Commit 8f4628e0 authored by qisan

[Bugfix] Use a new data layout; NN GEMM performance now exceeds rocBLAS

parent 9a640856
@@ -8,7 +8,8 @@ from tilelang.intrinsics.mmac_macro_generator import (
 from tilelang.transform import simplify_prim_func
 from tilelang import disable_cache

 disable_cache()


 def make_swizzle_layout(shared_buf):
     dtype = shared_buf.dtype
@@ -81,7 +82,7 @@ def tl_matmul(
     threads = warp_size * (block_row_warps * block_col_warps)
     local_size_a = (micro_size_x * micro_size_k) // warp_size
     local_size_b = (micro_size_y * micro_size_k) // warp_size
     local_size_c = (micro_size_x * micro_size_y) // warp_size

     warp_rows = warp_row_tiles // micro_size_x
     warp_cols = warp_col_tiles // micro_size_y
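As a sanity check on these derived sizes, here is the arithmetic under an assumed CDNA-style configuration (wavefront of 64, 16x16x16 micro tiles; the concrete tile counts are illustrative, not taken from this diff):

```python
# Hypothetical tile configuration, for illustration only.
warp_size = 64                      # CDNA wavefront size
micro_size_x = micro_size_y = micro_size_k = 16
block_row_warps, block_col_warps = 2, 2
warp_row_tiles, warp_col_tiles = 32, 32

threads = warp_size * (block_row_warps * block_col_warps)   # 256
local_size_a = (micro_size_x * micro_size_k) // warp_size   # 4 regs per lane
local_size_b = (micro_size_y * micro_size_k) // warp_size   # 4
local_size_c = (micro_size_x * micro_size_y) // warp_size   # 4
warp_rows = warp_row_tiles // micro_size_x                  # 2
warp_cols = warp_col_tiles // micro_size_y                  # 2
print(threads, local_size_a, local_size_b, local_size_c, warp_rows, warp_cols)
```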
@@ -152,7 +153,7 @@ def tl_matmul(
             for i, j in T.Parallel(block_M, block_N):
                 C[by * block_M + i, bx * block_N + j] = C_shared[
                     j // micro_size_y,
                     i // micro_size_x,
                     i % micro_size_x,
                     j % micro_size_y,
                 ]
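This write-back is the "new data layout" from the commit title made explicit: C_shared is a 4D blocked buffer indexed as [j_block, i_block, i_inner, j_inner], with the N-direction block outermost. A minimal sketch of the mapping and its inverse, assuming 16x16 micro tiles:

```python
# Sketch of the blocked C_shared indexing used above (illustrative only).
micro_size_x = micro_size_y = 16

def to_blocked(i, j):
    """(row, col) within the block tile -> 4D blocked coordinate."""
    return (j // micro_size_y, i // micro_size_x,
            i % micro_size_x, j % micro_size_y)

def from_blocked(jb, ib, ii, ji):
    """4D blocked coordinate -> (row, col)."""
    return (ib * micro_size_x + ii, jb * micro_size_y + ji)

# Round-trip check over a 64x64 tile.
assert all(from_blocked(*to_blocked(i, j)) == (i, j)
           for i in range(64) for j in range(64))
```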
@@ -186,5 +187,3 @@ def main():
 if __name__ == "__main__":
     main()
@@ -157,8 +157,8 @@ Fragment makeGemmSparseFragmentC(const int block_m, const int block_n,
 }

 Fragment makeGemmFragmentCDCU(const int block_m, const int block_n,
                               const int warp_m, const int warp_n,
                               const int element_size) {
   if (element_size == 64)
     LOG(FATAL) << "Not supported";
   ICHECK(block_m % warp_m == 0);
@@ -169,7 +169,7 @@ Fragment makeGemmFragmentCDCU(const int block_m, const int block_n,
   auto warp_layout =
       base_layout->Repeat({warp_m / 16, warp_n / 16}, false, false);
   auto block_layout =
-      warp_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, true);
+      warp_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false);
   return block_layout;
 }
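The behavioral change in this file is the final Repeat argument flipping from true to false at the block level. Judging from the call pattern (a base 16x16 fragment repeated across warps, then across the block), the flag plausibly selects which replication axis varies fastest; the sketch below illustrates that general idea, but the flag's exact semantics in tilelang are an assumption here, not something this diff confirms:

```python
# Analogy only: a Repeat({rm, rn}, ..., flag)-style replication where the
# boolean picks which axis of the repeated grid varies fastest.
def repeat_order(rm, rn, n_fastest):
    if n_fastest:
        return [(m, n) for m in range(rm) for n in range(rn)]
    return [(m, n) for n in range(rn) for m in range(rm)]

print(repeat_order(2, 2, True))   # [(0, 0), (0, 1), (1, 0), (1, 1)]
print(repeat_order(2, 2, False))  # [(0, 0), (1, 0), (0, 1), (1, 1)]
```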
@@ -747,7 +747,8 @@ Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
   if (!k_inner && element_size == 8) // int8 KxN
     return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size);
   else if (mat_continuous % (vector_size * 8) == 0)
     // return makeHalfBankSwizzleLayout(mat_stride, mat_continuous,
     // element_size);
     return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size);
   else if (mat_continuous % (vector_size * 4) == 0)
     return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size);
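When `mat_continuous % (vector_size * 8) == 0`, a shared-memory row spans a full set of LDS banks and takes the full-bank swizzle path. As a generic illustration of what an XOR bank swizzle buys (this is the textbook pattern, not necessarily tilelang's exact formula):

```python
# Generic XOR bank-swizzle sketch: permute columns within a row so that a
# column walk from a warp hits distinct banks. Illustrative only.
def swizzle(row, col, banks=8):
    return col ^ (row % banks)

# Reading one logical column across 8 consecutive rows touches every bank
# exactly once, avoiding bank conflicts on column-wise accesses.
cols = [swizzle(r, 3) for r in range(8)]
assert sorted(cols) == list(range(8))
```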
...
@@ -151,8 +151,8 @@ Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n,
                               const int warp_m, const int warp_n,
                               const int element_size);

 Fragment makeGemmFragmentCDCU(const int block_m, const int block_n,
                               const int warp_m, const int warp_n,
                               const int element_size);

 Fragment makeGemmFragmentCHopper(const int block_m, const int block_n,
                                  const int warp_m, const int warp_n,
                                  const int element_size);
...
@@ -4,8 +4,8 @@
  */
 #include "gemm.h"
-#include <fstream>
 #include "builtin.h"
+#include <fstream>
 #include <tvm/tir/builtin.h>
 #include <tvm/tir/op.h>
 #include <tvm/tir/op_attr_types.h>
@@ -828,15 +828,14 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
     ICHECK(C.scope() == "local.fragment")
         << "CDNA gemm (FMMA) only supports C in local.fragment scope, got "
         << C.scope();
     if (TargetIsDCU(T.target)) {
       auto fragment =
           makeGemmFragmentCDCU(M, N, M / warp_m, N / warp_n, C->dtype.bits());
       results.Set(C, fragment->BindThreadRange(thread_range));
     } else {
       auto fragment =
           makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
       results.Set(C, fragment->BindThreadRange(thread_range));
     }
     if (A.scope() == "shared" || A.scope() == "shared.dyn") {
...
@@ -249,7 +249,6 @@ TVM_REGISTER_OP("tir.hip.__shfl")
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));

 TVM_REGISTER_OP("tir.hip.__shfl_sync")
     .set_num_inputs(4)
     .add_argument("mask", "Expr", "The thread mask.")
...
@@ -5,7 +5,6 @@
 #include "utils.h"

 namespace tvm {
 namespace tl {
...
 #pragma once

 #include "core.hpp"
 #include <hip/hip_bf16.h>
 #include <hip/hip_fp16.h>
@@ -106,41 +105,36 @@ TL_DEVICE unsigned __pack_bfloat162(const bfloat16_t x, const bfloat16_t y) {
   return (v1 << 16) | v0;
 }

 template <typename T> struct is_half_type : std::false_type {};

 template <> struct is_half_type<__half> : std::true_type {};

 template <> struct is_half_type<half_t> : std::true_type {};

 template <typename T>
 inline constexpr bool is_half_v = is_half_type<std::decay_t<T>>::value;

 template <typename T1, typename T2>
 TL_DEVICE void AtomicAdd(T1 *address, T2 val) {
   if constexpr (is_half_v<T1>) {
     __half *addr = reinterpret_cast<__half *>(address);
     __half hval = __float2half(static_cast<float>(val));
     atomicAdd(addr, hval);
   } else {
     atomicAdd(address, static_cast<T1>(val));
   }
 }

 template <typename T1, typename T2> TL_DEVICE void AtomicAdd(T1 &ref, T2 val) {
   AtomicAdd(&ref, val);
 }

 template <typename T1, typename T2> TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val) {
   return atomicAdd(&ref, static_cast<T1>(val));
 }

 template <typename T> TL_DEVICE void AtomicAddx4(T *ref, const T val[4]) {
   atomicAdd(&ref[0], val[0]);
   atomicAdd(&ref[1], val[1]);
   atomicAdd(&ref[2], val[2]);
   atomicAdd(&ref[3], val[3]);
 }
\ No newline at end of file
@@ -108,4 +108,3 @@ TL_DEVICE void cp_async_gs_conditional(void *lds_base_ptr,
 }

 } // namespace tl
@@ -14,7 +14,7 @@
 #ifndef __HIP_DEVICE_COMPILE__ // for host code
 #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0xffffffff
 #elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) ||    \
     defined(__gfx9__) // for GPU code
 #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000
 #elif defined(__gfx103__) // for GPU code
@@ -25,82 +25,53 @@
 #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
 #endif

 namespace ck_tile {

 using int32x4_t = int32_t __attribute__((ext_vector_type(4)));

 template <typename T> CK_TILE_HOST_DEVICE constexpr T max(T x) { return x; }

 template <typename T> CK_TILE_HOST constexpr T max(T x, T y) {
   return x > y ? x : y;
 }

 template <typename T> CK_TILE_DEVICE constexpr T max(T x, T y) {
   return x > y ? x : y;
 }

 template <> CK_TILE_DEVICE float max(float x, float y) {
   return __builtin_fmaxf(x, y); // can result in v_max3_f32
 }

 template <> CK_TILE_DEVICE double max(double x, double y) {
   return __builtin_fmax(x, y); // maybe still v_max3_f32
 }

 template <typename X, typename... Ys>
 CK_TILE_HOST_DEVICE constexpr auto max(X x, Ys... ys) {
   static_assert(sizeof...(Ys) > 0, "not enough argument");
   return max(x, max(ys...));
 }

 template <typename T> CK_TILE_HOST_DEVICE constexpr T min(T x) { return x; }

 template <typename T> CK_TILE_HOST constexpr T min(T x, T y) {
   return x < y ? x : y;
 }

 template <typename T> CK_TILE_DEVICE constexpr T min(T x, T y) {
   return x < y ? x : y;
 }

 template <> CK_TILE_DEVICE float min(float x, float y) {
   return __builtin_fminf(x, y);
 }

 template <> CK_TILE_DEVICE double min(double x, double y) {
   return __builtin_fmin(x, y);
 }

 template <typename X, typename... Ys>
 CK_TILE_HOST_DEVICE constexpr auto min(X x, Ys... ys) {
   static_assert(sizeof...(Ys) > 0, "not enough argument");
   return min(x, min(ys...));
 }
 } // namespace ck_tile
@@ -189,4 +189,3 @@ __device__ void debug_print_buffer_value<double>(const char *msg,
          (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name,
          index, var);
 }
@@ -69,7 +69,7 @@ template <int M, int N, int K, int num_warp_n, int num_warp_m, bool TransposeA,
           typename B_type, typename C_type, typename AccDataType = float>
 class GemmTensorOp {
 public:
   // static_assert(!clear_accum, "clear_accum=true is not supported yet");

   static constexpr int micro_size_x = 16;
   static constexpr int micro_size_y = 16;
@@ -156,8 +156,8 @@ public:
                               C_type *C_local) {
     auto tid = threadIdx.x;
     auto warp_id = tid / warp_size;
-    auto warp_n = warp_id / block_row_warps;
-    auto warp_m = warp_id % block_row_warps;
+    auto warp_m = warp_id / block_col_warps;
+    auto warp_n = warp_id % block_col_warps;
     auto warp_row_tiles = warp_rows * micro_size_x;
     auto warp_col_tiles = warp_cols * micro_size_y;
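This is one substantive half of the fix: warp indices are now derived row-major over a block_row_warps x block_col_warps grid rather than column-major over the row count, so the warp-to-tile assignment agrees with the new C layout. A quick enumeration of the two mappings (the 2x4 grid is an assumed example):

```python
# Old vs. new warp_id -> (warp_m, warp_n) mappings from the hunk above.
block_row_warps, block_col_warps = 2, 4   # assumed example grid

def old_mapping(warp_id):
    return (warp_id % block_row_warps, warp_id // block_row_warps)

def new_mapping(warp_id):
    return (warp_id // block_col_warps, warp_id % block_col_warps)

for w in range(block_row_warps * block_col_warps):
    print(w, old_mapping(w), new_mapping(w))
# Both enumerate the full 2x4 grid, but in different orders; the tile
# offsets computed from warp_m/warp_n below rely on the new order.
```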
@@ -165,8 +165,8 @@ public:
     auto tx = lane_id;
     auto alane_id = lane_id;
     auto blane_id =
         ((lane_id & 15) >> 2) + ((lane_id & 3) << 2) + ((lane_id >> 4) << 4);

     constexpr auto local_size_a = (micro_size_x * micro_size_k) / warp_size;
     constexpr auto local_size_b = (micro_size_y * micro_size_k) / warp_size;
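`blane_id` is the other half: B-operand lanes are permuted by swapping the two 2-bit fields of the low 4 lane-index bits while the upper bits stay put. The permutation is easy to verify in Python (a wavefront size of 64 is assumed, as on CDNA):

```python
# The B-lane permutation from the code above, written in Python.
def blane(lane_id):
    return ((lane_id & 15) >> 2) + ((lane_id & 3) << 2) + ((lane_id >> 4) << 4)

lanes = [blane(t) for t in range(64)]
assert sorted(lanes) == list(range(64))              # a true permutation
assert all(blane(blane(t)) == t for t in range(64))  # and its own inverse
```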
@@ -186,15 +186,14 @@ public:
       for (int local_id = 0; local_id < (kPack * local_size_b); local_id++) {
         if constexpr (TransposeB) {
           auto [row, col] = reverse_index_map(blane_id, local_id);
           B_local[i * kPack * local_size_b + local_id] =
               B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
                   l + row, r + col)];
         } else {
           auto [row, col] = reverse_index_map_transposed(blane_id, local_id);
           B_local[i * kPack * local_size_b + local_id] =
               B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
                   r + row, l + col)];
         }
       }
     }
@@ -205,12 +204,12 @@ public:
       for (int local_id = 0; local_id < (kPack * local_size_a); local_id++) {
         if constexpr (TransposeA) {
           auto [row, col] = reverse_index_map_transposed(alane_id, local_id);
           A_local[j * kPack * local_size_a + local_id] =
               A_shared[make_swizzle_layout<last_dim_a, sizeof(A_type)>(
                   r + row, l + col)];
         } else {
           auto [row, col] = reverse_index_map(alane_id, local_id);
           A_local[j * kPack * local_size_a + local_id] =
               A_shared[make_swizzle_layout<last_dim_a, sizeof(A_type)>(
                   l + row, r + col)];
         }
@@ -237,8 +236,8 @@ public:
                               C_type *C_local) {
     auto tid = threadIdx.x;
     auto warp_id = tid / warp_size;
-    auto warp_n = warp_id / block_row_warps;
-    auto warp_m = warp_id % block_row_warps;
+    auto warp_m = warp_id / block_col_warps;
+    auto warp_n = warp_id % block_col_warps;
     auto warp_row_tiles = warp_rows * micro_size_x;
     auto warp_col_tiles = warp_cols * micro_size_y;
@@ -246,7 +245,8 @@ public:
     auto tx = lane_id;
     auto alane_id = lane_id;
     auto blane_id =
         ((lane_id & 15) >> 2) + ((lane_id & 3) << 2) + ((lane_id >> 4) << 4);

     constexpr auto local_size_a = (micro_size_x * micro_size_k) / warp_size;
     constexpr auto local_size_b = (micro_size_y * micro_size_k) / warp_size;
@@ -265,12 +265,12 @@ public:
       for (int local_id = 0; local_id < (kPack * local_size_b); local_id++) {
         if constexpr (TransposeB) {
           auto [row, col] = reverse_index_map(blane_id, local_id);
           B_local[i * kPack * local_size_b + local_id] =
               B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
                   l + row, r + col)];
         } else {
           auto [row, col] = reverse_index_map_transposed(blane_id, local_id);
           B_local[i * kPack * local_size_b + local_id] =
               B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
                   r + row, l + col)];
         }
@@ -321,4 +321,3 @@ TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
 }

 } // namespace tl
@@ -72,4 +72,3 @@ __device__ fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x, fp8_e4_t y, fp8_e4_t z,
   res.y = *reinterpret_cast<fp8_e4_4_t *>(&b);
   return res;
 }
@@ -22,14 +22,11 @@ struct MinOp {
   }
 };

 // Detect half types
 template <typename T> struct is_half_type : std::false_type {};

 template <> struct is_half_type<__half> : std::true_type {};

 template <> struct is_half_type<_Float16> : std::true_type {};

 template <typename T>
 inline constexpr bool is_half_v = is_half_type<std::decay_t<T>>::value;
@@ -56,7 +53,10 @@ struct AllReduce {
       if constexpr (std::is_same_v<std::decay_t<T>, __half>) {
         x_raw = __half_as_ushort(x);
       } else { // _Float16
         union {
           _Float16 f;
           unsigned short s;
         } u;
         u.f = x;
         x_raw = u.s;
       }
@@ -67,7 +67,10 @@ struct AllReduce {
       if constexpr (std::is_same_v<std::decay_t<T>, __half>) {
         shuffled_x = __ushort_as_half(shuffled_raw);
       } else { // _Float16
         union {
           unsigned short s;
           _Float16 f;
         } u;
         u.s = shuffled_raw;
         shuffled_x = u.f;
       }
@@ -116,7 +119,7 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
       T val = (col < W) ? src[real_row * W + real_col] : (T)0;
 #pragma unroll
       for (int off = 1; off < SEG; off <<= 1) {
         T n = (T)__shfl_down_sync(MASK, val, off);
         if (lane < SEG - off)
@@ -142,7 +145,7 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
       T val = (col < W) ? src[real_row * W + real_col] : (T)0;
 #pragma unroll
       for (int off = 1; off < SEG; off <<= 1) {
         T n = (T)__shfl_up_sync(MASK, val, off);
         if (lane >= off)
@@ -164,4 +167,3 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
   }
 };
 } // namespace tl
@@ -43,4 +43,3 @@ template <int panel_width> TL_DEVICE dim3 rasterization2DColumn() {
 }

 } // namespace tl
@@ -12,6 +12,7 @@ from tilelang.transform import simplify_prim_func
 tilelang.testing.set_random_seed(0)
 tilelang.disable_cache()


 def make_swizzle_layout(shared_buf):
     dtype = shared_buf.dtype
     shape = shared_buf.shape
@@ -63,7 +64,7 @@ def tl_matmul(
     chunk = 32 * k_pack

     shared_scope = "shared"
-    cache_write_shared = False
+    # cache_write_shared = False

     block_M = block_row_warps * warp_row_tiles
     block_N = block_col_warps * warp_col_tiles
@@ -171,7 +172,7 @@ def tl_matmul(
             for i, j in T.Parallel(block_M, block_N):
                 C[by * block_M + i, bx * block_N + j] = C_shared[
                     j // micro_size_y,
                     i // micro_size_x,
                     i % micro_size_x,
                     j % micro_size_y,
                 ]
...
+from __future__ import annotations

 from tilelang import tvm as tvm
 import tilelang.language as T
-from typing import Tuple
 from tvm import DataType
 from tvm.tir import PrimExpr
 from tvm.runtime import convert
-from typing import Optional
 from .utils import (
     mfma_store_index_map,)

 lift = convert
-class MatrixCoreIntrinEmitter(object):
+class MatrixCoreIntrinEmitter:
     """
     To eliminate Python syntax within TIR Macro.
     """
@@ -51,9 +50,9 @@ class MatrixCoreIntrinEmitter(object):
         chunk: int = 16,
         reduce_k: int = 1,
         num_elems_per_byte: int = 1,
-        k_pack: Optional[int] = None,
-        is_m_first: Optional[bool] = False,
-        b_preshuffle: Optional[bool] = False,
+        k_pack: int | None = None,
+        is_m_first: bool | None = False,
+        b_preshuffle: bool | None = False,
     ):
         self.a_dtype = a_dtype
         self.b_dtype = b_dtype
@@ -119,7 +118,7 @@ class MatrixCoreIntrinEmitter(object):
             "float16": "f16",
             "float32": "f32",
             "int8": "i8",
             "bfloat16": "bf16"
         }[in_dtype]

         self.mmac_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv}"
@@ -129,15 +128,15 @@ class MatrixCoreIntrinEmitter(object):
         self.micro_size_y = n_dim
         self.micro_size_k = k_dim

-    def _initialize_k_pack(self, k_pack: Optional[int] = None):
+    def _initialize_k_pack(self, k_pack: int | None = None):
         if k_pack is not None:
             self.k_pack = k_pack

-    def _initialize_is_m_first(self, is_m_first: Optional[bool] = False):
+    def _initialize_is_m_first(self, is_m_first: bool | None = False):
         if is_m_first is not None:
             self.is_m_first = is_m_first

-    def _initialize_b_preshuffle(self, b_preshuffle: Optional[bool] = False):
+    def _initialize_b_preshuffle(self, b_preshuffle: bool | None = False):
         if b_preshuffle is not None:
             self.b_preshuffle = b_preshuffle
@@ -197,7 +196,7 @@ class MatrixCoreIntrinEmitter(object):
     def extract_thread_binding(self,
                                thread_id,
-                               is_m_first=None) -> Tuple[PrimExpr, PrimExpr, PrimExpr]:
+                               is_m_first=None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
         '''
         is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
             which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
@@ -290,7 +289,9 @@ class MatrixCoreIntrinEmitter(object):
         if is_transposed:
             for j in T.serial(warp_cols):
                 for local_id in T.vectorized(k_pack * local_size_b):
                     row, col = T.meta_var(
                         reverse_index_map((tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16,
                                           local_id))
                     l, r = (
                         warp_n * warp_col_tiles + j * micro_size_y,
                         rk * chunk + ki * (k_pack * micro_size_k),
@@ -301,7 +302,9 @@ class MatrixCoreIntrinEmitter(object):
         else:
             for j in T.serial(warp_cols):
                 for local_id in T.vectorized(k_pack * local_size_b):
                     row, col = T.meta_var(
                         reverse_index_map((tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16,
                                           local_id))
                     l, r = (
                         rk * chunk + ki * (k_pack * micro_size_k),
                         warp_n * warp_col_tiles + j * micro_size_y,
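The `(tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16` remap used in these Python macros is the same permutation as the C++ `blane_id` above, just spelled with division and multiplication instead of shifts; a one-line check:

```python
# Quick self-check that the two spellings of the lane remap agree.
f_py = lambda t: (t & 15) // 4 + (t & 3) * 4 + (t // 16) * 16
f_cpp = lambda t: ((t & 15) >> 2) + ((t & 3) << 2) + ((t >> 4) << 4)
assert all(f_py(t) == f_cpp(t) for t in range(64))
```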
@@ -372,13 +375,13 @@ class MatrixCoreIntrinEmitter(object):
                 row, col = T.meta_var(mfma_store_index_map(tx, local_id))
                 if C_buf_dims == 2:
                     C_buf[(warp_m * warp_rows + i) * M_DIM + row,
                           (warp_n * warp_cols + j) * N_DIM +
                           col] = C_local_buf[j * (warp_rows * local_size_out) +
                                              i * local_size_out + local_id]
                 else:
                     C_buf[warp_n * warp_cols + j, warp_m * warp_rows + i, row,
                           col] = C_local_buf[j * warp_rows * local_size_out +
                                              i * local_size_out + local_id]
     @T.macro
     def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding):
@@ -412,10 +415,10 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
         chunk: int = 16,
         reduce_k: int = 1,
         num_elems_per_byte: int = 1,
-        k_pack: Optional[int] = None,
-        is_m_first: Optional[bool] = False,
-        a_preshuffle: Optional[bool] = False,
-        b_preshuffle: Optional[bool] = False,
+        k_pack: int | None = None,
+        is_m_first: bool | None = False,
+        a_preshuffle: bool | None = False,
+        b_preshuffle: bool | None = False,
     ):
         self.a_dtype = a_dtype
@@ -579,7 +582,9 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
         if is_transposed:
             for j in T.serial(warp_cols):
                 for local_id in T.vectorized(k_pack * local_size_b):
                     row, col = T.meta_var(
                         reverse_index_map(((tx & 15) >> 2) + ((tx & 3) << 2) + ((tx >> 4) << 4),
                                           local_id))
                     l, r = (
                         warp_n * warp_cols + j,
                         rk * (chunk // micro_size_k) + ki,
@@ -589,7 +594,9 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
         else:
             for j in T.serial(warp_cols):
                 for local_id in T.vectorized(k_pack * local_size_b):
                     row, col = T.meta_var(
                         reverse_index_map(((tx & 15) >> 2) + ((tx & 3) << 2) + ((tx >> 4) << 4),
                                           local_id))
                     l, r = (
                         rk * (chunk // micro_size_k) + ki,
                         warp_n * warp_cols + j,
@@ -600,4 +607,3 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
         return _warp_ldmatrix_b_global(B_local_buf, B_buf, ki, thread_binding,
                                        rk) if is_global else _warp_ldmatrix_b_shared(
                                            B_local_buf, B_buf, ki, thread_binding, rk)