Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
29051439
Unverified
Commit
29051439
authored
Dec 12, 2025
by
Lei Wang
Committed by
GitHub
Dec 12, 2025
Browse files
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
parent
e84b24bc
Changes
467
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
378 additions
and
667 deletions
+378
-667
examples/gemm/example_gemm.py
examples/gemm/example_gemm.py
+3
-4
examples/gemm/example_gemm_autotune.py
examples/gemm/example_gemm_autotune.py
+24
-65
examples/gemm/example_gemm_intrinsics.py
examples/gemm/example_gemm_intrinsics.py
+11
-11
examples/gemm/example_gemm_persistent.py
examples/gemm/example_gemm_persistent.py
+20
-43
examples/gemm/example_gemm_schedule.py
examples/gemm/example_gemm_schedule.py
+3
-4
examples/gemm_fp8/example_tilelang_gemm_amd.py
examples/gemm_fp8/example_tilelang_gemm_amd.py
+28
-49
examples/gemm_fp8/example_tilelang_gemm_fp8.py
examples/gemm_fp8/example_tilelang_gemm_fp8.py
+7
-8
examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py
examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py
+8
-8
examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
+11
-11
examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py
examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py
+4
-6
examples/gemm_sm100/gemm_mma.py
examples/gemm_sm100/gemm_mma.py
+5
-5
examples/gemm_sm100/gemm_tcgen5mma.py
examples/gemm_sm100/gemm_tcgen5mma.py
+8
-16
examples/gemm_sp/example_custom_compress.py
examples/gemm_sp/example_custom_compress.py
+81
-107
examples/gemm_sp/example_gemm_sp.py
examples/gemm_sp/example_gemm_sp.py
+55
-68
examples/gemm_splitk/example_tilelang_gemm_splitk.py
examples/gemm_splitk/example_tilelang_gemm_splitk.py
+5
-16
examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py
...plitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py
+5
-16
examples/gemm_streamk/example_tilelang_gemm_streamk.py
examples/gemm_streamk/example_tilelang_gemm_streamk.py
+4
-6
examples/gemv/example_gemv.py
examples/gemv/example_gemv.py
+34
-47
examples/grouped_gemm/example_grouped_gemm_bwd.py
examples/grouped_gemm/example_grouped_gemm_bwd.py
+38
-113
examples/grouped_gemm/example_grouped_gemm_fwd.py
examples/grouped_gemm/example_grouped_gemm_fwd.py
+24
-64
No files found.
examples/gemm/example_gemm.py
View file @
29051439
...
...
@@ -4,7 +4,6 @@ import tilelang.language as T
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
def
gemm
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
...
...
examples/gemm/example_gemm_autotune.py
View file @
29051439
...
...
@@ -90,7 +90,8 @@ def get_configs(M, N, K, with_roller=False, topk=20):
num_stages
,
thread_num
,
enable_rasterization
,
))
)
)
configs
=
[
{
...
...
@@ -100,13 +101,13 @@ def get_configs(M, N, K, with_roller=False, topk=20):
"num_stages"
:
c
[
3
],
"thread_num"
:
c
[
4
],
"enable_rasteration"
:
c
[
5
],
# keep param name for backward-compat
}
for
c
in
_configs
}
for
c
in
_configs
]
return
configs
def
get_best_config
(
M
,
N
,
K
,
with_roller
=
False
):
def
kernel
(
block_M
=
None
,
block_N
=
None
,
...
...
@@ -124,8 +125,7 @@ def get_best_config(M, N, K, with_roller=False):
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_N
,
block_K
),
dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
...
@@ -146,15 +146,18 @@ def get_best_config(M, N, K, with_roller=False):
return
main
autotuner
=
AutoTuner
.
from_kernel
(
kernel
=
kernel
,
configs
=
get_configs
(
M
,
N
,
K
,
with_roller
)).
set_compile_args
(
autotuner
=
(
AutoTuner
.
from_kernel
(
kernel
=
kernel
,
configs
=
get_configs
(
M
,
N
,
K
,
with_roller
))
.
set_compile_args
(
out_idx
=
[
-
1
],
target
=
"auto"
,
).
set_profile_args
(
)
.
set_profile_args
(
supply_type
=
tl
.
TensorSupplyType
.
Integer
,
ref_prog
=
ref_program
,
skip_check
=
False
,
)
)
return
autotuner
.
run
(
warmup
=
3
,
rep
=
20
)
...
...
@@ -167,47 +170,15 @@ def get_heuristic_config() -> dict:
sm_version
=
sm_major
*
10
+
sm_minor
print
(
f
"CUDA device capability:
{
sm_version
}
"
)
if
sm_version
in
{
80
}:
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
32
,
"num_stages"
:
2
,
"thread_num"
:
128
,
"enable_rasteration"
:
True
}
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
32
,
"num_stages"
:
2
,
"thread_num"
:
128
,
"enable_rasteration"
:
True
}
elif
sm_version
in
{
90
}:
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
64
,
"num_stages"
:
3
,
"thread_num"
:
256
,
"enable_rasteration"
:
True
}
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
64
,
"num_stages"
:
3
,
"thread_num"
:
256
,
"enable_rasteration"
:
True
}
else
:
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
32
,
"num_stages"
:
0
,
"thread_num"
:
128
,
"enable_rasteration"
:
True
}
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
32
,
"num_stages"
:
0
,
"thread_num"
:
128
,
"enable_rasteration"
:
True
}
@
tl
.
jit
(
out_idx
=
[
-
1
])
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
enable_rasteration
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
enable_rasteration
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
def
gemm_autotune
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
...
...
@@ -236,11 +207,7 @@ def matmul(M,
return
gemm_autotune
def
main
(
M
:
int
=
4096
,
N
:
int
=
4096
,
K
:
int
=
4096
,
use_autotune
:
bool
=
False
,
with_roller
:
bool
=
False
):
def
main
(
M
:
int
=
4096
,
N
:
int
=
4096
,
K
:
int
=
4096
,
use_autotune
:
bool
=
False
,
with_roller
:
bool
=
False
):
use_autotune
=
True
if
use_autotune
:
result
=
get_best_config
(
M
,
N
,
K
,
with_roller
)
...
...
@@ -266,15 +233,7 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--m"
,
type
=
int
,
default
=
4096
,
help
=
"Matrix dimension M"
)
parser
.
add_argument
(
"--n"
,
type
=
int
,
default
=
4096
,
help
=
"Matrix dimension N"
)
parser
.
add_argument
(
"--k"
,
type
=
int
,
default
=
4096
,
help
=
"Matrix dimension K"
)
parser
.
add_argument
(
"--use_autotune"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to use autotune for matmul configs"
)
parser
.
add_argument
(
"--with_roller"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to enable BitBLAS roller for search space"
)
parser
.
add_argument
(
"--use_autotune"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to use autotune for matmul configs"
)
parser
.
add_argument
(
"--with_roller"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to enable BitBLAS roller for search space"
)
args
=
parser
.
parse_args
()
main
(
args
.
m
,
args
.
n
,
args
.
k
,
args
.
use_autotune
,
args
.
with_roller
)
examples/gemm/example_gemm_intrinsics.py
View file @
29051439
...
...
@@ -4,7 +4,8 @@ import tilelang
import
tilelang.language
as
T
from
tilelang.intrinsics
import
get_swizzle_layout
from
tilelang.intrinsics.mma_macro_generator
import
(
TensorCoreIntrinEmitter
,)
TensorCoreIntrinEmitter
,
)
from
tilelang.transform
import
simplify_prim_func
...
...
@@ -104,7 +105,6 @@ def tl_matmul(
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
C_shared
=
T
.
alloc_shared
(
C_shared_shape
,
out_dtype
,
scope
=
shared_scope
)
...
...
@@ -112,10 +112,12 @@ def tl_matmul(
B_local
=
T
.
alloc_local
((
warp_cols
*
local_size_b
),
in_dtype
)
C_local
=
T
.
alloc_local
((
warp_rows
*
warp_cols
*
local_size_c
),
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
{
A_shared
:
make_swizzle_layout
(
A_shared
),
B_shared
:
make_swizzle_layout
(
B_shared
),
})
}
)
# Improve L2 Cache
T
.
use_swizzle
(
panel_size
=
10
)
...
...
@@ -123,7 +125,6 @@ def tl_matmul(
T
.
clear
(
C_local
)
for
ko
in
T
.
Pipelined
((
K
//
block_K
),
num_stages
=
stage
):
# Load A into shared memory
for
i
,
k
in
T
.
Parallel
(
block_M
,
block_K
):
A_shared
[
i
,
k
]
=
A
[
by
*
block_M
+
i
,
ko
*
block_K
+
k
]
...
...
@@ -133,7 +134,6 @@ def tl_matmul(
B_shared
[
j
,
k
]
=
B
[
bx
*
block_N
+
j
,
ko
*
block_K
+
k
]
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
# Load A into fragment
mma_emitter
.
ldmatrix_a
(
A_local
,
A_shared
,
ki
)
...
...
examples/gemm/example_gemm_persistent.py
View file @
29051439
...
...
@@ -5,17 +5,7 @@ import argparse
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
matmul_non_persistent
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
threads
,
num_stages
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
matmul_non_persistent
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
threads
,
num_stages
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
...
...
@@ -43,18 +33,9 @@ def matmul_non_persistent(M,
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
matmul_persistent
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
threads
,
num_stages
,
dtype
=
"float16"
,
accum_dtype
=
"float"
,
use_persistent_primitive
=
True
):
def
matmul_persistent
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
threads
,
num_stages
,
dtype
=
"float16"
,
accum_dtype
=
"float"
,
use_persistent_primitive
=
True
):
sm_num
=
driver
.
get_num_sms
()
m_blocks
=
T
.
ceildiv
(
M
,
block_M
)
n_blocks
=
T
.
ceildiv
(
N
,
block_N
)
...
...
@@ -100,8 +81,7 @@ def matmul_persistent(M,
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
C_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
dtype
)
for
bx
,
by
in
T
.
Persistent
(
[
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
)],
sm_num
,
block_id
):
for
bx
,
by
in
T
.
Persistent
([
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
)],
sm_num
,
block_id
):
T
.
clear
(
C_local
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
A
[
bx
*
block_M
,
k
*
block_K
],
A_shared
)
...
...
@@ -128,18 +108,15 @@ def main(M=4096, N=4096, K=4096):
num_stages
=
3
persistent_kernel
=
matmul_persistent
(
M
,
N
,
K
,
BLOCK_M
,
BLOCK_N
,
BLOCK_K
,
threads
,
num_stages
)
persistent_profiler
=
persistent_kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Randn
)
persistent_profiler
=
persistent_kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Randn
)
persistent_profiler
.
assert_allclose
(
ref_program
,
rtol
=
0.01
,
atol
=
0.01
)
print
(
"Persistent GEMM: All check passed."
)
persistent_latency
=
persistent_profiler
.
do_bench
(
warmup
=
500
)
print
(
f
"Persistent GEMM Latency:
{
persistent_latency
}
ms"
)
print
(
f
"Persistent GEMM TFlops:
{
total_flops
/
persistent_latency
*
1e-9
}
TFlops"
)
non_persistent_kernel
=
matmul_non_persistent
(
M
,
N
,
K
,
BLOCK_M
,
BLOCK_N
,
BLOCK_K
,
threads
,
num_stages
)
non_persistent_profiler
=
non_persistent_kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Randn
)
non_persistent_kernel
=
matmul_non_persistent
(
M
,
N
,
K
,
BLOCK_M
,
BLOCK_N
,
BLOCK_K
,
threads
,
num_stages
)
non_persistent_profiler
=
non_persistent_kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Randn
)
non_persistent_profiler
.
assert_allclose
(
ref_program
,
rtol
=
0.01
,
atol
=
0.01
)
print
(
"Non-Persistent GEMM: All check passed."
)
non_persistent_latency
=
non_persistent_profiler
.
do_bench
(
warmup
=
500
)
...
...
@@ -151,9 +128,9 @@ def main(M=4096, N=4096, K=4096):
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--M
'
,
type
=
int
,
default
=
8192
,
help
=
'
M dimension
'
)
parser
.
add_argument
(
'
--N
'
,
type
=
int
,
default
=
8192
,
help
=
'
N dimension
'
)
parser
.
add_argument
(
'
--K
'
,
type
=
int
,
default
=
8192
,
help
=
'
K dimension
'
)
parser
.
add_argument
(
"
--M
"
,
type
=
int
,
default
=
8192
,
help
=
"
M dimension
"
)
parser
.
add_argument
(
"
--N
"
,
type
=
int
,
default
=
8192
,
help
=
"
N dimension
"
)
parser
.
add_argument
(
"
--K
"
,
type
=
int
,
default
=
8192
,
help
=
"
K dimension
"
)
args
=
parser
.
parse_args
()
M
,
N
,
K
=
args
.
M
,
args
.
N
,
args
.
K
main
(
M
,
N
,
K
)
examples/gemm/example_gemm_schedule.py
View file @
29051439
...
...
@@ -4,7 +4,6 @@ import tilelang.language as T
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
def
gemm_schedule
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
...
...
examples/gemm_fp8/example_tilelang_gemm_amd.py
View file @
29051439
...
...
@@ -17,10 +17,8 @@ def supply_prog(args):
a_param
,
b_param
=
args
M
,
K
=
a_param
.
shape
N
,
_
=
b_param
.
shape
a
=
(
torch
.
randn
(
M
,
K
,
dtype
=
torch
.
float16
,
device
=
'cuda'
)
*
0.01
).
to
(
dtype
=
torch
.
float8_e4m3fnuz
)
b
=
(
torch
.
randn
(
N
,
K
,
dtype
=
torch
.
float16
,
device
=
'cuda'
)
*
0.01
).
to
(
dtype
=
torch
.
float8_e4m3fnuz
)
a
=
(
torch
.
randn
(
M
,
K
,
dtype
=
torch
.
float16
,
device
=
"cuda"
)
*
0.01
).
to
(
dtype
=
torch
.
float8_e4m3fnuz
)
b
=
(
torch
.
randn
(
N
,
K
,
dtype
=
torch
.
float16
,
device
=
"cuda"
)
*
0.01
).
to
(
dtype
=
torch
.
float8_e4m3fnuz
)
return
[
a
,
b
]
...
...
@@ -35,10 +33,9 @@ def get_configs():
valid_configs
=
[]
for
m
,
n
,
k
,
stages
,
t
,
kp
,
gemm_type
in
itertools
.
product
(
block_Ms
,
block_Ns
,
block_Ks
,
num_stages
,
num_threads
,
k_packs
,
gemm_types
):
valid_configs
.
append
({
for
m
,
n
,
k
,
stages
,
t
,
kp
,
gemm_type
in
itertools
.
product
(
block_Ms
,
block_Ns
,
block_Ks
,
num_stages
,
num_threads
,
k_packs
,
gemm_types
):
valid_configs
.
append
(
{
"block_M"
:
m
,
"block_N"
:
n
,
"block_K"
:
k
,
...
...
@@ -46,16 +43,14 @@ def get_configs():
"num_threads"
:
t
,
"k_pack"
:
kp
,
"gemm_type"
:
gemm_type
,
})
}
)
return
valid_configs
@
tilelang
.
autotune
(
configs
=
get_configs
(),
cache_input_tensors
=
True
,
ref_prog
=
ref_program
,
manual_check_prog
=
manual_check_prog
,
supply_prog
=
supply_prog
)
configs
=
get_configs
(),
cache_input_tensors
=
True
,
ref_prog
=
ref_program
,
manual_check_prog
=
manual_check_prog
,
supply_prog
=
supply_prog
)
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
fp8_matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
num_stages
,
num_threads
,
k_pack
,
gemm_type
):
dtype
=
"float8_e4m3fnuz"
...
...
@@ -67,8 +62,7 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
accum_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
num_threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
num_threads
)
as
(
bx
,
by
):
A_local
=
T
.
alloc_fragment
((
block_M
,
block_K
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_N
,
block_K
),
dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
...
@@ -77,13 +71,7 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
A
[
by
*
block_M
,
k
*
block_K
],
A_local
)
T
.
copy
(
B
[
bx
*
block_N
,
k
*
block_K
],
B_shared
)
T
.
gemm
(
A_local
,
B_shared
,
C_local
,
transpose_B
=
True
,
k_pack
=
k_pack
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
A_local
,
B_shared
,
C_local
,
transpose_B
=
True
,
k_pack
=
k_pack
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
copy
(
C_local
,
C
[
by
*
block_M
,
bx
*
block_N
])
...
...
@@ -93,8 +81,7 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
accum_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
num_threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
num_threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_N
,
block_K
),
dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
...
@@ -103,13 +90,7 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
A
[
by
*
block_M
,
k
*
block_K
],
A_shared
)
T
.
copy
(
B
[
bx
*
block_N
,
k
*
block_K
],
B_shared
)
T
.
gemm
(
A_shared
,
B_shared
,
C_local
,
transpose_B
=
True
,
k_pack
=
k_pack
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
A_shared
,
B_shared
,
C_local
,
transpose_B
=
True
,
k_pack
=
k_pack
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
copy
(
C_local
,
C
[
by
*
block_M
,
bx
*
block_N
])
...
...
@@ -123,10 +104,8 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
def
test_gemm_fp8
(
M
,
N
,
K
):
kernel
=
fp8_matmul
(
M
,
N
,
K
)
a
=
(
torch
.
randn
(
M
,
K
,
dtype
=
torch
.
float16
,
device
=
'cuda'
)
*
0.01
).
to
(
dtype
=
torch
.
float8_e4m3fnuz
)
b
=
(
torch
.
randn
(
N
,
K
,
dtype
=
torch
.
float16
,
device
=
'cuda'
)
*
0.01
).
to
(
dtype
=
torch
.
float8_e4m3fnuz
)
a
=
(
torch
.
randn
(
M
,
K
,
dtype
=
torch
.
float16
,
device
=
"cuda"
)
*
0.01
).
to
(
dtype
=
torch
.
float8_e4m3fnuz
)
b
=
(
torch
.
randn
(
N
,
K
,
dtype
=
torch
.
float16
,
device
=
"cuda"
)
*
0.01
).
to
(
dtype
=
torch
.
float8_e4m3fnuz
)
c
=
kernel
(
a
,
b
)
ref_c
=
ref_program
(
a
,
b
)
torch_assert_close
(
c
,
ref_c
,
rtol
=
1e-2
,
atol
=
1e-2
)
...
...
examples/gemm_fp8/example_tilelang_gemm_fp8.py
View file @
29051439
...
...
@@ -13,7 +13,6 @@ def calc_diff(x, y):
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
,
accum_dtype
=
"float"
):
@
T
.
prim_func
def
gemm_fp8
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
...
...
@@ -41,8 +40,8 @@ def test_gemm_fp8(M, N, K, dtype):
kernel
=
matmul
(
M
,
N
,
K
,
128
,
128
,
64
,
dtype
)
a
=
torch
.
randn
(
M
,
K
,
dtype
=
torch
.
float16
,
device
=
'
cuda
'
).
to
(
dtype
=
torch_dtype
)
b
=
torch
.
randn
(
N
,
K
,
dtype
=
torch
.
float16
,
device
=
'
cuda
'
).
to
(
dtype
=
torch_dtype
)
a
=
torch
.
randn
(
M
,
K
,
dtype
=
torch
.
float16
,
device
=
"
cuda
"
).
to
(
dtype
=
torch_dtype
)
b
=
torch
.
randn
(
N
,
K
,
dtype
=
torch
.
float16
,
device
=
"
cuda
"
).
to
(
dtype
=
torch_dtype
)
c
=
kernel
(
a
,
b
)
...
...
@@ -57,8 +56,8 @@ def test_gemm_fp8(M, N, K, dtype):
def
main
():
test_gemm_fp8
(
1024
,
1024
,
1024
,
'
float8_e4m3
'
)
test_gemm_fp8
(
1024
,
1024
,
1024
,
'
float8_e5m2
'
)
test_gemm_fp8
(
1024
,
1024
,
1024
,
"
float8_e4m3
"
)
test_gemm_fp8
(
1024
,
1024
,
1024
,
"
float8_e5m2
"
)
if
__name__
==
"__main__"
:
...
...
examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py
View file @
29051439
...
...
@@ -59,14 +59,14 @@ def test_gemm_fp8(M, N, K, dtype):
kernel
=
matmul
(
M
,
N
,
K
,
128
,
128
,
64
,
dtype
)
a
=
torch
.
rand
(
M
,
K
,
dtype
=
torch
.
float16
,
device
=
'
cuda
'
)
a
=
torch
.
rand
(
M
,
K
,
dtype
=
torch
.
float16
,
device
=
"
cuda
"
)
a
=
(
100
*
(
2
*
a
-
1
)).
to
(
dtype
=
torch_dtype
)
b
=
torch
.
rand
(
N
,
K
,
dtype
=
torch
.
float16
,
device
=
'
cuda
'
)
b
=
torch
.
rand
(
N
,
K
,
dtype
=
torch
.
float16
,
device
=
"
cuda
"
)
b
=
(
100
*
(
2
*
b
-
1
)).
to
(
dtype
=
torch_dtype
)
c
=
kernel
(
a
,
b
)
ref_c
=
(
a
.
float
()
@
b
.
float
().
T
)
ref_c
=
a
.
float
()
@
b
.
float
().
T
diff
=
calc_diff
(
c
,
ref_c
)
print
(
f
"diff:
{
diff
}
"
)
...
...
@@ -74,8 +74,8 @@ def test_gemm_fp8(M, N, K, dtype):
def
main
():
test_gemm_fp8
(
1024
,
1024
,
8192
,
'
float8_e4m3
'
)
test_gemm_fp8
(
1024
,
1024
,
8192
,
'
float8_e5m2
'
)
test_gemm_fp8
(
1024
,
1024
,
8192
,
"
float8_e4m3
"
)
test_gemm_fp8
(
1024
,
1024
,
8192
,
"
float8_e5m2
"
)
if
__name__
==
"__main__"
:
...
...
examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
View file @
29051439
...
...
@@ -5,7 +5,8 @@ from tvm import DataType
import
tilelang.language
as
T
from
tilelang.intrinsics
import
get_swizzle_layout
from
tilelang.intrinsics.mma_macro_generator
import
(
TensorCoreIntrinEmitter
,)
TensorCoreIntrinEmitter
,
)
from
tilelang.transform
import
simplify_prim_func
from
tilelang.utils.tensor
import
map_torch_type
...
...
@@ -115,7 +116,6 @@ def tl_matmul(
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
C_shared
=
T
.
alloc_shared
(
C_shared_shape
,
out_dtype
,
scope
=
shared_scope
)
...
...
@@ -123,10 +123,12 @@ def tl_matmul(
B_local
=
T
.
alloc_local
((
warp_cols
*
local_size_b
),
in_dtype
)
C_local
=
T
.
alloc_local
((
warp_rows
*
warp_cols
*
local_size_c
),
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
{
A_shared
:
make_swizzle_layout
(
A_shared
),
B_shared
:
make_swizzle_layout
(
B_shared
),
})
}
)
# Improve L2 Cache
T
.
use_swizzle
(
panel_size
=
10
)
...
...
@@ -134,7 +136,6 @@ def tl_matmul(
T
.
clear
(
C_local
)
for
ko
in
T
.
Pipelined
((
K
//
block_K
),
num_stages
=
stage
):
# Load A into shared memory
for
i
,
k
in
T
.
Parallel
(
block_M
,
block_K
):
A_shared
[
i
,
k
]
=
A
[
by
*
block_M
+
i
,
ko
*
block_K
+
k
]
...
...
@@ -144,7 +145,6 @@ def tl_matmul(
B_shared
[
j
,
k
]
=
B
[
bx
*
block_N
+
j
,
ko
*
block_K
+
k
]
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
# Load A into fragment
mma_emitter
.
ldmatrix_a
(
A_local
,
...
...
examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py
View file @
29051439
...
...
@@ -121,6 +121,4 @@ for tvm_fp8_dtype in ["float8_e4m3", "float8_e5m2"]:
profiler
=
jit_kernel
.
get_profiler
()
latency
=
profiler
.
do_bench
()
print
(
f
"[
{
tvm_fp8_dtype
}
->
{
tvm_acc_dtype
}
] Latency:
{
latency
}
ms"
)
print
(
f
"[
{
tvm_fp8_dtype
}
->
{
tvm_acc_dtype
}
] Flops:
{
2
*
M
*
N
*
K
/
(
latency
/
1e3
)
/
1e12
}
TFLOPS"
)
print
(
f
"[
{
tvm_fp8_dtype
}
->
{
tvm_acc_dtype
}
] Flops:
{
2
*
M
*
N
*
K
/
(
latency
/
1e3
)
/
1e12
}
TFLOPS"
)
examples/gemm_sm100/gemm_mma.py
View file @
29051439
...
...
@@ -5,7 +5,6 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
...
...
@@ -62,7 +61,8 @@ jit_kernel = tilelang.compile(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
print
(
jit_kernel
.
get_kernel_source
())
# 3. Test the kernel in Python with PyTorch data
import
torch
...
...
examples/gemm_sm100/gemm_tcgen5mma.py
View file @
29051439
...
...
@@ -40,15 +40,7 @@ def matmul(
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
A
[
by
*
block_M
,
k
*
block_K
],
A_shared
)
T
.
copy
(
B
[
bx
*
block_N
,
k
*
block_K
],
B_shared
)
T
.
gemm
(
A_shared
,
B_shared
,
C_tmem
,
trans_A
,
trans_B
,
mbar
=
mbar
,
wg_wait
=-
1
,
clear_accum
=
k
==
0
)
T
.
gemm
(
A_shared
,
B_shared
,
C_tmem
,
trans_A
,
trans_B
,
mbar
=
mbar
,
wg_wait
=-
1
,
clear_accum
=
k
==
0
)
T
.
mbarrier_wait_parity
(
mbar
,
k
%
2
)
T
.
copy
(
C_tmem
,
C_local
)
...
...
@@ -66,8 +58,7 @@ in_dtype, out_dtype, accum_dtype = "bfloat16", "bfloat16", "float"
num_stages
=
2
threads
=
256
func
=
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
accum_dtype
,
num_stages
,
threads
)
func
=
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
accum_dtype
,
num_stages
,
threads
)
jit_kernel
=
tilelang
.
compile
(
func
,
out_idx
=
[
2
],
...
...
@@ -75,7 +66,8 @@ jit_kernel = tilelang.compile(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
print
(
jit_kernel
.
get_kernel_source
())
...
...
@@ -88,4 +80,4 @@ torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
profiler
=
jit_kernel
.
get_profiler
()
latency
=
profiler
.
do_bench
()
print
(
f
"Latency:
{
latency
}
ms"
)
print
(
f
"Flops:
{
2
*
M
*
N
*
K
/
(
latency
/
1e3
)
/
1e12
}
TFLOPS"
)
print
(
f
"Flops:
{
2
*
M
*
N
*
K
/
(
latency
/
1e3
)
/
1e12
}
TFLOPS"
)
examples/gemm_sp/example_custom_compress.py
View file @
29051439
...
...
@@ -17,77 +17,76 @@ torch.manual_seed(42)
DEFAULT_CONFIG
=
{
# take best config from autotune script
"4090"
:
{
'float'
:
{
'block_M'
:
128
,
'block_N'
:
64
,
'block_K'
:
64
,
'num_stages'
:
1
,
'thread_num'
:
128
,
'policy'
:
T
.
GemmWarpPolicy
.
Square
,
'enable_rasterization'
:
True
"float"
:
{
"block_M"
:
128
,
"block_N"
:
64
,
"block_K"
:
64
,
"num_stages"
:
1
,
"thread_num"
:
128
,
"policy"
:
T
.
GemmWarpPolicy
.
Square
,
"enable_rasterization"
:
True
,
},
"float16"
:
{
"block_M"
:
256
,
"block_N"
:
128
,
"block_K"
:
64
,
"num_stages"
:
2
,
"thread_num"
:
128
,
"policy"
:
T
.
GemmWarpPolicy
.
Square
,
"enable_rasterization"
:
True
,
},
'float16'
:
{
'block_M'
:
256
,
'block_N'
:
128
,
'block_K'
:
64
,
'num_stages'
:
2
,
'thread_num'
:
128
,
'policy'
:
T
.
GemmWarpPolicy
.
Square
,
'enable_rasterization'
:
True
}
},
"h20"
:
{
'float'
:
{
'block_M'
:
128
,
'block_N'
:
64
,
'block_K'
:
128
,
'num_stages'
:
3
,
'thread_num'
:
128
,
'policy'
:
T
.
GemmWarpPolicy
.
Square
,
'enable_rasterization'
:
True
"float"
:
{
"block_M"
:
128
,
"block_N"
:
64
,
"block_K"
:
128
,
"num_stages"
:
3
,
"thread_num"
:
128
,
"policy"
:
T
.
GemmWarpPolicy
.
Square
,
"enable_rasterization"
:
True
,
},
"float16"
:
{
"block_M"
:
128
,
"block_N"
:
64
,
"block_K"
:
128
,
"num_stages"
:
3
,
"thread_num"
:
128
,
"policy"
:
T
.
GemmWarpPolicy
.
Square
,
"enable_rasterization"
:
True
,
},
},
'float16'
:
{
'block_M'
:
128
,
'block_N'
:
64
,
'block_K'
:
128
,
'num_stages'
:
3
,
'thread_num'
:
128
,
'policy'
:
T
.
GemmWarpPolicy
.
Square
,
'enable_rasterization'
:
True
}
}
}
ARCH_INFO
=
{
"8.0"
:
(
16
,
"int16"
),
"8.9"
:
(
16
,
"int16"
),
"9.0"
:
(
8
,
"uint8"
)}
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
matmul_sp_fp16_custom_compress
(
M
,
N
,
K
,
accum_dtype
,
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
policy
,
enable_rasterization
,
use_cutlass_layout
):
def
matmul_sp_fp16_custom_compress
(
M
,
N
,
K
,
accum_dtype
,
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
policy
,
enable_rasterization
,
use_cutlass_layout
):
e_factor
,
e_dtype
=
(
16
,
"int16"
)
@
T
.
prim_func
def
gemm_sp_fp16_custom_compress
(
A_sparse
:
T
.
Tensor
((
M
,
K
//
2
),
'
float16
'
),
A_sparse
:
T
.
Tensor
((
M
,
K
//
2
),
"
float16
"
),
E
:
T
.
Tensor
((
M
,
K
//
e_factor
),
e_dtype
),
B
:
T
.
Tensor
((
K
,
N
),
'
float16
'
),
B
:
T
.
Tensor
((
K
,
N
),
"
float16
"
),
C
:
T
.
Tensor
((
M
,
N
),
accum_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
2
),
'
float16
'
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
2
),
"
float16
"
)
E_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
e_factor
),
e_dtype
)
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
'
float16
'
)
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
"
float16
"
)
C_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
accum_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
if
use_cutlass_layout
:
T
.
annotate_layout
({
E
:
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
"float16"
,
arch
=
"8.0"
,
block_k
=
block_K
),
E_shared
:
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
"float16"
,
arch
=
"8.0"
,
block_k
=
block_K
),
})
T
.
annotate_layout
(
{
E
:
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
"float16"
,
arch
=
"8.0"
,
block_k
=
block_K
),
E_shared
:
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
"float16"
,
arch
=
"8.0"
,
block_k
=
block_K
),
}
)
T
.
clear
(
C_local
)
T
.
disable_warp_group_reg_alloc
()
T
.
use_swizzle
(
panel_size
=
10
,
enable
=
enable_rasterization
)
...
...
@@ -108,8 +107,7 @@ def torch_compress(dense):
A naive compression function, where each 4-bit meta matches 4 elements in original matrix in row major layout.
"""
if
dense
.
dim
()
!=
2
:
raise
RuntimeError
(
f
"Expected 2-dimensional dense tensor, got
{
dense
.
dim
()
}
-dimensional tensor"
)
raise
RuntimeError
(
f
"Expected 2-dimensional dense tensor, got
{
dense
.
dim
()
}
-dimensional tensor"
)
m
,
k
=
dense
.
shape
...
...
@@ -131,9 +129,7 @@ def torch_compress(dense):
if
m
%
32
!=
0
:
raise
RuntimeError
(
f
"Number of rows of dense matrix
{
m
}
must be divisible by 32"
)
if
k
%
(
4
*
quadbits_per_meta_elem
)
!=
0
:
raise
RuntimeError
(
f
"Number of columns of dense matrix
{
k
}
must be divisible by
{
4
*
quadbits_per_meta_elem
}
"
)
raise
RuntimeError
(
f
"Number of columns of dense matrix
{
k
}
must be divisible by
{
4
*
quadbits_per_meta_elem
}
"
)
if
dense
.
dtype
!=
torch
.
float
:
ksparse
=
4
...
...
@@ -194,19 +190,13 @@ def torch_compress(dense):
sparse1
=
dense_4
.
gather
(
-
1
,
idxs1
.
unsqueeze
(
-
1
))
sparse
=
torch
.
stack
((
sparse0
,
sparse1
),
dim
=-
1
).
view
(
m
,
k
//
2
)
else
:
sparse
=
dense_2
.
gather
(
-
1
,
idxs0
.
unsqueeze
(
-
1
)
//
2
).
view
(
m
,
k
//
2
)
# type: ignore[possibly-undefined]
sparse
=
dense_2
.
gather
(
-
1
,
idxs0
.
unsqueeze
(
-
1
)
//
2
).
view
(
m
,
k
//
2
)
# type: ignore[possibly-undefined]
meta_4
=
idxs0
|
(
idxs1
<<
2
)
meta_n
=
meta_4
.
view
((
-
1
,
meta_ncols
,
quadbits_per_meta_elem
)).
to
(
meta_dtype
)
if
quadbits_per_meta_elem
==
4
:
meta
=
(
meta_n
[:,
:,
0
]
|
(
meta_n
[:,
:,
1
]
<<
4
)
|
(
meta_n
[:,
:,
2
]
<<
8
)
|
(
meta_n
[:,
:,
3
]
<<
12
))
meta
=
meta_n
[:,
:,
0
]
|
(
meta_n
[:,
:,
1
]
<<
4
)
|
(
meta_n
[:,
:,
2
]
<<
8
)
|
(
meta_n
[:,
:,
3
]
<<
12
)
elif
quadbits_per_meta_elem
==
8
:
meta
=
(
meta_n
[:,
:,
0
]
...
...
@@ -216,7 +206,8 @@ def torch_compress(dense):
|
(
meta_n
[:,
:,
4
]
<<
16
)
|
(
meta_n
[:,
:,
5
]
<<
20
)
|
(
meta_n
[:,
:,
6
]
<<
24
)
|
(
meta_n
[:,
:,
7
]
<<
28
))
|
(
meta_n
[:,
:,
7
]
<<
28
)
)
return
(
sparse
,
meta
)
...
...
@@ -234,9 +225,11 @@ def decode_metadata(meta: torch.Tensor) -> torch.Tensor:
@
tilelang
.
jit
(
out_idx
=
[
1
,
2
],
pass_configs
=
{
out_idx
=
[
1
,
2
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TIR_DISABLE_VECTORIZE
:
True
,
})
},
)
def
compress_kernel
(
M
,
K
,
block_M
,
block_K
,
dtype
,
use_cutlass_layout
):
e_factor
,
e_dtype
=
ARCH_INFO
[
"8.0"
]
e_K
=
K
//
e_factor
...
...
@@ -258,14 +251,12 @@ def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
A_sp_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
2
),
dtype
)
E_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
e_factor
),
e_dtype
)
if
use_cutlass_layout
:
T
.
annotate_layout
({
E
:
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
"float16"
,
arch
=
"8.0"
,
block_k
=
block_K
),
E_shared
:
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
"float16"
,
arch
=
"8.0"
,
block_k
=
block_K
),
})
T
.
annotate_layout
(
{
E
:
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
"float16"
,
arch
=
"8.0"
,
block_k
=
block_K
),
E_shared
:
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
"float16"
,
arch
=
"8.0"
,
block_k
=
block_K
),
}
)
T
.
clear
(
A_sp_shared
)
T
.
clear
(
E_shared
)
# TODO: alloc_var seems buggy here
...
...
@@ -295,8 +286,7 @@ def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
non_zero_elt_log_idx
[
1
]
=
3
for
i
in
T
.
serial
(
elem
):
val
=
non_zero_elt_log_idx
[
i
]
E_shared
[
tm
,
a_k
//
e_factor
]
|=
T
.
shift_left
(
val
,
4
*
(
g_i
%
(
e_factor
//
group
))
+
2
*
i
)
E_shared
[
tm
,
a_k
//
e_factor
]
|=
T
.
shift_left
(
val
,
4
*
(
g_i
%
(
e_factor
//
group
))
+
2
*
i
)
T
.
copy
(
A_sp_shared
,
A_sp
[
bx
*
block_M
,
by
*
block_K
//
2
])
T
.
copy
(
E_shared
,
E
[
bx
*
block_M
,
by
*
block_K
//
e_factor
])
...
...
@@ -304,41 +294,27 @@ def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
def
main
():
parser
=
argparse
.
ArgumentParser
(
description
=
"Autotuned MatMul Benchmark"
)
parser
.
add_argument
(
"--m"
,
type
=
int
,
default
=
16384
,
help
=
"Matrix dimension M"
)
parser
.
add_argument
(
"--n"
,
type
=
int
,
default
=
16384
,
help
=
"Matrix dimension N"
)
parser
.
add_argument
(
"--k"
,
type
=
int
,
default
=
16384
,
help
=
"Matrix dimension K"
)
parser
.
add_argument
(
"--use_cutlass_layout"
,
action
=
'store_true'
,
help
=
"Use cutlass layout for E tensor"
)
parser
.
add_argument
(
"--use_torch_compressor"
,
action
=
'store_true'
,
help
=
"Use torch sparse for reference"
)
parser
.
add_argument
(
"--accum_dtype"
,
type
=
str
,
default
=
"float"
,
choices
=
[
"float"
,
"float16"
],
help
=
"Accumulation datatype"
)
parser
.
add_argument
(
"--use_cutlass_layout"
,
action
=
"store_true"
,
help
=
"Use cutlass layout for E tensor"
)
parser
.
add_argument
(
"--use_torch_compressor"
,
action
=
"store_true"
,
help
=
"Use torch sparse for reference"
)
parser
.
add_argument
(
"--accum_dtype"
,
type
=
str
,
default
=
"float"
,
choices
=
[
"float"
,
"float16"
],
help
=
"Accumulation datatype"
)
parser
.
add_argument
(
"--cfg"
,
type
=
str
,
choices
=
[
"4090"
],
default
=
"4090"
)
args
=
parser
.
parse_args
()
kernel
=
matmul_sp_fp16_custom_compress
(
args
.
m
,
args
.
n
,
args
.
k
,
args
.
accum_dtype
,
**
DEFAULT_CONFIG
[
args
.
cfg
][
args
.
accum_dtype
],
use_cutlass_layout
=
args
.
use_cutlass_layout
)
args
.
m
,
args
.
n
,
args
.
k
,
args
.
accum_dtype
,
**
DEFAULT_CONFIG
[
args
.
cfg
][
args
.
accum_dtype
],
use_cutlass_layout
=
args
.
use_cutlass_layout
)
a
=
randn_semi_sparse
(
args
.
m
,
args
.
k
,
device
=
'
cuda
'
,
dtype
=
torch
.
half
)
b
=
torch
.
randn
(
args
.
k
,
args
.
n
,
device
=
'
cuda
'
,
dtype
=
torch
.
half
)
a
=
randn_semi_sparse
(
args
.
m
,
args
.
k
,
device
=
"
cuda
"
,
dtype
=
torch
.
half
)
b
=
torch
.
randn
(
args
.
k
,
args
.
n
,
device
=
"
cuda
"
,
dtype
=
torch
.
half
)
if
args
.
use_torch_compressor
:
assert
not
args
.
use_cutlass_layout
,
"torch sparse must be used with naive layout"
a_sparse
,
e
=
torch_compress
(
a
)
else
:
a_sparse
,
e
=
compress_kernel
(
args
.
m
,
args
.
k
,
32
,
32
,
"float16"
,
use_cutlass_layout
=
args
.
use_cutlass_layout
)(
a
)
a_sparse
,
e
=
compress_kernel
(
args
.
m
,
args
.
k
,
32
,
32
,
"float16"
,
use_cutlass_layout
=
args
.
use_cutlass_layout
)(
a
)
c
=
kernel
(
a_sparse
,
e
,
b
)
...
...
@@ -346,9 +322,7 @@ def main():
assert
not
c
.
isnan
().
any
(),
"Reference result contains NaNs, please report an issue"
torch_assert_close
(
c
,
ref_c
.
to
(
c
.
dtype
),
rtol
=
1e-3
,
atol
=
1e-3
)
print
(
f
"Precision check passed. Max diff:
{
(
c
-
ref_c
).
abs
().
max
()
}
, Mean diff:
{
(
c
-
ref_c
).
abs
().
mean
()
}
"
)
print
(
f
"Precision check passed. Max diff:
{
(
c
-
ref_c
).
abs
().
max
()
}
, Mean diff:
{
(
c
-
ref_c
).
abs
().
mean
()
}
"
)
latency
=
do_bench
(
lambda
:
kernel
(
a_sparse
,
e
,
b
))
ref_latency
=
do_bench
(
lambda
:
a
@
b
)
...
...
@@ -356,8 +330,8 @@ def main():
total_flops
=
2
*
args
.
m
*
args
.
n
*
args
.
k
tflops
=
total_flops
/
latency
/
1e9
ref_tflops
=
total_flops
/
ref_latency
/
1e9
print
(
f
"Sparse TFLOPS:
{
tflops
:.
2
f
}
, Latency:
{
latency
/
1e3
}
s"
)
print
(
f
"Reference TFLOPS:
{
ref_tflops
:.
2
f
}
, Latency:
{
ref_latency
/
1e3
:
}
s"
)
print
(
f
"Sparse TFLOPS:
{
tflops
:.
2
f
}
, Latency:
{
latency
/
1e3
}
s"
)
print
(
f
"Reference TFLOPS:
{
ref_tflops
:.
2
f
}
, Latency:
{
ref_latency
/
1e3
:
}
s"
)
if
__name__
==
"__main__"
:
...
...
examples/gemm_sp/example_gemm_sp.py
View file @
29051439
...
...
@@ -16,80 +16,77 @@ arch = nvcc.get_target_compute_version()
DEFAULT_CONFIG
=
{
# take best config from autotune script
"4090"
:
{
'float'
:
{
'block_M'
:
128
,
'block_N'
:
64
,
'block_K'
:
64
,
'num_stages'
:
1
,
'thread_num'
:
128
,
'policy'
:
T
.
GemmWarpPolicy
.
Square
,
'enable_rasterization'
:
True
"float"
:
{
"block_M"
:
128
,
"block_N"
:
64
,
"block_K"
:
64
,
"num_stages"
:
1
,
"thread_num"
:
128
,
"policy"
:
T
.
GemmWarpPolicy
.
Square
,
"enable_rasterization"
:
True
,
},
"float16"
:
{
"block_M"
:
256
,
"block_N"
:
128
,
"block_K"
:
64
,
"num_stages"
:
2
,
"thread_num"
:
128
,
"policy"
:
T
.
GemmWarpPolicy
.
Square
,
"enable_rasterization"
:
True
,
},
'float16'
:
{
'block_M'
:
256
,
'block_N'
:
128
,
'block_K'
:
64
,
'num_stages'
:
2
,
'thread_num'
:
128
,
'policy'
:
T
.
GemmWarpPolicy
.
Square
,
'enable_rasterization'
:
True
}
},
"h20"
:
{
'float'
:
{
'block_M'
:
128
,
'block_N'
:
64
,
'block_K'
:
128
,
'num_stages'
:
3
,
'thread_num'
:
128
,
'policy'
:
T
.
GemmWarpPolicy
.
Square
,
'enable_rasterization'
:
True
"float"
:
{
"block_M"
:
128
,
"block_N"
:
64
,
"block_K"
:
128
,
"num_stages"
:
3
,
"thread_num"
:
128
,
"policy"
:
T
.
GemmWarpPolicy
.
Square
,
"enable_rasterization"
:
True
,
},
"float16"
:
{
"block_M"
:
128
,
"block_N"
:
64
,
"block_K"
:
128
,
"num_stages"
:
3
,
"thread_num"
:
128
,
"policy"
:
T
.
GemmWarpPolicy
.
Square
,
"enable_rasterization"
:
True
,
},
},
'float16'
:
{
'block_M'
:
128
,
'block_N'
:
64
,
'block_K'
:
128
,
'num_stages'
:
3
,
'thread_num'
:
128
,
'policy'
:
T
.
GemmWarpPolicy
.
Square
,
'enable_rasterization'
:
True
}
}
}
ARCH_INFO
=
{
"8.0"
:
(
16
,
"int16"
),
"8.9"
:
(
16
,
"int16"
),
"9.0"
:
(
8
,
"uint8"
)}
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
matmul_sp_fp16
(
M
,
N
,
K
,
accum_dtype
,
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
policy
,
enable_rasterization
):
def
matmul_sp_fp16
(
M
,
N
,
K
,
accum_dtype
,
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
policy
,
enable_rasterization
):
e_factor
,
e_dtype
=
ARCH_INFO
[
arch
]
@
T
.
prim_func
def
gemm_sp_fp16
(
A_sparse
:
T
.
Tensor
((
M
,
K
//
2
),
'
float16
'
),
A_sparse
:
T
.
Tensor
((
M
,
K
//
2
),
"
float16
"
),
E
:
T
.
Tensor
((
M
,
K
//
e_factor
),
e_dtype
),
B
:
T
.
Tensor
((
K
,
N
),
'
float16
'
),
B
:
T
.
Tensor
((
K
,
N
),
"
float16
"
),
C
:
T
.
Tensor
((
M
,
N
),
accum_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
2
),
'
float16
'
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
2
),
"
float16
"
)
E_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
e_factor
),
e_dtype
)
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
'
float16
'
)
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
"
float16
"
)
C_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
accum_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
T
.
clear
(
C_local
)
T
.
disable_warp_group_reg_alloc
()
T
.
use_swizzle
(
panel_size
=
10
,
enable
=
enable_rasterization
)
T
.
annotate_layout
({
E
:
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
"float16"
,
block_k
=
block_K
,
arch
=
arch
),
E_shared
:
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
"float16"
,
block_k
=
block_K
,
arch
=
arch
),
})
T
.
annotate_layout
(
{
E
:
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
"float16"
,
block_k
=
block_K
,
arch
=
arch
),
E_shared
:
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
"float16"
,
block_k
=
block_K
,
arch
=
arch
),
}
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
A_sparse
[
by
*
block_M
,
k
*
block_K
//
2
],
A_shared
)
T
.
copy
(
E
[
by
*
block_M
,
k
*
block_K
//
e_factor
],
E_shared
)
...
...
@@ -107,25 +104,15 @@ def main():
parser
.
add_argument
(
"--m"
,
type
=
int
,
default
=
16384
,
help
=
"Matrix dimension M"
)
parser
.
add_argument
(
"--n"
,
type
=
int
,
default
=
16384
,
help
=
"Matrix dimension N"
)
parser
.
add_argument
(
"--k"
,
type
=
int
,
default
=
16384
,
help
=
"Matrix dimension K"
)
parser
.
add_argument
(
"--accum_dtype"
,
type
=
str
,
default
=
"float"
,
choices
=
[
"float"
,
"float16"
],
help
=
"Accumulation datatype"
)
parser
.
add_argument
(
"--accum_dtype"
,
type
=
str
,
default
=
"float"
,
choices
=
[
"float"
,
"float16"
],
help
=
"Accumulation datatype"
)
parser
.
add_argument
(
"--cfg"
,
type
=
str
,
choices
=
[
"4090"
,
"h20"
],
default
=
"4090"
)
args
=
parser
.
parse_args
()
kernel
=
matmul_sp_fp16
(
args
.
m
,
args
.
n
,
args
.
k
,
args
.
accum_dtype
,
**
DEFAULT_CONFIG
[
args
.
cfg
][
args
.
accum_dtype
])
kernel
=
matmul_sp_fp16
(
args
.
m
,
args
.
n
,
args
.
k
,
args
.
accum_dtype
,
**
DEFAULT_CONFIG
[
args
.
cfg
][
args
.
accum_dtype
])
a
=
randn_semi_sparse
(
args
.
m
,
args
.
k
,
device
=
'
cuda
'
,
dtype
=
torch
.
half
)
b
=
torch
.
randn
(
args
.
k
,
args
.
n
,
device
=
'
cuda
'
,
dtype
=
torch
.
half
)
a
=
randn_semi_sparse
(
args
.
m
,
args
.
k
,
device
=
"
cuda
"
,
dtype
=
torch
.
half
)
b
=
torch
.
randn
(
args
.
k
,
args
.
n
,
device
=
"
cuda
"
,
dtype
=
torch
.
half
)
a_sparse
,
e
=
compress
(
a
,
transposed
=
False
,
block_k
=
DEFAULT_CONFIG
[
args
.
cfg
][
args
.
accum_dtype
][
'block_K'
],
arch
=
arch
)
a_sparse
,
e
=
compress
(
a
,
transposed
=
False
,
block_k
=
DEFAULT_CONFIG
[
args
.
cfg
][
args
.
accum_dtype
][
"block_K"
],
arch
=
arch
)
c
=
kernel
(
a_sparse
,
e
,
b
)
ref_c
=
a
@
b
...
...
@@ -140,8 +127,8 @@ def main():
total_flops
=
2
*
args
.
m
*
args
.
n
*
args
.
k
tflops
=
total_flops
/
latency
/
1e9
ref_tflops
=
total_flops
/
ref_latency
/
1e9
print
(
f
"Sparse TFLOPS:
{
tflops
:.
2
f
}
, Latency:
{
latency
/
1e3
}
s"
)
print
(
f
"Reference TFLOPS:
{
ref_tflops
:.
2
f
}
, Latency:
{
ref_latency
/
1e3
:
}
s"
)
print
(
f
"Sparse TFLOPS:
{
tflops
:.
2
f
}
, Latency:
{
latency
/
1e3
}
s"
)
print
(
f
"Reference TFLOPS:
{
ref_tflops
:.
2
f
}
, Latency:
{
ref_latency
/
1e3
:
}
s"
)
if
__name__
==
"__main__"
:
...
...
examples/gemm_splitk/example_tilelang_gemm_splitk.py
View file @
29051439
...
...
@@ -3,17 +3,7 @@ import tilelang.language as T
@
tilelang
.
jit
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
split_k
,
dtype
=
"float16"
,
accum_dtype
=
"float"
,
out_dtype
=
"float32"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
split_k
,
dtype
=
"float16"
,
accum_dtype
=
"float"
,
out_dtype
=
"float32"
):
splitK
=
K
//
split_k
@
T
.
prim_func
...
...
@@ -22,8 +12,7 @@ def matmul(M,
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
split_k
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
split_k
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
C_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
out_dtype
)
...
...
examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py
View file @
29051439
...
...
@@ -3,17 +3,7 @@ import tilelang.language as T
@
tilelang
.
jit
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
split_k
,
dtype
=
"float16"
,
accum_dtype
=
"float"
,
out_dtype
=
"float32"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
split_k
,
dtype
=
"float16"
,
accum_dtype
=
"float"
,
out_dtype
=
"float32"
):
splitK
=
K
//
split_k
@
T
.
prim_func
...
...
@@ -22,8 +12,7 @@ def matmul(M,
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
split_k
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
split_k
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
C_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
out_dtype
)
...
...
examples/gemm_streamk/example_tilelang_gemm_streamk.py
View file @
29051439
...
...
@@ -39,7 +39,7 @@ total_tiles = num_block_m * num_block_n
# Two-tile SK + DP
streamk_tiles
=
total_tiles
%
streamk_programs
if
(
total_tiles
-
streamk_tiles
>
streamk_programs
)
:
# (total_tiles // total_programs > 1)
if
total_tiles
-
streamk_tiles
>
streamk_programs
:
# (total_tiles // total_programs > 1)
streamk_tiles
+=
streamk_programs
blocking_tiles
=
total_tiles
-
streamk_tiles
...
...
@@ -135,7 +135,6 @@ def tl_matmul_streamk(
C
:
T
.
Tensor
,
C_local
:
T
.
LocalBuffer
,
):
for
p
in
T
.
serial
(
sm_patition_factor
):
tile_id
=
pid
+
streamk_tiles
+
p
*
total_sm
pid_m
=
tile_id
//
T
.
ceildiv
(
N
,
block_N
)
...
...
@@ -155,7 +154,6 @@ def tl_matmul_streamk(
C
:
T
.
Tensor
((
M
,
N
),
dtypeC
),
):
with
T
.
Kernel
(
streamk_programs
,
threads
=
threads
)
as
pid
:
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
dtypeAB
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
dtypeAB
)
A_shared_full_tiles
=
T
.
alloc_shared
(
A_shared_shape
,
dtypeAB
)
...
...
examples/gemv/example_gemv.py
View file @
29051439
...
...
@@ -20,7 +20,6 @@ def naive_gemv(
dtype
:
str
=
"float16"
,
accum_dtype
:
str
=
"float"
,
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
K
,),
dtype
),
...
...
@@ -38,8 +37,7 @@ def naive_gemv(
A_shared
[
tk
]
=
A
[
bk
*
BLOCK_K
+
tk
]
B_shared
[
tn
,
tk
]
=
B
[
bn
*
BLOCK_N
+
tn
,
bk
*
BLOCK_K
+
tk
]
for
tk
in
T
.
serial
(
BLOCK_K
):
C_reg
[
0
]
+=
A_shared
[
tk
].
astype
(
accum_dtype
)
*
B_shared
[
tn
,
tk
].
astype
(
accum_dtype
)
C_reg
[
0
]
+=
A_shared
[
tk
].
astype
(
accum_dtype
)
*
B_shared
[
tn
,
tk
].
astype
(
accum_dtype
)
C
[
bn
*
BLOCK_N
+
tn
]
=
C_reg
[
0
]
return
main
...
...
@@ -54,7 +52,6 @@ def naive_splitk_gemv(
dtype
:
str
=
"float16"
,
accum_dtype
:
str
=
"float"
,
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
K
,),
dtype
),
...
...
@@ -209,7 +206,8 @@ def splitk_gemv_vectorized_tvm(
C_reduced
[
0
],
tk
,
dtype
=
"handle"
,
))
)
)
C
[
bn
*
BLOCK_N
+
tn
]
=
C_reduced
[
0
]
...
...
@@ -218,10 +216,8 @@ def splitk_gemv_vectorized_tvm(
def
get_block_template_configs
():
iter_params
=
dict
(
block_M
=
[
2
,
4
,
8
,
32
,
64
,
128
],
block_N
=
[
2
,
4
,
8
,
32
,
64
,
128
],
num_stages
=
[
0
,
1
,
2
,
3
,
4
],
threads
=
[
32
,
64
,
128
,
256
])
block_M
=
[
2
,
4
,
8
,
32
,
64
,
128
],
block_N
=
[
2
,
4
,
8
,
32
,
64
,
128
],
num_stages
=
[
0
,
1
,
2
,
3
,
4
],
threads
=
[
32
,
64
,
128
,
256
]
)
return
[
dict
(
zip
(
iter_params
,
values
))
for
values
in
itertools
.
product
(
*
iter_params
.
values
())]
...
...
@@ -237,18 +233,9 @@ def get_block_template_configs():
},
out_idx
=
[
2
],
)
def
gemv_alloc_reducer
(
M
,
N
,
block_M
=
128
,
block_N
=
128
,
num_stages
=
2
,
threads
=
256
,
dtype
:
str
=
"float16"
,
accum_dtype
:
str
=
"float"
):
def
gemv_alloc_reducer
(
M
,
N
,
block_M
=
128
,
block_N
=
128
,
num_stages
=
2
,
threads
=
256
,
dtype
:
str
=
"float16"
,
accum_dtype
:
str
=
"float"
):
@
T
.
prim_func
def
main
(
a
:
T
.
Tensor
((
M
,
N
),
dtype
),
x
:
T
.
Tensor
(
N
,
dtype
),
o
:
T
.
Tensor
(
M
,
dtype
)):
# type: ignore
def
main
(
a
:
T
.
Tensor
((
M
,
N
),
dtype
),
x
:
T
.
Tensor
(
N
,
dtype
),
o
:
T
.
Tensor
(
M
,
dtype
)):
# type: ignore
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
i0_m
:
o_reducer
=
T
.
alloc_reducer
(
block_M
,
accum_dtype
,
replication
=
"all"
)
T
.
clear
(
o_reducer
)
...
...
@@ -327,7 +314,8 @@ def get_autotuned_kernel(
C_reduced
[
0
],
tk
,
dtype
=
"handle"
,
))
)
)
C
[
bn
*
BLOCK_N
+
tn
]
=
C_reduced
[
0
]
...
...
@@ -355,8 +343,7 @@ def main(do_bench: bool = True):
check_correctness_and_bench
(
splitk_gemv
(
N
,
K
,
32
,
32
,
32
),
N
,
K
,
do_bench
=
do_bench
)
check_correctness_and_bench
(
splitk_gemv_vectorized
(
N
,
K
,
2
,
32
),
N
,
K
,
do_bench
=
do_bench
)
check_correctness_and_bench
(
splitk_gemv_vectorized_tvm
(
N
,
K
,
2
,
32
),
N
,
K
,
do_bench
=
do_bench
)
check_correctness_and_bench
(
gemv_alloc_reducer
(
N
,
K
,
block_M
=
128
,
block_N
=
128
),
N
,
K
,
do_bench
=
do_bench
)
check_correctness_and_bench
(
gemv_alloc_reducer
(
N
,
K
,
block_M
=
128
,
block_N
=
128
),
N
,
K
,
do_bench
=
do_bench
)
print
(
"Test passed!"
)
...
...
examples/grouped_gemm/example_grouped_gemm_bwd.py
View file @
29051439
...
...
@@ -5,21 +5,8 @@ import tilelang
import
tilelang.language
as
T
@
tilelang
.
jit
(
out_idx
=
[
2
],
pass_configs
=
{
"tl.disable_tma_lower"
:
True
,
"tl.disable_warp_specialized"
:
True
})
def
grouped_gemm_fwd
(
batch_sum
,
batch_count
,
K
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
=
2
,
threads
=
128
,
dtype
=
"float16"
):
@
tilelang
.
jit
(
out_idx
=
[
2
],
pass_configs
=
{
"tl.disable_tma_lower"
:
True
,
"tl.disable_warp_specialized"
:
True
})
def
grouped_gemm_fwd
(
batch_sum
,
batch_count
,
K
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
=
2
,
threads
=
128
,
dtype
=
"float16"
):
"""
args:
a (torch.Tensor): Input tensor of shape (M, K).
...
...
@@ -36,10 +23,7 @@ def grouped_gemm_fwd(batch_sum,
batch_offsets
:
T
.
Tensor
([
batch_count
],
"int32"
),
# type: ignore
batch_padded_offsets
:
T
.
Tensor
([
batch_count
],
"int32"
),
# type: ignore
):
with
T
.
Kernel
(
T
.
ceildiv
(
batch_sum
,
block_M
)
+
batch_count
,
T
.
ceildiv
(
N
,
block_N
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
batch_sum
,
block_M
)
+
batch_count
,
T
.
ceildiv
(
N
,
block_N
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
([
block_M
,
block_K
],
dtype
)
B_shared
=
T
.
alloc_shared
([
block_K
,
block_N
],
dtype
)
C_local
=
T
.
alloc_fragment
([
block_M
,
block_N
],
accum_dtype
)
...
...
@@ -49,23 +33,17 @@ def grouped_gemm_fwd(batch_sum,
m_start_padded
=
bx
*
block_M
for
i
in
range
(
batch_count
):
in_cur_batch_idx
=
(
m_start_padded
>=
batch_padded_offsets
[
i
]
)
in_cur_batch_idx
=
m_start_padded
>=
batch_padded_offsets
[
i
]
cur_batch_idx
[
0
]
=
T
.
if_then_else
(
in_cur_batch_idx
,
i
,
cur_batch_idx
[
0
])
cur_batch_size
[
0
]
=
batch_sizes
[
cur_batch_idx
[
0
]]
m_start
=
m_start_padded
-
batch_padded_offsets
[
cur_batch_idx
[
0
]]
+
batch_offsets
[
cur_batch_idx
[
0
]]
actual_rows
=
T
.
max
(
0
,
T
.
min
(
block_M
,
cur_batch_size
[
0
]
+
batch_padded_offsets
[
cur_batch_idx
[
0
]]
-
m_start_padded
))
m_start
=
m_start_padded
-
batch_padded_offsets
[
cur_batch_idx
[
0
]]
+
batch_offsets
[
cur_batch_idx
[
0
]]
actual_rows
=
T
.
max
(
0
,
T
.
min
(
block_M
,
cur_batch_size
[
0
]
+
batch_padded_offsets
[
cur_batch_idx
[
0
]]
-
m_start_padded
))
T
.
clear
(
C_local
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
A
[
m_start
:
m_start
+
block_M
,
k
*
block_K
:(
k
+
1
)
*
block_K
],
A_shared
)
T
.
copy
(
B
[
cur_batch_idx
[
0
],
k
*
block_K
:(
k
+
1
)
*
block_K
,
by
*
block_N
:(
by
+
1
)
*
block_N
],
B_shared
)
T
.
copy
(
A
[
m_start
:
m_start
+
block_M
,
k
*
block_K
:
(
k
+
1
)
*
block_K
],
A_shared
)
T
.
copy
(
B
[
cur_batch_idx
[
0
],
k
*
block_K
:
(
k
+
1
)
*
block_K
,
by
*
block_N
:
(
by
+
1
)
*
block_N
],
B_shared
)
T
.
gemm
(
A_shared
,
B_shared
,
C_local
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
...
...
@@ -76,7 +54,6 @@ def grouped_gemm_fwd(batch_sum,
class
_GroupedGEMM
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
a
,
b
,
batch_sizes
):
block_M
=
64
...
...
@@ -99,15 +76,11 @@ class _GroupedGEMM(torch.autograd.Function):
for
i
in
range
(
batch_count
-
1
):
batch_offsets_list
.
append
(
batch_offsets_list
[
-
1
]
+
batch_sizes
[
i
])
for
i
in
range
(
batch_count
-
1
):
batch_padded_offsets_list
.
append
(
batch_padded_offsets_list
[
-
1
]
+
math
.
ceil
((
batch_sizes
[
i
]
+
1
)
/
padding_M
)
*
padding_M
)
batch_padded_offsets_list
.
append
(
batch_padded_offsets_list
[
-
1
]
+
math
.
ceil
((
batch_sizes
[
i
]
+
1
)
/
padding_M
)
*
padding_M
)
batch_offsets
=
torch
.
tensor
(
batch_offsets_list
,
device
=
a
.
device
,
dtype
=
torch
.
int32
)
batch_padded_offsets
=
torch
.
tensor
(
batch_padded_offsets_list
,
device
=
a
.
device
,
dtype
=
torch
.
int32
)
batch_padded_offsets
=
torch
.
tensor
(
batch_padded_offsets_list
,
device
=
a
.
device
,
dtype
=
torch
.
int32
)
kernel
=
grouped_gemm_fwd
(
batch_sum
,
batch_count
,
K
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
,
threads
)
kernel
=
grouped_gemm_fwd
(
batch_sum
,
batch_count
,
K
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
,
threads
)
o
=
kernel
(
a
,
b
,
batch_sizes
,
batch_offsets
,
batch_padded_offsets
)
ctx
.
save_for_backward
(
a
,
b
,
batch_sizes
,
batch_offsets
)
...
...
@@ -135,8 +108,7 @@ class _GroupedGEMM(torch.autograd.Function):
return
x
A
,
B
,
batch_sizes
=
[
maybe_contiguous
(
x
)
for
x
in
(
A
,
B
,
batch_sizes
)]
kernel
=
grouped_gemm_bwd
(
ctx
.
batch_sum
,
ctx
.
batch_count
,
M
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
,
threads
)
kernel
=
grouped_gemm_bwd
(
ctx
.
batch_sum
,
ctx
.
batch_count
,
M
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
,
threads
)
dB
=
kernel
(
A
,
grad_output
,
batch_sizes
,
batch_offsets
)
return
None
,
dB
,
None
...
...
@@ -172,9 +144,7 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
for
i
in
range
(
batch_count
-
1
):
batch_offsets_list
.
append
(
batch_offsets_list
[
-
1
]
+
batch_sizes_list
[
i
])
for
i
in
range
(
batch_count
-
1
):
batch_padded_offsets_list
.
append
(
batch_padded_offsets_list
[
-
1
]
+
math
.
ceil
((
batch_sizes_list
[
i
]
+
1
)
/
padding_M
)
*
padding_M
)
batch_padded_offsets_list
.
append
(
batch_padded_offsets_list
[
-
1
]
+
math
.
ceil
((
batch_sizes_list
[
i
]
+
1
)
/
padding_M
)
*
padding_M
)
A
=
torch
.
randn
(
batch_sum
,
K
,
device
=
device
,
dtype
=
dtype
)
B
=
torch
.
randn
(
batch_count
,
K
,
M
,
device
=
device
,
dtype
=
dtype
)
C
=
torch
.
empty
(
batch_sum
,
M
,
device
=
device
,
dtype
=
dtype
)
...
...
@@ -187,21 +157,8 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
return
A
,
B
,
C
,
batch_sizes
,
batch_offsets
,
batch_padded_offsets
@
tilelang
.
jit
(
out_idx
=
[
2
],
pass_configs
=
{
"tl.disable_tma_lower"
:
True
,
"tl.disable_warp_specialized"
:
True
})
def
grouped_gemm_bwd
(
batch_sum
,
batch_count
,
M
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
=
2
,
threads
=
128
,
dtype
=
"float16"
):
@
tilelang
.
jit
(
out_idx
=
[
2
],
pass_configs
=
{
"tl.disable_tma_lower"
:
True
,
"tl.disable_warp_specialized"
:
True
})
def
grouped_gemm_bwd
(
batch_sum
,
batch_count
,
M
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
=
2
,
threads
=
128
,
dtype
=
"float16"
):
"""
args:
a (torch.Tensor): Input tensor of shape (M, K).
...
...
@@ -217,10 +174,7 @@ def grouped_gemm_bwd(batch_sum,
batch_sizes
:
T
.
Tensor
([
batch_count
],
"int32"
),
# type: ignore
batch_offsets
:
T
.
Tensor
([
batch_count
],
"int32"
),
# type: ignore
):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
batch_count
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
batch_count
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
A_shared
=
T
.
alloc_shared
([
block_K
,
block_M
],
dtype
)
B_shared
=
T
.
alloc_shared
([
block_K
,
block_N
],
dtype
)
C_local
=
T
.
alloc_fragment
([
block_M
,
block_N
],
accum_dtype
)
...
...
@@ -228,13 +182,9 @@ def grouped_gemm_bwd(batch_sum,
T
.
clear
(
C_local
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
batch_sizes
[
bz
],
block_K
),
num_stages
=
num_stages
):
for
i
,
j
in
T
.
Parallel
(
block_K
,
block_M
):
A_shared
[
i
,
j
]
=
T
.
if_then_else
(
i
<
batch_sizes
[
bz
],
A
[
batch_offsets
[
bz
]
+
k
*
block_K
+
i
,
bx
*
block_M
+
j
],
0
)
A_shared
[
i
,
j
]
=
T
.
if_then_else
(
i
<
batch_sizes
[
bz
],
A
[
batch_offsets
[
bz
]
+
k
*
block_K
+
i
,
bx
*
block_M
+
j
],
0
)
for
i
,
j
in
T
.
Parallel
(
block_K
,
block_N
):
B_shared
[
i
,
j
]
=
T
.
if_then_else
(
i
<
batch_sizes
[
bz
],
B
[
batch_offsets
[
bz
]
+
k
*
block_K
+
i
,
by
*
block_N
+
j
],
0
)
B_shared
[
i
,
j
]
=
T
.
if_then_else
(
i
<
batch_sizes
[
bz
],
B
[
batch_offsets
[
bz
]
+
k
*
block_K
+
i
,
by
*
block_N
+
j
],
0
)
T
.
gemm
(
A_shared
,
B_shared
,
C_local
,
transpose_A
=
True
)
T
.
copy
(
C_local
,
C
[
bz
,
bx
*
block_M
,
by
*
block_N
])
...
...
@@ -242,23 +192,12 @@ def grouped_gemm_bwd(batch_sum,
return
kernel
def
run_tilelang_grouped_gemm
(
batch_sizes_list
,
K
,
M
,
block_M
,
block_N
,
block_K
,
trans_b
,
num_stages
=
2
,
threads
=
128
,
profile
=
False
):
def
run_tilelang_grouped_gemm
(
batch_sizes_list
,
K
,
M
,
block_M
,
block_N
,
block_K
,
trans_b
,
num_stages
=
2
,
threads
=
128
,
profile
=
False
):
padding_M
=
block_M
device
=
torch
.
device
(
"cuda"
)
dtype
=
torch
.
float16
A
,
B
,
C
,
batch_sizes
,
batch_offsets
,
batch_padded_offsets
=
construct_inputs
(
batch_sizes_list
,
K
,
M
,
False
,
padding_M
,
device
,
dtype
)
A
,
B
,
C
,
batch_sizes
,
batch_offsets
,
batch_padded_offsets
=
construct_inputs
(
batch_sizes_list
,
K
,
M
,
False
,
padding_M
,
device
,
dtype
)
A
.
requires_grad_
(
False
)
B
.
requires_grad_
(
True
)
...
...
@@ -273,10 +212,7 @@ def run_tilelang_grouped_gemm(batch_sizes_list,
O
.
backward
(
dO
,
retain_graph
=
True
)
dB
,
B
.
grad
=
B
.
grad
.
clone
(),
None
if
(
torch
.
allclose
(
O
,
O_ref
,
rtol
=
1e-2
,
atol
=
1e-2
)
and
\
torch
.
allclose
(
dB
,
dB_ref
,
rtol
=
1e-2
,
atol
=
1e-2
)
):
if
torch
.
allclose
(
O
,
O_ref
,
rtol
=
1e-2
,
atol
=
1e-2
)
and
torch
.
allclose
(
dB
,
dB_ref
,
rtol
=
1e-2
,
atol
=
1e-2
):
print
(
"✅ Tilelang and Torch match"
)
else
:
print
(
"❌ Tilelang and Torch mismatch"
)
...
...
@@ -284,12 +220,11 @@ def run_tilelang_grouped_gemm(batch_sizes_list,
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--batch_sizes'
,
type
=
str
,
default
=
"64, 128"
,
help
=
'comma-separated batch sizes'
)
parser
.
add_argument
(
'--K'
,
type
=
int
,
default
=
8192
,
help
=
'reduce dim'
)
parser
.
add_argument
(
'--M'
,
type
=
int
,
default
=
8192
,
help
=
'output dim'
)
parser
.
add_argument
(
'--trans_b'
,
action
=
"store_true"
,
help
=
"transpose B"
)
parser
.
add_argument
(
'--profile'
,
action
=
"store_true"
,
help
=
"profile"
)
parser
.
add_argument
(
"--batch_sizes"
,
type
=
str
,
default
=
"64, 128"
,
help
=
"comma-separated batch sizes"
)
parser
.
add_argument
(
"--K"
,
type
=
int
,
default
=
8192
,
help
=
"reduce dim"
)
parser
.
add_argument
(
"--M"
,
type
=
int
,
default
=
8192
,
help
=
"output dim"
)
parser
.
add_argument
(
"--trans_b"
,
action
=
"store_true"
,
help
=
"transpose B"
)
parser
.
add_argument
(
"--profile"
,
action
=
"store_true"
,
help
=
"profile"
)
args
=
parser
.
parse_args
()
batch_sizes_list
=
[
int
(
x
)
for
x
in
args
.
batch_sizes
.
split
(
","
)]
...
...
@@ -301,14 +236,4 @@ if __name__ == "__main__":
num_stages
=
2
threads
=
256
run_tilelang_grouped_gemm
(
batch_sizes_list
,
K
,
M
,
block_M
,
block_N
,
block_K
,
trans_b
,
num_stages
,
threads
,
profile
=
args
.
profile
)
run_tilelang_grouped_gemm
(
batch_sizes_list
,
K
,
M
,
block_M
,
block_N
,
block_K
,
trans_b
,
num_stages
,
threads
,
profile
=
args
.
profile
)
examples/grouped_gemm/example_grouped_gemm_fwd.py
View file @
29051439
...
...
@@ -18,8 +18,7 @@ def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False):
torch.Tensor: Resulting tensor after grouped matrix multiplication.
"""
assert
a
.
shape
[
0
]
==
sum
(
batch_sizes
),
"Sum of batch_sizes must equal the first dimension of a"
assert
b
.
shape
[
0
]
==
len
(
batch_sizes
),
"The first dimension of b must match the length of batch_sizes"
assert
b
.
shape
[
0
]
==
len
(
batch_sizes
),
"The first dimension of b must match the length of batch_sizes"
# Initialize output tensor
output
=
torch
.
empty
((
sum
(
batch_sizes
),
b
.
shape
[
2
]),
device
=
a
.
device
,
dtype
=
a
.
dtype
)
...
...
@@ -38,15 +37,7 @@ def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False):
@
tilelang
.
jit
(
out_idx
=
[
2
])
def
grouped_gemm
(
batch_sizes_list
,
K
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
=
2
,
threads
=
128
,
dtype
=
"float16"
):
def
grouped_gemm
(
batch_sizes_list
,
K
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
=
2
,
threads
=
128
,
dtype
=
"float16"
):
"""
args:
a (torch.Tensor): Input tensor of shape (M, K).
...
...
@@ -66,7 +57,6 @@ def grouped_gemm(batch_sizes_list,
batch_offsets
:
T
.
Tensor
([
batch_count
],
"int32"
),
# type: ignore
batch_padded_offsets
:
T
.
Tensor
([
batch_count
],
"int32"
),
# type: ignore
):
with
T
.
Kernel
(
total_m_blocks
,
T
.
ceildiv
(
N
,
block_N
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
([
block_M
,
block_K
],
dtype
)
B_shared
=
T
.
alloc_shared
([
block_K
,
block_N
],
dtype
)
...
...
@@ -77,23 +67,17 @@ def grouped_gemm(batch_sizes_list,
m_start_padded
=
bx
*
block_M
for
i
in
range
(
batch_count
):
in_cur_batch_idx
=
(
m_start_padded
>=
batch_padded_offsets
[
i
]
)
in_cur_batch_idx
=
m_start_padded
>=
batch_padded_offsets
[
i
]
cur_batch_idx
[
0
]
=
T
.
if_then_else
(
in_cur_batch_idx
,
i
,
cur_batch_idx
[
0
])
cur_batch_size
[
0
]
=
batch_sizes
[
cur_batch_idx
[
0
]]
m_start
=
m_start_padded
-
batch_padded_offsets
[
cur_batch_idx
[
0
]]
+
batch_offsets
[
cur_batch_idx
[
0
]]
actual_rows
=
T
.
max
(
0
,
T
.
min
(
block_M
,
cur_batch_size
[
0
]
+
batch_padded_offsets
[
cur_batch_idx
[
0
]]
-
m_start_padded
))
m_start
=
m_start_padded
-
batch_padded_offsets
[
cur_batch_idx
[
0
]]
+
batch_offsets
[
cur_batch_idx
[
0
]]
actual_rows
=
T
.
max
(
0
,
T
.
min
(
block_M
,
cur_batch_size
[
0
]
+
batch_padded_offsets
[
cur_batch_idx
[
0
]]
-
m_start_padded
))
T
.
clear
(
C_local
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
A
[
m_start
:
m_start
+
block_M
,
k
*
block_K
:(
k
+
1
)
*
block_K
],
A_shared
)
T
.
copy
(
B
[
cur_batch_idx
[
0
],
k
*
block_K
:(
k
+
1
)
*
block_K
,
by
*
block_N
:(
by
+
1
)
*
block_N
],
B_shared
)
T
.
copy
(
A
[
m_start
:
m_start
+
block_M
,
k
*
block_K
:
(
k
+
1
)
*
block_K
],
A_shared
)
T
.
copy
(
B
[
cur_batch_idx
[
0
],
k
*
block_K
:
(
k
+
1
)
*
block_K
,
by
*
block_N
:
(
by
+
1
)
*
block_N
],
B_shared
)
T
.
gemm
(
A_shared
,
B_shared
,
C_local
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
...
...
@@ -111,8 +95,7 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
for
i
in
range
(
batch_count
-
1
):
batch_offsets_list
.
append
(
batch_offsets_list
[
-
1
]
+
batch_sizes_list
[
i
])
for
i
in
range
(
batch_count
-
1
):
batch_padded_offsets_list
.
append
(
batch_padded_offsets_list
[
-
1
]
+
math
.
ceil
((
batch_sizes_list
[
i
])
/
padding_M
)
*
padding_M
)
batch_padded_offsets_list
.
append
(
batch_padded_offsets_list
[
-
1
]
+
math
.
ceil
((
batch_sizes_list
[
i
])
/
padding_M
)
*
padding_M
)
A
=
torch
.
randn
(
batch_sum
,
K
,
device
=
device
,
dtype
=
dtype
)
B
=
torch
.
randn
(
batch_count
,
K
,
M
,
device
=
device
,
dtype
=
dtype
)
C
=
torch
.
empty
(
batch_sum
,
M
,
device
=
device
,
dtype
=
dtype
)
...
...
@@ -125,27 +108,16 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
return
A
,
B
,
C
,
batch_sizes
,
batch_offsets
,
batch_padded_offsets
def
run_tilelang_grouped_gemm
(
batch_sizes_list
,
K
,
M
,
block_M
,
block_N
,
block_K
,
trans_b
,
num_stages
=
2
,
threads
=
128
,
profile
=
False
):
def
run_tilelang_grouped_gemm
(
batch_sizes_list
,
K
,
M
,
block_M
,
block_N
,
block_K
,
trans_b
,
num_stages
=
2
,
threads
=
128
,
profile
=
False
):
padding_M
=
block_M
batch_sum
=
sum
(
batch_sizes_list
)
kernel
=
grouped_gemm
(
tuple
(
batch_sizes_list
),
K
,
M
,
block_M
,
block_N
,
block_K
,
num_stages
,
threads
)
kernel
=
grouped_gemm
(
tuple
(
batch_sizes_list
),
K
,
M
,
block_M
,
block_N
,
block_K
,
num_stages
,
threads
)
# print(kernel.get_kernel_source())
device
=
torch
.
device
(
"cuda"
)
dtype
=
torch
.
float16
A
,
B
,
C
,
batch_sizes
,
batch_offsets
,
batch_padded_offsets
=
construct_inputs
(
batch_sizes_list
,
K
,
M
,
trans_b
,
padding_M
,
device
,
dtype
)
A
,
B
,
C
,
batch_sizes
,
batch_offsets
,
batch_padded_offsets
=
construct_inputs
(
batch_sizes_list
,
K
,
M
,
trans_b
,
padding_M
,
device
,
dtype
)
out
=
kernel
(
A
,
B
,
batch_sizes
,
batch_offsets
,
batch_padded_offsets
)
ref_output
=
torch_gmm
(
A
,
B
,
batch_sizes
,
batch_offsets
,
trans_b
)
# print(out)
...
...
@@ -157,8 +129,7 @@ def run_tilelang_grouped_gemm(batch_sizes_list,
if
profile
:
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Auto
)
latency
=
profiler
.
do_bench
(
warmup
=
500
,
input_tensors
=
[
A
,
B
,
batch_sizes
,
batch_offsets
,
batch_padded_offsets
])
latency
=
profiler
.
do_bench
(
warmup
=
500
,
input_tensors
=
[
A
,
B
,
batch_sizes
,
batch_offsets
,
batch_padded_offsets
])
print
(
f
"Latency:
{
latency
}
ms"
)
print
(
f
"TFlops:
{
batch_sum
*
K
*
M
*
2
/
latency
*
1e-9
}
TFlops"
)
...
...
@@ -173,12 +144,11 @@ def test_grouped_gemm():
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--batch_sizes'
,
type
=
str
,
default
=
"64, 128"
,
help
=
'comma-separated batch sizes'
)
parser
.
add_argument
(
'--K'
,
type
=
int
,
default
=
8192
,
help
=
'reduce dim'
)
parser
.
add_argument
(
'--M'
,
type
=
int
,
default
=
8192
,
help
=
'output dim'
)
parser
.
add_argument
(
'--trans_b'
,
action
=
"store_true"
,
help
=
"transpose B"
)
parser
.
add_argument
(
'--profile'
,
action
=
"store_true"
,
help
=
"profile"
)
parser
.
add_argument
(
"--batch_sizes"
,
type
=
str
,
default
=
"64, 128"
,
help
=
"comma-separated batch sizes"
)
parser
.
add_argument
(
"--K"
,
type
=
int
,
default
=
8192
,
help
=
"reduce dim"
)
parser
.
add_argument
(
"--M"
,
type
=
int
,
default
=
8192
,
help
=
"output dim"
)
parser
.
add_argument
(
"--trans_b"
,
action
=
"store_true"
,
help
=
"transpose B"
)
parser
.
add_argument
(
"--profile"
,
action
=
"store_true"
,
help
=
"profile"
)
args
=
parser
.
parse_args
()
batch_sizes_list
=
[
int
(
x
)
for
x
in
args
.
batch_sizes
.
split
(
","
)]
...
...
@@ -190,14 +160,4 @@ if __name__ == "__main__":
num_stages
=
2
threads
=
256
run_tilelang_grouped_gemm
(
batch_sizes_list
,
K
,
M
,
block_M
,
block_N
,
block_K
,
trans_b
,
num_stages
,
threads
,
profile
=
args
.
profile
)
run_tilelang_grouped_gemm
(
batch_sizes_list
,
K
,
M
,
block_M
,
block_N
,
block_K
,
trans_b
,
num_stages
,
threads
,
profile
=
args
.
profile
)
Prev
1
…
4
5
6
7
8
9
10
11
12
…
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment