OpenDAS / tilelang · Commits

Commit 29051439 (Unverified)
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
Authored Dec 12, 2025 by Lei Wang; committed via GitHub on Dec 12, 2025.
Parent: e84b24bc
Changes: 467

Showing 20 changed files with 378 additions and 667 deletions (+378, -667).
examples/gemm/example_gemm.py (+3, -4)
examples/gemm/example_gemm_autotune.py (+24, -65)
examples/gemm/example_gemm_intrinsics.py (+11, -11)
examples/gemm/example_gemm_persistent.py (+20, -43)
examples/gemm/example_gemm_schedule.py (+3, -4)
examples/gemm_fp8/example_tilelang_gemm_amd.py (+28, -49)
examples/gemm_fp8/example_tilelang_gemm_fp8.py (+7, -8)
examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py (+8, -8)
examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py (+11, -11)
examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py (+4, -6)
examples/gemm_sm100/gemm_mma.py (+5, -5)
examples/gemm_sm100/gemm_tcgen5mma.py (+8, -16)
examples/gemm_sp/example_custom_compress.py (+81, -107)
examples/gemm_sp/example_gemm_sp.py (+55, -68)
examples/gemm_splitk/example_tilelang_gemm_splitk.py (+5, -16)
examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py (+5, -16)
examples/gemm_streamk/example_tilelang_gemm_streamk.py (+4, -6)
examples/gemv/example_gemv.py (+34, -47)
examples/grouped_gemm/example_grouped_gemm_bwd.py (+38, -113)
examples/grouped_gemm/example_grouped_gemm_fwd.py (+24, -64)
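The hunks below are almost entirely mechanical reformatting from yapf to ruff format: single quotes normalized to double quotes, closing brackets moved onto their own line, and trailing commas added when an argument or dict is split one item per line. As a rough orientation, here is a minimal before/after sketch in the style of these examples. It is a sketch only: the function name matmul_config is illustrative (not from the diff), and exact wrapping depends on the project's line-length setting, which is not shown on this page.

# Before: yapf-era conventions seen on the removed side of the hunks below —
# hanging indents, single quotes, closing brace hugging the last element.
def matmul_config(block_M, block_N, block_K, num_stages,
                  thread_num, enable_rasterization):
    return {'block_M': block_M, 'block_N': block_N, 'block_K': block_K,
            'num_stages': num_stages, 'thread_num': thread_num,
            'enable_rasterization': enable_rasterization}


# After: ruff-format conventions seen on the added side — double quotes,
# one item per line once the literal no longer fits, a trailing comma after
# the last item, and the closing brace on its own line.
# (This second definition intentionally overrides the first; the two are
# kept side by side only for comparison.)
def matmul_config(block_M, block_N, block_K, num_stages, thread_num, enable_rasterization):
    return {
        "block_M": block_M,
        "block_N": block_N,
        "block_K": block_K,
        "num_stages": num_stages,
        "thread_num": thread_num,
        "enable_rasterization": enable_rasterization,
    }


print(matmul_config(128, 64, 64, 1, 128, True))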
examples/gemm/example_gemm.py
@@ -4,7 +4,6 @@ import tilelang.language as T
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def gemm(
        A: T.Tensor((M, K), dtype),
examples/gemm/example_gemm_autotune.py
@@ -90,7 +90,8 @@ def get_configs(M, N, K, with_roller=False, topk=20):
                num_stages,
                thread_num,
                enable_rasterization,
            )
        )

    configs = [
        {
@@ -100,13 +101,13 @@ def get_configs(M, N, K, with_roller=False, topk=20):
            "num_stages": c[3],
            "thread_num": c[4],
            "enable_rasteration": c[5],  # keep param name for backward-compat
        }
        for c in _configs
    ]
    return configs


def get_best_config(M, N, K, with_roller=False):
    def kernel(
        block_M=None,
        block_N=None,
@@ -124,8 +125,7 @@ def get_best_config(M, N, K, with_roller=False):
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), dtype),
        ):
            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
                A_shared = T.alloc_shared((block_M, block_K), dtype)
                B_shared = T.alloc_shared((block_N, block_K), dtype)
                C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
@@ -146,15 +146,18 @@ def get_best_config(M, N, K, with_roller=False):
        return main

    autotuner = (
        AutoTuner.from_kernel(kernel=kernel, configs=get_configs(M, N, K, with_roller))
        .set_compile_args(
            out_idx=[-1],
            target="auto",
        )
        .set_profile_args(
            supply_type=tl.TensorSupplyType.Integer,
            ref_prog=ref_program,
            skip_check=False,
        )
    )
    return autotuner.run(warmup=3, rep=20)
@@ -167,47 +170,15 @@ def get_heuristic_config() -> dict:
    sm_version = sm_major * 10 + sm_minor
    print(f"CUDA device capability: {sm_version}")
    if sm_version in {80}:
        return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 2, "thread_num": 128, "enable_rasteration": True}
    elif sm_version in {90}:
        return {"block_M": 128, "block_N": 256, "block_K": 64, "num_stages": 3, "thread_num": 256, "enable_rasteration": True}
    else:
        return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 0, "thread_num": 128, "enable_rasteration": True}


@tl.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def gemm_autotune(
        A: T.Tensor((M, K), dtype),
@@ -236,11 +207,7 @@ def matmul(M,
    return gemm_autotune


def main(M: int = 4096, N: int = 4096, K: int = 4096, use_autotune: bool = False, with_roller: bool = False):
    use_autotune = True
    if use_autotune:
        result = get_best_config(M, N, K, with_roller)
@@ -266,15 +233,7 @@ if __name__ == "__main__":
    parser.add_argument("--m", type=int, default=4096, help="Matrix dimension M")
    parser.add_argument("--n", type=int, default=4096, help="Matrix dimension N")
    parser.add_argument("--k", type=int, default=4096, help="Matrix dimension K")
    parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune for matmul configs")
    parser.add_argument("--with_roller", action="store_true", default=False, help="Whether to enable BitBLAS roller for search space")
    args = parser.parse_args()
    main(args.m, args.n, args.k, args.use_autotune, args.with_roller)
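The largest structural change in this file is the autotuner construction above: the chained .set_compile_args(...).set_profile_args(...) calls no longer hang off the previous call's closing parenthesis but are wrapped as a parenthesized fluent chain with each call on its own line. A minimal runnable sketch of the same wrapping pattern follows; Builder is a stand-in class invented here for illustration, not tilelang's AutoTuner.

class Builder:
    """Stand-in for a fluent builder such as AutoTuner (illustrative only)."""

    def __init__(self):
        self.settings = {}

    def set_compile_args(self, **kwargs):
        self.settings.update(kwargs)
        return self  # returning self is what makes the chain possible

    def set_profile_args(self, **kwargs):
        self.settings.update(kwargs)
        return self


# Post-format wrapping: the whole chain sits inside parentheses and each
# chained call starts on its own line, instead of continuing after the
# closing parenthesis of the previous call.
builder = (
    Builder()
    .set_compile_args(out_idx=[-1], target="auto")
    .set_profile_args(skip_check=False)
)
print(builder.settings)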
examples/gemm/example_gemm_intrinsics.py
@@ -4,7 +4,8 @@ import tilelang
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
    TensorCoreIntrinEmitter,
)
from tilelang.transform import simplify_prim_func
@@ -104,7 +105,6 @@ def tl_matmul(
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
            C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
@@ -112,10 +112,12 @@ def tl_matmul(
            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
            C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

            T.annotate_layout(
                {
                    A_shared: make_swizzle_layout(A_shared),
                    B_shared: make_swizzle_layout(B_shared),
                }
            )

            # Improve L2 Cache
            T.use_swizzle(panel_size=10)
@@ -123,7 +125,6 @@ def tl_matmul(
            T.clear(C_local)

            for ko in T.Pipelined((K // block_K), num_stages=stage):
                # Load A into shared memory
                for i, k in T.Parallel(block_M, block_K):
                    A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
@@ -133,7 +134,6 @@ def tl_matmul(
                    B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]

                for ki in T.serial(0, (block_K // micro_size_k)):
                    # Load A into fragment
                    mma_emitter.ldmatrix_a(A_local, A_shared, ki)
examples/gemm/example_gemm_persistent.py
@@ -5,17 +5,7 @@ import argparse
@tilelang.jit(out_idx=[-1])
def matmul_non_persistent(M, N, K, block_M, block_N, block_K, threads, num_stages, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype),
@@ -43,18 +33,9 @@ def matmul_non_persistent(M,
@tilelang.jit(out_idx=[-1])
def matmul_persistent(M, N, K, block_M, block_N, block_K, threads, num_stages, dtype="float16", accum_dtype="float", use_persistent_primitive=True):
    sm_num = driver.get_num_sms()
    m_blocks = T.ceildiv(M, block_M)
    n_blocks = T.ceildiv(N, block_N)
@@ -100,8 +81,7 @@ def matmul_persistent(M,
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_N), dtype)

            for bx, by in T.Persistent([T.ceildiv(M, block_M), T.ceildiv(N, block_N)], sm_num, block_id):
                T.clear(C_local)
                for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                    T.copy(A[bx * block_M, k * block_K], A_shared)
@@ -128,18 +108,15 @@ def main(M=4096, N=4096, K=4096):
    num_stages = 3

    persistent_kernel = matmul_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages)
    persistent_profiler = persistent_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
    persistent_profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
    print("Persistent GEMM: All check passed.")
    persistent_latency = persistent_profiler.do_bench(warmup=500)
    print(f"Persistent GEMM Latency: {persistent_latency} ms")
    print(f"Persistent GEMM TFlops: {total_flops / persistent_latency * 1e-9} TFlops")

    non_persistent_kernel = matmul_non_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages)
    non_persistent_profiler = non_persistent_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
    non_persistent_profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
    print("Non-Persistent GEMM: All check passed.")
    non_persistent_latency = non_persistent_profiler.do_bench(warmup=500)
@@ -151,9 +128,9 @@ def main(M=4096, N=4096, K=4096):
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--M", type=int, default=8192, help="M dimension")
    parser.add_argument("--N", type=int, default=8192, help="N dimension")
    parser.add_argument("--K", type=int, default=8192, help="K dimension")
    args = parser.parse_args()
    M, N, K = args.M, args.N, args.K
    main(M, N, K)
examples/gemm/example_gemm_schedule.py
@@ -4,7 +4,6 @@ import tilelang.language as T
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def gemm_schedule(
        A: T.Tensor((M, K), dtype),
examples/gemm_fp8/example_tilelang_gemm_amd.py
@@ -17,10 +17,8 @@ def supply_prog(args):
    a_param, b_param = args
    M, K = a_param.shape
    N, _ = b_param.shape
    a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
    b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
    return [a, b]
@@ -35,10 +33,9 @@ def get_configs():
    valid_configs = []

    for m, n, k, stages, t, kp, gemm_type in itertools.product(
        block_Ms, block_Ns, block_Ks, num_stages, num_threads, k_packs, gemm_types
    ):
        valid_configs.append(
            {
                "block_M": m,
                "block_N": n,
                "block_K": k,
@@ -46,16 +43,14 @@ def get_configs():
                "num_threads": t,
                "k_pack": kp,
                "gemm_type": gemm_type,
            }
        )
    return valid_configs


@tilelang.autotune(
    configs=get_configs(), cache_input_tensors=True, ref_prog=ref_program, manual_check_prog=manual_check_prog, supply_prog=supply_prog
)
@tilelang.jit(out_idx=[-1])
def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pack, gemm_type):
    dtype = "float8_e4m3fnuz"
@@ -67,8 +62,7 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
        B: T.Tensor((N, K), dtype),
        C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
            A_local = T.alloc_fragment((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
@@ -77,13 +71,7 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_local)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(A_local, B_shared, C_local, transpose_B=True, k_pack=k_pack, policy=T.GemmWarpPolicy.FullRow)

            T.copy(C_local, C[by * block_M, bx * block_N])
@@ -93,8 +81,7 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
        B: T.Tensor((N, K), dtype),
        C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
@@ -103,13 +90,7 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(A_shared, B_shared, C_local, transpose_B=True, k_pack=k_pack, policy=T.GemmWarpPolicy.FullRow)

            T.copy(C_local, C[by * block_M, bx * block_N])
@@ -123,10 +104,8 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
def test_gemm_fp8(M, N, K):
    kernel = fp8_matmul(M, N, K)
    a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
    b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)

    c = kernel(a, b)
    ref_c = ref_program(a, b)
    torch_assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
examples/gemm_fp8/example_tilelang_gemm_fp8.py
@@ -13,7 +13,6 @@ def calc_diff(x, y):
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"):
    @T.prim_func
    def gemm_fp8(
        A: T.Tensor((M, K), dtype),
@@ -41,8 +40,8 @@ def test_gemm_fp8(M, N, K, dtype):
    kernel = matmul(M, N, K, 128, 128, 64, dtype)
    a = torch.randn(M, K, dtype=torch.float16, device="cuda").to(dtype=torch_dtype)
    b = torch.randn(N, K, dtype=torch.float16, device="cuda").to(dtype=torch_dtype)

    c = kernel(a, b)
@@ -57,8 +56,8 @@ def test_gemm_fp8(M, N, K, dtype):
def main():
    test_gemm_fp8(1024, 1024, 1024, "float8_e4m3")
    test_gemm_fp8(1024, 1024, 1024, "float8_e5m2")


if __name__ == "__main__":
examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py
@@ -59,14 +59,14 @@ def test_gemm_fp8(M, N, K, dtype):
    kernel = matmul(M, N, K, 128, 128, 64, dtype)
    a = torch.rand(M, K, dtype=torch.float16, device="cuda")
    a = (100 * (2 * a - 1)).to(dtype=torch_dtype)
    b = torch.rand(N, K, dtype=torch.float16, device="cuda")
    b = (100 * (2 * b - 1)).to(dtype=torch_dtype)

    c = kernel(a, b)
    ref_c = a.float() @ b.float().T
    diff = calc_diff(c, ref_c)
    print(f"diff: {diff}")
@@ -74,8 +74,8 @@ def test_gemm_fp8(M, N, K, dtype):
def main():
    test_gemm_fp8(1024, 1024, 8192, "float8_e4m3")
    test_gemm_fp8(1024, 1024, 8192, "float8_e5m2")


if __name__ == "__main__":
examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
@@ -5,7 +5,8 @@ from tvm import DataType
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
    TensorCoreIntrinEmitter,
)
from tilelang.transform import simplify_prim_func
from tilelang.utils.tensor import map_torch_type
@@ -115,7 +116,6 @@ def tl_matmul(
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
            C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
@@ -123,10 +123,12 @@ def tl_matmul(
            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
            C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

            T.annotate_layout(
                {
                    A_shared: make_swizzle_layout(A_shared),
                    B_shared: make_swizzle_layout(B_shared),
                }
            )

            # Improve L2 Cache
            T.use_swizzle(panel_size=10)
@@ -134,7 +136,6 @@ def tl_matmul(
            T.clear(C_local)

            for ko in T.Pipelined((K // block_K), num_stages=stage):
                # Load A into shared memory
                for i, k in T.Parallel(block_M, block_K):
                    A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
@@ -144,7 +145,6 @@ def tl_matmul(
                    B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]

                for ki in T.serial(0, (block_K // micro_size_k)):
                    # Load A into fragment
                    mma_emitter.ldmatrix_a(
                        A_local,
examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py
@@ -121,6 +121,4 @@ for tvm_fp8_dtype in ["float8_e4m3", "float8_e5m2"]:
    profiler = jit_kernel.get_profiler()
    latency = profiler.do_bench()
    print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Latency: {latency} ms")
    print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS")
examples/gemm_sm100/gemm_mma.py
@@ -5,7 +5,6 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype),
@@ -62,7 +61,8 @@ jit_kernel = tilelang.compile(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    },
)
print(jit_kernel.get_kernel_source())

# 3. Test the kernel in Python with PyTorch data
import torch
examples/gemm_sm100/gemm_tcgen5mma.py
@@ -40,15 +40,7 @@ def matmul(
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, wg_wait=-1, clear_accum=k == 0)
                T.mbarrier_wait_parity(mbar, k % 2)

            T.copy(C_tmem, C_local)
@@ -66,8 +58,7 @@ in_dtype, out_dtype, accum_dtype = "bfloat16", "bfloat16", "float"
num_stages = 2
threads = 256

func = matmul(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads)
jit_kernel = tilelang.compile(
    func,
    out_idx=[2],
@@ -75,7 +66,8 @@ jit_kernel = tilelang.compile(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    },
)
print(jit_kernel.get_kernel_source())
@@ -88,4 +80,4 @@ torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
profiler = jit_kernel.get_profiler()
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
print(f"Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS")
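Both SM100 examples above end their tilelang.compile(...) call with "}," followed by ")" on its own line instead of the old "})": once a call's arguments are laid out one per line, the formatter adds a trailing comma after the last keyword argument and dedents the closing parenthesis. A runnable sketch of the same layout follows; compile_kernel is a stand-in function invented here (the real tilelang.compile call is already shown in the hunks above), and the string pass-config keys mirror the ones used in the grouped_gemm examples later in this diff.

def compile_kernel(func, out_idx, pass_configs):
    """Stand-in for tilelang.compile, used only to show the call layout."""
    return func, out_idx, pass_configs


# After reformatting: trailing comma after the last argument, closing
# parenthesis on its own line at the call's indentation level.
jit_kernel = compile_kernel(
    lambda a, b: a @ b,
    out_idx=[2],
    pass_configs={
        "tl.disable_tma_lower": True,
        "tl.disable_warp_specialized": True,
    },
)
print(jit_kernel[1:])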
examples/gemm_sp/example_custom_compress.py
@@ -17,77 +17,76 @@ torch.manual_seed(42)
DEFAULT_CONFIG = {
    # take best config from autotune script
    "4090": {
        "float": {"block_M": 128, "block_N": 64, "block_K": 64, "num_stages": 1, "thread_num": 128, "policy": T.GemmWarpPolicy.Square, "enable_rasterization": True},
        "float16": {"block_M": 256, "block_N": 128, "block_K": 64, "num_stages": 2, "thread_num": 128, "policy": T.GemmWarpPolicy.Square, "enable_rasterization": True},
    },
    "h20": {
        "float": {"block_M": 128, "block_N": 64, "block_K": 128, "num_stages": 3, "thread_num": 128, "policy": T.GemmWarpPolicy.Square, "enable_rasterization": True},
        "float16": {"block_M": 128, "block_N": 64, "block_K": 128, "num_stages": 3, "thread_num": 128, "policy": T.GemmWarpPolicy.Square, "enable_rasterization": True},
    },
}

ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}


@tilelang.jit(out_idx=[-1])
def matmul_sp_fp16_custom_compress(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, enable_rasterization, use_cutlass_layout):
    e_factor, e_dtype = (16, "int16")

    @T.prim_func
    def gemm_sp_fp16_custom_compress(
        A_sparse: T.Tensor((M, K // 2), "float16"),
        E: T.Tensor((M, K // e_factor), e_dtype),
        B: T.Tensor((K, N), "float16"),
        C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K // 2), "float16")
            E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
            B_shared = T.alloc_shared((block_K, block_N), "float16")
            C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            if use_cutlass_layout:
                T.annotate_layout(
                    {
                        E: make_cutlass_metadata_layout(E, mma_dtype="float16", arch="8.0", block_k=block_K),
                        E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", arch="8.0", block_k=block_K),
                    }
                )
            T.clear(C_local)
            T.disable_warp_group_reg_alloc()
            T.use_swizzle(panel_size=10, enable=enable_rasterization)
@@ -108,8 +107,7 @@ def torch_compress(dense):
    A naive compression function, where each 4-bit meta matches 4 elements in original matrix in row major layout.
    """
    if dense.dim() != 2:
        raise RuntimeError(f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor")

    m, k = dense.shape
@@ -131,9 +129,7 @@ def torch_compress(dense):
    if m % 32 != 0:
        raise RuntimeError(f"Number of rows of dense matrix {m} must be divisible by 32")
    if k % (4 * quadbits_per_meta_elem) != 0:
        raise RuntimeError(f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}")

    if dense.dtype != torch.float:
        ksparse = 4
@@ -194,19 +190,13 @@ def torch_compress(dense):
        sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
        sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
    else:
        sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2)  # type: ignore[possibly-undefined]

    meta_4 = idxs0 | (idxs1 << 2)
    meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)

    if quadbits_per_meta_elem == 4:
        meta = meta_n[:, :, 0] | (meta_n[:, :, 1] << 4) | (meta_n[:, :, 2] << 8) | (meta_n[:, :, 3] << 12)
    elif quadbits_per_meta_elem == 8:
        meta = (
            meta_n[:, :, 0]
@@ -216,7 +206,8 @@ def torch_compress(dense):
            | (meta_n[:, :, 4] << 16)
            | (meta_n[:, :, 5] << 20)
            | (meta_n[:, :, 6] << 24)
            | (meta_n[:, :, 7] << 28)
        )

    return (sparse, meta)
@@ -234,9 +225,11 @@ def decode_metadata(meta: torch.Tensor) -> torch.Tensor:
@tilelang.jit(
    out_idx=[1, 2],
    pass_configs={
        tilelang.PassConfigKey.TIR_DISABLE_VECTORIZE: True,
    },
)
def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
    e_factor, e_dtype = ARCH_INFO["8.0"]
    e_K = K // e_factor
@@ -258,14 +251,12 @@ def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
            A_sp_shared = T.alloc_shared((block_M, block_K // 2), dtype)
            E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
            if use_cutlass_layout:
                T.annotate_layout(
                    {
                        E: make_cutlass_metadata_layout(E, mma_dtype="float16", arch="8.0", block_k=block_K),
                        E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", arch="8.0", block_k=block_K),
                    }
                )
            T.clear(A_sp_shared)
            T.clear(E_shared)
            # TODO: alloc_var seems buggy here
@@ -295,8 +286,7 @@ def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
                        non_zero_elt_log_idx[1] = 3
                    for i in T.serial(elem):
                        val = non_zero_elt_log_idx[i]
                        E_shared[tm, a_k // e_factor] |= T.shift_left(val, 4 * (g_i % (e_factor // group)) + 2 * i)
            T.copy(A_sp_shared, A_sp[bx * block_M, by * block_K // 2])
            T.copy(E_shared, E[bx * block_M, by * block_K // e_factor])
@@ -304,41 +294,27 @@ def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
def main():
    parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
    parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
    parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
    parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
    parser.add_argument("--use_cutlass_layout", action="store_true", help="Use cutlass layout for E tensor")
    parser.add_argument("--use_torch_compressor", action="store_true", help="Use torch sparse for reference")
    parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype")
    parser.add_argument("--cfg", type=str, choices=["4090"], default="4090")
    args = parser.parse_args()

    kernel = matmul_sp_fp16_custom_compress(
        args.m,
        args.n,
        args.k,
        args.accum_dtype,
        **DEFAULT_CONFIG[args.cfg][args.accum_dtype],
        use_cutlass_layout=args.use_cutlass_layout,
    )
    a = randn_semi_sparse(args.m, args.k, device="cuda", dtype=torch.half)
    b = torch.randn(args.k, args.n, device="cuda", dtype=torch.half)

    if args.use_torch_compressor:
        assert not args.use_cutlass_layout, "torch sparse must be used with naive layout"
        a_sparse, e = torch_compress(a)
    else:
        a_sparse, e = compress_kernel(args.m, args.k, 32, 32, "float16", use_cutlass_layout=args.use_cutlass_layout)(a)

    c = kernel(a_sparse, e, b)
@@ -346,9 +322,7 @@ def main():
    assert not c.isnan().any(), "Reference result contains NaNs, please report an issue"
    torch_assert_close(c, ref_c.to(c.dtype), rtol=1e-3, atol=1e-3)
    print(f"Precision check passed. Max diff: {(c - ref_c).abs().max()}, Mean diff: {(c - ref_c).abs().mean()}")

    latency = do_bench(lambda: kernel(a_sparse, e, b))
    ref_latency = do_bench(lambda: a @ b)
@@ -356,8 +330,8 @@ def main():
    total_flops = 2 * args.m * args.n * args.k
    tflops = total_flops / latency / 1e9
    ref_tflops = total_flops / ref_latency / 1e9
    print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3} s")
    print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency / 1e3:} s")


if __name__ == "__main__":
examples/gemm_sp/example_gemm_sp.py
@@ -16,80 +16,77 @@ arch = nvcc.get_target_compute_version()
DEFAULT_CONFIG = {
    # take best config from autotune script
    "4090": {
        "float": {"block_M": 128, "block_N": 64, "block_K": 64, "num_stages": 1, "thread_num": 128, "policy": T.GemmWarpPolicy.Square, "enable_rasterization": True},
        "float16": {"block_M": 256, "block_N": 128, "block_K": 64, "num_stages": 2, "thread_num": 128, "policy": T.GemmWarpPolicy.Square, "enable_rasterization": True},
    },
    "h20": {
        "float": {"block_M": 128, "block_N": 64, "block_K": 128, "num_stages": 3, "thread_num": 128, "policy": T.GemmWarpPolicy.Square, "enable_rasterization": True},
        "float16": {"block_M": 128, "block_N": 64, "block_K": 128, "num_stages": 3, "thread_num": 128, "policy": T.GemmWarpPolicy.Square, "enable_rasterization": True},
    },
}

ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}


@tilelang.jit(out_idx=[-1])
def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, enable_rasterization):
    e_factor, e_dtype = ARCH_INFO[arch]

    @T.prim_func
    def gemm_sp_fp16(
        A_sparse: T.Tensor((M, K // 2), "float16"),
        E: T.Tensor((M, K // e_factor), e_dtype),
        B: T.Tensor((K, N), "float16"),
        C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K // 2), "float16")
            E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
            B_shared = T.alloc_shared((block_K, block_N), "float16")
            C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            T.disable_warp_group_reg_alloc()
            T.use_swizzle(panel_size=10, enable=enable_rasterization)
            T.annotate_layout(
                {
                    E: make_cutlass_metadata_layout(E, mma_dtype="float16", block_k=block_K, arch=arch),
                    E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", block_k=block_K, arch=arch),
                }
            )
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
                T.copy(E[by * block_M, k * block_K // e_factor], E_shared)
@@ -107,25 +104,15 @@ def main():
    parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
    parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
    parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
    parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype")
    parser.add_argument("--cfg", type=str, choices=["4090", "h20"], default="4090")
    args = parser.parse_args()

    kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype, **DEFAULT_CONFIG[args.cfg][args.accum_dtype])
    a = randn_semi_sparse(args.m, args.k, device="cuda", dtype=torch.half)
    b = torch.randn(args.k, args.n, device="cuda", dtype=torch.half)
    a_sparse, e = compress(a, transposed=False, block_k=DEFAULT_CONFIG[args.cfg][args.accum_dtype]["block_K"], arch=arch)

    c = kernel(a_sparse, e, b)
    ref_c = a @ b
@@ -140,8 +127,8 @@ def main():
    total_flops = 2 * args.m * args.n * args.k
    tflops = total_flops / latency / 1e9
    ref_tflops = total_flops / ref_latency / 1e9
    print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3} s")
    print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency / 1e3:} s")


if __name__ == "__main__":
...
examples/gemm_splitk/example_tilelang_gemm_splitk.py
View file @
29051439
...
...
@@ -3,17 +3,7 @@ import tilelang.language as T
@
tilelang
.
jit
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
split_k
,
dtype
=
"float16"
,
accum_dtype
=
"float"
,
out_dtype
=
"float32"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
split_k
,
dtype
=
"float16"
,
accum_dtype
=
"float"
,
out_dtype
=
"float32"
):
splitK
=
K
//
split_k
@
T
.
prim_func
...
...
@@ -22,8 +12,7 @@ def matmul(M,
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
split_k
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
split_k
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
C_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
out_dtype
)
...
...
examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py
@@ -3,17 +3,7 @@ import tilelang.language as T
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype="float16", accum_dtype="float", out_dtype="float32"):
    splitK = K // split_k

    @T.prim_func
@@ -22,8 +12,7 @@ def matmul(M,
        B: T.Tensor((N, K), dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_shared = T.alloc_shared((block_M, block_N), out_dtype)
examples/gemm_streamk/example_tilelang_gemm_streamk.py
@@ -39,7 +39,7 @@ total_tiles = num_block_m * num_block_n
# Two-tile SK + DP
streamk_tiles = total_tiles % streamk_programs
if total_tiles - streamk_tiles > streamk_programs:  # (total_tiles // total_programs > 1)
    streamk_tiles += streamk_programs

blocking_tiles = total_tiles - streamk_tiles
@@ -135,7 +135,6 @@ def tl_matmul_streamk(
        C: T.Tensor,
        C_local: T.LocalBuffer,
    ):
        for p in T.serial(sm_patition_factor):
            tile_id = pid + streamk_tiles + p * total_sm
            pid_m = tile_id // T.ceildiv(N, block_N)
@@ -155,7 +154,6 @@ def tl_matmul_streamk(
        C: T.Tensor((M, N), dtypeC),
    ):
        with T.Kernel(streamk_programs, threads=threads) as pid:
            A_shared = T.alloc_shared(A_shared_shape, dtypeAB)
            B_shared = T.alloc_shared(B_shared_shape, dtypeAB)
            A_shared_full_tiles = T.alloc_shared(A_shared_shape, dtypeAB)
examples/gemv/example_gemv.py
@@ -20,7 +20,6 @@ def naive_gemv(
    dtype: str = "float16",
    accum_dtype: str = "float",
):
    @T.prim_func
    def main(
        A: T.Tensor((K,), dtype),
@@ -38,8 +37,7 @@ def naive_gemv(
                A_shared[tk] = A[bk * BLOCK_K + tk]
                B_shared[tn, tk] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk]
            for tk in T.serial(BLOCK_K):
                C_reg[0] += A_shared[tk].astype(accum_dtype) * B_shared[tn, tk].astype(accum_dtype)
            C[bn * BLOCK_N + tn] = C_reg[0]

    return main
@@ -54,7 +52,6 @@ def naive_splitk_gemv(
    dtype: str = "float16",
    accum_dtype: str = "float",
):
    @T.prim_func
    def main(
        A: T.Tensor((K,), dtype),
@@ -209,7 +206,8 @@ def splitk_gemv_vectorized_tvm(
                    C_reduced[0],
                    tk,
                    dtype="handle",
                )
            )

            C[bn * BLOCK_N + tn] = C_reduced[0]
@@ -218,10 +216,8 @@
def get_block_template_configs():
    iter_params = dict(
        block_M=[2, 4, 8, 32, 64, 128], block_N=[2, 4, 8, 32, 64, 128], num_stages=[0, 1, 2, 3, 4], threads=[32, 64, 128, 256]
    )
    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@@ -237,18 +233,9 @@ def get_block_template_configs():
    },
    out_idx=[2],
)
def gemv_alloc_reducer(M, N, block_M=128, block_N=128, num_stages=2, threads=256, dtype: str = "float16", accum_dtype: str = "float"):
    @T.prim_func
    def main(a: T.Tensor((M, N), dtype), x: T.Tensor(N, dtype), o: T.Tensor(M, dtype)):  # type: ignore
        with T.Kernel(T.ceildiv(M, block_M), threads=threads) as i0_m:
            o_reducer = T.alloc_reducer(block_M, accum_dtype, replication="all")
            T.clear(o_reducer)
@@ -327,7 +314,8 @@ def get_autotuned_kernel(
                    C_reduced[0],
                    tk,
                    dtype="handle",
                )
            )

            C[bn * BLOCK_N + tn] = C_reduced[0]
@@ -355,8 +343,7 @@ def main(do_bench: bool = True):
    check_correctness_and_bench(splitk_gemv(N, K, 32, 32, 32), N, K, do_bench=do_bench)
    check_correctness_and_bench(splitk_gemv_vectorized(N, K, 2, 32), N, K, do_bench=do_bench)
    check_correctness_and_bench(splitk_gemv_vectorized_tvm(N, K, 2, 32), N, K, do_bench=do_bench)
    check_correctness_and_bench(gemv_alloc_reducer(N, K, block_M=128, block_N=128), N, K, do_bench=do_bench)
    print("Test passed!")
examples/grouped_gemm/example_grouped_gemm_bwd.py
View file @
29051439
...
...
@@ -5,21 +5,8 @@ import tilelang
import
tilelang.language
as
T
@
tilelang
.
jit
(
out_idx
=
[
2
],
pass_configs
=
{
"tl.disable_tma_lower"
:
True
,
"tl.disable_warp_specialized"
:
True
})
def
grouped_gemm_fwd
(
batch_sum
,
batch_count
,
K
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
=
2
,
threads
=
128
,
dtype
=
"float16"
):
@
tilelang
.
jit
(
out_idx
=
[
2
],
pass_configs
=
{
"tl.disable_tma_lower"
:
True
,
"tl.disable_warp_specialized"
:
True
})
def
grouped_gemm_fwd
(
batch_sum
,
batch_count
,
K
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
=
2
,
threads
=
128
,
dtype
=
"float16"
):
"""
args:
a (torch.Tensor): Input tensor of shape (M, K).
...
...
@@ -36,10 +23,7 @@ def grouped_gemm_fwd(batch_sum,
batch_offsets
:
T
.
Tensor
([
batch_count
],
"int32"
),
# type: ignore
batch_padded_offsets
:
T
.
Tensor
([
batch_count
],
"int32"
),
# type: ignore
):
with
T
.
Kernel
(
T
.
ceildiv
(
batch_sum
,
block_M
)
+
batch_count
,
T
.
ceildiv
(
N
,
block_N
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
batch_sum
,
block_M
)
+
batch_count
,
T
.
ceildiv
(
N
,
block_N
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
([
block_M
,
block_K
],
dtype
)
B_shared
=
T
.
alloc_shared
([
block_K
,
block_N
],
dtype
)
C_local
=
T
.
alloc_fragment
([
block_M
,
block_N
],
accum_dtype
)
...
...
@@ -49,23 +33,17 @@ def grouped_gemm_fwd(batch_sum,
m_start_padded
=
bx
*
block_M
for
i
in
range
(
batch_count
):
in_cur_batch_idx
=
(
m_start_padded
>=
batch_padded_offsets
[
i
]
)
in_cur_batch_idx
=
m_start_padded
>=
batch_padded_offsets
[
i
]
cur_batch_idx
[
0
]
=
T
.
if_then_else
(
in_cur_batch_idx
,
i
,
cur_batch_idx
[
0
])
cur_batch_size
[
0
]
=
batch_sizes
[
cur_batch_idx
[
0
]]
m_start
=
m_start_padded
-
batch_padded_offsets
[
cur_batch_idx
[
0
]]
+
batch_offsets
[
cur_batch_idx
[
0
]]
actual_rows
=
T
.
max
(
0
,
T
.
min
(
block_M
,
cur_batch_size
[
0
]
+
batch_padded_offsets
[
cur_batch_idx
[
0
]]
-
m_start_padded
))
m_start
=
m_start_padded
-
batch_padded_offsets
[
cur_batch_idx
[
0
]]
+
batch_offsets
[
cur_batch_idx
[
0
]]
actual_rows
=
T
.
max
(
0
,
T
.
min
(
block_M
,
cur_batch_size
[
0
]
+
batch_padded_offsets
[
cur_batch_idx
[
0
]]
-
m_start_padded
))
T
.
clear
(
C_local
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
A
[
m_start
:
m_start
+
block_M
,
k
*
block_K
:(
k
+
1
)
*
block_K
],
A_shared
)
T
.
copy
(
B
[
cur_batch_idx
[
0
],
k
*
block_K
:(
k
+
1
)
*
block_K
,
by
*
block_N
:(
by
+
1
)
*
block_N
],
B_shared
)
T
.
copy
(
A
[
m_start
:
m_start
+
block_M
,
k
*
block_K
:
(
k
+
1
)
*
block_K
],
A_shared
)
T
.
copy
(
B
[
cur_batch_idx
[
0
],
k
*
block_K
:
(
k
+
1
)
*
block_K
,
by
*
block_N
:
(
by
+
1
)
*
block_N
],
B_shared
)
T
.
gemm
(
A_shared
,
B_shared
,
C_local
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
...
...
@@ -76,7 +54,6 @@ def grouped_gemm_fwd(batch_sum,
class
_GroupedGEMM
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
a
,
b
,
batch_sizes
):
block_M
=
64
...
...
@@ -99,15 +76,11 @@ class _GroupedGEMM(torch.autograd.Function):
for
i
in
range
(
batch_count
-
1
):
batch_offsets_list
.
append
(
batch_offsets_list
[
-
1
]
+
batch_sizes
[
i
])
for
i
in
range
(
batch_count
-
1
):
batch_padded_offsets_list
.
append
(
batch_padded_offsets_list
[
-
1
]
+
math
.
ceil
((
batch_sizes
[
i
]
+
1
)
/
padding_M
)
*
padding_M
)
batch_padded_offsets_list
.
append
(
batch_padded_offsets_list
[
-
1
]
+
math
.
ceil
((
batch_sizes
[
i
]
+
1
)
/
padding_M
)
*
padding_M
)
batch_offsets
=
torch
.
tensor
(
batch_offsets_list
,
device
=
a
.
device
,
dtype
=
torch
.
int32
)
batch_padded_offsets
=
torch
.
tensor
(
batch_padded_offsets_list
,
device
=
a
.
device
,
dtype
=
torch
.
int32
)
batch_padded_offsets
=
torch
.
tensor
(
batch_padded_offsets_list
,
device
=
a
.
device
,
dtype
=
torch
.
int32
)
kernel
=
grouped_gemm_fwd
(
batch_sum
,
batch_count
,
K
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
,
threads
)
kernel
=
grouped_gemm_fwd
(
batch_sum
,
batch_count
,
K
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
,
threads
)
o
=
kernel
(
a
,
b
,
batch_sizes
,
batch_offsets
,
batch_padded_offsets
)
ctx
.
save_for_backward
(
a
,
b
,
batch_sizes
,
batch_offsets
)
...
...
@@ -135,8 +108,7 @@ class _GroupedGEMM(torch.autograd.Function):
return
x
A
,
B
,
batch_sizes
=
[
maybe_contiguous
(
x
)
for
x
in
(
A
,
B
,
batch_sizes
)]
kernel
=
grouped_gemm_bwd
(
ctx
.
batch_sum
,
ctx
.
batch_count
,
M
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
,
threads
)
kernel
=
grouped_gemm_bwd
(
ctx
.
batch_sum
,
ctx
.
batch_count
,
M
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
,
threads
)
dB
=
kernel
(
A
,
grad_output
,
batch_sizes
,
batch_offsets
)
return
None
,
dB
,
None
...
...
@@ -172,9 +144,7 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
for
i
in
range
(
batch_count
-
1
):
batch_offsets_list
.
append
(
batch_offsets_list
[
-
1
]
+
batch_sizes_list
[
i
])
for
i
in
range
(
batch_count
-
1
):
batch_padded_offsets_list
.
append
(
batch_padded_offsets_list
[
-
1
]
+
math
.
ceil
((
batch_sizes_list
[
i
]
+
1
)
/
padding_M
)
*
padding_M
)
batch_padded_offsets_list
.
append
(
batch_padded_offsets_list
[
-
1
]
+
math
.
ceil
((
batch_sizes_list
[
i
]
+
1
)
/
padding_M
)
*
padding_M
)
A
=
torch
.
randn
(
batch_sum
,
K
,
device
=
device
,
dtype
=
dtype
)
B
=
torch
.
randn
(
batch_count
,
K
,
M
,
device
=
device
,
dtype
=
dtype
)
C
=
torch
.
empty
(
batch_sum
,
M
,
device
=
device
,
dtype
=
dtype
)
...
...
@@ -187,21 +157,8 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
return
A
,
B
,
C
,
batch_sizes
,
batch_offsets
,
batch_padded_offsets
@
tilelang
.
jit
(
out_idx
=
[
2
],
pass_configs
=
{
"tl.disable_tma_lower"
:
True
,
"tl.disable_warp_specialized"
:
True
})
def
grouped_gemm_bwd
(
batch_sum
,
batch_count
,
M
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
=
2
,
threads
=
128
,
dtype
=
"float16"
):
@
tilelang
.
jit
(
out_idx
=
[
2
],
pass_configs
=
{
"tl.disable_tma_lower"
:
True
,
"tl.disable_warp_specialized"
:
True
})
def
grouped_gemm_bwd
(
batch_sum
,
batch_count
,
M
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
=
2
,
threads
=
128
,
dtype
=
"float16"
):
"""
args:
a (torch.Tensor): Input tensor of shape (M, K).
...
...
@@ -217,10 +174,7 @@ def grouped_gemm_bwd(batch_sum,
batch_sizes
:
T
.
Tensor
([
batch_count
],
"int32"
),
# type: ignore
batch_offsets
:
T
.
Tensor
([
batch_count
],
"int32"
),
# type: ignore
):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
batch_count
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
batch_count
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
A_shared
=
T
.
alloc_shared
([
block_K
,
block_M
],
dtype
)
B_shared
=
T
.
alloc_shared
([
block_K
,
block_N
],
dtype
)
C_local
=
T
.
alloc_fragment
([
block_M
,
block_N
],
accum_dtype
)
...
...
@@ -228,13 +182,9 @@ def grouped_gemm_bwd(batch_sum,
T
.
clear
(
C_local
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
batch_sizes
[
bz
],
block_K
),
num_stages
=
num_stages
):
for
i
,
j
in
T
.
Parallel
(
block_K
,
block_M
):
A_shared
[
i
,
j
]
=
T
.
if_then_else
(
i
<
batch_sizes
[
bz
],
A
[
batch_offsets
[
bz
]
+
k
*
block_K
+
i
,
bx
*
block_M
+
j
],
0
)
A_shared
[
i
,
j
]
=
T
.
if_then_else
(
i
<
batch_sizes
[
bz
],
A
[
batch_offsets
[
bz
]
+
k
*
block_K
+
i
,
bx
*
block_M
+
j
],
0
)
for
i
,
j
in
T
.
Parallel
(
block_K
,
block_N
):
B_shared
[
i
,
j
]
=
T
.
if_then_else
(
i
<
batch_sizes
[
bz
],
B
[
batch_offsets
[
bz
]
+
k
*
block_K
+
i
,
by
*
block_N
+
j
],
0
)
B_shared
[
i
,
j
]
=
T
.
if_then_else
(
i
<
batch_sizes
[
bz
],
B
[
batch_offsets
[
bz
]
+
k
*
block_K
+
i
,
by
*
block_N
+
j
],
0
)
T
.
gemm
(
A_shared
,
B_shared
,
C_local
,
transpose_A
=
True
)
T
.
copy
(
C_local
,
C
[
bz
,
bx
*
block_M
,
by
*
block_N
])
...
...
@@ -242,23 +192,12 @@ def grouped_gemm_bwd(batch_sum,
    return kernel


def run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b,
                              num_stages=2, threads=128, profile=False):
def run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b,
                              num_stages=2, threads=128, profile=False):
    padding_M = block_M
    device = torch.device("cuda")
    dtype = torch.float16
    A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs(
        batch_sizes_list, K, M, False, padding_M, device, dtype)
    A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs(
        batch_sizes_list, K, M, False, padding_M, device, dtype)
    A.requires_grad_(False)
    B.requires_grad_(True)
...
...
@@ -273,10 +212,7 @@ def run_tilelang_grouped_gemm(batch_sizes_list,
    O.backward(dO, retain_graph=True)
    dB, B.grad = B.grad.clone(), None
    if (torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2) and \
            torch.allclose(dB, dB_ref, rtol=1e-2, atol=1e-2)):
    if torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2) and torch.allclose(dB, dB_ref, rtol=1e-2, atol=1e-2):
        print("✅ Tilelang and Torch match")
    else:
        print("❌ Tilelang and Torch mismatch")
...
...
@@ -284,12 +220,11 @@ def run_tilelang_grouped_gemm(batch_sizes_list,
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_sizes', type=str, default="64, 128", help='comma-separated batch sizes')
    parser.add_argument('--K', type=int, default=8192, help='reduce dim')
    parser.add_argument('--M', type=int, default=8192, help='output dim')
    parser.add_argument('--trans_b', action="store_true", help="transpose B")
    parser.add_argument('--profile', action="store_true", help="profile")
    parser.add_argument("--batch_sizes", type=str, default="64, 128", help="comma-separated batch sizes")
    parser.add_argument("--K", type=int, default=8192, help="reduce dim")
    parser.add_argument("--M", type=int, default=8192, help="output dim")
    parser.add_argument("--trans_b", action="store_true", help="transpose B")
    parser.add_argument("--profile", action="store_true", help="profile")
    args = parser.parse_args()
    batch_sizes_list = [int(x) for x in args.batch_sizes.split(",")]
...
...
@@ -301,14 +236,4 @@ if __name__ == "__main__":
    num_stages = 2
    threads = 256
    run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b,
                              num_stages, threads, profile=args.profile)
    run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b,
                              num_stages, threads, profile=args.profile)
examples/grouped_gemm/example_grouped_gemm_fwd.py
View file @
29051439
...
...
@@ -18,8 +18,7 @@ def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False):
torch.Tensor: Resulting tensor after grouped matrix multiplication.
"""
    assert a.shape[0] == sum(batch_sizes), "Sum of batch_sizes must equal the first dimension of a"
    assert b.shape[0] == len(batch_sizes), "The first dimension of b must match the length of batch_sizes"
    assert b.shape[0] == len(batch_sizes), "The first dimension of b must match the length of batch_sizes"
    # Initialize output tensor
    output = torch.empty((sum(batch_sizes), b.shape[2]), device=a.device, dtype=a.dtype)
...
...
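The remainder of torch_gmm is collapsed in this view; conceptually it fills output group by group with a_g @ b[g] (or a_g @ b[g].T when trans_b is set). A short sketch of such a reference loop, written as an illustration rather than the file's exact code:

import torch

def torch_gmm_sketch(a, b, batch_sizes, trans_b=False):
    # a: (sum(batch_sizes), K); b: (num_groups, K, N), or (num_groups, N, K) if trans_b.
    out_dim = b.shape[1] if trans_b else b.shape[2]
    output = torch.empty((sum(batch_sizes), out_dim), device=a.device, dtype=a.dtype)
    offset = 0
    for g, size in enumerate(batch_sizes):
        w = b[g].t() if trans_b else b[g]
        output[offset:offset + size] = a[offset:offset + size] @ w
        offset += size
    return output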
@@ -38,15 +37,7 @@ def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False):
@tilelang.jit(out_idx=[2])
def grouped_gemm(batch_sizes_list, K, N, block_M, block_N, block_K,
                 num_stages=2, threads=128, dtype="float16"):
def grouped_gemm(batch_sizes_list, K, N, block_M, block_N, block_K,
                 num_stages=2, threads=128, dtype="float16"):
"""
args:
a (torch.Tensor): Input tensor of shape (M, K).
...
...
@@ -66,7 +57,6 @@ def grouped_gemm(batch_sizes_list,
            batch_offsets: T.Tensor([batch_count], "int32"),  # type: ignore
            batch_padded_offsets: T.Tensor([batch_count], "int32"),  # type: ignore
    ):
        with T.Kernel(total_m_blocks, T.ceildiv(N, block_N), threads=threads) as (bx, by):
            A_shared = T.alloc_shared([block_M, block_K], dtype)
            B_shared = T.alloc_shared([block_K, block_N], dtype)
...
...
@@ -77,23 +67,17 @@ def grouped_gemm(batch_sizes_list,
            m_start_padded = bx * block_M
            for i in range(batch_count):
                in_cur_batch_idx = (m_start_padded >= batch_padded_offsets[i])
                in_cur_batch_idx = m_start_padded >= batch_padded_offsets[i]
                cur_batch_idx[0] = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx[0])
            cur_batch_size[0] = batch_sizes[cur_batch_idx[0]]
            m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[cur_batch_idx[0]]
            actual_rows = T.max(0, T.min(block_M, cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded))
            m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[cur_batch_idx[0]]
            actual_rows = T.max(0, T.min(block_M, cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded))
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[m_start:m_start + block_M, k * block_K:(k + 1) * block_K], A_shared)
                T.copy(B[cur_batch_idx[0], k * block_K:(k + 1) * block_K, by * block_N:(by + 1) * block_N], B_shared)
                T.copy(A[m_start:m_start + block_M, k * block_K : (k + 1) * block_K], A_shared)
                T.copy(B[cur_batch_idx[0], k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            for i, j in T.Parallel(block_M, block_N):
...
...
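The linear scan above resolves which group a padded row-block belongs to: the last group whose padded offset has already been reached wins, and the block is then translated back to unpadded rows plus a row count clamped to the group's real size. The same logic, run on the host with made-up numbers:

# Hypothetical setup: block_M = 64, two groups with 10 and 100 rows.
block_M = 64
batch_sizes = [10, 100]
batch_offsets = [0, 10]
batch_padded_offsets = [0, 64]   # each group rounded up to a block boundary

def locate(bx):
    m_start_padded = bx * block_M
    cur = 0
    for i, off in enumerate(batch_padded_offsets):
        if m_start_padded >= off:
            cur = i  # keep the last group we have already passed
    m_start = m_start_padded - batch_padded_offsets[cur] + batch_offsets[cur]
    actual_rows = max(0, min(block_M, batch_sizes[cur] + batch_padded_offsets[cur] - m_start_padded))
    return cur, m_start, actual_rows

print(locate(0))  # (0, 0, 10): the first block covers group 0's 10 real rows
print(locate(1))  # (1, 10, 64): the next block starts on group 1's rows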
@@ -111,8 +95,7 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
    for i in range(batch_count - 1):
        batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes_list[i])
    for i in range(batch_count - 1):
        batch_padded_offsets_list.append(
            batch_padded_offsets_list[-1] + math.ceil((batch_sizes_list[i]) / padding_M) * padding_M)
        batch_padded_offsets_list.append(
            batch_padded_offsets_list[-1] + math.ceil((batch_sizes_list[i]) / padding_M) * padding_M)
    A = torch.randn(batch_sum, K, device=device, dtype=dtype)
    B = torch.randn(batch_count, K, M, device=device, dtype=dtype)
    C = torch.empty(batch_sum, M, device=device, dtype=dtype)
...
...
@@ -125,27 +108,16 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
    return A, B, C, batch_sizes, batch_offsets, batch_padded_offsets


def run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b,
                              num_stages=2, threads=128, profile=False):
def run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b,
                              num_stages=2, threads=128, profile=False):
    padding_M = block_M
    batch_sum = sum(batch_sizes_list)
    kernel = grouped_gemm(tuple(batch_sizes_list), K, M, block_M, block_N, block_K, num_stages, threads)
    kernel = grouped_gemm(tuple(batch_sizes_list), K, M, block_M, block_N, block_K, num_stages, threads)
    # print(kernel.get_kernel_source())
    device = torch.device("cuda")
    dtype = torch.float16
    A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs(
        batch_sizes_list, K, M, trans_b, padding_M, device, dtype)
    A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs(
        batch_sizes_list, K, M, trans_b, padding_M, device, dtype)
    out = kernel(A, B, batch_sizes, batch_offsets, batch_padded_offsets)
    ref_output = torch_gmm(A, B, batch_sizes, batch_offsets, trans_b)
# print(out)
...
...
@@ -157,8 +129,7 @@ def run_tilelang_grouped_gemm(batch_sizes_list,
    if profile:
        profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
        latency = profiler.do_bench(
            warmup=500, input_tensors=[A, B, batch_sizes, batch_offsets, batch_padded_offsets])
        latency = profiler.do_bench(
            warmup=500, input_tensors=[A, B, batch_sizes, batch_offsets, batch_padded_offsets])
        print(f"Latency: {latency} ms")
        print(f"TFlops: {batch_sum * K * M * 2 / latency * 1e-9} TFlops")
...
...
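The throughput print above counts 2 * batch_sum * K * M floating-point operations for the grouped GEMM and divides by a latency measured in milliseconds; scaling by 1e-9 turns FLOP/ms into GFLOP/ms, which is numerically equal to TFLOP/s. A quick check with assumed numbers:

batch_sum, K, M = 192, 8192, 8192   # assumed problem size
latency = 1.5                       # assumed latency in ms

tflops = batch_sum * K * M * 2 / latency * 1e-9   # same expression as the print above
print(f"TFlops: {tflops}")                        # ~17.18 for these numbers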
@@ -173,12 +144,11 @@ def test_grouped_gemm():
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_sizes', type=str, default="64, 128", help='comma-separated batch sizes')
    parser.add_argument('--K', type=int, default=8192, help='reduce dim')
    parser.add_argument('--M', type=int, default=8192, help='output dim')
    parser.add_argument('--trans_b', action="store_true", help="transpose B")
    parser.add_argument('--profile', action="store_true", help="profile")
    parser.add_argument("--batch_sizes", type=str, default="64, 128", help="comma-separated batch sizes")
    parser.add_argument("--K", type=int, default=8192, help="reduce dim")
    parser.add_argument("--M", type=int, default=8192, help="output dim")
    parser.add_argument("--trans_b", action="store_true", help="transpose B")
    parser.add_argument("--profile", action="store_true", help="profile")
    args = parser.parse_args()
    batch_sizes_list = [int(x) for x in args.batch_sizes.split(",")]
...
...
@@ -190,14 +160,4 @@ if __name__ == "__main__":
    num_stages = 2
    threads = 256
    run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b,
                              num_stages, threads, profile=args.profile)
    run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b,
                              num_stages, threads, profile=args.profile)