"examples/flash_decoding/test_example_flash_decoding.py" did not exist on "d4f096efcc30547c658b3092e8a5730207fbc45a"
Commit eab47249 authored by Lei Wang, committed by LeiWang1999

[AMD] Adapt ROCm and support `T.gemm` with `transpose_b=False` for the AMD backend (#327)



* [Enhancement] Update GEMM and ROCm Integration

- Removed the requirement in `gemm.cc` that matrix B be transposed for CDNA targets, allowing non-transposed B layouts in GEMM.
- Added a new debug header file `debug.h` for enhanced debugging capabilities in ROCm kernels.
- Updated `codegen_hip.cc` to include the new debug header and improved handling of float16 and bfloat16 types in vector element stores.
- Refactored `rt_mod_hip.cc` to return a ROCM module directly from `BuildTileLangHIPWithoutCompile`, enhancing the module creation process.
- Introduced a new ROCm utility in `rocm.py` for linking and managing ROCm paths, improving the build process for ROCm applications.
- Updated tests to reflect changes in GEMM configurations and ensure compatibility with the new features.

These changes make the GEMM operations more flexible, improve debugging support, and tighten the integration with the ROCm backend; a usage sketch of the newly supported layout follows the commit metadata below.

* [Fix] Corrected syntax error in pyproject.toml and improved error message formatting in rocm.py

- Added missing quotation mark for "HSA" in the `select` section of `pyproject.toml`.
- Simplified the error message formatting in `get_rocm_arch` function of `rocm.py` for better readability and consistency.

* lint fix

* Update tilelang/jit/adapter/wrapper.py
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* lint fix

---------
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
parent 2cec52aa
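The headline change is that `T.gemm` on the CDNA (AMD) path no longer requires a transposed B operand. The sketch below is not part of the commit; it illustrates the newly supported layout, assuming the same TileLang primitives (`T.Kernel`, `T.copy`, `T.Pipelined`, `T.gemm`) used by the repository's existing GEMM tests, with B stored K-major, i.e. shape (K, N):

import tilelang
import tilelang.language as T


def matmul_b_not_transposed(M, N, K, block_M, block_N, block_K, threads=128, num_stages=0):
    # Hypothetical helper (not in this commit): builds a kernel where B is (K, N),
    # i.e. transpose_B=False, which previously tripped the CDNA ICHECK removed in gemm.cc.

    @T.prim_func
    def main(
            A: T.Buffer((M, K), "float16"),
            B: T.Buffer((K, N), "float16"),  # K-major B, no transpose
            C: T.Buffer((M, N), "float32"),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), "float16")
            B_shared = T.alloc_shared((block_K, block_N), "float16")
            C_local = T.alloc_fragment((block_M, block_N), "float32")
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                # Both operands untransposed; this now lowers to the MFMA path on CDNA.
                T.gemm(A_shared, B_shared, C_local, transpose_A=False, transpose_B=False)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main

This mirrors the new `trans_B=False` case added to `test_gemm_f16f32f32_nt` further down in the diff; exact helper names and compile entry points are illustrative.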
@@ -13,7 +13,7 @@ column_limit = 100
 indent_width = 4

 [tool.codespell]
-ignore-words-list = "nd, te, ist, LOD, offen, NotIn"
+ignore-words-list = "nd, te, ist, LOD, offen, NotIn, HSA"
 skip = [
     "build",
     "3rdparty",
...
@@ -253,8 +253,6 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
       ICHECK(0) << "WGMMA only support B in shared.";
     }
   } else if (TargetIsCDNA(T.target)) {
-    ICHECK(trans_B == true) << "Currently only support Transpose B for CDNA";
-
     const int warp_size = 64;
     auto [warp_m, warp_n] =
         ComputeWarpPartition(T.block_size / warp_size, T.target);
...
@@ -109,6 +109,7 @@ std::string CodeGenTileLangHIP::Finish() {
   decl_stream << "#include <tl_templates/hip/reduce.h>\n";
   decl_stream << "#include <tl_templates/hip/ldsm.h>\n";
   decl_stream << "#include <tl_templates/hip/threadblock_swizzle.h>\n";
+  decl_stream << "#include <tl_templates/hip/debug.h>\n";
   decl_stream << "\n";
   return CodeGenC::Finish();
 }
@@ -502,11 +503,11 @@ void CodeGenTileLangHIP::PrintVecElemStore(const std::string &vec, DataType t,
       stream << "(" << value << " << " << i % 4 * 8 << ");\n";
     }
   } else if (t.is_float16()) {
-    stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
-           << access[i % 2] << " = " << value << ";\n";
+    stream << "*((half_t*)(&(((half2*)(&(" << vec << "." << access[i / 2]
+           << ")))->" << access[i % 2] << "))) = " << value << ";\n";
   } else if (t.is_bfloat16()) {
-    stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->"
-           << access[i % 2] << " = " << value << ";\n";
+    stream << "*((bfloat16_t*)(&((half2*)(&(" << vec << "." << access[i / 2]
+           << ")))->" << access[i % 2] << "))) = " << value << ";\n";
   } else if (t.lanes() > 4 && t.lanes() <= 8) {
     std::string type_name;
     if (t.bits() == 16) {
...
@@ -73,7 +73,7 @@ runtime::Module BuildTileLangHIP(IRModule mod, Target target) {
   return ROCMModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code, std::string());
 }

-String BuildTileLangHIPWithoutCompile(IRModule mod, Target target) {
+runtime::Module BuildTileLangHIPWithoutCompile(IRModule mod, Target target) {
   using tvm::runtime::Registry;
   bool output_ssa = false;
   CodeGenTileLangHIP cg;
@@ -92,7 +92,8 @@ String BuildTileLangHIPWithoutCompile(IRModule mod, Target target) {
   if (const auto *f = Registry::Get("tilelang_callback_hip_postproc")) {
     code = (*f)(code, target).operator std::string();
   }
-  return String(code);
+  return ROCMModuleCreate("ptx", "fmt", ExtractFuncInfo(mod), code,
+                          std::string());
 }

 TVM_REGISTER_GLOBAL("target.build.tilelang_hip")
     .set_body_typed(BuildTileLangHIP);
...
 #pragma once

 #include <ck_tile/core.hpp>
+#include <hip/hip_bf16.h>
 #include <hip/hip_fp16.h>
 #include <hip/hip_runtime.h>
 #include <rocwmma/rocwmma.hpp>
-using ck_tile::half_t;

 #define HIPRT_INF_F __int_as_float(0x7f800000)
 #define HIPRT_NEGINF_F __int_as_float(0xff800000)
 #define HIPRT_NAN_F __int_as_float(0x7fffffff)
@@ -33,7 +32,6 @@ using ck_tile::half_t;
 #define hsqrt __ocml_sqrt_f16

 using float16_t = _Float16;
 using float16x2 =
     __attribute__((__vector_size__(2 * sizeof(float16_t)))) float16_t;
 using float16x4 =
@@ -43,6 +41,26 @@ using float16x8 =
 using float16x16 =
     __attribute__((__vector_size__(16 * sizeof(float16_t)))) float16_t;
+using half_t = float16_t;
+using bfloat16_t = __hip_bfloat16;
+
+struct bfloat16x2 {
+  bfloat16_t data[2];
+};
+struct bfloat16x4 {
+  bfloat16_t data[4];
+};
+struct bfloat16x8 {
+  bfloat16_t data[8];
+};
+struct bfloat16x16 {
+  bfloat16_t data[16];
+};
+
 using int32x4 = __attribute__((__vector_size__(4 * sizeof(int)))) int;
 using float32x4 = __attribute__((__vector_size__(4 * sizeof(float)))) float;
 using float32x16 = __attribute__((__vector_size__(16 * sizeof(float)))) float;
...
#pragma once
#include <hip/hip_runtime.h>
// Base template declaration
template <typename T> __device__ void debug_print_var(const char *msg, T var);
// Specialization for signed char type
template <>
__device__ void debug_print_var<signed char>(const char *msg, signed char var) {
const char *safe_msg = msg;
int value = static_cast<int>(var);
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=signed "
"char value=%d\n",
safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
(int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, value);
}
// Specialization for unsigned char type
template <>
__device__ void debug_print_var<unsigned char>(const char *msg,
unsigned char var) {
const char *safe_msg = msg;
unsigned int value = static_cast<unsigned int>(var);
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): "
"dtype=unsigned char value=%u\n",
safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
(int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, value);
}
// Specialization for int type
template <> __device__ void debug_print_var<int>(const char *msg, int var) {
const char *safe_msg = msg;
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=int "
"value=%d\n",
safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
(int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, var);
}
// Specialization for unsigned int type
template <>
__device__ void debug_print_var<unsigned int>(const char *msg,
unsigned int var) {
const char *safe_msg = msg;
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): "
"dtype=unsigned int value=%u\n",
safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
(int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, var);
}
// Specialization for float type
template <> __device__ void debug_print_var<float>(const char *msg, float var) {
const char *safe_msg = msg;
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=float "
"value=%f\n",
safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
(int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, var);
}
// Specialization for double type
template <>
__device__ void debug_print_var<double>(const char *msg, double var) {
const char *safe_msg = msg;
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=double "
"value=%lf\n",
safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
(int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, var);
}
// Specialization for bool type
template <> __device__ void debug_print_var<bool>(const char *msg, bool var) {
const char *safe_msg = msg;
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=bool "
"value=%s\n",
safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
(int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z,
var ? "true" : "false");
}
// Specialization for short type
template <> __device__ void debug_print_var<short>(const char *msg, short var) {
const char *safe_msg = msg;
int value = static_cast<int>(var);
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=short "
"value=%d\n",
safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
(int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, value);
}
// Specialization for unsigned short type
template <>
__device__ void debug_print_var<unsigned short>(const char *msg,
unsigned short var) {
const char *safe_msg = msg;
unsigned int value = static_cast<unsigned int>(var);
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): "
"dtype=unsigned short value=%u\n",
safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
(int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, value);
}
// Template declaration for device-side debug printing (buffer only)
template <typename T>
__device__ void debug_print_buffer_value(const char *msg, const char *buf_name,
int index, T var);
// Specialization for signed char type
template <>
__device__ void
debug_print_buffer_value<signed char>(const char *msg, const char *buf_name,
int index, signed char var) {
const char *safe_msg = msg;
const char *safe_buf_name = buf_name;
int value = static_cast<int>(var);
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=signed char value=%d\n",
safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
(int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name,
index, value);
}
// Specialization for unsigned char type
template <>
__device__ void
debug_print_buffer_value<unsigned char>(const char *msg, const char *buf_name,
int index, unsigned char var) {
const char *safe_msg = msg;
const char *safe_buf_name = buf_name;
unsigned int value = static_cast<unsigned int>(var);
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=unsigned char value=%u\n",
safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
(int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name,
index, value);
}
// Specialization for integer type
template <>
__device__ void debug_print_buffer_value<int>(const char *msg,
const char *buf_name, int index,
int var) {
const char *safe_msg = msg;
const char *safe_buf_name = buf_name;
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=int value=%d\n",
safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
(int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name,
index, var);
}
// Specialization for float type
template <>
__device__ void debug_print_buffer_value<float>(const char *msg,
const char *buf_name, int index,
float var) {
const char *safe_msg = msg;
const char *safe_buf_name = buf_name;
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=float value=%f\n",
safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
(int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name,
index, var);
}
// Specialization for half_t type
template <>
__device__ void debug_print_buffer_value<half_t>(const char *msg,
const char *buf_name,
int index, half_t var) {
const char *safe_msg = msg;
const char *safe_buf_name = buf_name;
float value = static_cast<float>(var);
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=half_t value=%f\n",
safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
(int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name,
index, value);
}
// Specialization for double type
template <>
__device__ void debug_print_buffer_value<double>(const char *msg,
const char *buf_name,
int index, double var) {
const char *safe_msg = msg;
const char *safe_buf_name = buf_name;
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=double value=%lf\n",
safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
(int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name,
index, var);
}
@@ -6,10 +6,12 @@ namespace tl {
 // ref to bitblas/tl/mfma_macro_generator.py::kPack
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool TransposeA,
-          bool TransposeB, int kPack, typename A_type, typename B_type,
-          typename C_type, typename AccDataType = float>
+          bool TransposeB, bool clear_accum, int kPack, typename A_type,
+          typename B_type, typename C_type, typename AccDataType = float>
 class GemmTensorOp {
 public:
+  static_assert(!clear_accum, "clear_accum=true is not supported yet");
+
   static constexpr int micro_size_x = 16;
   static constexpr int micro_size_y = 16;
   static constexpr int micro_size_k = 16;
@@ -128,25 +130,37 @@ public:
       const auto l = warp_m * warp_row_tiles + i * micro_size_x;
       const auto r = ki * (kPack * micro_size_k);
       for (int local_id = 0; local_id < (kPack * local_size_a); local_id++) {
-        auto [row, col] = reverse_index_map(lane_id, local_id);
-        A_local[i * kPack * local_size_a + local_id] =
-            A_shared[make_swizzle_layout<last_dim_a, sizeof(A_type)>(
-                l + row, r + col)];
+        if constexpr (TransposeA) {
+          auto [row, col] = reverse_index_map_transposed(lane_id, local_id);
+          A_local[i * kPack * local_size_a + local_id] =
+              A_shared[make_swizzle_layout<last_dim_a, sizeof(A_type)>(
+                  l + col, r + row)];
+        } else {
+          auto [row, col] = reverse_index_map(lane_id, local_id);
+          A_local[i * kPack * local_size_a + local_id] =
+              A_shared[make_swizzle_layout<last_dim_a, sizeof(A_type)>(
+                  l + row, r + col)];
+        }
       }
     }

     // Fetch B into register
     for (int j = 0; j < warp_cols; j++) {
       const auto l = warp_n * warp_col_tiles + j * micro_size_y;
       const auto r = ki * (kPack * micro_size_k);
       for (int local_id = 0; local_id < (kPack * local_size_b); local_id++) {
-        auto [row, col] = reverse_index_map(lane_id, local_id);
-        B_local[j * kPack * local_size_b + local_id] =
-            B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
-                l + row, r + col)];
+        if constexpr (TransposeB) {
+          auto [row, col] = reverse_index_map(lane_id, local_id);
+          B_local[j * kPack * local_size_b + local_id] =
+              B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
+                  l + row, r + col)];
+        } else {
+          auto [row, col] = reverse_index_map_transposed(lane_id, local_id);
+          B_local[j * kPack * local_size_b + local_id] =
+              B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
+                  r + row, l + col)];
+        }
       }
     }

     // Compute
     for (int kp = 0; kp < kPack; kp++) {
       for (int i = 0; i < warp_rows; ++i) {
@@ -189,10 +203,17 @@ public:
       const auto l = warp_n * warp_col_tiles + j * micro_size_y;
       const auto r = ki * kPack * micro_size_k;
       for (int local_id = 0; local_id < kPack * local_size_b; local_id++) {
-        auto [row, col] = reverse_index_map(lane_id, local_id);
-        B_local[j * local_size_b + local_id] =
-            B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
-                l + row, r + col)];
+        if constexpr (TransposeB) {
+          auto [row, col] = reverse_index_map(lane_id, local_id);
+          B_local[j * local_size_b + local_id] =
+              B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
+                  l + row, r + col)];
+        } else {
+          auto [row, col] = reverse_index_map_transposed(lane_id, local_id);
+          B_local[j * local_size_b + local_id] =
+              B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
+                  r + row, l + col)];
+        }
       }
     }
@@ -218,20 +239,22 @@ public:
 namespace tl {

 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
-          bool trans_B, int kPack, typename A_type, typename B_type,
-          typename C_type>
+          bool trans_B, bool clear_accum, int kPack, typename A_type,
+          typename B_type, typename C_type>
 TL_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
-  using Compute = GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
-                               trans_B, kPack, A_type, B_type, C_type>;
+  using Compute =
+      GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B,
+                   clear_accum, kPack, A_type, B_type, C_type>;
   Compute::body(pA, pB, accum);
 }

 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
-          bool trans_B, int kPack, typename A_type, typename B_type,
-          typename C_type>
+          bool trans_B, bool clear_accum, int kPack, typename A_type,
+          typename B_type, typename C_type>
 TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
-  using Compute = GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
-                               trans_B, kPack, A_type, B_type, C_type>;
+  using Compute =
+      GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B,
+                   clear_accum, kPack, A_type, B_type, C_type>;
   Compute::body_rs(pA, pB, accum);
 }
...
@@ -61,7 +61,7 @@ def run_gemm(
     block_M,
     block_N,
     block_K,
-    num_stages=3,
+    num_stages=0,
     num_threads=128,
     k_pack=1,
 ):
@@ -91,15 +91,14 @@ def run_gemm(
             A = A.T
         if trans_B:
             B = B.T
-        C = torch.matmul(A.to(torch.float), B.to(torch.float))
-        C = C.to(torch.__getattribute__(out_dtype))
-        return C
+        return (A @ B).to(torch.__getattribute__(out_dtype))

     profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


 @tilelang.testing.requires_rocm
 def test_gemm_f16f32f32_nt():
+    run_gemm(1024, 1024, 1024, False, False, "float16", "float32", "float32", 128, 128, 32)
     run_gemm(1024, 1024, 1024, False, True, "float16", "float32", "float32", 128, 128, 32)
     run_gemm(1024, 1024, 1024, False, True, "float16", "float32", "float32", 128, 128, 32, k_pack=2)
...
@@ -97,7 +97,7 @@ from . import (
     engine,  # noqa: F401
 )

-from .engine import lower  # noqa: F401
+from .engine import lower, register_cuda_postproc, register_hip_postproc  # noqa: F401
 from .version import __version__  # noqa: F401
...
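The re-export above makes the HIP post-processing hook reachable from the top-level package. A hedged sketch of how it might be used, assuming the registered callable receives `(code, target)` and returns the modified source, matching how `tilelang_callback_hip_postproc` is invoked in `rt_mod_hip.cc`; the exact registration semantics are an assumption, not shown in this diff:

import tilelang


def tag_generated_hip(code: str, target) -> str:
    # Prepend a marker so post-processed HIP sources are easy to spot in logs.
    return "// post-processed by a registered tilelang HIP callback\n" + code


# Assumed call pattern for the newly exported helper.
tilelang.register_hip_postproc(tag_generated_hip)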
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Utility for ROCm backend"""
# ruff: noqa
import re
import subprocess
import os
from os.path import join, exists
import tvm._ffi
from tvm._ffi.base import py_str
import tvm.runtime
import tvm.target
from tvm.contrib import utils
def find_lld(required=True):
"""Find ld.lld in system.
Parameters
----------
required : bool
Whether it is required,
runtime error will be raised if the compiler is required.
Returns
-------
valid_list : list of str
List of possible paths.
Note
----
This function will first search ld.lld that
matches the major llvm version that built with tvm
"""
lld_list = []
major = tvm.target.codegen.llvm_version_major(allow_none=True)
if major is not None:
lld_list += [f"ld.lld-{major}.0"]
lld_list += [f"ld.lld-{major}"]
lld_list += ["ld.lld"]
lld_list += [f"/opt/rocm/llvm/bin/{x}" for x in lld_list]
valid_list = [utils.which(x) for x in lld_list]
valid_list = [x for x in valid_list if x]
if not valid_list and required:
raise RuntimeError("cannot find ld.lld, candidates are: " + str(lld_list))
return valid_list
def rocm_link(in_file, out_file, lld=None):
"""Link relocatable ELF object to shared ELF object using lld
Parameters
----------
in_file : str
Input file name (relocatable ELF object file)
out_file : str
Output file name (shared ELF object file)
lld : str, optional
The lld linker, if not specified,
we will try to guess the matched clang version.
"""
# if our result has undefined symbols, it will fail to load
# (hipModuleLoad/hipModuleLoadData), but with a somewhat opaque message
# so we have ld.lld check this here.
# If you get a complaint about missing symbols you might want to check the
# list of bitcode files below.
args = [
lld if lld is not None else find_lld()[0],
"--no-undefined",
"-shared",
in_file,
"-o",
out_file,
]
proc = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = "Linking error using ld.lld:\n"
msg += py_str(out)
raise RuntimeError(msg)
@tvm._ffi.register_func("tvm_callback_rocm_link", override=True)
def callback_rocm_link(obj_bin):
"""Links object file generated from LLVM to HSA Code Object
Parameters
----------
obj_bin : bytearray
The object file
Return
------
cobj_bin : bytearray
The HSA Code Object
"""
tmp_dir = utils.tempdir()
tmp_obj = tmp_dir.relpath("rocm_kernel.o")
tmp_cobj = tmp_dir.relpath("rocm_kernel.co")
with open(tmp_obj, "wb") as out_file:
out_file.write(bytes(obj_bin))
rocm_link(tmp_obj, tmp_cobj)
cobj_bin = bytearray(open(tmp_cobj, "rb").read())
return cobj_bin
@tvm._ffi.register_func("tvm_callback_rocm_bitcode_path", override=True)
def callback_rocm_bitcode_path(rocdl_dir=None):
"""Utility function to find ROCm device library bitcodes
Parameters
----------
rocdl_dir : str
The path to rocm library directory
The default value is the standard location
"""
# seems link order matters.
if rocdl_dir is None:
if exists("/opt/rocm/amdgcn/bitcode/"):
rocdl_dir = "/opt/rocm/amdgcn/bitcode/" # starting with rocm 3.9
else:
rocdl_dir = "/opt/rocm/lib/" # until rocm 3.8
bitcode_names = [
"oclc_daz_opt_on",
"ocml",
"irif", # this does not exist in rocm 3.9, drop eventually
"oclc_correctly_rounded_sqrt_off",
"oclc_correctly_rounded_sqrt_on",
"oclc_daz_opt_off",
"oclc_finite_only_off",
"oclc_finite_only_on",
# todo (t-vi): an alternative might be to scan for the
"oclc_isa_version_803",
"oclc_isa_version_900", # isa version files (if the linker throws out
"oclc_isa_version_906", # the unneeded ones or we filter for the arch we need)
"oclc_isa_version_1030",
"oclc_unsafe_math_off",
"oclc_unsafe_math_on",
"oclc_wavefrontsize64_on",
"oclc_abi_version_500",
]
bitcode_files = []
for n in bitcode_names:
p = join(rocdl_dir, n + ".bc") # rocm >= 3.9
if not exists(p): # rocm <= 3.8
p = join(rocdl_dir, n + ".amdgcn.bc")
if exists(p):
bitcode_files.append(p)
elif "isa_version" not in n and n not in {"irif"}:
raise RuntimeError("could not find bitcode " + n)
return tvm.runtime.convert(bitcode_files)
def parse_compute_version(compute_version):
"""Parse compute capability string to divide major and minor version
Parameters
----------
compute_version : str
compute capability of a GPU (e.g. "6.0")
Returns
-------
major : int
major version number
minor : int
minor version number
"""
split_ver = compute_version.split(".")
try:
major = int(split_ver[0])
minor = int(split_ver[1])
return major, minor
except (IndexError, ValueError) as err:
# pylint: disable=raise-missing-from
raise RuntimeError("Compute version parsing error: " + str(err))
def have_matrixcore(compute_version=None):
"""Either MatrixCore support is provided in the compute capability or not
Parameters
----------
compute_version : str, optional
compute capability of a GPU (e.g. "7.0").
Returns
-------
have_matrixcore : bool
True if MatrixCore support is provided, False otherwise
"""
if compute_version is None:
if tvm.rocm(0).exist:
compute_version = tvm.rocm(0).compute_version
else:
raise RuntimeError("No ROCm runtime found")
major, _ = parse_compute_version(compute_version)
# matrix core first introduced in 8.0
if major >= 8:
return True
return False
@tvm._ffi.register_func("tvm_callback_rocm_get_arch", override=True)
def get_rocm_arch(rocm_path="/opt/rocm"):
"""Utility function to get the AMD GPU architecture
Parameters
----------
rocm_path : str
The path to rocm installation directory
Returns
-------
gpu_arch : str
The AMD GPU architecture
"""
gpu_arch = "gfx900"
# check if rocm is installed
if not os.path.exists(rocm_path):
print("ROCm not detected, using default gfx900")
return gpu_arch
try:
# Execute rocminfo command
rocminfo_output = subprocess.check_output([f"{rocm_path}/bin/rocminfo"]).decode("utf-8")
# Use regex to match the "Name" field
match = re.search(r"Name:\s+(gfx\d+[a-zA-Z]*)", rocminfo_output)
if match:
gpu_arch = match.group(1)
return gpu_arch
except subprocess.CalledProcessError:
print(f"Unable to execute rocminfo command, \
please ensure ROCm is installed and you have an AMD GPU on your system.\
using default {gpu_arch}.")
return gpu_arch
def find_rocm_path():
"""Utility function to find ROCm path
Returns
-------
path : str
Path to ROCm root.
"""
if "ROCM_PATH" in os.environ:
return os.environ["ROCM_PATH"]
cmd = ["which", "hipcc"]
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
out = out.decode("utf-8").strip()
if proc.returncode == 0:
return os.path.realpath(os.path.join(out, "../.."))
rocm_path = "/opt/rocm"
if os.path.exists(os.path.join(rocm_path, "bin/hipcc")):
return rocm_path
raise RuntimeError("Cannot find ROCm path")
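A short usage sketch (not part of the new rocm.py module above) showing how its helpers fit together; it mirrors the `--offload-arch` change to the library generator shown below, and the exact flags and file names depend on the local ROCm installation:

import subprocess

from tilelang.contrib.rocm import find_rocm_path, get_rocm_arch

rocm_path = find_rocm_path()     # honors ROCM_PATH, then the hipcc location, then /opt/rocm
arch = get_rocm_arch(rocm_path)  # e.g. "gfx90a"; falls back to "gfx900" if rocminfo is unavailable
subprocess.check_call([
    "hipcc", "-std=c++17", "-fPIC", f"--offload-arch={arch}",
    "--shared", "kernel.cpp", "-o", "kernel.so",  # placeholder file names
])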
 from .lower import lower, is_device_call  # noqa: F401
 from .param import KernelParam  # noqa: F401
+from .callback import register_cuda_postproc, register_hip_postproc  # noqa: F401
@@ -9,6 +9,7 @@ import tempfile
 import subprocess
 import logging
 from tilelang.env import TILELANG_TEMPLATE_PATH, CUTLASS_INCLUDE_DIR
+from tilelang.contrib.rocm import find_rocm_path, get_rocm_arch

 logger = logging.getLogger(__name__)
@@ -58,11 +59,13 @@ class LibraryGenerator(object):
         elif is_hip_target(target):
             src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False)
             libpath = src.name.replace(".cpp", ".so")
+            rocm_path = find_rocm_path()
+            arch = get_rocm_arch(rocm_path)
             command = [
                 "hipcc",
                 "-std=c++17",
                 "-fPIC",
+                f"--offload-arch={arch}",
                 "--shared",
                 src.name,
             ]
@@ -84,7 +87,6 @@ class LibraryGenerator(object):
                 "-I" + TILELANG_TEMPLATE_PATH,
                 "-I" + CUTLASS_INCLUDE_DIR,
             ]
-            command += ["-diag-suppress=20013"]
             command += ["-o", libpath]
             src.write(self.lib_code)
...
@@ -8,7 +8,7 @@ import re
 import logging
 import textwrap

-PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY = """
+PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY = """
    cudaError_t result_{0} = cudaFuncSetAttribute({0}, cudaFuncAttributeMaxDynamicSharedMemorySize, {1});
    if (result_{0} != CUDA_SUCCESS) {{
        snprintf(error_buf, ERROR_BUF_SIZE, "Failed to set the allowed dynamic shared memory size to %d with error: %s", {1}, cudaGetErrorString(result_{0}));
@@ -16,6 +16,14 @@ PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY = """
    }}
 """

+PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY_HIP = """
+    hipError_t result_{0} = hipFuncSetAttribute((const void *){0}, hipFuncAttributeMaxDynamicSharedMemorySize, {1});
+    if (result_{0} != HIP_SUCCESS) {{
+        snprintf(error_buf, ERROR_BUF_SIZE, "Failed to set the allowed dynamic shared memory size to %d with error: %s", {1}, hipGetErrorString(result_{0}));
+        return -1;
+    }}
+"""
+
 PREDEF_INIT_FUNC = """
 #define ERROR_BUF_SIZE 1024
 static char error_buf[ERROR_BUF_SIZE];
@@ -159,7 +167,7 @@ class TLCUDASourceWrapper(object):
         if dyn_sym not in [arg["name"] for arg in function_args]:
             function_args.append({"name": dyn_sym, "type": "int"})

-        function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},)
+        function_args.append(self.get_stream_type())

         # Format the function arguments for declaration
         def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])
@@ -351,14 +359,14 @@ class TLCUDASourceWrapper(object):
                     dynamic_symbolic_set.append(dim.name)
         return dynamic_symbolic_set

-    def get_cuda_init_func(self):
+    def get_init_func(self):
         # Initialize an empty string for the CUDA function call
         call_str = """"""
         # If dynamic shared memory buffer is specified, prepare the cudaFuncSetAttribute call
         for function_name, dynamic_smem_buf in self.dynamic_smem_buf.items():
             if dynamic_smem_buf is not None:
                 # Format the cudaFuncSetAttribute call for dynamic shared memory
-                call_str += PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY.format(
+                call_str += PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY.format(
                     function_name, dynamic_smem_buf)
         # Format the initialization function using the call_str
         init_funcs = PREDEF_INIT_FUNC.format(call_str)
@@ -370,7 +378,7 @@ class TLCUDASourceWrapper(object):
         # Get the function names
         function_names = self.function_names
         # Get the CUDA initialization function
-        init_func = self.get_cuda_init_func()
+        init_func = self.get_init_func()
         # Organize function information for code generation
         function_informations = {}
@@ -392,6 +400,9 @@ class TLCUDASourceWrapper(object):
         lib_code = self.source + init_func + host_func
         return lib_code

+    def get_stream_type(self) -> Dict[str, str]:
+        return {"name": "stream=cudaStreamDefault", "type": "cudaStream_t"}
+
     @property
     def prim_func(self):
         if len(self.mod.get_global_vars()) == 1:
@@ -420,19 +431,21 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper):
                  pass_configs: Optional[Dict[str, Any]] = None):
         super().__init__(scheduled_ir_module, source, target, device_mod, host_mod, pass_configs)

-    def get_hip_init_func(self):
+    def get_init_func(self):
         # Initialize an empty string for the CUDA function call
         call_str = """"""
         # If dynamic shared memory buffer is specified, prepare the cudaFuncSetAttribute call
-        if self.dynamic_smem_buf is not None:
-            call_str = PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY.format(self.function_name,
-                                                                  self.dynamic_smem_buf)
+        for function_name, dynamic_smem_buf in self.dynamic_smem_buf.items():
+            if dynamic_smem_buf is not None:
+                # Format the cudaFuncSetAttribute call for dynamic shared memory
+                call_str += PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY_HIP.format(
+                    function_name, dynamic_smem_buf)
         # Format the initialization function using the call_str
         init_funcs = PREDEF_INIT_FUNC.format(call_str)
         return init_funcs

-    def get_stream_type(self, function_args):
-        function_args.append({"name": "stream=hipStreamDefault", "type": "hipStream_t"},)
+    def get_stream_type(self) -> Dict[str, str]:
+        return {"name": "stream=hipStreamDefault", "type": "hipStream_t"}

 class TLCPUSourceWrapper(object):
...
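The wrapper refactor above replaces the backend-specific `get_cuda_init_func`/`get_hip_init_func` pair with a single `get_init_func` and small hooks that subclasses override. A stripped-down sketch of that pattern (illustrative only, not the actual wrapper classes):

from typing import Dict, List


class BaseSourceWrapper:
    """Builds host-side argument lists; backend-specific details live in small hooks."""

    def get_stream_type(self) -> Dict[str, str]:
        return {"name": "stream=cudaStreamDefault", "type": "cudaStream_t"}

    def build_function_args(self, args: List[Dict[str, str]]) -> List[Dict[str, str]]:
        # The base class appends whatever stream argument the hook reports.
        return args + [self.get_stream_type()]


class HIPSourceWrapper(BaseSourceWrapper):

    def get_stream_type(self) -> Dict[str, str]:
        # Only the hook changes; the argument-building logic is inherited unchanged.
        return {"name": "stream=hipStreamDefault", "type": "hipStream_t"}


print(HIPSourceWrapper().build_function_args([{"name": "A", "type": "half_t*"}]))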