Unverified Commit 9fd6bb30 authored by Jiaxing Ding, committed by GitHub

[AMD] support mfma i32_16x16x32_i8 (#800)


Co-authored-by: Jiaxing Ding <jiaxing.ding@bytedance.com>
parent 54aaec98
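For orientation: the compiler builtin this commit wires up performs one 16x16 int8 matrix multiply-accumulate with K = 32 per instruction, accumulating into int32. A minimal sketch of a direct call, assuming a CDNA3-class (gfx94x) GPU and hipcc; the v4i32 alias and the mfma_i8_step helper are illustrative, not part of the patch:

// Hedged sketch, not from the commit: one 16x16x32 int8 MFMA step.
// Each lane supplies 8 packed int8 values of A and of B (one int64_t each,
// assumed 8-byte aligned) and accumulates 4 int32 results per lane.
#include <hip/hip_runtime.h>
#include <cstdint>

using v4i32 = __attribute__((__vector_size__(4 * sizeof(int)))) int;

__device__ void mfma_i8_step(const int8_t *a_frag, const int8_t *b_frag,
                             v4i32 &acc) {
  int64_t a_packed = *reinterpret_cast<const int64_t *>(a_frag);
  int64_t b_packed = *reinterpret_cast<const int64_t *>(b_frag);
  acc = __builtin_amdgcn_mfma_i32_16x16x32_i8(a_packed, b_packed, acc, 0, 0, 0);
}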
@@ -880,7 +880,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
   os << "]" << ((i < 3) ? ", " : ")");
   }
 } else if (op->op.same_as(tl::tvm_mfma())) {
-  // arg 0: prefix: {otype}_16x16x16{itype}
+  // arg 0: prefix: {otype}_{intrM}x{intrN}x{intrK}_{itype}
   // arg 1: A layout: row/col
   // arg 2: B layout: row/col
   // arg 3: A precision: float16, float32, ...
@@ -914,6 +914,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
 {"int8", "char"},
 {"int32", "int"},
 {"int8x4", "int32_t"},
+{"int8x8", "int64_t"},
 {"int32x4", "int32x4"},
 {"float16", "half"},
 {"float32", "float"},
@@ -925,17 +926,17 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
 {"float8_e4m3fnuzx8", "long"},
 {"float32x16", "float32x16"}};
 std::string call_mfma_code = R"({
-  *((({C_dytpe}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dytpe}*){a_ref}) + {a_bias}),
-  *((({B_dytpe}*){b_ref}) + {b_bias}),
-  *((({C_dytpe}*){c_ref}) + {c_bias}), 0, 0, 0);
+  *((({C_dtype}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dtype}*){a_ref}) + {a_bias}),
+  *((({B_dtype}*){b_ref}) + {b_bias}),
+  *((({C_dtype}*){c_ref}) + {c_bias}), 0, 0, 0);
 })";
 std::string mfma_buildin = "__builtin_amdgcn_mfma_" + prefix;
 Replacer replacer;
 replacer.register_rule("{mfma_buildin}", mfma_buildin);
-replacer.register_rule("{A_dytpe}", dtype_map[A_dtype]);
-replacer.register_rule("{B_dytpe}", dtype_map[B_dtype]);
-replacer.register_rule("{C_dytpe}", dtype_map[C_dtype]);
+replacer.register_rule("{A_dtype}", dtype_map[A_dtype]);
+replacer.register_rule("{B_dtype}", dtype_map[B_dtype]);
+replacer.register_rule("{C_dtype}", dtype_map[C_dtype]);
 replacer.register_rule("{a_ref}", a_ref);
 replacer.register_rule("{a_bias}", a_bias);
 replacer.register_rule("{b_ref}", b_ref);
......
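For the new int8 path, the replacer above expands the R"( )" template into a call like the one sketched below. This is a hedged reconstruction: the prefix "i32_16x16x32_i8" is supplied by the Python emitter, {A_dtype}/{B_dtype} resolve to int64_t through the new "int8x8" map entry, {C_dtype} to int32x4, and the fragment names and zero biases are placeholders rather than values from the patch.

// Hedged reconstruction of the emitted source, wrapped in a tiny harness so
// the operand types are visible; C_local/A_local/B_local are placeholders.
#include <hip/hip_runtime.h>
#include <cstdint>

using int32x4 = __attribute__((__vector_size__(4 * sizeof(int)))) int;

__device__ void emitted_mfma_i8(void *C_local, void *A_local, void *B_local) {
  *(((int32x4 *)C_local) + 0) = __builtin_amdgcn_mfma_i32_16x16x32_i8(
      *(((int64_t *)A_local) + 0), *(((int64_t *)B_local) + 0),
      *(((int32x4 *)C_local) + 0), 0, 0, 0);
}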
@@ -8,6 +8,18 @@ namespace tl {
 // Trait to determine the MFMA instruction to use based on data type
 template <typename T> struct MfmaTraits;
+
+// Specialization for int8
+template <> struct MfmaTraits<int8_t> {
+  template <typename AccType>
+  static TL_DEVICE void mfma_op(const int8_t *b, const int8_t *a, AccType *c) {
+    int64_t *b_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(b));
+    int64_t *a_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(a));
+    *c = __builtin_amdgcn_mfma_i32_16x16x32_i8(*b_packed, *a_packed, *c, 0, 0,
+                                               0);
+  }
+};
+
 // Specialization for half/float16
 template <> struct MfmaTraits<half> {
   template <typename AccType>
......
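The new MfmaTraits<int8_t> specialization lets the shared GEMM template dispatch on the fragment element type, mirroring the existing half/float specializations. A hedged usage sketch, assuming this header (and its TL_DEVICE macro) is in scope; the mma_step helper, fragment names, and v4i32 alias are illustrative, not the library's actual call site:

// Hedged sketch: dispatching through the trait from a generic inner loop.
using v4i32 = __attribute__((__vector_size__(4 * sizeof(int)))) int;

template <typename T>
TL_DEVICE void mma_step(const T *b_frag, const T *a_frag, v4i32 *acc) {
  // For T = int8_t this resolves to __builtin_amdgcn_mfma_i32_16x16x32_i8,
  // consuming 8 int8 values per operand per lane as one packed int64_t.
  MfmaTraits<T>::mfma_op(b_frag, a_frag, acc);
}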
@@ -41,7 +41,9 @@ def tl_matmul(
     block_col_warps = 2
     warp_row_tiles = 32
     warp_col_tiles = 32
-    chunk = 32
+    chunk = 32 * k_pack
     shared_scope = "shared"
     cache_write_shared = False
@@ -193,6 +195,7 @@ def assert_tl_matmul_correctness(M,
     C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))
     kernel(A, B, C)
+    print(kernel.get_kernel_source())
     profiler = kernel.get_profiler()
@@ -227,6 +230,9 @@ def test_assert_tl_matmul():
     assert_tl_matmul_correctness(128, 128, 128, "float16", "float16")
     assert_tl_matmul_correctness(128, 256, 256, "float16", "float32")
     assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", k_pack=2)
+    assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32")
+    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32")
+    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2)

 if __name__ == "__main__":
......
@@ -81,7 +81,7 @@ class MatrixCoreIntrinEmitter(object):
     def _initialize_k_dim(self, a_dtype="float16"):
         if isinstance(a_dtype, str):
-            if a_dtype in ["float8_e4m3fnuz"]:
+            if a_dtype in ["float8_e4m3fnuz", "int8"]:
                 self.k_dim = 32
                 return
         a_dtype = DataType(a_dtype)
@@ -123,6 +123,8 @@ class MatrixCoreIntrinEmitter(object):
         if in_dtype_abbrv == "fp8":
             self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_fp8_fp8"
+        elif in_dtype_abbrv == "i8":
+            self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_i8"
         else:
             self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv}"
......
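Taken together with the codegen change above: for int8 inputs the emitter now selects k_dim = 32 and, for an int32 output, produces the suffix "i32_16x16x32_i8", which the HIP backend prefixes with "__builtin_amdgcn_mfma_" to arrive at exactly the builtin used by the new MfmaTraits<int8_t> specialization.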