Unverified Commit 9fd6bb30 authored by Jiaxing Ding, committed by GitHub

[AMD] support mfma i32_16x16x32_i8 (#800)


Co-authored-by: Jiaxing Ding <jiaxing.ding@bytedance.com>
parent 54aaec98
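
For context: the commit wires TileLang's HIP backend to the AMD MFMA builtin __builtin_amdgcn_mfma_i32_16x16x32_i8 (available on gfx940-class GPUs), which multiplies 16x32 and 32x16 int8 tiles into a 16x16 int32 tile per wavefront. A minimal illustrative HIP kernel (assumptions: one 64-lane wavefront, 8-byte-aligned inputs; the exact MFMA lane-to-element mapping is glossed over):

#include <cstdint>
#include <hip/hip_runtime.h>

typedef int int32x4 __attribute__((ext_vector_type(4)));

// One wavefront computes a 16x16 int32 tile from a 16x32 int8 A tile and a
// 32x16 int8 B tile; each lane supplies 8 packed int8 values per operand as
// a single int64_t and receives 4 of the 256 accumulators.
__global__ void mfma_i8_demo(const int8_t *a, const int8_t *b, int32_t *c) {
  const int lane = threadIdx.x; // assumes blockDim.x == 64 (one wavefront)
  const int64_t a_packed = *reinterpret_cast<const int64_t *>(a + 8 * lane);
  const int64_t b_packed = *reinterpret_cast<const int64_t *>(b + 8 * lane);
  int32x4 acc = {0, 0, 0, 0};
  // i32_16x16x32_i8: int32 accumulate, 16x16 output tile, K = 32, int8 inputs.
  acc = __builtin_amdgcn_mfma_i32_16x16x32_i8(a_packed, b_packed, acc, 0, 0, 0);
  for (int i = 0; i < 4; ++i)
    c[4 * lane + i] = acc[i];
}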
@@ -880,7 +880,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
os << "]" << ((i < 3) ? ", " : ")");
}
} else if (op->op.same_as(tl::tvm_mfma())) {
- // arg 0: prefix: {otype}_16x16x16{itype}
+ // arg 0: prefix: {otype}_{intrM}x{intrN}x{intrK}_{itype}
// arg 1: A layout: row/col
// arg 2: B layout: row/col
// arg 3: A precision: float16, float32, ...
@@ -914,6 +914,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
{"int8", "char"},
{"int32", "int"},
{"int8x4", "int32_t"},
{"int8x8", "int64_t"},
{"int32x4", "int32x4"},
{"float16", "half"},
{"float32", "float"},
@@ -925,17 +926,17 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
{"float8_e4m3fnuzx8", "long"},
{"float32x16", "float32x16"}};
std::string call_mfma_code = R"({
- *((({C_dytpe}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dytpe}*){a_ref}) + {a_bias}),
- *((({B_dytpe}*){b_ref}) + {b_bias}),
- *((({C_dytpe}*){c_ref}) + {c_bias}), 0, 0, 0);
+ *((({C_dtype}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dtype}*){a_ref}) + {a_bias}),
+ *((({B_dtype}*){b_ref}) + {b_bias}),
+ *((({C_dtype}*){c_ref}) + {c_bias}), 0, 0, 0);
})";
std::string mfma_buildin = "__builtin_amdgcn_mfma_" + prefix;
Replacer replacer;
replacer.register_rule("{mfma_buildin}", mfma_buildin);
replacer.register_rule("{A_dytpe}", dtype_map[A_dtype]);
replacer.register_rule("{B_dytpe}", dtype_map[B_dtype]);
replacer.register_rule("{C_dytpe}", dtype_map[C_dtype]);
replacer.register_rule("{A_dtype}", dtype_map[A_dtype]);
replacer.register_rule("{B_dtype}", dtype_map[B_dtype]);
replacer.register_rule("{C_dtype}", dtype_map[C_dtype]);
replacer.register_rule("{a_ref}", a_ref);
replacer.register_rule("{a_bias}", a_bias);
replacer.register_rule("{b_ref}", b_ref);
......
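
With the placeholder typo fixed ({A_dytpe} → {A_dtype}) and the new int8x8 → int64_t mapping, the call_mfma_code template plausibly expands as follows for prefix i32_16x16x32_i8 (a sketch; the *_local and *_bias names stand in for values produced by the surrounding codegen):

typedef int int32x4 __attribute__((ext_vector_type(4)));

// Hypothetical expansion: {A_dtype} = {B_dtype} = int64_t (from "int8x8"),
// {C_dtype} = int32x4, {mfma_buildin} = __builtin_amdgcn_mfma_i32_16x16x32_i8.
__device__ void mfma_i8_expanded(int64_t *A_local, int64_t *B_local,
                                 int32x4 *C_local, int a_bias, int b_bias,
                                 int c_bias) {
  *(((int32x4 *)C_local) + c_bias) = __builtin_amdgcn_mfma_i32_16x16x32_i8(
      *(((int64_t *)A_local) + a_bias), *(((int64_t *)B_local) + b_bias),
      *(((int32x4 *)C_local) + c_bias), 0, 0, 0);
}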
@@ -8,6 +8,18 @@ namespace tl {
// Trait to determine the MFMA instruction to use based on data type
template <typename T> struct MfmaTraits;
+ // Specialization for int8
+ template <> struct MfmaTraits<int8_t> {
+   template <typename AccType>
+   static TL_DEVICE void mfma_op(const int8_t *b, const int8_t *a, AccType *c) {
+     int64_t *b_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(b));
+     int64_t *a_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(a));
+     *c = __builtin_amdgcn_mfma_i32_16x16x32_i8(*b_packed, *a_packed, *c, 0, 0, 0);
+   }
+ };
// Specialization for half/float16
template <> struct MfmaTraits<half> {
template <typename AccType>
......
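
A hypothetical call site (not part of this patch) shows what the trait buys: the fragment element type alone selects the matching MFMA builtin:

// Sketch of a generic MMA step dispatching through MfmaTraits. For
// T = int8_t the specialization above packs 8 bytes per operand into an
// int64_t and issues __builtin_amdgcn_mfma_i32_16x16x32_i8; for T = half
// it issues the fp16 MFMA. Note mfma_op takes the B fragment first.
template <typename T, typename AccType>
TL_DEVICE void mma_step(const T *a_frag, const T *b_frag, AccType *acc) {
  MfmaTraits<T>::mfma_op(b_frag, a_frag, acc);
}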
@@ -41,7 +41,9 @@ def tl_matmul(
block_col_warps = 2
warp_row_tiles = 32
warp_col_tiles = 32
- chunk = 32
+ chunk = 32 * k_pack
shared_scope = "shared"
cache_write_shared = False
@@ -193,6 +195,7 @@ def assert_tl_matmul_correctness(M,
C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))
kernel(A, B, C)
+ print(kernel.get_kernel_source())
profiler = kernel.get_profiler()
@@ -227,6 +230,9 @@ def test_assert_tl_matmul():
assert_tl_matmul_correctness(128, 128, 128, "float16", "float16")
assert_tl_matmul_correctness(128, 256, 256, "float16", "float32")
assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", k_pack=2)
assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32")
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32")
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2)
if __name__ == "__main__":
......
@@ -81,7 +81,7 @@ class MatrixCoreIntrinEmitter(object):
def _initialize_k_dim(self, a_dtype="float16"):
if isinstance(a_dtype, str):
if a_dtype in ["float8_e4m3fnuz"]:
if a_dtype in ["float8_e4m3fnuz", "int8"]:
self.k_dim = 32
return
a_dtype = DataType(a_dtype)
@@ -123,6 +123,8 @@ class MatrixCoreIntrinEmitter(object):
if in_dtype_abbrv == "fp8":
self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_fp8_fp8"
elif in_dtype_abbrv == "i8":
self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_i8"
else:
self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv}"
......
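
The suffix strings above map one-to-one onto builtin names ("i32_16x16x32_i8" → __builtin_amdgcn_mfma_i32_16x16x32_i8). A C++ rendition of the same selection, illustrative only since the real logic lives in the Python emitter:

#include <string>

// fp8 and i8 use underscore-separated suffixes ("f32_16x16x32_fp8_fp8",
// "i32_16x16x32_i8"); fp16 keeps the legacy fused form ("f32_16x16x16f16").
std::string mfma_suffix(const std::string &out_abbrv,
                        const std::string &in_abbrv, int m, int n, int k) {
  const std::string shape =
      std::to_string(m) + "x" + std::to_string(n) + "x" + std::to_string(k);
  if (in_abbrv == "fp8")
    return out_abbrv + "_" + shape + "_fp8_fp8";
  if (in_abbrv == "i8")
    return out_abbrv + "_" + shape + "_i8";
  return out_abbrv + "_" + shape + in_abbrv;
}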