OpenDAS / tilelang

Commit 9a640856
Authored Oct 28, 2025 by qisan

[Feature] Add some testing files for Hygon DCU

Parent: 2c490782

Showing 4 changed files with 267 additions and 10 deletions (+267 -10)
src/target/codegen_hip.cc                                   +0   -1
src/tl_templates/dcu_hip/gemm.h                             +2   -2
testing/python/dcu/test_tilelang_gemm_mmac_intrinsic.py     +260 -0
tilelang/intrinsics/mmac_macro_generator.py                 +5   -7
src/target/codegen_hip.cc

@@ -1383,7 +1383,6 @@ void CodeGenTileLangHIP::AddFunction(const PrimFunc &f) {
   CodeGenC::PrintType(f->ret_type, stream);
   this->PrintExtraAttrs(f, stream);
   this->stream << " " << static_cast<std::string>(global_symbol.value()) << "(";
   for (size_t i = 0; i < f->params.size(); ++i) {
     tir::Var v = f->params[i];
     std::string vid = AllocVarID(v.get());
src/tl_templates/dcu_hip/gemm.h

@@ -165,7 +165,7 @@ public:
     auto tx = lane_id;
     auto alane_id = lane_id;
-    auto blane_id = (lane_id & 15) / 4 + (lane_id & 3) * 4 + (lane_id / 16) * 16;
+    auto blane_id = ((lane_id & 15) >> 2) + ((lane_id & 3) << 2) + ((lane_id >> 4) << 4);
     constexpr auto local_size_a = (micro_size_x * micro_size_k) / warp_size;

@@ -246,7 +246,7 @@ public:
     auto tx = lane_id;
     auto alane_id = lane_id;
-    auto blane_id = (lane_id & 15) / 4 + (lane_id & 3) * 4 + (lane_id / 16) * 16;
+    auto blane_id = ((lane_id & 15) >> 2) + ((lane_id & 3) << 2) + ((lane_id >> 4) << 4);
     constexpr auto local_size_a = (micro_size_x * micro_size_k) / warp_size;
     constexpr auto local_size_b = (micro_size_y * micro_size_k) / warp_size;
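Both hunks replace the divide/multiply form of the B-operand lane remapping with a shift form. Since lane_id is a non-negative lane index within a 64-wide DCU wavefront, the two expressions are equivalent; a minimal Python sketch (not part of the commit) that checks this over all 64 lanes:

def blane_id_old(lane_id):
    # original arithmetic form from gemm.h (integer division in the C++ source)
    return (lane_id & 15) // 4 + (lane_id & 3) * 4 + (lane_id // 16) * 16

def blane_id_new(lane_id):
    # shift form introduced by this commit
    return ((lane_id & 15) >> 2) + ((lane_id & 3) << 2) + ((lane_id >> 4) << 4)

# every lane of a 64-wide wavefront maps to the same index under both forms
assert all(blane_id_old(lane) == blane_id_new(lane) for lane in range(64))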
testing/python/dcu/test_tilelang_gemm_mmac_intrinsic.py (new file, 0 → 100644)
import torch
import tilelang.testing
from tilelang import tvm as tvm
from tvm import DataType
import tilelang.language as T
# from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mmac_macro_generator import (MatrixCoreIntrinEmitter,)
from tilelang.transform import simplify_prim_func

tilelang.testing.set_random_seed(0)
tilelang.disable_cache()
def make_swizzle_layout(shared_buf):
    dtype = shared_buf.dtype
    shape = shared_buf.shape

    can_swizzle = shape[-1] * DataType(dtype).bits == 512
    if not can_swizzle:
        return T.Layout(shape, lambda *args: args)

    def transform_func(i, j):
        new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
        return [new_warp_i, new_warp_j]

    return T.Layout(shape, transform_func)
@simplify_prim_func
def tl_matmul(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype,
    a_transposed=False,
    b_transposed=True,
    k_pack=1,
):
    assert in_dtype in [
        "float16",
        "bfloat16",
        "int8",
    ], "Currently only float16, bfloat16 and int8 are supported"
    assert out_dtype in [
        "float16",
        "float32",
        "int32",
    ], "Currently only float16, float32 and int32 are supported"

    micro_size_x = micro_size_y = micro_size_k = 16

    if in_dtype in {"float8_e4m3fnuz", "int8"}:
        micro_size_k = 32

    block_row_warps = 2
    block_col_warps = 2
    warp_row_tiles = 32
    warp_col_tiles = 32
    chunk = 32 * k_pack
    shared_scope = "shared"
    cache_write_shared = False

    block_M = block_row_warps * warp_row_tiles
    block_N = block_col_warps * warp_col_tiles
    block_K = chunk

    A_shape = (K, M) if a_transposed else (M, K)
    B_shape = (N, K) if b_transposed else (K, N)
    A_shared_shape = (block_K, block_M) if a_transposed else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if b_transposed else (block_K, block_N)
    C_shared_shape = (
        block_M // micro_size_x,
        block_N // micro_size_y,
        micro_size_x,
        micro_size_y,
    )

    warp_size = 64
    threads = warp_size * (block_row_warps * block_col_warps)
    local_size_a = (k_pack * micro_size_x * micro_size_k) // warp_size
    local_size_b = (k_pack * micro_size_y * micro_size_k) // warp_size
    local_size_c = (micro_size_x * micro_size_y) // warp_size
    warp_rows = warp_row_tiles // micro_size_x
    warp_cols = warp_col_tiles // micro_size_y

    # MMA Wrapper to Auto Generate Code for MMA
    mmac_emitter = MatrixCoreIntrinEmitter(
        a_dtype=in_dtype,
        b_dtype=in_dtype,
        accum_dtype=accum_dtype,
        a_transposed=a_transposed,
        b_transposed=b_transposed,
        block_row_warps=block_row_warps,
        block_col_warps=block_col_warps,
        warp_row_tiles=warp_row_tiles,
        warp_col_tiles=warp_col_tiles,
        chunk=chunk,
        k_pack=k_pack,
    )
    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):

            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
            C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
            A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
            C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

            T.annotate_layout({
                A_shared: make_swizzle_layout(A_shared),
                B_shared: make_swizzle_layout(B_shared),
            })

            # Improve L2 Cache
            T.use_swizzle(panel_size=10)

            T.clear(C_local)

            for ko in T.Pipelined((K // block_K), num_stages=0):

                # Load A into shared memory
                if a_transposed:
                    T.copy(A[ko * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, ko * block_K], A_shared)

                # Load B into shared memory
                if b_transposed:
                    T.copy(B[bx * block_N, ko * block_K], B_shared)
                else:
                    T.copy(B[ko * block_K, bx * block_N], B_shared)

                for ki in T.serial(0, (block_K // (k_pack * micro_size_k))):

                    # Load A into fragment
                    mmac_emitter.ldmatrix_a(
                        A_local,
                        A_shared,
                        ki,
                    )

                    # Load B into fragment
                    mmac_emitter.ldmatrix_b(
                        B_local,
                        B_shared,
                        ki,
                    )

                    # Perform Matrix Multiplication
                    mmac_emitter.mmac(A_local, B_local, C_local)

            # Perform STMatrix
            mmac_emitter.stmatrix(
                C_local,
                C_shared,
            )

            # Store shared into global
            for i, j in T.Parallel(block_M, block_N):
                C[by * block_M + i, bx * block_N + j] = C_shared[
                    j // micro_size_y,
                    i // micro_size_x,
                    i % micro_size_x,
                    j % micro_size_y,
                ]

    return main
def assert_tl_matmul_correctness(M,
                                 N,
                                 K,
                                 in_dtype,
                                 out_dtype,
                                 accum_dtype="float32",
                                 a_transposed=False,
                                 b_transposed=True,
                                 k_pack=1):
    matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed,
                       k_pack)
    print(matmul)
    kernel = tilelang.compile(matmul)
    src_code = kernel.get_kernel_source()
    # src_code is the generated cuda source
    assert src_code is not None

    A_shape = (K, M) if a_transposed else (M, K)
    B_shape = (N, K) if b_transposed else (K, N)
    if in_dtype == "int8":
        A = torch.randint(-128, 127, A_shape, device="cuda", dtype=torch.int8)
        B = torch.randint(-128, 127, B_shape, device="cuda", dtype=torch.int8)
    else:
        A = torch.rand(A_shape, device="cuda", dtype=getattr(torch, in_dtype))
        B = torch.rand(B_shape, device="cuda", dtype=getattr(torch, in_dtype))
    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))

    kernel(A, B, C)

    print(kernel.get_kernel_source())

    profiler = kernel.get_profiler()

    latency = profiler.do_bench()

    # Ensure that the latency is not None
    assert latency is not None

    if a_transposed and b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.T.to(torch.float32),
                             B.T.to(torch.float32)).to(getattr(torch, out_dtype))
    elif a_transposed and not b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.T.to(torch.float32),
                             B.to(torch.float32)).to(getattr(torch, out_dtype))
    elif not a_transposed and b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.to(torch.float32),
                             B.T.to(torch.float32)).to(getattr(torch, out_dtype))
    else:
        # Get Reference Result
        ref_c = torch.matmul(A.to(torch.float32),
                             B.to(torch.float32)).to(getattr(torch, out_dtype))
    print(C)
    print(ref_c)
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
@tilelang.testing.requires_rocm
def test_assert_tl_matmul():
    assert_tl_matmul_correctness(128, 128, 128, "float16", "float16")
    assert_tl_matmul_correctness(128, 256, 256, "float16", "float32")
    assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", k_pack=2)
    assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32")
    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32")
    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2)
    assert_tl_matmul_correctness(
        128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32")
    assert_tl_matmul_correctness(
        128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", k_pack=2)
    # assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3fnuz", "float16")
    # assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32")
    # assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", k_pack=2)
    # assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", b_transposed=False)
    # assert_tl_matmul_correctness(
    #     128, 256, 256, "float8_e4m3fnuz", "float32", b_transposed=False, k_pack=2)
if __name__ == "__main__":
    tilelang.testing.main()
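For reference, the fixed configuration in tl_matmul resolves to small concrete numbers, which also shows when the 512-bit condition in make_swizzle_layout fires: with float16 inputs and k_pack=1, a shared row of block_K = 32 elements is 32 * 16 = 512 bits, so A_shared and B_shared get the swizzled layout. A small recomputation sketch (not part of the commit, values taken from the defaults above):

# illustrative recomputation of the derived sizes in tl_matmul (float16, k_pack=1)
block_row_warps = block_col_warps = 2
warp_row_tiles = warp_col_tiles = 32
micro_size_x = micro_size_y = micro_size_k = 16
warp_size = 64
k_pack = 1
chunk = 32 * k_pack

block_M = block_row_warps * warp_row_tiles                            # 64
block_N = block_col_warps * warp_col_tiles                            # 64
block_K = chunk                                                       # 32
threads = warp_size * (block_row_warps * block_col_warps)             # 256
local_size_a = (k_pack * micro_size_x * micro_size_k) // warp_size    # 4 elements per lane
local_size_b = (k_pack * micro_size_y * micro_size_k) // warp_size    # 4
local_size_c = (micro_size_x * micro_size_y) // warp_size             # 4

# swizzle condition from make_swizzle_layout: last shared dim * element bits == 512
fp16_bits = 16
assert block_K * fp16_bits == 512   # 32 * 16 = 512, so the swizzled layout is applied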
tilelang/intrinsics/mmac_macro_generator.py

@@ -118,12 +118,10 @@ class MatrixCoreIntrinEmitter(object):
         in_dtype_abbrv = {
             "float16": "f16",
             "float32": "f32",
-            "int8": "i8"
+            "int8": "i8",
+            "bfloat16": "bf16"
         }[in_dtype]

-        if in_dtype_abbrv == "i8":
-            self.mmac_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_i8"
-        else:
-            self.mmac_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv}"
+        self.mmac_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv}"

     def _initialize_micro_size(self, m_dim=16, n_dim=16, k_dim=16):

@@ -581,7 +579,7 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
         if is_transposed:
             for j in T.serial(warp_cols):
                 for local_id in T.vectorized(k_pack * local_size_b):
-                    row, col = T.meta_var(reverse_index_map(((tx & 15) / 4 + (tx & 3) * 4 + (tx / 16) * 16), local_id))
+                    row, col = T.meta_var(reverse_index_map(((tx & 15) >> 2) + ((tx & 3) << 2) + ((tx >> 4) << 4), local_id))
                     l, r = (
                         warp_n * warp_cols + j,
                         rk * (chunk // micro_size_k) + ki,

@@ -591,7 +589,7 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
         else:
             for j in T.serial(warp_cols):
                 for local_id in T.vectorized(k_pack * local_size_b):
-                    row, col = T.meta_var(reverse_index_map(((tx & 15) / 4 + (tx & 3) * 4 + (tx / 16) * 16), local_id))
+                    row, col = T.meta_var(reverse_index_map(((tx & 15) >> 2) + ((tx & 3) << 2) + ((tx >> 4) << 4), local_id))
                     l, r = (
                         rk * (chunk // micro_size_k) + ki,
                         warp_n * warp_cols + j,
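The last two hunks apply the same divide/multiply-to-shift rewrite already seen in gemm.h, while the first hunk adds the bfloat16 abbreviation and keeps a single intrinsic-suffix format for all input types. A short sketch (illustrative only, not part of the emitter) of the strings that format produces, assuming the 16x16x16 defaults of _initialize_micro_size and k_dim = 32 for 8-bit inputs:

# illustrative suffix construction mirroring the format kept by this commit
def mmac_suffix(out_dtype_abbrv, in_dtype_abbrv, M_DIM=16, N_DIM=16, k_dim=16):
    return f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv}"

print(mmac_suffix("f32", "f16"))            # f32_16x16x16f16
print(mmac_suffix("f32", "bf16"))           # f32_16x16x16bf16  (enabled by the new dict entry)
print(mmac_suffix("i32", "i8", k_dim=32))   # i32_16x16x32i8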