Unverified Commit 409ab83d authored by Tang Xinsheng, committed by GitHub

[AMD] support fp8 T.gemm (#804)



* [AMD] support fp8 T.gemm

* format

---------
Co-authored-by: tangxinsheng.txs <tangxinsheng.txs@alibaba-inc.com>
parent b62a0b43
import itertools

import torch

import tilelang
import tilelang.language as T
from tilelang.utils.tensor import torch_assert_close
def ref_program(A, B):
    # Reference: upcast the fp8 inputs to fp16 for the matmul, return fp32.
    return (A.half() @ B.half().T).to(dtype=torch.float32)


def manual_check_prog(C, C_ref):
    torch_assert_close(C[0], C_ref[0], rtol=0.01, atol=0.1)


def supply_prog(args):
    # Supply fp8 inputs for autotuning; the 0.01 scale keeps values well
    # inside the narrow e4m3fnuz dynamic range.
    a_param, b_param = args
    M, K = a_param.shape
    N, _ = b_param.shape
    a = (torch.randn(M, K, dtype=torch.float16, device='cuda') *
         0.01).to(dtype=torch.float8_e4m3fnuz)
    b = (torch.randn(N, K, dtype=torch.float16, device='cuda') *
         0.01).to(dtype=torch.float8_e4m3fnuz)
    return [a, b]
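Aside: the 0.01 scaling and the loose rtol/atol above absorb fp8 quantization error. A minimal sketch, assuming a ROCm PyTorch build that exposes torch.float8_e4m3fnuz, to inspect that error directly:

# Round-trip error of the fp8 cast used in supply_prog.
x = torch.randn(1 << 20, dtype=torch.float16, device='cuda') * 0.01
rt = x.to(torch.float8_e4m3fnuz).half()
print('max |x - fp8(x)| =', (x - rt).abs().max().item())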
def get_configs():
    # 3 * 3 * 2 * 1 * 1 * 2 * 2 = 72 candidate configurations for the autotuner.
    block_Ms = [32, 64, 128]
    block_Ns = [32, 64, 128]
    block_Ks = [64, 128]
    num_stages = [0]
    num_threads = [256]
    k_packs = [1, 2]
    gemm_types = ["ss", "rs"]
    valid_configs = []

    for m, n, k, stages, t, kp, gemm_type in itertools.product(block_Ms, block_Ns, block_Ks,
                                                               num_stages, num_threads, k_packs,
                                                               gemm_types):
        valid_configs.append({
            "block_M": m,
            "block_N": n,
            "block_K": k,
            "num_stages": stages,
            "num_threads": t,
            "k_pack": kp,
            "gemm_type": gemm_type,
        })
    return valid_configs
@tilelang.autotune(
    configs=get_configs(),
    cache_input_tensors=True,
    ref_prog=ref_program,
    manual_check_prog=manual_check_prog,
    supply_prog=supply_prog)
@tilelang.jit(out_idx=[-1])
def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pack, gemm_type):
    dtype = "float8_e4m3fnuz"
    accum_dtype = "float"

    # "rs" variant: A is staged through registers (fragment), B through shared memory.
    @T.prim_func
    def gemm_fp8_rs(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(
                T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
            A_local = T.alloc_fragment((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_local)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(
                    A_local,
                    B_shared,
                    C_local,
                    transpose_B=True,
                    k_pack=k_pack,
                    policy=T.GemmWarpPolicy.FullRow)
            T.copy(C_local, C[by * block_M, bx * block_N])

    # "ss" variant: both A and B are staged through shared memory.
    @T.prim_func
    def gemm_fp8_ss(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(
                T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(
                    A_shared,
                    B_shared,
                    C_local,
                    transpose_B=True,
                    k_pack=k_pack,
                    policy=T.GemmWarpPolicy.FullRow)
            T.copy(C_local, C[by * block_M, bx * block_N])

    if gemm_type == "ss":
        return gemm_fp8_ss
    elif gemm_type == "rs":
        return gemm_fp8_rs
    else:
        raise ValueError(f"Invalid gemm_type: {gemm_type}")
def test_gemm_fp8(M, N, K):
    kernel = fp8_matmul(M, N, K)
    a = (torch.randn(M, K, dtype=torch.float16, device='cuda') *
         0.01).to(dtype=torch.float8_e4m3fnuz)
    b = (torch.randn(N, K, dtype=torch.float16, device='cuda') *
         0.01).to(dtype=torch.float8_e4m3fnuz)
    c = kernel(a, b)
    ref_c = ref_program(a, b)
    torch_assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
    print("passed~")


if __name__ == "__main__":
    test_gemm_fp8(512, 512, 512)
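To gauge throughput of the tuned kernel, a simple timing sketch using plain torch.cuda.Event (bench_gemm_fp8 is a hypothetical helper, not part of this test):

def bench_gemm_fp8(M, N, K, iters=100):
    kernel = fp8_matmul(M, N, K)
    a = (torch.randn(M, K, dtype=torch.float16, device='cuda') *
         0.01).to(dtype=torch.float8_e4m3fnuz)
    b = (torch.randn(N, K, dtype=torch.float16, device='cuda') *
         0.01).to(dtype=torch.float8_e4m3fnuz)
    for _ in range(10):  # warmup
        kernel(a, b)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        kernel(a, b)
    end.record()
    torch.cuda.synchronize()
    ms = start.elapsed_time(end) / iters
    # 2*M*N*K flops per GEMM, reported in TFLOPS.
    print(f"{ms:.3f} ms/iter, {2 * M * N * K / (ms * 1e-3) / 1e12:.2f} TFLOPS")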
...
@@ -59,21 +59,39 @@ From https://github.com/RadeonOpenCompute/amd_matrix_instruction_calculator
 ./matrix_calculator.py --architecture cdna1 --instruction v_mfma_f32_16x16x16f16
 --detail-instruction
 */
-Fragment makeGemmFragmentAB16x16CDNA() {
+Fragment makeGemmFragmentAB16x16CDNA(const int k_pack) {
   IterVar i = make_itervar("i", 16);
-  IterVar j = make_itervar("j", 16);
+  IterVar j = make_itervar("j", 16 * k_pack);
   IterVar rep = make_itervar("rep", 1);
-  PrimExpr forward_thread = 16 * FloorDiv(j->var, 4) + i;
-  PrimExpr index = FloorMod(j->var, 4);
+  PrimExpr forward_thread = 16 * FloorDiv(j->var, 4 * k_pack) + i;
+  PrimExpr index = FloorMod(j->var, 4 * k_pack);
   return Fragment({i, j}, {index}, forward_thread, rep);
 }
 
-Fragment makeGemmFragmentAB16x16CDNATransposed() {
-  IterVar i = make_itervar("i", 16);
+Fragment makeGemmFragmentAB16x16CDNATransposed(const int k_pack) {
+  IterVar i = make_itervar("i", 16 * k_pack);
   IterVar j = make_itervar("j", 16);
   IterVar rep = make_itervar("rep", 1);
-  PrimExpr forward_thread = 16 * FloorDiv(i->var, 4) + j;
-  PrimExpr index = FloorMod(i->var, 4);
+  PrimExpr forward_thread = 16 * FloorDiv(i->var, 4 * k_pack) + j;
+  PrimExpr index = FloorMod(i->var, 4 * k_pack);
   return Fragment({i, j}, {index}, forward_thread, rep);
 }
+
+Fragment makeGemmFragmentAB16x32CDNA(const int k_pack) {
+  IterVar i = make_itervar("i", 16);
+  IterVar j = make_itervar("j", 32 * k_pack);
+  IterVar rep = make_itervar("rep", 1);
+  PrimExpr forward_thread = 16 * FloorDiv(j->var, 8 * k_pack) + i;
+  PrimExpr index = FloorMod(j->var, 8 * k_pack);
+  return Fragment({i, j}, {index}, forward_thread, rep);
+}
+
+Fragment makeGemmFragmentAB16x32CDNATransposed(const int k_pack) {
+  IterVar i = make_itervar("i", 32 * k_pack);
+  IterVar j = make_itervar("j", 16);
+  IterVar rep = make_itervar("rep", 1);
+  PrimExpr forward_thread = 16 * FloorDiv(i->var, 8 * k_pack) + j;
+  PrimExpr index = FloorMod(i->var, 8 * k_pack);
+  return Fragment({i, j}, {index}, forward_thread, rep);
+}
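For intuition: the new 16x32 fragments assign element (i, j) of a 16 x (32*k_pack) operand tile to lane 16*(j // (8*k_pack)) + i, register slot j % (8*k_pack), so each of the 64 lanes holds 8*k_pack contiguous K-elements. A standalone Python sketch (illustration only, not part of the commit) checking that this map is a bijection:

def check_fragment_16x32(k_pack):
    # Mirrors makeGemmFragmentAB16x32CDNA: forward_thread and index per element.
    seen = set()
    for i in range(16):
        for j in range(32 * k_pack):
            thread = 16 * (j // (8 * k_pack)) + i
            index = j % (8 * k_pack)
            seen.add((thread, index))
    assert len(seen) == 16 * 32 * k_pack              # one-to-one onto (lane, slot)
    assert max(t for t, _ in seen) == 63              # exactly one 64-lane wavefront
    assert max(s for _, s in seen) == 8 * k_pack - 1  # 8*k_pack K-elements per lane

check_fragment_16x32(1)
check_fragment_16x32(2)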
@@ -224,27 +242,34 @@ Fragment makeGemmFragmentB(const int block_m, const int block_n,
 Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
                                const int block_k, const int warp_m,
                                const int warp_n, const int element_size,
-                               bool transposed) {
+                               const int k_pack, bool transposed) {
   // assume not transposed
   ICHECK(block_m % warp_m == 0);
   ICHECK(block_n % warp_n == 0);
   ICHECK(warp_m % 16 == 0);
-  ICHECK(block_k % 16 == 0);
+  const int mfma_k = k_pack * (element_size == 16 ? 16 : 32);
+  ICHECK(block_k % mfma_k == 0);
   ICHECK(element_size == 8 || element_size == 16)
       << "element bitwidth=" << element_size;
   if (transposed) {
     auto base_layout =
-        makeGemmFragmentAB16x16CDNATransposed()->Repeat({1, 1}, false, false);
+        element_size == 16
+            ? makeGemmFragmentAB16x16CDNATransposed(k_pack)->Repeat(
+                  {1, 1}, false, false)
+            : makeGemmFragmentAB16x32CDNATransposed(k_pack)->Repeat(
+                  {1, 1}, false, false);
     auto warp_layout =
-        base_layout->Repeat({block_k / 16, warp_m / 16}, false, true);
+        base_layout->Repeat({block_k / mfma_k, warp_m / 16}, false, true);
     auto block_layout = warp_layout->Repeat({1, block_m / warp_m}, true, true)
                             ->Replicate(block_n / warp_n);
     return block_layout;
   } else {
     auto base_layout =
-        makeGemmFragmentAB16x16CDNA()->Repeat({1, 1}, false, false);
+        element_size == 16
+            ? makeGemmFragmentAB16x16CDNA(k_pack)->Repeat({1, 1}, false, false)
+            : makeGemmFragmentAB16x32CDNA(k_pack)->Repeat({1, 1}, false, false);
     auto warp_layout =
-        base_layout->Repeat({warp_m / 16, block_k / 16}, false, false);
+        base_layout->Repeat({warp_m / 16, block_k / mfma_k}, false, false);
     auto block_layout = warp_layout->Repeat({block_m / warp_m, 1}, true, true)
                             ->Replicate(block_n / warp_n);
     return block_layout;
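Concretely, mfma_k above is the K-extent one (possibly k-packed) MFMA consumes. A quick sanity check of the block_k divisibility constraint (standalone sketch, not part of the commit):

# mfma_k = k_pack * (16 if element_size == 16 else 32), and block_k % mfma_k == 0.
def mfma_k(element_size, k_pack):
    return k_pack * (16 if element_size == 16 else 32)

assert mfma_k(16, 1) == 16  # v_mfma_f32_16x16x16_f16
assert mfma_k(8, 1) == 32   # v_mfma_f32_16x16x32_fp8_fp8
assert mfma_k(8, 2) == 64   # with k_pack=2, block_K must be a multiple of 64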
@@ -397,7 +422,7 @@ Layout makeMatrixCoreSwizzleLayout(int stride, int continuous, int element_size,
   const int numBanks = 32;
   const int bankBitWidth = 32;
   const int SIMDWidth = 16;
-  const int vecSize = 4 * kPack;
+  const int vecSize = (64 / element_size) * kPack;
   const int innerDimLength = continuous;
   const int typeWidthInBit = element_size;
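The updated vecSize keeps each lane's contiguous fetch at 64 bits regardless of element width. A quick illustration (assuming element_size is in bits, as in the surrounding code):

# vecSize = (64 / element_size) * kPack  ->  elements per lane per fetch
for element_size, kPack in [(16, 1), (16, 2), (8, 1), (8, 2)]:
    vecSize = (64 // element_size) * kPack
    print(f"{element_size}-bit, kPack={kPack}: vecSize={vecSize}"
          f" ({vecSize * element_size} bits per lane)")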
@@ -616,12 +641,7 @@ Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous,
 Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size,
                             int kPack) {
-  int vector_size = 128 / element_size;
-  if (continuous % (vector_size * 4) == 0)
-    return makeMatrixCoreSwizzleLayout(stride, continuous, element_size, kPack);
-  else {
-    return makeGemmABLayoutPadded(stride, continuous, element_size);
-  }
+  return makeMatrixCoreSwizzleLayout(stride, continuous, element_size, kPack);
 }
 } // namespace tl
 } // namespace tvm
...
@@ -154,7 +154,7 @@ Fragment makeGemmFragmentB(const int block_m, const int block_n,
 Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
                                const int block_k, const int warp_m,
                                const int warp_n, const int element_size,
-                               bool transposed = false);
+                               const int k_pack, bool transposed = false);
 // Default Memory Layout
 Layout makeGemmLayoutLinear(int stride, int continuous);
...
@@ -582,7 +582,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
     results.Set(A, shared_layout);
   } else if (A.scope() == "local.fragment") {
     auto fragment = makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n,
-                                          A->dtype.bits(), trans_A);
+                                          A->dtype.bits(), kPack, trans_A);
     results.Set(A, fragment->BindThreadRange(thread_range));
   } else {
     ICHECK(0);
@@ -594,10 +594,6 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
         *as_const_int(B->shape[dim_B - 1]), B->dtype.bits(), kPack);
     results.Set(B, shared_layout);
-  } else if (B.scope() == "local.fragment") {
-    auto fragment =
-        makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
-    results.Set(B, fragment->BindThreadRange(thread_range));
   } else {
     ICHECK(0);
   }
...
@@ -51,6 +51,19 @@ template <> struct MfmaTraits<bfloat16_t> {
   }
 };
 
+#if defined(HIP_FP8_ENABLED)
+// Specialization for fp8_e4_t
+template <> struct MfmaTraits<fp8_e4_t> {
+  template <typename AccType>
+  static TL_DEVICE void mfma_op(const fp8_e4_t *b, const fp8_e4_t *a,
+                                AccType *c) {
+    int64_t a_val = *reinterpret_cast<const int64_t *>(a);
+    int64_t b_val = *reinterpret_cast<const int64_t *>(b);
+    *c = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(b_val, a_val, *c, 0, 0, 0);
+  }
+};
+#endif
+
 // ref to bitblas/tl/mfma_macro_generator.py::kPack
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool TransposeA,
           bool TransposeB, bool clear_accum, int kPack, typename A_type,
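Both operands of this MFMA are 64-bit packets: each lane feeds 8 e4m3fnuz bytes per instruction, which is why mfma_op reinterprets the fp8 pointers as int64_t. A standalone Python illustration of that packing (not part of the commit):

import struct

# Eight fp8 bytes owned by one lane, reinterpreted as the single 64-bit
# operand handed to __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8.
fp8_bytes = bytes([0x38, 0x40, 0x44, 0x48, 0x4C, 0x50, 0x52, 0x54])
(operand,) = struct.unpack('<q', fp8_bytes)
print(hex(operand))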
@@ -61,7 +74,8 @@ public:
   static constexpr int micro_size_x = 16;
   static constexpr int micro_size_y = 16;
-  static constexpr int micro_size_k = 16;
+  static constexpr int micro_size_k = 32 / sizeof(A_type);
+  static constexpr int vec_size = 8 / sizeof(A_type);
 
   // This part comes from the Codegen
   static constexpr int M_Tile = M;
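These two constants now derive from the element width; the resulting shapes, worked out in a standalone sketch (not part of the commit):

# micro_size_k = 32 / sizeof(A_type), vec_size = 8 / sizeof(A_type)
for name, nbytes in [("fp16/bf16", 2), ("fp8", 1)]:
    print(f"{name}: micro_size_k = {32 // nbytes}, vec_size = {8 // nbytes}")
# fp16/bf16: 16x16x16 MFMA, 4-element lane vectors
# fp8:       16x16x32 MFMA, 8-element lane vectors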
@@ -88,12 +102,12 @@ public:
   TL_DEVICE static constexpr auto reverse_index_map(int thread_id,
                                                     int local_id) {
     return std::make_pair(thread_id % 16,
-                          (thread_id / 16) * (4 * kPack) + local_id);
+                          (thread_id / 16) * (vec_size * kPack) + local_id);
   }
 
   TL_DEVICE static constexpr auto reverse_index_map_transposed(int thread_id,
                                                                int local_id) {
-    return std::make_pair((thread_id / 16) * (4 * kPack) + local_id,
+    return std::make_pair((thread_id / 16) * (vec_size * kPack) + local_id,
                           thread_id % 16);
   }
@@ -108,7 +122,7 @@ public:
     const int numBanks = 32;
     const int bankBitWidth = 32;
     const int SIMDWidth = 16;
-    const int vecSize = 4 * kPack;
+    const int vecSize = vec_size * kPack;
     const int innerDimLength = continuous;
     const int typeWidthInBit = dtype_bits;
@@ -134,19 +148,9 @@ public:
   template <int continuous = 32, int element_size = 2>
   TL_DEVICE static constexpr auto make_swizzle_layout(const int row,
                                                       const int col) {
-    constexpr auto vector_size = BANK_SIZE_BYTES / (element_size * 8);
-
-    if (continuous % (vector_size * 4) == 0) {
-      auto [n_row, n_col] =
-          make_mfma_swizzle_layout<continuous, element_size>(row, col);
-      return n_row * continuous + n_col;
-    } else {
-      auto [n_row, n_col] = make_layout_padded(row, col);
-      int padded = continuous;
-      if ((element_size * 8 * continuous) % 256 == 0)
-        padded += BANK_SIZE_BYTES / (element_size * 8);
-      return n_row * padded + n_col;
-    }
+    auto [n_row, n_col] =
+        make_mfma_swizzle_layout<continuous, element_size>(row, col);
+    return n_row * continuous + n_col;
   }
 
   static TL_DEVICE void body(A_type *A_shared, B_type *B_shared,
@@ -213,11 +217,11 @@ public:
       for (int i = 0; i < warp_rows; ++i) {
         for (int j = 0; j < warp_cols; ++j) {
           auto acc_ptr = ((float32x4 *)C_local) + ((i * warp_cols) + j);
-          auto b_ptr = ((B_type *)B_local) + (j * kPack + kp) * 4;
-          auto a_ptr = ((A_type *)A_local) + (i * kPack + kp) * 4;
+          auto b_ptr = ((B_type *)B_local) + (j * kPack + kp) * vec_size;
+          auto a_ptr = ((A_type *)A_local) + (i * kPack + kp) * vec_size;
 
-          // Use the trait to select the correct MFMA instruction, either fp16
-          // or bf16 currently
+          // Use the trait to select the correct MFMA instruction, either fp8,
+          // fp16 or bf16 currently
           MfmaTraits<A_type>::mfma_op(b_ptr, a_ptr, acc_ptr);
         }
       }
@@ -254,12 +258,12 @@ public:
       for (int local_id = 0; local_id < kPack * local_size_b; local_id++) {
         if constexpr (TransposeB) {
           auto [row, col] = reverse_index_map(lane_id, local_id);
-          B_local[j * local_size_b + local_id] =
+          B_local[j * kPack * local_size_b + local_id] =
               B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
                   l + row, r + col)];
         } else {
           auto [row, col] = reverse_index_map_transposed(lane_id, local_id);
-          B_local[j * local_size_b + local_id] =
+          B_local[j * kPack * local_size_b + local_id] =
               B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
                   r + row, l + col)];
         }
@@ -271,12 +275,12 @@ public:
     for (int i = 0; i < warp_rows; ++i) {
       for (int j = 0; j < warp_cols; ++j) {
         auto acc_ptr = ((float32x4 *)C_local) + ((i * warp_cols) + j);
-        auto b_ptr = ((B_type *)B_local) + (j * kPack + kp) * 4;
+        auto b_ptr = ((B_type *)B_local) + (j * kPack + kp) * vec_size;
         auto a_ptr = ((A_type *)A_local) +
-                     (ki * warp_rows * kPack + i * kPack + kp) * 4;
+                     (ki * warp_rows * kPack + i * kPack + kp) * vec_size;
 
-        // Use the trait to select the correct MFMA instruction, either fp16
-        // or bf16 currently
+        // Use the trait to select the correct MFMA instruction, either fp8,
+        // fp16 or bf16 currently
         MfmaTraits<A_type>::mfma_op(b_ptr, a_ptr, acc_ptr);
       }
     }
...
 #include <hip/amd_detail/amd_hip_fp8.h>
+#define HIP_FP8_ENABLED 1
 
 using fp8_e4_t = __hip_fp8_e4m3_fnuz;
 using fp8_e4_2_t = __hip_fp8x2_e4m3_fnuz;
-using fp8_e4_4_t = __hip_fp8x4_e4m3_fnuz;
+
+// Simple wrapper that provides member access for generated code
+struct fp8_e4_4_t {
+  union {
+    __hip_fp8x4_e4m3_fnuz data;
+    struct {
+      fp8_e4_t x, y, z, w;
+    };
+  };
+
+  // Default constructor
+  __device__ fp8_e4_4_t() = default;
+
+  // Constructor from __hip_fp8x4_e4m3_fnuz
+  __device__ fp8_e4_4_t(const __hip_fp8x4_e4m3_fnuz &val) : data(val) {}
+
+  // Constructor from float4
+  __device__ fp8_e4_4_t(const float4 &val) : data(val) {}
+
+  // Conversion operator to __hip_fp8x4_e4m3_fnuz
+  __device__ operator __hip_fp8x4_e4m3_fnuz() const { return data; }
+
+  // Assignment operator
+  __device__ fp8_e4_4_t &operator=(const __hip_fp8x4_e4m3_fnuz &val) {
+    data = val;
+    return *this;
+  }
+};
+
 struct __align__(8) fp8_e4_8_t {
   fp8_e4_4_t x;
...