OpenDAS / tilelang · Commit 9fd6bb30 (Unverified)

[AMD] support mfma i32_16x16x32_i8 (#800)

Authored by Jiaxing Ding on Sep 10, 2025; committed by GitHub on Sep 10, 2025
Co-authored-by: Jiaxing Ding <jiaxing.ding@bytedance.com>
parent 54aaec98
Showing 4 changed files with 30 additions and 9 deletions:

- src/target/codegen_hip.cc (+8, -7)
- src/tl_templates/hip/gemm.h (+12, -0)
- testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py (+7, -1)
- tilelang/intrinsics/mfma_macro_generator.py (+3, -1)
src/target/codegen_hip.cc
```diff
@@ -880,7 +880,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
       os << "]" << ((i < 3) ? ", " : ")");
     }
   } else if (op->op.same_as(tl::tvm_mfma())) {
-    // arg 0: prefix: {otype}_16x16x16{itype}
+    // arg 0: prefix: {otype}_{intrM}x{intrN}x{intrK}_{itype}
     // arg 1: A layout: row/col
     // arg 2: B layout: row/col
     // arg 3: A precision: float16, float32, ...
```
```diff
@@ -914,6 +914,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
         {"int8", "char"},
         {"int32", "int"},
         {"int8x4", "int32_t"},
+        {"int8x8", "int64_t"},
         {"int32x4", "int32x4"},
         {"float16", "half"},
         {"float32", "float"},
```
```diff
@@ -925,17 +926,17 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
         {"float8_e4m3fnuzx8", "long"},
         {"float32x16", "float32x16"}};
-    std::string call_mfma_code = R"({
-  *((({C_dytpe}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dytpe}*){a_ref}) + {a_bias}),
-                  *((({B_dytpe}*){b_ref}) + {b_bias}),
-                  *((({C_dytpe}*){c_ref}) + {c_bias}), 0, 0, 0);
-  })";
+    std::string call_mfma_code = R"({
+  *((({C_dtype}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dtype}*){a_ref}) + {a_bias}),
+                  *((({B_dtype}*){b_ref}) + {b_bias}),
+                  *((({C_dtype}*){c_ref}) + {c_bias}), 0, 0, 0);
+  })";
     std::string mfma_buildin = "__builtin_amdgcn_mfma_" + prefix;
     Replacer replacer;
     replacer.register_rule("{mfma_buildin}", mfma_buildin);
-    replacer.register_rule("{A_dytpe}", dtype_map[A_dtype]);
-    replacer.register_rule("{B_dytpe}", dtype_map[B_dtype]);
-    replacer.register_rule("{C_dytpe}", dtype_map[C_dtype]);
+    replacer.register_rule("{A_dtype}", dtype_map[A_dtype]);
+    replacer.register_rule("{B_dtype}", dtype_map[B_dtype]);
+    replacer.register_rule("{C_dtype}", dtype_map[C_dtype]);
     replacer.register_rule("{a_ref}", a_ref);
     replacer.register_rule("{a_bias}", a_bias);
     replacer.register_rule("{b_ref}", b_ref);
```
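For the new int8 path, the (now correctly spelled) `{A_dtype}`/`{B_dtype}`/`{C_dtype}` placeholders are filled from `dtype_map` and the prefix comes from the macro generator. The Python sketch below only imitates that substitution outside the codegen; it is not the `Replacer` class itself, and the operand names `A_local`/`B_local`/`C_local` and the zero biases are illustrative placeholders, not values taken from this commit.

```python
# Sketch of how call_mfma_code expands for the new int8 MFMA path.
call_mfma_code = (
    "*((({C_dtype}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dtype}*){a_ref}) + {a_bias}),\n"
    "                *((({B_dtype}*){b_ref}) + {b_bias}),\n"
    "                *((({C_dtype}*){c_ref}) + {c_bias}), 0, 0, 0);"
)

rules = {
    "{mfma_buildin}": "__builtin_amdgcn_mfma_" + "i32_16x16x32_i8",  # prefix from the macro generator
    "{A_dtype}": "int64_t",   # dtype_map["int8x8"]: 8 packed int8 per operand
    "{B_dtype}": "int64_t",
    "{C_dtype}": "int32x4",   # dtype_map["int32x4"]
    "{a_ref}": "A_local", "{a_bias}": "0",   # illustrative operand names/offsets
    "{b_ref}": "B_local", "{b_bias}": "0",
    "{c_ref}": "C_local", "{c_bias}": "0",
}

for placeholder, value in rules.items():
    call_mfma_code = call_mfma_code.replace(placeholder, value)

print(call_mfma_code)
```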
src/tl_templates/hip/gemm.h
```diff
@@ -8,6 +8,18 @@ namespace tl {

 // Trait to determine the MFMA instruction to use based on data type
 template <typename T> struct MfmaTraits;

+// Specialization for int8
+template <> struct MfmaTraits<int8_t> {
+  template <typename AccType>
+  static TL_DEVICE void mfma_op(const int8_t *b, const int8_t *a, AccType *c) {
+    int64_t *b_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(b));
+    int64_t *a_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(a));
+    *c = __builtin_amdgcn_mfma_i32_16x16x32_i8(*b_packed, *a_packed, *c, 0, 0, 0);
+  }
+};
+
 // Specialization for half/float16
 template <> struct MfmaTraits<half> {
   template <typename AccType>
```
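The `reinterpret_cast<int64_t *>` in `mfma_op` treats 8 consecutive int8 values as one 64-bit operand, which is the form `__builtin_amdgcn_mfma_i32_16x16x32_i8` expects for its A and B fragments. The numpy snippet below is a host-side illustration of that same reinterpretation (assuming a little-endian host); it is not HIP code and not part of the commit.

```python
import numpy as np

a_fragment = np.arange(8, dtype=np.int8)   # 8 int8 values held by one lane
a_packed = a_fragment.view(np.int64)       # same bytes, viewed as a single 64-bit integer

assert a_packed.shape == (1,)
print(hex(int(a_packed[0])))               # 0x706050403020100 on a little-endian host
```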
testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
```diff
@@ -41,7 +41,9 @@ def tl_matmul(
     block_col_warps = 2
     warp_row_tiles = 32
     warp_col_tiles = 32
-    chunk = 32
+    chunk = 32 * k_pack
     shared_scope = "shared"
     cache_write_shared = False
```
```diff
@@ -193,6 +195,7 @@ def assert_tl_matmul_correctness(M,
     C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))

     kernel(A, B, C)
+    print(kernel.get_kernel_source())

     profiler = kernel.get_profiler()
```
```diff
@@ -227,6 +230,9 @@ def test_assert_tl_matmul():
     assert_tl_matmul_correctness(128, 128, 128, "float16", "float16")
     assert_tl_matmul_correctness(128, 256, 256, "float16", "float32")
     assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", k_pack=2)
+    assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32")
+    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32")
+    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2)


 if __name__ == "__main__":
```
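The three new cases all pass `accum_dtype="int32"`: int8 products have to be accumulated in 32-bit integers, which is exactly what the i32_16x16x32_i8 instruction produces. The numpy sketch below only illustrates that accumulation rule; the test file's own reference check is not shown in this diff, so this is not its implementation.

```python
import numpy as np

M, N, K = 128, 128, 128
rng = np.random.default_rng(0)
A = rng.integers(-128, 128, size=(M, K), dtype=np.int8)
B = rng.integers(-128, 128, size=(K, N), dtype=np.int8)

# int8 x int8 partial products are widened to int32 before summation,
# mirroring accum_dtype="int32" in the new test cases.
C_ref = A.astype(np.int32) @ B.astype(np.int32)
assert C_ref.dtype == np.int32
```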
tilelang/intrinsics/mfma_macro_generator.py
```diff
@@ -81,7 +81,7 @@ class MatrixCoreIntrinEmitter(object):

     def _initialize_k_dim(self, a_dtype="float16"):
         if isinstance(a_dtype, str):
-            if a_dtype in ["float8_e4m3fnuz"]:
+            if a_dtype in ["float8_e4m3fnuz", "int8"]:
                 self.k_dim = 32
                 return
             a_dtype = DataType(a_dtype)
```
```diff
@@ -123,6 +123,8 @@ class MatrixCoreIntrinEmitter(object):
         if in_dtype_abbrv == "fp8":
             self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_fp8_fp8"
+        elif in_dtype_abbrv == "i8":
+            self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_i8"
         else:
             self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv}"
```
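With `k_dim` forced to 32 for int8 inputs and the new `elif` branch, the emitter composes the suffix that names this commit. A minimal sketch of that composition, assuming `M_DIM` and `N_DIM` are the 16x16 matrix-core tile sizes the emitter targets:

```python
M_DIM, N_DIM = 16, 16        # assumed matrix-core tile sizes
k_dim = 32                   # _initialize_k_dim now selects 32 for "int8"
out_dtype_abbrv = "i32"
in_dtype_abbrv = "i8"

# Mirrors the new elif branch in MatrixCoreIntrinEmitter.
mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_i8"

# codegen_hip.cc prepends the builtin prefix to this suffix.
builtin = "__builtin_amdgcn_mfma_" + mfma_suffix
assert builtin == "__builtin_amdgcn_mfma_i32_16x16x32_i8"
print(builtin)
```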