Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
667632cc
Unverified
Commit
667632cc
authored
Dec 22, 2025
by
guchaoyang
Committed by
GitHub
Dec 22, 2025
Browse files
Merge branch 'main' into dcu
parents
d6dd2ddf
a874e4e8
Changes
343
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
799 additions
and
504 deletions
+799
-504
examples/gemm/example_gemm_intrinsics.py
examples/gemm/example_gemm_intrinsics.py
+19
-19
examples/gemm/example_gemm_persistent.py
examples/gemm/example_gemm_persistent.py
+20
-43
examples/gemm/example_gemm_schedule.py
examples/gemm/example_gemm_schedule.py
+4
-5
examples/gemm_fp8/example_tilelang_gemm_amd.py
examples/gemm_fp8/example_tilelang_gemm_amd.py
+30
-51
examples/gemm_fp8/example_tilelang_gemm_fp8.py
examples/gemm_fp8/example_tilelang_gemm_fp8.py
+9
-11
examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py
examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py
+10
-11
examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
+28
-23
examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py
examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py
+124
-0
examples/gemm_sm100/README.md
examples/gemm_sm100/README.md
+8
-8
examples/gemm_sm100/gemm_mma.py
examples/gemm_sm100/gemm_mma.py
+6
-6
examples/gemm_sm100/gemm_tcgen5mma.py
examples/gemm_sm100/gemm_tcgen5mma.py
+9
-17
examples/gemm_sp/example_custom_compress.py
examples/gemm_sp/example_custom_compress.py
+336
-0
examples/gemm_sp/example_gemm_sp.py
examples/gemm_sp/example_gemm_sp.py
+60
-79
examples/gemm_sp/test_example_gemm_sp.py
examples/gemm_sp/test_example_gemm_sp.py
+16
-0
examples/gemm_splitk/example_tilelang_gemm_splitk.py
examples/gemm_splitk/example_tilelang_gemm_splitk.py
+5
-16
examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py
...plitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py
+5
-16
examples/gemm_streamk/example_tilelang_gemm_streamk.py
examples/gemm_streamk/example_tilelang_gemm_streamk.py
+9
-11
examples/gemv/example_gemv.py
examples/gemv/example_gemv.py
+58
-68
examples/gemv/test_example_gemv.py
examples/gemv/test_example_gemv.py
+1
-3
examples/grouped_gemm/example_grouped_gemm_bwd.py
examples/grouped_gemm/example_grouped_gemm_bwd.py
+42
-117
No files found.
Too many changes to show.
To preserve performance only
343 of 343+
files are displayed.
Plain diff
Email patch
examples/gemm/example_gemm_intrinsics.py
View file @
667632cc
...
...
@@ -4,7 +4,8 @@ import tilelang
import
tilelang.language
as
T
from
tilelang.intrinsics
import
get_swizzle_layout
from
tilelang.intrinsics.mma_macro_generator
import
(
TensorCoreIntrinEmitter
,)
TensorCoreIntrinEmitter
,
)
from
tilelang.transform
import
simplify_prim_func
...
...
@@ -34,18 +35,18 @@ def tl_matmul(
accum_dtype
,
):
assert
in_dtype
in
[
"
float16
"
,
"
int8
"
,
T
.
float16
,
T
.
int8
,
],
"Currently only float16 and int8 are supported"
assert
out_dtype
in
[
"
float16
"
,
"
float32
"
,
"
int32
"
,
T
.
float16
,
T
.
float32
,
T
.
int32
,
],
"Currently only float16, float32 and int32 are supported"
micro_size_x
=
micro_size_y
=
micro_size_k
=
16
if
out_dtype
==
"
int32
"
:
if
out_dtype
==
T
.
int32
:
micro_size_k
=
32
# This is a debug config
...
...
@@ -53,7 +54,7 @@ def tl_matmul(
block_col_warps
=
2
warp_row_tiles
=
64
warp_col_tiles
=
64
# chunk = 32 if in_dtype ==
"
float16
"
else 64
# chunk = 32 if in_dtype ==
T.
float16 else 64
chunk
=
32
shared_scope
=
"shared.dyn"
...
...
@@ -104,7 +105,6 @@ def tl_matmul(
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
C_shared
=
T
.
alloc_shared
(
C_shared_shape
,
out_dtype
,
scope
=
shared_scope
)
...
...
@@ -112,10 +112,12 @@ def tl_matmul(
B_local
=
T
.
alloc_local
((
warp_cols
*
local_size_b
),
in_dtype
)
C_local
=
T
.
alloc_local
((
warp_rows
*
warp_cols
*
local_size_c
),
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
{
A_shared
:
make_swizzle_layout
(
A_shared
),
B_shared
:
make_swizzle_layout
(
B_shared
),
})
}
)
# Improve L2 Cache
T
.
use_swizzle
(
panel_size
=
10
)
...
...
@@ -123,7 +125,6 @@ def tl_matmul(
T
.
clear
(
C_local
)
for
ko
in
T
.
Pipelined
((
K
//
block_K
),
num_stages
=
stage
):
# Load A into shared memory
for
i
,
k
in
T
.
Parallel
(
block_M
,
block_K
):
A_shared
[
i
,
k
]
=
A
[
by
*
block_M
+
i
,
ko
*
block_K
+
k
]
...
...
@@ -133,7 +134,6 @@ def tl_matmul(
B_shared
[
j
,
k
]
=
B
[
bx
*
block_N
+
j
,
ko
*
block_K
+
k
]
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
# Load A into fragment
mma_emitter
.
ldmatrix_a
(
A_local
,
A_shared
,
ki
)
...
...
@@ -163,7 +163,7 @@ def ref_program(A, B):
def
main
(
M
=
4096
,
N
=
4096
,
K
=
4096
):
in_dtype
,
out_dtype
,
accum_dtype
=
"
float16
"
,
"
float16
"
,
"
float32
"
in_dtype
,
out_dtype
,
accum_dtype
=
T
.
float16
,
T
.
float16
,
T
.
float32
kernel
=
tl_matmul
(
M
,
N
,
K
,
in_dtype
,
out_dtype
,
accum_dtype
)
src_code
=
kernel
.
get_kernel_source
()
# src_code is the generated cuda source
...
...
examples/gemm/example_gemm_persistent.py
View file @
667632cc
...
...
@@ -5,17 +5,7 @@ import argparse
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
matmul_non_persistent
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
threads
,
num_stages
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
matmul_non_persistent
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
threads
,
num_stages
,
dtype
=
T
.
float16
,
accum_dtype
=
T
.
float32
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
...
...
@@ -43,18 +33,9 @@ def matmul_non_persistent(M,
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
matmul_persistent
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
threads
,
num_stages
,
dtype
=
"float16"
,
accum_dtype
=
"float"
,
use_persistent_primitive
=
True
):
def
matmul_persistent
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
threads
,
num_stages
,
dtype
=
T
.
float16
,
accum_dtype
=
T
.
float32
,
use_persistent_primitive
=
True
):
sm_num
=
driver
.
get_num_sms
()
m_blocks
=
T
.
ceildiv
(
M
,
block_M
)
n_blocks
=
T
.
ceildiv
(
N
,
block_N
)
...
...
@@ -100,8 +81,7 @@ def matmul_persistent(M,
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
C_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
dtype
)
for
bx
,
by
in
T
.
Persistent
(
[
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
)],
sm_num
,
block_id
):
for
bx
,
by
in
T
.
Persistent
([
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
)],
sm_num
,
block_id
):
T
.
clear
(
C_local
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
A
[
bx
*
block_M
,
k
*
block_K
],
A_shared
)
...
...
@@ -128,18 +108,15 @@ def main(M=4096, N=4096, K=4096):
num_stages
=
3
persistent_kernel
=
matmul_persistent
(
M
,
N
,
K
,
BLOCK_M
,
BLOCK_N
,
BLOCK_K
,
threads
,
num_stages
)
persistent_profiler
=
persistent_kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Randn
)
persistent_profiler
=
persistent_kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Randn
)
persistent_profiler
.
assert_allclose
(
ref_program
,
rtol
=
0.01
,
atol
=
0.01
)
print
(
"Persistent GEMM: All check passed."
)
persistent_latency
=
persistent_profiler
.
do_bench
(
warmup
=
500
)
print
(
f
"Persistent GEMM Latency:
{
persistent_latency
}
ms"
)
print
(
f
"Persistent GEMM TFlops:
{
total_flops
/
persistent_latency
*
1e-9
}
TFlops"
)
non_persistent_kernel
=
matmul_non_persistent
(
M
,
N
,
K
,
BLOCK_M
,
BLOCK_N
,
BLOCK_K
,
threads
,
num_stages
)
non_persistent_profiler
=
non_persistent_kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Randn
)
non_persistent_kernel
=
matmul_non_persistent
(
M
,
N
,
K
,
BLOCK_M
,
BLOCK_N
,
BLOCK_K
,
threads
,
num_stages
)
non_persistent_profiler
=
non_persistent_kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Randn
)
non_persistent_profiler
.
assert_allclose
(
ref_program
,
rtol
=
0.01
,
atol
=
0.01
)
print
(
"Non-Persistent GEMM: All check passed."
)
non_persistent_latency
=
non_persistent_profiler
.
do_bench
(
warmup
=
500
)
...
...
@@ -151,9 +128,9 @@ def main(M=4096, N=4096, K=4096):
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--M
'
,
type
=
int
,
default
=
8192
,
help
=
'
M dimension
'
)
parser
.
add_argument
(
'
--N
'
,
type
=
int
,
default
=
8192
,
help
=
'
N dimension
'
)
parser
.
add_argument
(
'
--K
'
,
type
=
int
,
default
=
8192
,
help
=
'
K dimension
'
)
parser
.
add_argument
(
"
--M
"
,
type
=
int
,
default
=
8192
,
help
=
"
M dimension
"
)
parser
.
add_argument
(
"
--N
"
,
type
=
int
,
default
=
8192
,
help
=
"
N dimension
"
)
parser
.
add_argument
(
"
--K
"
,
type
=
int
,
default
=
8192
,
help
=
"
K dimension
"
)
args
=
parser
.
parse_args
()
M
,
N
,
K
=
args
.
M
,
args
.
N
,
args
.
K
main
(
M
,
N
,
K
)
examples/gemm/example_gemm_schedule.py
View file @
667632cc
...
...
@@ -3,8 +3,7 @@ import tilelang.language as T
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
T
.
float16
,
accum_dtype
=
T
.
float32
):
@
T
.
prim_func
def
gemm_schedule
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
...
...
examples/gemm_fp8/example_tilelang_gemm_amd.py
View file @
667632cc
...
...
@@ -17,10 +17,8 @@ def supply_prog(args):
a_param
,
b_param
=
args
M
,
K
=
a_param
.
shape
N
,
_
=
b_param
.
shape
a
=
(
torch
.
randn
(
M
,
K
,
dtype
=
torch
.
float16
,
device
=
'cuda'
)
*
0.01
).
to
(
dtype
=
torch
.
float8_e4m3fnuz
)
b
=
(
torch
.
randn
(
N
,
K
,
dtype
=
torch
.
float16
,
device
=
'cuda'
)
*
0.01
).
to
(
dtype
=
torch
.
float8_e4m3fnuz
)
a
=
(
torch
.
randn
(
M
,
K
,
dtype
=
torch
.
float16
,
device
=
"cuda"
)
*
0.01
).
to
(
dtype
=
torch
.
float8_e4m3fnuz
)
b
=
(
torch
.
randn
(
N
,
K
,
dtype
=
torch
.
float16
,
device
=
"cuda"
)
*
0.01
).
to
(
dtype
=
torch
.
float8_e4m3fnuz
)
return
[
a
,
b
]
...
...
@@ -35,10 +33,9 @@ def get_configs():
valid_configs
=
[]
for
m
,
n
,
k
,
stages
,
t
,
kp
,
gemm_type
in
itertools
.
product
(
block_Ms
,
block_Ns
,
block_Ks
,
num_stages
,
num_threads
,
k_packs
,
gemm_types
):
valid_configs
.
append
({
for
m
,
n
,
k
,
stages
,
t
,
kp
,
gemm_type
in
itertools
.
product
(
block_Ms
,
block_Ns
,
block_Ks
,
num_stages
,
num_threads
,
k_packs
,
gemm_types
):
valid_configs
.
append
(
{
"block_M"
:
m
,
"block_N"
:
n
,
"block_K"
:
k
,
...
...
@@ -46,20 +43,18 @@ def get_configs():
"num_threads"
:
t
,
"k_pack"
:
kp
,
"gemm_type"
:
gemm_type
,
})
}
)
return
valid_configs
@
tilelang
.
autotune
(
configs
=
get_configs
(),
cache_input_tensors
=
True
,
ref_prog
=
ref_program
,
manual_check_prog
=
manual_check_prog
,
supply_prog
=
supply_prog
)
configs
=
get_configs
(),
cache_input_tensors
=
True
,
ref_prog
=
ref_program
,
manual_check_prog
=
manual_check_prog
,
supply_prog
=
supply_prog
)
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
fp8_matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
num_stages
,
num_threads
,
k_pack
,
gemm_type
):
dtype
=
"
float8_e4m3fnuz
"
accum_dtype
=
"
float
"
dtype
=
T
.
float8_e4m3fnuz
accum_dtype
=
T
.
float
32
@
T
.
prim_func
def
gemm_fp8_rs
(
...
...
@@ -67,8 +62,7 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
accum_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
num_threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
num_threads
)
as
(
bx
,
by
):
A_local
=
T
.
alloc_fragment
((
block_M
,
block_K
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_N
,
block_K
),
dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
...
@@ -77,13 +71,7 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
A
[
by
*
block_M
,
k
*
block_K
],
A_local
)
T
.
copy
(
B
[
bx
*
block_N
,
k
*
block_K
],
B_shared
)
T
.
gemm
(
A_local
,
B_shared
,
C_local
,
transpose_B
=
True
,
k_pack
=
k_pack
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
A_local
,
B_shared
,
C_local
,
transpose_B
=
True
,
k_pack
=
k_pack
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
copy
(
C_local
,
C
[
by
*
block_M
,
bx
*
block_N
])
...
...
@@ -93,8 +81,7 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
accum_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
num_threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
num_threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_N
,
block_K
),
dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
...
@@ -103,13 +90,7 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
A
[
by
*
block_M
,
k
*
block_K
],
A_shared
)
T
.
copy
(
B
[
bx
*
block_N
,
k
*
block_K
],
B_shared
)
T
.
gemm
(
A_shared
,
B_shared
,
C_local
,
transpose_B
=
True
,
k_pack
=
k_pack
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
A_shared
,
B_shared
,
C_local
,
transpose_B
=
True
,
k_pack
=
k_pack
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
copy
(
C_local
,
C
[
by
*
block_M
,
bx
*
block_N
])
...
...
@@ -123,10 +104,8 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
def
test_gemm_fp8
(
M
,
N
,
K
):
kernel
=
fp8_matmul
(
M
,
N
,
K
)
a
=
(
torch
.
randn
(
M
,
K
,
dtype
=
torch
.
float16
,
device
=
'cuda'
)
*
0.01
).
to
(
dtype
=
torch
.
float8_e4m3fnuz
)
b
=
(
torch
.
randn
(
N
,
K
,
dtype
=
torch
.
float16
,
device
=
'cuda'
)
*
0.01
).
to
(
dtype
=
torch
.
float8_e4m3fnuz
)
a
=
(
torch
.
randn
(
M
,
K
,
dtype
=
torch
.
float16
,
device
=
"cuda"
)
*
0.01
).
to
(
dtype
=
torch
.
float8_e4m3fnuz
)
b
=
(
torch
.
randn
(
N
,
K
,
dtype
=
torch
.
float16
,
device
=
"cuda"
)
*
0.01
).
to
(
dtype
=
torch
.
float8_e4m3fnuz
)
c
=
kernel
(
a
,
b
)
ref_c
=
ref_program
(
a
,
b
)
torch_assert_close
(
c
,
ref_c
,
rtol
=
1e-2
,
atol
=
1e-2
)
...
...
examples/gemm_fp8/example_tilelang_gemm_fp8.py
View file @
667632cc
import
torch
import
tilelang
import
tilelang.language
as
T
from
tilelang.utils.tensor
import
map_torch_type
def
calc_diff
(
x
,
y
):
...
...
@@ -12,8 +11,7 @@ def calc_diff(x, y):
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
,
accum_dtype
=
"float"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
,
accum_dtype
=
T
.
float32
):
@
T
.
prim_func
def
gemm_fp8
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
...
...
@@ -37,12 +35,12 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"):
def
test_gemm_fp8
(
M
,
N
,
K
,
dtype
):
torch_dtype
=
map_torch_
type
(
dtype
)
torch_dtype
=
T
.
d
type
(
dtype
)
.
as_torch
()
kernel
=
matmul
(
M
,
N
,
K
,
128
,
128
,
64
,
dtype
)
a
=
torch
.
randn
(
M
,
K
,
dtype
=
torch
.
float16
,
device
=
'
cuda
'
).
to
(
dtype
=
torch_dtype
)
b
=
torch
.
randn
(
N
,
K
,
dtype
=
torch
.
float16
,
device
=
'
cuda
'
).
to
(
dtype
=
torch_dtype
)
a
=
torch
.
randn
(
M
,
K
,
dtype
=
torch
.
float16
,
device
=
"
cuda
"
).
to
(
dtype
=
torch_dtype
)
b
=
torch
.
randn
(
N
,
K
,
dtype
=
torch
.
float16
,
device
=
"
cuda
"
).
to
(
dtype
=
torch_dtype
)
c
=
kernel
(
a
,
b
)
...
...
@@ -57,8 +55,8 @@ def test_gemm_fp8(M, N, K, dtype):
def
main
():
test_gemm_fp8
(
1024
,
1024
,
1024
,
'
float8_e4m3
'
)
test_gemm_fp8
(
1024
,
1024
,
1024
,
'
float8_e5m2
'
)
test_gemm_fp8
(
1024
,
1024
,
1024
,
T
.
float8_e4m3
fn
)
test_gemm_fp8
(
1024
,
1024
,
1024
,
T
.
float8_e5m2
)
if
__name__
==
"__main__"
:
...
...
examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py
View file @
667632cc
import
torch
import
tilelang
import
tilelang.language
as
T
from
tilelang.utils.tensor
import
map_torch_type
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
,
accum_dtype
=
"
float
"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
,
accum_dtype
=
T
.
float
32
):
# for fp8 gemm, do one promote after 4 wgmma inst, i.e. block_K = 128.
# if block_K < 128, promote after 128/block_K iters.
# if block_K > 128, promote after every iter.
...
...
@@ -55,18 +54,18 @@ def calc_diff(x, y):
def
test_gemm_fp8
(
M
,
N
,
K
,
dtype
):
torch_dtype
=
map_torch_
type
(
dtype
)
torch_dtype
=
T
.
d
type
(
dtype
)
.
as_torch
()
kernel
=
matmul
(
M
,
N
,
K
,
128
,
128
,
64
,
dtype
)
a
=
torch
.
rand
(
M
,
K
,
dtype
=
torch
.
float16
,
device
=
'
cuda
'
)
a
=
torch
.
rand
(
M
,
K
,
dtype
=
torch
.
float16
,
device
=
"
cuda
"
)
a
=
(
100
*
(
2
*
a
-
1
)).
to
(
dtype
=
torch_dtype
)
b
=
torch
.
rand
(
N
,
K
,
dtype
=
torch
.
float16
,
device
=
'
cuda
'
)
b
=
torch
.
rand
(
N
,
K
,
dtype
=
torch
.
float16
,
device
=
"
cuda
"
)
b
=
(
100
*
(
2
*
b
-
1
)).
to
(
dtype
=
torch_dtype
)
c
=
kernel
(
a
,
b
)
ref_c
=
(
a
.
float
()
@
b
.
float
().
T
)
ref_c
=
a
.
float
()
@
b
.
float
().
T
diff
=
calc_diff
(
c
,
ref_c
)
print
(
f
"diff:
{
diff
}
"
)
...
...
@@ -74,8 +73,8 @@ def test_gemm_fp8(M, N, K, dtype):
def
main
():
test_gemm_fp8
(
1024
,
1024
,
8192
,
'
float8_e4m3
'
)
test_gemm_fp8
(
1024
,
1024
,
8192
,
'
float8_e5m2
'
)
test_gemm_fp8
(
1024
,
1024
,
8192
,
T
.
float8_e4m3
fn
)
test_gemm_fp8
(
1024
,
1024
,
8192
,
T
.
float8_e5m2
)
if
__name__
==
"__main__"
:
...
...
examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
View file @
667632cc
...
...
@@ -5,7 +5,8 @@ from tvm import DataType
import
tilelang.language
as
T
from
tilelang.intrinsics
import
get_swizzle_layout
from
tilelang.intrinsics.mma_macro_generator
import
(
TensorCoreIntrinEmitter
,)
TensorCoreIntrinEmitter
,
)
from
tilelang.transform
import
simplify_prim_func
from
tilelang.utils.tensor
import
map_torch_type
...
...
@@ -38,21 +39,26 @@ def tl_matmul(
accum_dtype
,
):
assert
in_dtype
in
[
"
float16
"
,
"
float8_e4m3
"
,
"
float8_e5m2
"
,
"
int8
"
,
T
.
float16
,
T
.
float8_e4m3
fn
,
T
.
float8_e5m2
,
T
.
int8
,
],
"Currently only float16 and int8 are supported"
assert
out_dtype
in
[
"
float16
"
,
"
float32
"
,
"
int32
"
,
T
.
float16
,
T
.
float32
,
T
.
int32
,
],
"Currently only float16, float32 and int32 are supported"
micro_size_x
=
micro_size_y
=
micro_size_k
=
16
is_float8
=
in_dtype
in
[
"float8_e4m3"
,
"float8_e5m2"
]
if
out_dtype
==
"int32"
or
is_float8
:
is_float8
=
in_dtype
in
[
T
.
float8_e4m3fn
,
T
.
float8_e5m2
,
T
.
float8_e4m3fn
,
T
.
float8_e5m2fnuz
,
]
if
out_dtype
==
T
.
int32
or
is_float8
:
micro_size_k
=
32
# This is a debug config
...
...
@@ -60,7 +66,7 @@ def tl_matmul(
block_col_warps
=
2
warp_row_tiles
=
32
warp_col_tiles
=
32
chunk
=
32
if
in_dtype
==
"
float16
"
else
64
chunk
=
32
if
in_dtype
==
T
.
float16
else
64
shared_scope
=
"shared.dyn"
# Pipeline Stage
...
...
@@ -110,7 +116,6 @@ def tl_matmul(
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
C_shared
=
T
.
alloc_shared
(
C_shared_shape
,
out_dtype
,
scope
=
shared_scope
)
...
...
@@ -118,10 +123,12 @@ def tl_matmul(
B_local
=
T
.
alloc_local
((
warp_cols
*
local_size_b
),
in_dtype
)
C_local
=
T
.
alloc_local
((
warp_rows
*
warp_cols
*
local_size_c
),
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
{
A_shared
:
make_swizzle_layout
(
A_shared
),
B_shared
:
make_swizzle_layout
(
B_shared
),
})
}
)
# Improve L2 Cache
T
.
use_swizzle
(
panel_size
=
10
)
...
...
@@ -129,7 +136,6 @@ def tl_matmul(
T
.
clear
(
C_local
)
for
ko
in
T
.
Pipelined
((
K
//
block_K
),
num_stages
=
stage
):
# Load A into shared memory
for
i
,
k
in
T
.
Parallel
(
block_M
,
block_K
):
A_shared
[
i
,
k
]
=
A
[
by
*
block_M
+
i
,
ko
*
block_K
+
k
]
...
...
@@ -139,7 +145,6 @@ def tl_matmul(
B_shared
[
j
,
k
]
=
B
[
bx
*
block_N
+
j
,
ko
*
block_K
+
k
]
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
# Load A into fragment
mma_emitter
.
ldmatrix_a
(
A_local
,
...
...
@@ -215,8 +220,8 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
def
main
():
assert_tl_matmul_correctness
(
128
,
128
,
128
,
"
float8_e4m3
"
,
"
float32
"
,
"
float32
"
)
assert_tl_matmul_correctness
(
128
,
128
,
128
,
"
float8_e5m2
"
,
"
float32
"
,
"
float32
"
)
assert_tl_matmul_correctness
(
128
,
128
,
128
,
T
.
float8_e4m3
fn
,
T
.
float32
,
T
.
float32
)
assert_tl_matmul_correctness
(
128
,
128
,
128
,
T
.
float8_e5m2
,
T
.
float32
,
T
.
float32
)
if
__name__
==
"__main__"
:
...
...
examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py
0 → 100644
View file @
667632cc
import
torch
import
tilelang
import
tilelang.language
as
T
from
tilelang.utils.tensor
import
map_torch_type
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
accum_dtype
,
num_stages
,
threads
,
):
A_shape
=
(
K
,
M
)
if
trans_A
else
(
M
,
K
)
B_shape
=
(
N
,
K
)
if
trans_B
else
(
K
,
N
)
A_shared_shape
=
(
block_K
,
block_M
)
if
trans_A
else
(
block_M
,
block_K
)
B_shared_shape
=
(
block_N
,
block_K
)
if
trans_B
else
(
block_K
,
block_N
)
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
C_tmem
=
T
.
alloc_tmem
([
block_M
,
block_N
],
accum_dtype
)
mbar
=
T
.
alloc_barrier
(
1
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
C_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
out_dtype
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
A
[
by
*
block_M
,
k
*
block_K
],
A_shared
)
T
.
copy
(
B
[
bx
*
block_N
,
k
*
block_K
],
B_shared
)
T
.
gemm_v2
(
A_shared
,
B_shared
,
C_tmem
,
trans_A
,
trans_B
,
mbar
=
mbar
,
wg_wait
=-
1
,
clear_accum
=
(
k
==
0
),
)
T
.
mbarrier_wait_parity
(
mbar
,
k
%
2
)
T
.
copy
(
C_tmem
,
C_local
)
T
.
copy
(
C_local
,
C_shared
)
T
.
copy
(
C_shared
,
C
[
by
*
block_M
,
bx
*
block_N
])
return
main
def
calc_diff
(
x
,
y
):
x
,
y
=
x
.
double
(),
y
.
double
()
denominator
=
(
x
*
x
+
y
*
y
).
sum
()
sim
=
2
*
(
x
*
y
).
sum
()
/
denominator
return
1
-
sim
M
,
N
,
K
=
4096
,
4096
,
8192
block_M
,
block_N
,
block_K
=
64
,
256
,
32
trans_A
,
trans_B
=
False
,
True
num_stages
=
2
threads
=
256
for
tvm_fp8_dtype
in
[
T
.
float8_e4m3fn
,
T
.
float8_e5m2
]:
for
tvm_acc_dtype
in
[
T
.
float16
,
T
.
float32
]:
# , torch.float16]:
torch_fp8_dtype
=
map_torch_type
(
tvm_fp8_dtype
)
torch_acc_dtype
=
map_torch_type
(
tvm_acc_dtype
)
print
(
f
"running
{
tvm_fp8_dtype
}
->
{
tvm_acc_dtype
}
"
)
in_dtype
,
out_dtype
,
accum_dtype
=
tvm_fp8_dtype
,
tvm_acc_dtype
,
tvm_acc_dtype
func
=
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
accum_dtype
,
num_stages
,
threads
,
)
jit_kernel
=
tilelang
.
compile
(
func
,
out_idx
=
[
2
],
target
=
"cuda"
,
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_ENABLE_PTXAS_VERBOSE_OUTPUT
:
True
,
},
)
# jit_kernel.export_ptx("./dump.ptx")
# jit_kernel.export_sources("./dump.cu")
a
=
torch
.
randn
(
M
,
K
,
device
=
"cuda"
,
dtype
=
torch
.
float16
).
to
(
torch_fp8_dtype
)
b
=
torch
.
randn
(
N
,
K
,
device
=
"cuda"
,
dtype
=
torch
.
float16
).
to
(
torch_fp8_dtype
)
c
=
jit_kernel
(
a
,
b
)
ref_c
=
(
a
.
to
(
torch
.
half
)
@
b
.
T
.
to
(
torch
.
half
)).
float
()
c
=
c
.
float
()
diff
=
calc_diff
(
c
,
ref_c
)
# assert diff < 1e-3, f"{diff}"
print
(
f
"[
{
tvm_fp8_dtype
}
->
{
tvm_acc_dtype
}
] diff =
{
diff
}
"
)
profiler
=
jit_kernel
.
get_profiler
()
latency
=
profiler
.
do_bench
()
print
(
f
"[
{
tvm_fp8_dtype
}
->
{
tvm_acc_dtype
}
] Latency:
{
latency
}
ms"
)
print
(
f
"[
{
tvm_fp8_dtype
}
->
{
tvm_acc_dtype
}
] Flops:
{
2
*
M
*
N
*
K
/
(
latency
/
1e3
)
/
1e12
}
TFLOPS"
)
examples/gemm_sm100/README.md
View file @
667632cc
...
...
@@ -40,19 +40,19 @@ import tilelang.language as T
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
"
bfloat16
"
),
B
:
T
.
Tensor
((
N
,
K
),
"
bfloat16
"
),
C
:
T
.
Tensor
((
M
,
N
),
"
bfloat16
"
),
A
:
T
.
Tensor
((
M
,
K
),
T
.
bfloat16
),
B
:
T
.
Tensor
((
N
,
K
),
T
.
bfloat16
),
C
:
T
.
Tensor
((
M
,
N
),
T
.
bfloat16
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
256
)
as
(
bx
,
by
):
# 1. Allocate memory buffers
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
"
bfloat16
"
)
# A matrix shared memory
B_shared
=
T
.
alloc_shared
((
block_N
,
block_K
),
"
bfloat16
"
)
# B matrix shared memory
C_tmem
=
T
.
alloc_tmem
([
block_M
,
block_N
],
"
float
"
)
# TCGEN5MMA output to Tensor Memory
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
T
.
bfloat16
)
# A matrix shared memory
B_shared
=
T
.
alloc_shared
((
block_N
,
block_K
),
T
.
bfloat16
)
# B matrix shared memory
C_tmem
=
T
.
alloc_tmem
([
block_M
,
block_N
],
T
.
float
)
# TCGEN5MMA output to Tensor Memory
mbar
=
T
.
alloc_barrier
(
1
)
# mbarrier synchronization primitive
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
"
float
"
)
# Register storage
C_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
"
bfloat16
"
)
# Output shared memory
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
T
.
float
)
# Register storage
C_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
T
.
bfloat16
)
# Output shared memory
# 2. Main computation loop
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
1
):
...
...
examples/gemm_sm100/gemm_mma.py
View file @
667632cc
...
...
@@ -4,8 +4,7 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
T
.
float16
,
accum_dtype
=
T
.
float32
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
...
...
@@ -62,7 +61,8 @@ jit_kernel = tilelang.compile(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
print
(
jit_kernel
.
get_kernel_source
())
# 3. Test the kernel in Python with PyTorch data
import
torch
...
...
examples/gemm_sm100/gemm_tcgen5mma.py
View file @
667632cc
...
...
@@ -40,15 +40,7 @@ def matmul(
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
A
[
by
*
block_M
,
k
*
block_K
],
A_shared
)
T
.
copy
(
B
[
bx
*
block_N
,
k
*
block_K
],
B_shared
)
T
.
gemm
(
A_shared
,
B_shared
,
C_tmem
,
trans_A
,
trans_B
,
mbar
=
mbar
,
wg_wait
=-
1
,
clear_accum
=
k
==
0
)
T
.
gemm
(
A_shared
,
B_shared
,
C_tmem
,
trans_A
,
trans_B
,
mbar
=
mbar
,
wg_wait
=-
1
,
clear_accum
=
k
==
0
)
T
.
mbarrier_wait_parity
(
mbar
,
k
%
2
)
T
.
copy
(
C_tmem
,
C_local
)
...
...
@@ -62,12 +54,11 @@ def matmul(
M
,
N
,
K
=
4096
,
4096
,
8192
block_M
,
block_N
,
block_K
=
128
,
256
,
128
trans_A
,
trans_B
=
False
,
True
in_dtype
,
out_dtype
,
accum_dtype
=
"
bfloat16
"
,
"
bfloat16
"
,
"
float
"
in_dtype
,
out_dtype
,
accum_dtype
=
T
.
bfloat16
,
T
.
bfloat16
,
T
.
float
num_stages
=
2
threads
=
256
func
=
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
accum_dtype
,
num_stages
,
threads
)
func
=
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
accum_dtype
,
num_stages
,
threads
)
jit_kernel
=
tilelang
.
compile
(
func
,
out_idx
=
[
2
],
...
...
@@ -75,7 +66,8 @@ jit_kernel = tilelang.compile(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
print
(
jit_kernel
.
get_kernel_source
())
...
...
@@ -88,4 +80,4 @@ torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
profiler
=
jit_kernel
.
get_profiler
()
latency
=
profiler
.
do_bench
()
print
(
f
"Latency:
{
latency
}
ms"
)
print
(
f
"Flops:
{
2
*
M
*
N
*
K
/
(
latency
/
1e3
)
/
1e12
}
TFLOPS"
)
print
(
f
"Flops:
{
2
*
M
*
N
*
K
/
(
latency
/
1e3
)
/
1e12
}
TFLOPS"
)
examples/gemm_sp/example_custom_compress.py
0 → 100644
View file @
667632cc
import
argparse
import
tilelang
import
tilelang.language
as
T
from
tilelang.layout
import
make_cutlass_metadata_layout
from
tilelang.utils.sparse
import
randn_semi_sparse
from
tilelang.utils.tensor
import
torch_assert_close
from
triton.testing
import
do_bench
import
torch
torch
.
manual_seed
(
42
)
DEFAULT_CONFIG
=
{
# take best config from autotune script
"4090"
:
{
T
.
float
:
{
"block_M"
:
128
,
"block_N"
:
64
,
"block_K"
:
64
,
"num_stages"
:
1
,
"thread_num"
:
128
,
"policy"
:
T
.
GemmWarpPolicy
.
Square
,
"enable_rasterization"
:
True
,
},
T
.
float16
:
{
"block_M"
:
256
,
"block_N"
:
128
,
"block_K"
:
64
,
"num_stages"
:
2
,
"thread_num"
:
128
,
"policy"
:
T
.
GemmWarpPolicy
.
Square
,
"enable_rasterization"
:
True
,
},
},
"h20"
:
{
T
.
float
:
{
"block_M"
:
128
,
"block_N"
:
64
,
"block_K"
:
128
,
"num_stages"
:
3
,
"thread_num"
:
128
,
"policy"
:
T
.
GemmWarpPolicy
.
Square
,
"enable_rasterization"
:
True
,
},
T
.
float16
:
{
"block_M"
:
128
,
"block_N"
:
64
,
"block_K"
:
128
,
"num_stages"
:
3
,
"thread_num"
:
128
,
"policy"
:
T
.
GemmWarpPolicy
.
Square
,
"enable_rasterization"
:
True
,
},
},
}
ARCH_INFO
=
{
"8.0"
:
(
16
,
"int16"
),
"8.9"
:
(
16
,
"int16"
),
"9.0"
:
(
8
,
"uint8"
)}
@tilelang.jit(out_idx=[-1])
def matmul_sp_fp16_custom_compress(M, N, K, accum_dtype, block_M, block_N, block_K,
                                   num_stages, thread_num, policy, enable_rasterization,
                                   use_cutlass_layout):
    """Build a TileLang kernel computing C = A @ B with a 2:4 semi-sparse A.

    A is supplied pre-compressed: A_sparse holds the surviving values
    (M x K//2) and E holds the packed 2-bit position metadata
    (M x K//e_factor).  The returned prim_func writes C (M x N) in
    accum_dtype; out_idx=[-1] makes C the JIT-allocated output.

    When use_cutlass_layout is set, E / E_shared are annotated with the
    CUTLASS metadata layout expected by the sparse tensor cores.
    """
    # Hard-coded sm80 metadata parameters (16 elements per int16 word);
    # matches ARCH_INFO["8.0"] and the arch="8.0" layout annotations below.
    e_factor, e_dtype = (16, T.int16)

    @T.prim_func
    def gemm_sp_fp16_custom_compress(
            A_sparse: T.Tensor((M, K // 2), T.float16),
            E: T.Tensor((M, K // e_factor), e_dtype),
            B: T.Tensor((K, N), T.float16),
            C: T.Tensor((M, N), accum_dtype),
    ):
        # Grid: one block per (block_N, block_M) output tile.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K // 2), T.float16)
            E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
            B_shared = T.alloc_shared((block_K, block_N), T.float16)
            C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            if use_cutlass_layout:
                # Reorder metadata (both global and shared copies) into the
                # CUTLASS-native layout so gemm_sp_v2 can consume it directly.
                T.annotate_layout({
                    E:
                        make_cutlass_metadata_layout(
                            E, mma_dtype=T.float16, arch="8.0", block_k=block_K),
                    E_shared:
                        make_cutlass_metadata_layout(
                            E_shared, mma_dtype=T.float16, arch="8.0", block_k=block_K),
                })
            T.clear(C_local)
            T.disable_warp_group_reg_alloc()
            # Rasterization (panel swizzle) for better L2 reuse when enabled.
            T.use_swizzle(panel_size=10, enable=enable_rasterization)
            # Software-pipelined K loop: stage values, metadata and B tiles
            # into shared memory, then issue the sparse MMA.
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
                T.copy(E[by * block_M, k * block_K // e_factor], E_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm_sp_v2(A_shared, E_shared, B_shared, C_local, False, False, policy=policy)
            # Epilogue: fragment -> shared -> global.
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return gemm_sp_fp16_custom_compress
def torch_compress(dense):
    """
    A naive compression function, where each 4-bit meta matches 4 elements in original matrix in row major layout.

    Returns (sparse, meta): `sparse` is the dense matrix with half the columns
    (the two kept values per 4-element group, or one per 2-element group for
    float32), and `meta` packs the kept positions as 2-bit indices, one 4-bit
    nibble per group, into int16/int32 words.

    NOTE(review): this appears to mirror PyTorch's internal semi-structured
    sparse conversion routine — confirm against
    torch/sparse/_semi_structured_conversions.py when updating.
    """
    if dense.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor")

    m, k = dense.shape

    # Metadata word width depends on the element type: int8 data uses int32
    # metadata (8 groups/word), 16/32-bit floats use int16 (4 groups/word).
    meta_dtype = torch.int8
    if dense.dtype == torch.int8:
        meta_dtype = torch.int32
    elif dense.dtype in [torch.half, torch.bfloat16, torch.float]:
        meta_dtype = torch.int16
    else:
        raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
    # Number of 4-bit groups packed into one metadata word.
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
    if quadbits_per_meta_elem not in (4, 8):
        raise RuntimeError("Invalid number of elements per meta element calculated")

    if meta_dtype == torch.int32:
        if m % 16 != 0:
            raise RuntimeError(f"Number of rows of dense matrix {m} must be divisible by 16")
    else:
        if m % 32 != 0:
            raise RuntimeError(f"Number of rows of dense matrix {m} must be divisible by 32")
    if k % (4 * quadbits_per_meta_elem) != 0:
        raise RuntimeError(
            f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}"
        )

    if dense.dtype != torch.float:
        # 2:4 sparsity: view as groups of 4 and take per-position nonzero flags.
        ksparse = 4
        dense_4 = dense.view(-1, k // ksparse, ksparse)
        m0, m1, _m2, m3 = (dense_4 != 0).unbind(-1)
    else:
        # float32 uses 1:2 sparsity; duplicate the flags so the same 4-flag
        # encoding below applies.
        ksparse = 2
        dense_2 = dense.view(-1, k // ksparse, ksparse)
        m0, _m2 = m1, m3 = (dense_2 != 0).unbind(-1)
    meta_ncols = k // (ksparse * quadbits_per_meta_elem)

    # Encoding quadruples of True/False values as follows:
    #     [True,  True,  False, False] -> 0b0100
    #     [True,  False, True,  False] -> 0b1000
    #     [False, True,  True,  False] -> 0b1001
    #     [True,  False, False, True ] -> 0b1100
    #     [False, True,  False, True ] -> 0b1101
    #     [False, False, True,  True ] -> 0b1110
    # Thus, lower two bits in the encoding are index of the True value
    # at the lowest index in the quadruple, and the higher two bits in
    # the encoding are index of the other True value in the quadruple.
    # In case there are less than two True values, than False value or
    # values at some index or indices are considered True for the
    # encoding.  In case there are more than two True values, then the
    # excess True value(s) at some indices are considered False for
    # the encoding.  The exact encodings used for these cases are as
    # follows:
    #     [False, False, False, False] -> 0b1110
    #     [False, False, False, True ] -> 0b1110
    #     [False, False, True,  False] -> 0b1110
    #     [False, True,  False, False] -> 0b1001
    #     [False, True,  True,  True ] -> 0b1101
    #     [True,  False, False, False] -> 0b1000
    #     [True,  False, True,  True ] -> 0b1100
    #     [True,  True,  False, True ] -> 0b0100
    #     [True,  True,  True,  False] -> 0b0100
    #     [True,  True,  True,  True ] -> 0b0100
    # These particular encodings are chosen, with the help of Espresso
    # logic minimizer software, for the purpose of minimization of
    # corresponding Boolean functions, that translate non-zero flags
    # into encoding bits.  Note also possible choices for the first
    # and last of these encodings were limited only to (0b0100,
    # 0b1110), in order to produce valid encodings for 1:2 sparsity
    # case.
    expr0 = m0 & m1
    expr1 = ~m0 & m1
    expr2 = ~m0 & ~m1
    bit0 = expr1
    bit1 = expr2
    bit2 = expr0 | expr2 | m3
    bit3 = expr1 | ~m1
    # idxs0/idxs1: per-group 2-bit indices of the two values to keep.
    idxs0 = bit0 | (bit1.to(torch.int64) << 1)
    idxs1 = bit2 | (bit3.to(torch.int64) << 1)

    if dense.dtype != torch.float:
        # Gather the two surviving elements of each 4-group.
        sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1))  # type: ignore[possibly-undefined]
        sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
        sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
    else:
        # 1:2 case: one survivor per pair; halve the index into the pair view.
        sparse = dense_2.gather(-1,
                                idxs0.unsqueeze(-1) // 2).view(m, k // 2)  # type: ignore[possibly-undefined]

    # Pack the two 2-bit indices into one 4-bit nibble per group, then pack
    # quadbits_per_meta_elem nibbles into each metadata word.
    meta_4 = idxs0 | (idxs1 << 2)
    meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)

    if quadbits_per_meta_elem == 4:
        meta = meta_n[:, :, 0] | (meta_n[:, :, 1] << 4) | (meta_n[:, :, 2] << 8) | (
            meta_n[:, :, 3] << 12)
    elif quadbits_per_meta_elem == 8:
        meta = (meta_n[:, :, 0] | (meta_n[:, :, 1] << 4) | (meta_n[:, :, 2] << 8) |
                (meta_n[:, :, 3] << 12) | (meta_n[:, :, 4] << 16) | (meta_n[:, :, 5] << 20) |
                (meta_n[:, :, 6] << 24) | (meta_n[:, :, 7] << 28))

    return (sparse, meta)
def decode_metadata(meta: torch.Tensor) -> torch.Tensor:
    """Unpack 2:4-sparsity metadata words into their 2-bit indices.

    Each int16 word holds four 4-bit groups; a group stores two 2-bit
    indices (low pair first).  Returns a tensor of shape
    (meta.shape[0], -1) listing [idx0, idx1] for every group, nibble by
    nibble from least to most significant.
    """
    assert meta.dtype is torch.int16
    # One nibble per group: shifts 0, 4, 8, 12 over the 16-bit word.
    nibbles = [(meta >> shift) & 0xF for shift in range(0, 16, 4)]
    # Split each nibble into its low / high 2-bit index pair.
    pairs = [torch.stack([nib & 0x3, (nib >> 2) & 0x3], dim=-1) for nib in nibbles]
    return torch.concat(pairs, dim=-1).view(meta.shape[0], -1)
@tilelang.jit(
    out_idx=[1, 2],
    pass_configs={
        tilelang.PassConfigKey.TIR_DISABLE_VECTORIZE: True,
    },
)
def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
    """Build a TileLang kernel that 2:4-compresses a dense (M, K) matrix.

    Outputs (out_idx=[1, 2]) are A_sp, the kept values (M x K//2), and E,
    the packed 2-bit position metadata (M x K//e_factor).  Metadata
    parameters are taken from ARCH_INFO["8.0"] (sm80: 16 elements per
    int16 word); the cutlass layout annotations below hard-code the same
    arch.  Vectorization is disabled because the inner scan is inherently
    sequential per 4-element group.
    """
    e_factor, e_dtype = ARCH_INFO["8.0"]
    e_K = K // e_factor
    # elem: values kept per group; group: elements per group (2:4 sparsity).
    elem, group = 2, 4
    assert M % block_M == 0, "M must be divisible by block_M"
    assert K % block_K == 0, "K must be divisible by block_K"
    assert K % e_factor == 0, "K must be divisible by e_factor"
    assert block_K % e_factor == 0, "block_K must be divisible by e_factor"

    @T.prim_func
    def kernel(
            A: T.Tensor((M, K), dtype),
            A_sp: T.Tensor((M, K // 2), dtype),
            E: T.Tensor((M, e_K), e_dtype),
    ):
        # One block per (block_M, block_K) input tile; one thread per row.
        with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=block_M) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            A_sp_shared = T.alloc_shared((block_M, block_K // 2), dtype)
            E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
            if use_cutlass_layout:
                T.annotate_layout({
                    E:
                        make_cutlass_metadata_layout(
                            E, mma_dtype=T.float16, arch="8.0", block_k=block_K),
                    E_shared:
                        make_cutlass_metadata_layout(
                            E_shared, mma_dtype=T.float16, arch="8.0", block_k=block_K),
                })
            # Outputs are accumulated with |= / conditional stores, so they
            # must start zeroed.
            T.clear(A_sp_shared)
            T.clear(E_shared)
            # TODO: alloc_var seems buggy here
            non_zero_cnt = T.alloc_local((1,), dtype=T.uint8)
            non_zero_elt_log_idx = T.alloc_local((elem,), dtype=T.uint8)
            T.copy(A[bx * block_M, by * block_K], A_shared)
            for tm in T.Parallel(block_M):
                # Scan each 4-element group of this row.
                for g_i in range(0, block_K // group):
                    a_k = g_i * group
                    non_zero_cnt[0] = 0
                    for i in range(elem):
                        non_zero_elt_log_idx[i] = 0
                    # Record (up to elem) nonzero positions and pack their
                    # values to the front of the group's output slots.
                    for i in range(group):
                        val = A_shared[tm, a_k + i]
                        if val != 0.0:
                            non_zero_elt_log_idx[non_zero_cnt[0]] = i
                            A_sp_shared[tm, a_k // 2 + non_zero_cnt[0]] = val
                            non_zero_cnt[0] += 1
                    # TODO: use T.device_assert(non_zero_cnt <= 2) after rebasing main
                    if non_zero_cnt[0] == 1 and non_zero_elt_log_idx[0] == 3:
                        # Single survivor at position 3: indices must be
                        # strictly increasing, so encode it as the second
                        # slot (0, 3) and move the value accordingly.
                        non_zero_elt_log_idx[0] = 0
                        non_zero_elt_log_idx[1] = 3
                        A_sp_shared[tm, a_k // 2 + 1] = A_sp_shared[tm, a_k // 2]
                        A_sp_shared[tm, a_k // 2] = 0.0
                    elif non_zero_cnt[0] == 1:
                        # Single survivor elsewhere: pad slot 1 with a zero
                        # at position 3.
                        A_sp_shared[tm, a_k // 2 + 1] = 0
                        non_zero_elt_log_idx[1] = 3
                    # Pack the two 2-bit indices into this group's nibble of
                    # the metadata word (4 groups per 16-bit word).
                    for i in T.serial(elem):
                        val = non_zero_elt_log_idx[i]
                        E_shared[tm, a_k // e_factor] |= T.shift_left(
                            val, 4 * (g_i % (e_factor // group)) + 2 * i)
            T.copy(A_sp_shared, A_sp[bx * block_M, by * block_K // 2])
            T.copy(E_shared, E[bx * block_M, by * block_K // e_factor])

    return kernel
def main():
    """Run the custom-compress sparse GEMM example end to end.

    Parses dimensions and options, compresses a semi-sparse fp16 A (either
    with the reference torch_compress or the TileLang compress_kernel),
    runs the sparse GEMM, verifies it against a dense matmul, and prints
    achieved vs. reference TFLOPS.  Requires a CUDA device.
    """
    parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
    parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
    parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
    parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
    parser.add_argument(
        "--use_cutlass_layout", action="store_true", help="Use cutlass layout for E tensor")
    parser.add_argument(
        "--use_torch_compressor", action="store_true", help="Use torch sparse for reference")
    parser.add_argument(
        "--accum_dtype",
        type=str,
        default=T.float,
        choices=[T.float, T.float16],
        help="Accumulation datatype")
    parser.add_argument("--cfg", type=str, choices=["4090"], default="4090")
    args = parser.parse_args()

    kernel = matmul_sp_fp16_custom_compress(
        args.m,
        args.n,
        args.k,
        args.accum_dtype,
        **DEFAULT_CONFIG[args.cfg][args.accum_dtype],
        use_cutlass_layout=args.use_cutlass_layout)
    a = randn_semi_sparse(args.m, args.k, device="cuda", dtype=torch.half)
    b = torch.randn(args.k, args.n, device="cuda", dtype=torch.half)

    if args.use_torch_compressor:
        # torch_compress emits the naive row-major metadata layout only.
        assert not args.use_cutlass_layout, "torch sparse must be used with naive layout"
        a_sparse, e = torch_compress(a)
    else:
        a_sparse, e = compress_kernel(
            args.m, args.k, 32, 32, T.float16,
            use_cutlass_layout=args.use_cutlass_layout)(a)

    c = kernel(a_sparse, e, b)
    ref_c = a @ b
    # Fixed: the message previously blamed the *reference* result, but this
    # assertion checks the kernel output `c`.
    assert not c.isnan().any(), "Kernel result contains NaNs, please report an issue"
    torch_assert_close(c, ref_c.to(c.dtype), rtol=1e-3, atol=1e-3)
    print(f"Precision check passed. Max diff: {(c - ref_c).abs().max()}, "
          f"Mean diff: {(c - ref_c).abs().mean()}")

    latency = do_bench(lambda: kernel(a_sparse, e, b))
    ref_latency = do_bench(lambda: a @ b)
    total_flops = 2 * args.m * args.n * args.k
    tflops = total_flops / latency / 1e9
    ref_tflops = total_flops / ref_latency / 1e9
    print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3} s")
    print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency / 1e3} s")
# Script entry point: run the example only when executed directly.
if __name__ == "__main__":
    main()
examples/gemm_sp/example_gemm_sp.py
View file @
667632cc
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
import
argparse
import
tilelang
import
tilelang.language
as
T
from
tilelang.layout
import
make_metadata_layout
from
tilelang.layout
import
make_
cutlass_
metadata_layout
from
tilelang.utils.sparse
import
compress
,
randn_semi_sparse
from
tilelang.contrib
import
nvcc
from
triton.testing
import
do_bench
...
...
@@ -14,86 +12,79 @@ import torch
arch
=
nvcc
.
get_target_compute_version
()
ARCH_INFO
=
{
"8.0"
:
(
16
,
"int16"
),
"8.9"
:
(
16
,
"int16"
),
"9.0"
:
(
8
,
"uint8"
)}
default_config
=
{
# take best config from autotune script
DEFAULT_CONFIG
=
{
# take best config from autotune script
"4090"
:
{
'float'
:
{
'block_M'
:
128
,
'block_N'
:
64
,
'block_K'
:
64
,
'num_stages'
:
1
,
'thread_num'
:
128
,
'policy'
:
T
.
GemmWarpPolicy
.
Square
,
'enable_rasterization'
:
True
T
.
float
:
{
"block_M"
:
128
,
"block_N"
:
64
,
"block_K"
:
64
,
"num_stages"
:
1
,
"thread_num"
:
128
,
"policy"
:
T
.
GemmWarpPolicy
.
Square
,
"enable_rasterization"
:
True
,
},
T
.
float16
:
{
"block_M"
:
256
,
"block_N"
:
128
,
"block_K"
:
64
,
"num_stages"
:
2
,
"thread_num"
:
128
,
"policy"
:
T
.
GemmWarpPolicy
.
Square
,
"enable_rasterization"
:
True
,
},
'float16'
:
{
'block_M'
:
256
,
'block_N'
:
128
,
'block_K'
:
64
,
'num_stages'
:
2
,
'thread_num'
:
128
,
'policy'
:
T
.
GemmWarpPolicy
.
Square
,
'enable_rasterization'
:
True
}
},
"h20"
:
{
'float'
:
{
'block_M'
:
128
,
'block_N'
:
64
,
'block_K'
:
128
,
'num_stages'
:
3
,
'thread_num'
:
128
,
'policy'
:
T
.
GemmWarpPolicy
.
Square
,
'enable_rasterization'
:
True
T
.
float
:
{
"block_M"
:
128
,
"block_N"
:
64
,
"block_K"
:
128
,
"num_stages"
:
3
,
"thread_num"
:
128
,
"policy"
:
T
.
GemmWarpPolicy
.
Square
,
"enable_rasterization"
:
True
,
},
T
.
float16
:
{
"block_M"
:
128
,
"block_N"
:
64
,
"block_K"
:
128
,
"num_stages"
:
3
,
"thread_num"
:
128
,
"policy"
:
T
.
GemmWarpPolicy
.
Square
,
"enable_rasterization"
:
True
,
},
},
'float16'
:
{
'block_M'
:
128
,
'block_N'
:
64
,
'block_K'
:
128
,
'num_stages'
:
3
,
'thread_num'
:
128
,
'policy'
:
T
.
GemmWarpPolicy
.
Square
,
'enable_rasterization'
:
True
}
}
}
ARCH_INFO
=
{
"8.0"
:
(
16
,
"int16"
),
"8.9"
:
(
16
,
"int16"
),
"9.0"
:
(
8
,
"uint8"
)}
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
matmul_sp_fp16
(
M
,
N
,
K
,
accum_dtype
,
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
policy
,
enable_rasterization
):
def
matmul_sp_fp16
(
M
,
N
,
K
,
accum_dtype
,
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
policy
,
enable_rasterization
):
e_factor
,
e_dtype
=
ARCH_INFO
[
arch
]
@
T
.
prim_func
def
gemm_sp_fp16
(
A_sparse
:
T
.
Tensor
((
M
,
K
//
2
),
'
float16
'
),
A_sparse
:
T
.
Tensor
((
M
,
K
//
2
),
T
.
float16
),
E
:
T
.
Tensor
((
M
,
K
//
e_factor
),
e_dtype
),
B
:
T
.
Tensor
((
K
,
N
),
'
float16
'
),
B
:
T
.
Tensor
((
K
,
N
),
T
.
float16
),
C
:
T
.
Tensor
((
M
,
N
),
accum_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
2
),
'
float16
'
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
2
),
T
.
float16
)
E_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
e_factor
),
e_dtype
)
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
'
float16
'
)
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
T
.
float16
)
C_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
accum_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
T
.
clear
(
C_local
)
T
.
disable_warp_group_reg_alloc
()
T
.
use_swizzle
(
panel_size
=
10
,
enable
=
enable_rasterization
)
T
.
annotate_layout
({
E
:
make_metadata_layout
(
E
,
mma_dtype
=
"float16"
,
backend
=
"cutlass"
,
block_k
=
block_K
,
arch
=
arch
),
E_shared
:
make_metadata_layout
(
E_shared
,
mma_dtype
=
"float16"
,
backend
=
"cutlass"
,
block_k
=
block_K
,
arch
=
arch
),
})
T
.
annotate_layout
(
{
E
:
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
T
.
float16
,
block_k
=
block_K
,
arch
=
arch
),
E_shared
:
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
T
.
float16
,
block_k
=
block_K
,
arch
=
arch
),
}
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
A_sparse
[
by
*
block_M
,
k
*
block_K
//
2
],
A_shared
)
T
.
copy
(
E
[
by
*
block_M
,
k
*
block_K
//
e_factor
],
E_shared
)
...
...
@@ -111,25 +102,15 @@ def main():
parser
.
add_argument
(
"--m"
,
type
=
int
,
default
=
16384
,
help
=
"Matrix dimension M"
)
parser
.
add_argument
(
"--n"
,
type
=
int
,
default
=
16384
,
help
=
"Matrix dimension N"
)
parser
.
add_argument
(
"--k"
,
type
=
int
,
default
=
16384
,
help
=
"Matrix dimension K"
)
parser
.
add_argument
(
"--accum_dtype"
,
type
=
str
,
default
=
"float"
,
choices
=
[
"float"
,
"float16"
],
help
=
"Accumulation datatype"
)
parser
.
add_argument
(
"--cfg"
,
type
=
str
,
choices
=
[
"4090"
,
"h20"
],
required
=
True
)
parser
.
add_argument
(
"--accum_dtype"
,
type
=
str
,
default
=
T
.
float
,
choices
=
[
T
.
float
,
T
.
float16
],
help
=
"Accumulation datatype"
)
parser
.
add_argument
(
"--cfg"
,
type
=
str
,
choices
=
[
"4090"
,
"h20"
],
default
=
"4090"
)
args
=
parser
.
parse_args
()
kernel
=
matmul_sp_fp16
(
args
.
m
,
args
.
n
,
args
.
k
,
args
.
accum_dtype
,
**
default_config
[
args
.
cfg
][
args
.
accum_dtype
])
kernel
=
matmul_sp_fp16
(
args
.
m
,
args
.
n
,
args
.
k
,
args
.
accum_dtype
,
**
DEFAULT_CONFIG
[
args
.
cfg
][
args
.
accum_dtype
])
a
=
randn_semi_sparse
(
args
.
m
,
args
.
k
,
device
=
'
cuda
'
,
dtype
=
torch
.
half
)
b
=
torch
.
randn
(
args
.
k
,
args
.
n
,
device
=
'
cuda
'
,
dtype
=
torch
.
half
)
a
=
randn_semi_sparse
(
args
.
m
,
args
.
k
,
device
=
"
cuda
"
,
dtype
=
torch
.
half
)
b
=
torch
.
randn
(
args
.
k
,
args
.
n
,
device
=
"
cuda
"
,
dtype
=
torch
.
half
)
a_sparse
,
e
=
compress
(
a
,
transposed
=
False
,
block_k
=
default_config
[
args
.
cfg
][
args
.
accum_dtype
][
'block_K'
],
arch
=
arch
)
a_sparse
,
e
=
compress
(
a
,
transposed
=
False
,
block_k
=
DEFAULT_CONFIG
[
args
.
cfg
][
args
.
accum_dtype
][
"block_K"
],
arch
=
arch
)
c
=
kernel
(
a_sparse
,
e
,
b
)
ref_c
=
a
@
b
...
...
@@ -144,8 +125,8 @@ def main():
total_flops
=
2
*
args
.
m
*
args
.
n
*
args
.
k
tflops
=
total_flops
/
latency
/
1e9
ref_tflops
=
total_flops
/
ref_latency
/
1e9
print
(
f
"Sparse TFLOPS:
{
tflops
:.
2
f
}
, Latency:
{
latency
/
1e3
}
s"
)
print
(
f
"Reference TFLOPS:
{
ref_tflops
:.
2
f
}
, Latency:
{
ref_latency
/
1e3
:
}
s"
)
print
(
f
"Sparse TFLOPS:
{
tflops
:.
2
f
}
, Latency:
{
latency
/
1e3
}
s"
)
print
(
f
"Reference TFLOPS:
{
ref_tflops
:.
2
f
}
, Latency:
{
ref_latency
/
1e3
:
}
s"
)
if
__name__
==
"__main__"
:
...
...
examples/
dynamic_shape
/test_example_
dynamic
.py
→
examples/
gemm_sp
/test_example_
gemm_sp
.py
View file @
667632cc
import
tilelang.testing
import
example_dynamic
import
example_custom_compress
import
example_gemm_sp
def
test_example_dynamic
():
example_dynamic
.
main
(
M
=
1024
,
N
=
1024
,
K
=
1024
)
def
test_example_custom_compress
():
example_custom_compress
.
main
()
def
test_example_gemm_sp
():
example_gemm_sp
.
main
()
if
__name__
==
"__main__"
:
...
...
examples/gemm_splitk/example_tilelang_gemm_splitk.py
View file @
667632cc
...
...
@@ -3,17 +3,7 @@ import tilelang.language as T
@
tilelang
.
jit
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
split_k
,
dtype
=
"float16"
,
accum_dtype
=
"float"
,
out_dtype
=
"float32"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
split_k
,
dtype
=
T
.
float16
,
accum_dtype
=
T
.
float32
,
out_dtype
=
T
.
float32
):
splitK
=
K
//
split_k
@
T
.
prim_func
...
...
@@ -22,8 +12,7 @@ def matmul(M,
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
split_k
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
split_k
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
C_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
out_dtype
)
...
...
examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py
View file @
667632cc
...
...
@@ -3,17 +3,7 @@ import tilelang.language as T
@
tilelang
.
jit
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
split_k
,
dtype
=
"float16"
,
accum_dtype
=
"float"
,
out_dtype
=
"float32"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
split_k
,
dtype
=
T
.
float16
,
accum_dtype
=
T
.
float32
,
out_dtype
=
T
.
float32
):
splitK
=
K
//
split_k
@
T
.
prim_func
...
...
@@ -22,8 +12,7 @@ def matmul(M,
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
split_k
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
split_k
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
C_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
out_dtype
)
...
...
examples/gemm_streamk/example_tilelang_gemm_streamk.py
View file @
667632cc
...
...
@@ -39,7 +39,7 @@ total_tiles = num_block_m * num_block_n
# Two-tile SK + DP
streamk_tiles
=
total_tiles
%
streamk_programs
if
(
total_tiles
-
streamk_tiles
>
streamk_programs
)
:
# (total_tiles // total_programs > 1)
if
total_tiles
-
streamk_tiles
>
streamk_programs
:
# (total_tiles // total_programs > 1)
streamk_tiles
+=
streamk_programs
blocking_tiles
=
total_tiles
-
streamk_tiles
...
...
@@ -87,8 +87,8 @@ def tl_matmul_streamk(
C
:
T
.
Tensor
,
C_local
:
T
.
LocalBuffer
,
):
start_iter
=
T
.
alloc_fragment
((
1
,),
"
int32
"
,
"local"
)
end_iter
=
T
.
alloc_fragment
((
1
,),
"
int32
"
,
"local"
)
start_iter
=
T
.
alloc_fragment
((
1
,),
T
.
int32
,
"local"
)
end_iter
=
T
.
alloc_fragment
((
1
,),
T
.
int32
,
"local"
)
start_iter
[
0
]
=
pid
*
streamk_full_tiles
+
T
.
min
(
pid
,
streamk_partial_tiles
)
last_iter
=
(
pid
+
1
)
*
streamk_full_tiles
+
T
.
min
(
pid
+
1
,
streamk_partial_tiles
)
...
...
@@ -135,7 +135,6 @@ def tl_matmul_streamk(
C
:
T
.
Tensor
,
C_local
:
T
.
LocalBuffer
,
):
for
p
in
T
.
serial
(
sm_patition_factor
):
tile_id
=
pid
+
streamk_tiles
+
p
*
total_sm
pid_m
=
tile_id
//
T
.
ceildiv
(
N
,
block_N
)
...
...
@@ -155,7 +154,6 @@ def tl_matmul_streamk(
C
:
T
.
Tensor
((
M
,
N
),
dtypeC
),
):
with
T
.
Kernel
(
streamk_programs
,
threads
=
threads
)
as
pid
:
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
dtypeAB
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
dtypeAB
)
A_shared_full_tiles
=
T
.
alloc_shared
(
A_shared_shape
,
dtypeAB
)
...
...
@@ -181,9 +179,9 @@ def main():
BLOCK_SIZE_K
,
False
,
True
,
"
float16
"
,
"
float16
"
,
"
float32
"
,
T
.
float16
,
T
.
float16
,
T
.
float32
,
2
,
64
,
)
...
...
examples/gemv/example_gemv.py
View file @
667632cc
...
...
@@ -17,10 +17,9 @@ def naive_gemv(
K
:
int
,
BLOCK_N
:
int
,
BLOCK_K
:
int
,
dtype
:
str
=
"
float16
"
,
accum_dtype
:
str
=
"
float
"
,
dtype
:
T
.
dtype
=
T
.
float16
,
accum_dtype
:
T
.
dtype
=
T
.
float
,
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
K
,),
dtype
),
...
...
@@ -38,8 +37,7 @@ def naive_gemv(
A_shared
[
tk
]
=
A
[
bk
*
BLOCK_K
+
tk
]
B_shared
[
tn
,
tk
]
=
B
[
bn
*
BLOCK_N
+
tn
,
bk
*
BLOCK_K
+
tk
]
for
tk
in
T
.
serial
(
BLOCK_K
):
C_reg
[
0
]
+=
A_shared
[
tk
].
astype
(
accum_dtype
)
*
B_shared
[
tn
,
tk
].
astype
(
accum_dtype
)
C_reg
[
0
]
+=
A_shared
[
tk
].
astype
(
accum_dtype
)
*
B_shared
[
tn
,
tk
].
astype
(
accum_dtype
)
C
[
bn
*
BLOCK_N
+
tn
]
=
C_reg
[
0
]
return
main
...
...
@@ -51,10 +49,9 @@ def naive_splitk_gemv(
K
:
int
,
BLOCK_N
:
int
,
BLOCK_K
:
int
,
dtype
:
str
=
"
float16
"
,
accum_dtype
:
str
=
"
float
"
,
dtype
:
T
.
dtype
=
T
.
float16
,
accum_dtype
:
T
.
dtype
=
T
.
float
,
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
K
,),
dtype
),
...
...
@@ -88,8 +85,8 @@ def splitk_gemv(
BLOCK_N
:
int
,
BLOCK_K
:
int
,
reduce_threads
:
int
,
dtype
:
str
=
"
float16
"
,
accum_dtype
:
str
=
"
float
"
,
dtype
:
T
.
dtype
=
T
.
float16
,
accum_dtype
:
T
.
dtype
=
T
.
float
,
):
TILE_K
=
T
.
ceildiv
(
BLOCK_K
,
reduce_threads
)
...
...
@@ -127,8 +124,8 @@ def splitk_gemv_vectorized(
K
:
int
,
BLOCK_N
:
int
,
reduce_threads
:
int
,
dtype
:
str
=
"
float16
"
,
accum_dtype
:
str
=
"
float
"
,
dtype
:
T
.
dtype
=
T
.
float16
,
accum_dtype
:
T
.
dtype
=
T
.
float
,
):
MAX_TRANSACTION_SIZE_IN_BITS
=
128
TILE_K
=
MAX_TRANSACTION_SIZE_IN_BITS
//
DataType
(
dtype
).
bits
...
...
@@ -168,8 +165,8 @@ def splitk_gemv_vectorized_tvm(
K
:
int
,
BLOCK_N
:
int
,
reduce_threads
:
int
,
dtype
:
str
=
"
float16
"
,
accum_dtype
:
str
=
"
float
"
,
dtype
:
T
.
dtype
=
T
.
float16
,
accum_dtype
:
T
.
dtype
=
T
.
float
,
):
MAX_TRANSACTION_SIZE_IN_BITS
=
128
TILE_K
=
MAX_TRANSACTION_SIZE_IN_BITS
//
DataType
(
dtype
).
bits
...
...
@@ -209,7 +206,8 @@ def splitk_gemv_vectorized_tvm(
C_reduced
[
0
],
tk
,
dtype
=
"handle"
,
))
)
)
C
[
bn
*
BLOCK_N
+
tn
]
=
C_reduced
[
0
]
...
...
@@ -218,10 +216,8 @@ def splitk_gemv_vectorized_tvm(
def
get_block_template_configs
():
iter_params
=
dict
(
block_M
=
[
2
,
4
,
8
,
32
,
64
,
128
],
block_N
=
[
2
,
4
,
8
,
32
,
64
,
128
],
num_stages
=
[
0
,
1
,
2
,
3
,
4
],
threads
=
[
32
,
64
,
128
,
256
])
block_M
=
[
2
,
4
,
8
,
32
,
64
,
128
],
block_N
=
[
2
,
4
,
8
,
32
,
64
,
128
],
num_stages
=
[
0
,
1
,
2
,
3
,
4
],
threads
=
[
32
,
64
,
128
,
256
]
)
return
[
dict
(
zip
(
iter_params
,
values
))
for
values
in
itertools
.
product
(
*
iter_params
.
values
())]
...
...
@@ -237,18 +233,11 @@ def get_block_template_configs():
},
out_idx
=
[
2
],
)
def
gemv_alloc_reducer
(
M
,
N
,
block_M
=
128
,
block_N
=
128
,
num_stages
=
2
,
threads
=
256
,
dtype
:
str
=
"float16"
,
accum_dtype
:
str
=
"float"
):
def
gemv_alloc_reducer
(
M
,
N
,
block_M
=
128
,
block_N
=
128
,
num_stages
=
2
,
threads
=
256
,
dtype
:
T
.
dtype
=
T
.
float16
,
accum_dtype
:
T
.
dtype
=
T
.
float
):
@
T
.
prim_func
def
main
(
a
:
T
.
Tensor
((
M
,
N
),
dtype
),
x
:
T
.
Tensor
(
N
,
dtype
),
o
:
T
.
Tensor
(
M
,
dtype
)):
# type: ignore
def
main
(
a
:
T
.
Tensor
((
M
,
N
),
dtype
),
x
:
T
.
Tensor
(
N
,
dtype
),
o
:
T
.
Tensor
(
M
,
dtype
)):
# type: ignore
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
i0_m
:
o_reducer
=
T
.
alloc_reducer
(
block_M
,
accum_dtype
,
replication
=
"all"
)
T
.
clear
(
o_reducer
)
...
...
@@ -287,8 +276,8 @@ def get_autotuned_kernel(
BLOCK_N
=
None
,
reduce_threads
=
None
,
):
dtype
=
"
float16
"
accum_dtype
=
"
float
"
dtype
=
T
.
float16
accum_dtype
=
T
.
float
32
MAX_TRANSACTION_SIZE_IN_BITS
=
128
TILE_K
=
MAX_TRANSACTION_SIZE_IN_BITS
//
DataType
(
dtype
).
bits
BLOCK_K
=
reduce_threads
*
TILE_K
...
...
@@ -327,17 +316,18 @@ def get_autotuned_kernel(
C_reduced
[
0
],
tk
,
dtype
=
"handle"
,
))
)
)
C
[
bn
*
BLOCK_N
+
tn
]
=
C_reduced
[
0
]
return
main
def
check_correctness_and_bench
(
kernel
,
N
,
K
,
bench
_ref
=
True
):
def
check_correctness_and_bench
(
kernel
,
N
,
K
,
do_
bench
=
True
):
profiler
=
kernel
.
get_profiler
()
profiler
.
assert_allclose
(
lambda
x
,
y
:
x
@
y
.
T
,
atol
=
1e-2
,
rtol
=
1e-2
)
if
bench
_ref
:
if
do_
bench
:
latency
=
profiler
.
do_bench
(
lambda
x
,
y
:
x
@
y
.
T
,
warmup
=
50
)
print
(
f
"Torch Latency:
{
latency
}
ms"
)
latency
=
profiler
.
do_bench
(
kernel
,
warmup
=
50
)
...
...
@@ -350,16 +340,16 @@ def main(do_bench: bool = True):
parser
.
add_argument
(
"--k"
,
type
=
int
,
default
=
1024
,
help
=
"Matrix dimension K"
)
args
,
_
=
parser
.
parse_known_args
()
N
,
K
=
args
.
n
,
args
.
k
check_correctness_and_bench
(
naive_gemv
(
N
,
K
,
128
,
128
),
N
,
K
)
check_correctness_and_bench
(
naive_splitk_gemv
(
N
,
K
,
32
,
32
),
N
,
K
)
check_correctness_and_bench
(
splitk_gemv
(
N
,
K
,
32
,
32
,
32
),
N
,
K
)
check_correctness_and_bench
(
splitk_gemv_vectorized
(
N
,
K
,
2
,
32
),
N
,
K
)
check_correctness_and_bench
(
splitk_gemv_vectorized_tvm
(
N
,
K
,
2
,
32
),
N
,
K
)
check_correctness_and_bench
(
gemv_alloc_reducer
(
N
,
K
,
block_M
=
128
,
block_N
=
128
),
N
,
K
)
check_correctness_and_bench
(
naive_gemv
(
N
,
K
,
128
,
128
),
N
,
K
,
do_bench
=
do_bench
)
check_correctness_and_bench
(
naive_splitk_gemv
(
N
,
K
,
32
,
32
),
N
,
K
,
do_bench
=
do_bench
)
check_correctness_and_bench
(
splitk_gemv
(
N
,
K
,
32
,
32
,
32
),
N
,
K
,
do_bench
=
do_bench
)
check_correctness_and_bench
(
splitk_gemv_vectorized
(
N
,
K
,
2
,
32
),
N
,
K
,
do_bench
=
do_bench
)
check_correctness_and_bench
(
splitk_gemv_vectorized_tvm
(
N
,
K
,
2
,
32
),
N
,
K
,
do_bench
=
do_bench
)
check_correctness_and_bench
(
gemv_alloc_reducer
(
N
,
K
,
block_M
=
128
,
block_N
=
128
),
N
,
K
,
do_bench
=
do_bench
)
print
(
"Test passed!"
)
if
not
do_bench
:
if
do_bench
:
best_result
=
get_autotuned_kernel
(
N
,
K
)
best_config
=
best_result
.
config
kernel
=
splitk_gemv_vectorized_tvm
(
N
,
K
,
**
best_config
)
...
...
examples/gemv/test_example_gemv.py
View file @
667632cc
import
tilelang.testing
import
example_gemv
...
...
@@ -8,4 +6,4 @@ def test_example_gemv():
if
__name__
==
"__main__"
:
t
ilelang
.
testing
.
main
()
t
est_example_gemv
()
examples/grouped_gemm/example_grouped_gemm_bwd.py
View file @
667632cc
...
...
@@ -5,67 +5,45 @@ import tilelang
import
tilelang.language
as
T
@
tilelang
.
jit
(
out_idx
=
[
2
],
pass_configs
=
{
"tl.disable_tma_lower"
:
True
,
"tl.disable_warp_specialized"
:
True
})
def
grouped_gemm_fwd
(
batch_sum
,
batch_count
,
K
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
=
2
,
threads
=
128
,
dtype
=
"float16"
):
@
tilelang
.
jit
(
out_idx
=
[
2
],
pass_configs
=
{
"tl.disable_tma_lower"
:
True
,
"tl.disable_warp_specialized"
:
True
})
def
grouped_gemm_fwd
(
batch_sum
,
batch_count
,
K
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
=
2
,
threads
=
128
,
dtype
=
T
.
float16
):
"""
args:
a (torch.Tensor): Input tensor of shape (M, K).
b (torch.Tensor): Input tensor of shape (G, K, N).
"""
accum_dtype
=
"
float32
"
accum_dtype
=
T
.
float32
@
T
.
prim_func
def
kernel
(
A
:
T
.
Tensor
([
batch_sum
,
K
],
dtype
),
# type: ignore
B
:
T
.
Tensor
([
batch_count
,
K
,
N
],
dtype
),
# type: ignore
C
:
T
.
Tensor
([
batch_sum
,
N
],
dtype
),
# type: ignore
batch_sizes
:
T
.
Tensor
([
batch_count
],
"
int32
"
),
# type: ignore
batch_offsets
:
T
.
Tensor
([
batch_count
],
"
int32
"
),
# type: ignore
batch_padded_offsets
:
T
.
Tensor
([
batch_count
],
"
int32
"
),
# type: ignore
batch_sizes
:
T
.
Tensor
([
batch_count
],
T
.
int32
),
# type: ignore
batch_offsets
:
T
.
Tensor
([
batch_count
],
T
.
int32
),
# type: ignore
batch_padded_offsets
:
T
.
Tensor
([
batch_count
],
T
.
int32
),
# type: ignore
):
with
T
.
Kernel
(
T
.
ceildiv
(
batch_sum
,
block_M
)
+
batch_count
,
T
.
ceildiv
(
N
,
block_N
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
batch_sum
,
block_M
)
+
batch_count
,
T
.
ceildiv
(
N
,
block_N
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
([
block_M
,
block_K
],
dtype
)
B_shared
=
T
.
alloc_shared
([
block_K
,
block_N
],
dtype
)
C_local
=
T
.
alloc_fragment
([
block_M
,
block_N
],
accum_dtype
)
cur_batch_idx
=
T
.
alloc_local
([
1
],
"
int32
"
)
cur_batch_size
=
T
.
alloc_local
([
1
],
"
int32
"
)
cur_batch_idx
=
T
.
alloc_local
([
1
],
T
.
int32
)
cur_batch_size
=
T
.
alloc_local
([
1
],
T
.
int32
)
m_start_padded
=
bx
*
block_M
for
i
in
range
(
batch_count
):
in_cur_batch_idx
=
(
m_start_padded
>=
batch_padded_offsets
[
i
]
)
in_cur_batch_idx
=
m_start_padded
>=
batch_padded_offsets
[
i
]
cur_batch_idx
[
0
]
=
T
.
if_then_else
(
in_cur_batch_idx
,
i
,
cur_batch_idx
[
0
])
cur_batch_size
[
0
]
=
batch_sizes
[
cur_batch_idx
[
0
]]
m_start
=
m_start_padded
-
batch_padded_offsets
[
cur_batch_idx
[
0
]]
+
batch_offsets
[
cur_batch_idx
[
0
]]
actual_rows
=
T
.
max
(
0
,
T
.
min
(
block_M
,
cur_batch_size
[
0
]
+
batch_padded_offsets
[
cur_batch_idx
[
0
]]
-
m_start_padded
))
m_start
=
m_start_padded
-
batch_padded_offsets
[
cur_batch_idx
[
0
]]
+
batch_offsets
[
cur_batch_idx
[
0
]]
actual_rows
=
T
.
max
(
0
,
T
.
min
(
block_M
,
cur_batch_size
[
0
]
+
batch_padded_offsets
[
cur_batch_idx
[
0
]]
-
m_start_padded
))
T
.
clear
(
C_local
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
A
[
m_start
:
m_start
+
block_M
,
k
*
block_K
:(
k
+
1
)
*
block_K
],
A_shared
)
T
.
copy
(
B
[
cur_batch_idx
[
0
],
k
*
block_K
:(
k
+
1
)
*
block_K
,
by
*
block_N
:(
by
+
1
)
*
block_N
],
B_shared
)
T
.
copy
(
A
[
m_start
:
m_start
+
block_M
,
k
*
block_K
:
(
k
+
1
)
*
block_K
],
A_shared
)
T
.
copy
(
B
[
cur_batch_idx
[
0
],
k
*
block_K
:
(
k
+
1
)
*
block_K
,
by
*
block_N
:
(
by
+
1
)
*
block_N
],
B_shared
)
T
.
gemm
(
A_shared
,
B_shared
,
C_local
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
...
...
@@ -76,7 +54,6 @@ def grouped_gemm_fwd(batch_sum,
class
_GroupedGEMM
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
a
,
b
,
batch_sizes
):
block_M
=
64
...
...
@@ -99,15 +76,11 @@ class _GroupedGEMM(torch.autograd.Function):
for
i
in
range
(
batch_count
-
1
):
batch_offsets_list
.
append
(
batch_offsets_list
[
-
1
]
+
batch_sizes
[
i
])
for
i
in
range
(
batch_count
-
1
):
batch_padded_offsets_list
.
append
(
batch_padded_offsets_list
[
-
1
]
+
math
.
ceil
((
batch_sizes
[
i
]
+
1
)
/
padding_M
)
*
padding_M
)
batch_padded_offsets_list
.
append
(
batch_padded_offsets_list
[
-
1
]
+
math
.
ceil
((
batch_sizes
[
i
]
+
1
)
/
padding_M
)
*
padding_M
)
batch_offsets
=
torch
.
tensor
(
batch_offsets_list
,
device
=
a
.
device
,
dtype
=
torch
.
int32
)
batch_padded_offsets
=
torch
.
tensor
(
batch_padded_offsets_list
,
device
=
a
.
device
,
dtype
=
torch
.
int32
)
batch_padded_offsets
=
torch
.
tensor
(
batch_padded_offsets_list
,
device
=
a
.
device
,
dtype
=
torch
.
int32
)
kernel
=
grouped_gemm_fwd
(
batch_sum
,
batch_count
,
K
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
,
threads
)
kernel
=
grouped_gemm_fwd
(
batch_sum
,
batch_count
,
K
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
,
threads
)
o
=
kernel
(
a
,
b
,
batch_sizes
,
batch_offsets
,
batch_padded_offsets
)
ctx
.
save_for_backward
(
a
,
b
,
batch_sizes
,
batch_offsets
)
...
...
@@ -135,8 +108,7 @@ class _GroupedGEMM(torch.autograd.Function):
return
x
A
,
B
,
batch_sizes
=
[
maybe_contiguous
(
x
)
for
x
in
(
A
,
B
,
batch_sizes
)]
kernel
=
grouped_gemm_bwd
(
ctx
.
batch_sum
,
ctx
.
batch_count
,
M
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
,
threads
)
kernel
=
grouped_gemm_bwd
(
ctx
.
batch_sum
,
ctx
.
batch_count
,
M
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
,
threads
)
dB
=
kernel
(
A
,
grad_output
,
batch_sizes
,
batch_offsets
)
return
None
,
dB
,
None
...
...
@@ -172,9 +144,7 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
for
i
in
range
(
batch_count
-
1
):
batch_offsets_list
.
append
(
batch_offsets_list
[
-
1
]
+
batch_sizes_list
[
i
])
for
i
in
range
(
batch_count
-
1
):
batch_padded_offsets_list
.
append
(
batch_padded_offsets_list
[
-
1
]
+
math
.
ceil
((
batch_sizes_list
[
i
]
+
1
)
/
padding_M
)
*
padding_M
)
batch_padded_offsets_list
.
append
(
batch_padded_offsets_list
[
-
1
]
+
math
.
ceil
((
batch_sizes_list
[
i
]
+
1
)
/
padding_M
)
*
padding_M
)
A
=
torch
.
randn
(
batch_sum
,
K
,
device
=
device
,
dtype
=
dtype
)
B
=
torch
.
randn
(
batch_count
,
K
,
M
,
device
=
device
,
dtype
=
dtype
)
C
=
torch
.
empty
(
batch_sum
,
M
,
device
=
device
,
dtype
=
dtype
)
...
...
@@ -187,40 +157,24 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
return
A
,
B
,
C
,
batch_sizes
,
batch_offsets
,
batch_padded_offsets
@
tilelang
.
jit
(
out_idx
=
[
2
],
pass_configs
=
{
"tl.disable_tma_lower"
:
True
,
"tl.disable_warp_specialized"
:
True
})
def
grouped_gemm_bwd
(
batch_sum
,
batch_count
,
M
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
=
2
,
threads
=
128
,
dtype
=
"float16"
):
@
tilelang
.
jit
(
out_idx
=
[
2
],
pass_configs
=
{
"tl.disable_tma_lower"
:
True
,
"tl.disable_warp_specialized"
:
True
})
def
grouped_gemm_bwd
(
batch_sum
,
batch_count
,
M
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
=
2
,
threads
=
128
,
dtype
=
T
.
float16
):
"""
args:
a (torch.Tensor): Input tensor of shape (M, K).
b (torch.Tensor): Input tensor of shape (G, K, N).
"""
accum_dtype
=
"
float32
"
accum_dtype
=
T
.
float32
@
T
.
prim_func
def
kernel
(
A
:
T
.
Tensor
([
batch_sum
,
M
],
dtype
),
# type: ignore
B
:
T
.
Tensor
([
batch_sum
,
N
],
dtype
),
# type: ignore
C
:
T
.
Tensor
([
batch_count
,
M
,
N
],
dtype
),
# type: ignore
batch_sizes
:
T
.
Tensor
([
batch_count
],
"
int32
"
),
# type: ignore
batch_offsets
:
T
.
Tensor
([
batch_count
],
"
int32
"
),
# type: ignore
batch_sizes
:
T
.
Tensor
([
batch_count
],
T
.
int32
),
# type: ignore
batch_offsets
:
T
.
Tensor
([
batch_count
],
T
.
int32
),
# type: ignore
):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
batch_count
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
batch_count
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
A_shared
=
T
.
alloc_shared
([
block_K
,
block_M
],
dtype
)
B_shared
=
T
.
alloc_shared
([
block_K
,
block_N
],
dtype
)
C_local
=
T
.
alloc_fragment
([
block_M
,
block_N
],
accum_dtype
)
...
...
@@ -228,13 +182,9 @@ def grouped_gemm_bwd(batch_sum,
T
.
clear
(
C_local
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
batch_sizes
[
bz
],
block_K
),
num_stages
=
num_stages
):
for
i
,
j
in
T
.
Parallel
(
block_K
,
block_M
):
A_shared
[
i
,
j
]
=
T
.
if_then_else
(
i
<
batch_sizes
[
bz
],
A
[
batch_offsets
[
bz
]
+
k
*
block_K
+
i
,
bx
*
block_M
+
j
],
0
)
A_shared
[
i
,
j
]
=
T
.
if_then_else
(
i
<
batch_sizes
[
bz
],
A
[
batch_offsets
[
bz
]
+
k
*
block_K
+
i
,
bx
*
block_M
+
j
],
0
)
for
i
,
j
in
T
.
Parallel
(
block_K
,
block_N
):
B_shared
[
i
,
j
]
=
T
.
if_then_else
(
i
<
batch_sizes
[
bz
],
B
[
batch_offsets
[
bz
]
+
k
*
block_K
+
i
,
by
*
block_N
+
j
],
0
)
B_shared
[
i
,
j
]
=
T
.
if_then_else
(
i
<
batch_sizes
[
bz
],
B
[
batch_offsets
[
bz
]
+
k
*
block_K
+
i
,
by
*
block_N
+
j
],
0
)
T
.
gemm
(
A_shared
,
B_shared
,
C_local
,
transpose_A
=
True
)
T
.
copy
(
C_local
,
C
[
bz
,
bx
*
block_M
,
by
*
block_N
])
...
...
@@ -242,23 +192,12 @@ def grouped_gemm_bwd(batch_sum,
return
kernel
def
run_tilelang_grouped_gemm
(
batch_sizes_list
,
K
,
M
,
block_M
,
block_N
,
block_K
,
trans_b
,
num_stages
=
2
,
threads
=
128
,
profile
=
False
):
def
run_tilelang_grouped_gemm
(
batch_sizes_list
,
K
,
M
,
block_M
,
block_N
,
block_K
,
trans_b
,
num_stages
=
2
,
threads
=
128
,
profile
=
False
):
padding_M
=
block_M
device
=
torch
.
device
(
"cuda"
)
dtype
=
torch
.
float16
A
,
B
,
C
,
batch_sizes
,
batch_offsets
,
batch_padded_offsets
=
construct_inputs
(
batch_sizes_list
,
K
,
M
,
False
,
padding_M
,
device
,
dtype
)
A
,
B
,
C
,
batch_sizes
,
batch_offsets
,
batch_padded_offsets
=
construct_inputs
(
batch_sizes_list
,
K
,
M
,
False
,
padding_M
,
device
,
dtype
)
A
.
requires_grad_
(
False
)
B
.
requires_grad_
(
True
)
...
...
@@ -273,10 +212,7 @@ def run_tilelang_grouped_gemm(batch_sizes_list,
O
.
backward
(
dO
,
retain_graph
=
True
)
dB
,
B
.
grad
=
B
.
grad
.
clone
(),
None
if
(
torch
.
allclose
(
O
,
O_ref
,
rtol
=
1e-2
,
atol
=
1e-2
)
and
\
torch
.
allclose
(
dB
,
dB_ref
,
rtol
=
1e-2
,
atol
=
1e-2
)
):
if
torch
.
allclose
(
O
,
O_ref
,
rtol
=
1e-2
,
atol
=
1e-2
)
and
torch
.
allclose
(
dB
,
dB_ref
,
rtol
=
1e-2
,
atol
=
1e-2
):
print
(
"✅ Tilelang and Torch match"
)
else
:
print
(
"❌ Tilelang and Torch mismatch"
)
...
...
@@ -284,12 +220,11 @@ def run_tilelang_grouped_gemm(batch_sizes_list,
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--batch_sizes'
,
type
=
str
,
default
=
"64, 128"
,
help
=
'comma-separated batch sizes'
)
parser
.
add_argument
(
'--K'
,
type
=
int
,
default
=
8192
,
help
=
'reduce dim'
)
parser
.
add_argument
(
'--M'
,
type
=
int
,
default
=
8192
,
help
=
'output dim'
)
parser
.
add_argument
(
'--trans_b'
,
action
=
"store_true"
,
help
=
"transpose B"
)
parser
.
add_argument
(
'--profile'
,
action
=
"store_true"
,
help
=
"profile"
)
parser
.
add_argument
(
"--batch_sizes"
,
type
=
str
,
default
=
"64, 128"
,
help
=
"comma-separated batch sizes"
)
parser
.
add_argument
(
"--K"
,
type
=
int
,
default
=
8192
,
help
=
"reduce dim"
)
parser
.
add_argument
(
"--M"
,
type
=
int
,
default
=
8192
,
help
=
"output dim"
)
parser
.
add_argument
(
"--trans_b"
,
action
=
"store_true"
,
help
=
"transpose B"
)
parser
.
add_argument
(
"--profile"
,
action
=
"store_true"
,
help
=
"profile"
)
args
=
parser
.
parse_args
()
batch_sizes_list
=
[
int
(
x
)
for
x
in
args
.
batch_sizes
.
split
(
","
)]
...
...
@@ -301,14 +236,4 @@ if __name__ == "__main__":
num_stages
=
2
threads
=
256
run_tilelang_grouped_gemm
(
batch_sizes_list
,
K
,
M
,
block_M
,
block_N
,
block_K
,
trans_b
,
num_stages
,
threads
,
profile
=
args
.
profile
)
run_tilelang_grouped_gemm
(
batch_sizes_list
,
K
,
M
,
block_M
,
block_N
,
block_K
,
trans_b
,
num_stages
,
threads
,
profile
=
args
.
profile
)
Prev
1
…
6
7
8
9
10
11
12
13
14
…
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment