Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
29051439
Unverified
Commit
29051439
authored
Dec 12, 2025
by
Lei Wang
Committed by
GitHub
Dec 12, 2025
Browse files
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
parent
e84b24bc
Changes
426
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
378 additions
and
667 deletions
+378
-667
examples/gemm/example_gemm.py
examples/gemm/example_gemm.py
+3
-4
examples/gemm/example_gemm_autotune.py
examples/gemm/example_gemm_autotune.py
+24
-65
examples/gemm/example_gemm_intrinsics.py
examples/gemm/example_gemm_intrinsics.py
+11
-11
examples/gemm/example_gemm_persistent.py
examples/gemm/example_gemm_persistent.py
+20
-43
examples/gemm/example_gemm_schedule.py
examples/gemm/example_gemm_schedule.py
+3
-4
examples/gemm_fp8/example_tilelang_gemm_amd.py
examples/gemm_fp8/example_tilelang_gemm_amd.py
+28
-49
examples/gemm_fp8/example_tilelang_gemm_fp8.py
examples/gemm_fp8/example_tilelang_gemm_fp8.py
+7
-8
examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py
examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py
+8
-8
examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
+11
-11
examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py
examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py
+4
-6
examples/gemm_sm100/gemm_mma.py
examples/gemm_sm100/gemm_mma.py
+5
-5
examples/gemm_sm100/gemm_tcgen5mma.py
examples/gemm_sm100/gemm_tcgen5mma.py
+8
-16
examples/gemm_sp/example_custom_compress.py
examples/gemm_sp/example_custom_compress.py
+81
-107
examples/gemm_sp/example_gemm_sp.py
examples/gemm_sp/example_gemm_sp.py
+55
-68
examples/gemm_splitk/example_tilelang_gemm_splitk.py
examples/gemm_splitk/example_tilelang_gemm_splitk.py
+5
-16
examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py
...plitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py
+5
-16
examples/gemm_streamk/example_tilelang_gemm_streamk.py
examples/gemm_streamk/example_tilelang_gemm_streamk.py
+4
-6
examples/gemv/example_gemv.py
examples/gemv/example_gemv.py
+34
-47
examples/grouped_gemm/example_grouped_gemm_bwd.py
examples/grouped_gemm/example_grouped_gemm_bwd.py
+38
-113
examples/grouped_gemm/example_grouped_gemm_fwd.py
examples/grouped_gemm/example_grouped_gemm_fwd.py
+24
-64
No files found.
Too many changes to show.
To preserve performance only
426 of 426+
files are displayed.
Plain diff
Email patch
examples/gemm/example_gemm.py
View file @
29051439
...
@@ -4,12 +4,11 @@ import tilelang.language as T
...
@@ -4,12 +4,11 @@ import tilelang.language as T
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
@
T
.
prim_func
def
gemm
(
def
gemm
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
...
...
examples/gemm/example_gemm_autotune.py
View file @
29051439
...
@@ -90,7 +90,8 @@ def get_configs(M, N, K, with_roller=False, topk=20):
...
@@ -90,7 +90,8 @@ def get_configs(M, N, K, with_roller=False, topk=20):
num_stages
,
num_stages
,
thread_num
,
thread_num
,
enable_rasterization
,
enable_rasterization
,
))
)
)
configs
=
[
configs
=
[
{
{
...
@@ -100,13 +101,13 @@ def get_configs(M, N, K, with_roller=False, topk=20):
...
@@ -100,13 +101,13 @@ def get_configs(M, N, K, with_roller=False, topk=20):
"num_stages"
:
c
[
3
],
"num_stages"
:
c
[
3
],
"thread_num"
:
c
[
4
],
"thread_num"
:
c
[
4
],
"enable_rasteration"
:
c
[
5
],
# keep param name for backward-compat
"enable_rasteration"
:
c
[
5
],
# keep param name for backward-compat
}
for
c
in
_configs
}
for
c
in
_configs
]
]
return
configs
return
configs
def
get_best_config
(
M
,
N
,
K
,
with_roller
=
False
):
def
get_best_config
(
M
,
N
,
K
,
with_roller
=
False
):
def
kernel
(
def
kernel
(
block_M
=
None
,
block_M
=
None
,
block_N
=
None
,
block_N
=
None
,
...
@@ -120,12 +121,11 @@ def get_best_config(M, N, K, with_roller=False):
...
@@ -120,12 +121,11 @@ def get_best_config(M, N, K, with_roller=False):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
with
T
.
Kernel
(
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_N
,
block_K
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_N
,
block_K
),
dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
@@ -146,15 +146,18 @@ def get_best_config(M, N, K, with_roller=False):
...
@@ -146,15 +146,18 @@ def get_best_config(M, N, K, with_roller=False):
return
main
return
main
autotuner
=
AutoTuner
.
from_kernel
(
autotuner
=
(
kernel
=
kernel
,
configs
=
get_configs
(
M
,
N
,
K
,
with_roller
)).
set_compile_args
(
AutoTuner
.
from_kernel
(
kernel
=
kernel
,
configs
=
get_configs
(
M
,
N
,
K
,
with_roller
))
.
set_compile_args
(
out_idx
=
[
-
1
],
out_idx
=
[
-
1
],
target
=
"auto"
,
target
=
"auto"
,
).
set_profile_args
(
)
.
set_profile_args
(
supply_type
=
tl
.
TensorSupplyType
.
Integer
,
supply_type
=
tl
.
TensorSupplyType
.
Integer
,
ref_prog
=
ref_program
,
ref_prog
=
ref_program
,
skip_check
=
False
,
skip_check
=
False
,
)
)
)
return
autotuner
.
run
(
warmup
=
3
,
rep
=
20
)
return
autotuner
.
run
(
warmup
=
3
,
rep
=
20
)
...
@@ -167,52 +170,20 @@ def get_heuristic_config() -> dict:
...
@@ -167,52 +170,20 @@ def get_heuristic_config() -> dict:
sm_version
=
sm_major
*
10
+
sm_minor
sm_version
=
sm_major
*
10
+
sm_minor
print
(
f
"CUDA device capability:
{
sm_version
}
"
)
print
(
f
"CUDA device capability:
{
sm_version
}
"
)
if
sm_version
in
{
80
}:
if
sm_version
in
{
80
}:
return
{
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
32
,
"num_stages"
:
2
,
"thread_num"
:
128
,
"enable_rasteration"
:
True
}
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
32
,
"num_stages"
:
2
,
"thread_num"
:
128
,
"enable_rasteration"
:
True
}
elif
sm_version
in
{
90
}:
elif
sm_version
in
{
90
}:
return
{
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
64
,
"num_stages"
:
3
,
"thread_num"
:
256
,
"enable_rasteration"
:
True
}
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
64
,
"num_stages"
:
3
,
"thread_num"
:
256
,
"enable_rasteration"
:
True
}
else
:
else
:
return
{
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
32
,
"num_stages"
:
0
,
"thread_num"
:
128
,
"enable_rasteration"
:
True
}
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
32
,
"num_stages"
:
0
,
"thread_num"
:
128
,
"enable_rasteration"
:
True
}
@
tl
.
jit
(
out_idx
=
[
-
1
])
@
tl
.
jit
(
out_idx
=
[
-
1
])
def
matmul
(
M
,
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
enable_rasteration
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
N
,
K
,
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
enable_rasteration
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
@
T
.
prim_func
def
gemm_autotune
(
def
gemm_autotune
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
...
@@ -236,11 +207,7 @@ def matmul(M,
...
@@ -236,11 +207,7 @@ def matmul(M,
return
gemm_autotune
return
gemm_autotune
def
main
(
M
:
int
=
4096
,
def
main
(
M
:
int
=
4096
,
N
:
int
=
4096
,
K
:
int
=
4096
,
use_autotune
:
bool
=
False
,
with_roller
:
bool
=
False
):
N
:
int
=
4096
,
K
:
int
=
4096
,
use_autotune
:
bool
=
False
,
with_roller
:
bool
=
False
):
use_autotune
=
True
use_autotune
=
True
if
use_autotune
:
if
use_autotune
:
result
=
get_best_config
(
M
,
N
,
K
,
with_roller
)
result
=
get_best_config
(
M
,
N
,
K
,
with_roller
)
...
@@ -266,15 +233,7 @@ if __name__ == "__main__":
...
@@ -266,15 +233,7 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--m"
,
type
=
int
,
default
=
4096
,
help
=
"Matrix dimension M"
)
parser
.
add_argument
(
"--m"
,
type
=
int
,
default
=
4096
,
help
=
"Matrix dimension M"
)
parser
.
add_argument
(
"--n"
,
type
=
int
,
default
=
4096
,
help
=
"Matrix dimension N"
)
parser
.
add_argument
(
"--n"
,
type
=
int
,
default
=
4096
,
help
=
"Matrix dimension N"
)
parser
.
add_argument
(
"--k"
,
type
=
int
,
default
=
4096
,
help
=
"Matrix dimension K"
)
parser
.
add_argument
(
"--k"
,
type
=
int
,
default
=
4096
,
help
=
"Matrix dimension K"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--use_autotune"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to use autotune for matmul configs"
)
"--use_autotune"
,
parser
.
add_argument
(
"--with_roller"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to enable BitBLAS roller for search space"
)
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to use autotune for matmul configs"
)
parser
.
add_argument
(
"--with_roller"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to enable BitBLAS roller for search space"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
main
(
args
.
m
,
args
.
n
,
args
.
k
,
args
.
use_autotune
,
args
.
with_roller
)
main
(
args
.
m
,
args
.
n
,
args
.
k
,
args
.
use_autotune
,
args
.
with_roller
)
examples/gemm/example_gemm_intrinsics.py
View file @
29051439
...
@@ -4,7 +4,8 @@ import tilelang
...
@@ -4,7 +4,8 @@ import tilelang
import
tilelang.language
as
T
import
tilelang.language
as
T
from
tilelang.intrinsics
import
get_swizzle_layout
from
tilelang.intrinsics
import
get_swizzle_layout
from
tilelang.intrinsics.mma_macro_generator
import
(
from
tilelang.intrinsics.mma_macro_generator
import
(
TensorCoreIntrinEmitter
,)
TensorCoreIntrinEmitter
,
)
from
tilelang.transform
import
simplify_prim_func
from
tilelang.transform
import
simplify_prim_func
...
@@ -99,12 +100,11 @@ def tl_matmul(
...
@@ -99,12 +100,11 @@ def tl_matmul(
@
T
.
prim_func
@
T
.
prim_func
def
gemm_intrinsics
(
def
gemm_intrinsics
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
C_shared
=
T
.
alloc_shared
(
C_shared_shape
,
out_dtype
,
scope
=
shared_scope
)
C_shared
=
T
.
alloc_shared
(
C_shared_shape
,
out_dtype
,
scope
=
shared_scope
)
...
@@ -112,10 +112,12 @@ def tl_matmul(
...
@@ -112,10 +112,12 @@ def tl_matmul(
B_local
=
T
.
alloc_local
((
warp_cols
*
local_size_b
),
in_dtype
)
B_local
=
T
.
alloc_local
((
warp_cols
*
local_size_b
),
in_dtype
)
C_local
=
T
.
alloc_local
((
warp_rows
*
warp_cols
*
local_size_c
),
accum_dtype
)
C_local
=
T
.
alloc_local
((
warp_rows
*
warp_cols
*
local_size_c
),
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
A_shared
:
make_swizzle_layout
(
A_shared
),
{
B_shared
:
make_swizzle_layout
(
B_shared
),
A_shared
:
make_swizzle_layout
(
A_shared
),
})
B_shared
:
make_swizzle_layout
(
B_shared
),
}
)
# Improve L2 Cache
# Improve L2 Cache
T
.
use_swizzle
(
panel_size
=
10
)
T
.
use_swizzle
(
panel_size
=
10
)
...
@@ -123,7 +125,6 @@ def tl_matmul(
...
@@ -123,7 +125,6 @@ def tl_matmul(
T
.
clear
(
C_local
)
T
.
clear
(
C_local
)
for
ko
in
T
.
Pipelined
((
K
//
block_K
),
num_stages
=
stage
):
for
ko
in
T
.
Pipelined
((
K
//
block_K
),
num_stages
=
stage
):
# Load A into shared memory
# Load A into shared memory
for
i
,
k
in
T
.
Parallel
(
block_M
,
block_K
):
for
i
,
k
in
T
.
Parallel
(
block_M
,
block_K
):
A_shared
[
i
,
k
]
=
A
[
by
*
block_M
+
i
,
ko
*
block_K
+
k
]
A_shared
[
i
,
k
]
=
A
[
by
*
block_M
+
i
,
ko
*
block_K
+
k
]
...
@@ -133,7 +134,6 @@ def tl_matmul(
...
@@ -133,7 +134,6 @@ def tl_matmul(
B_shared
[
j
,
k
]
=
B
[
bx
*
block_N
+
j
,
ko
*
block_K
+
k
]
B_shared
[
j
,
k
]
=
B
[
bx
*
block_N
+
j
,
ko
*
block_K
+
k
]
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
# Load A into fragment
# Load A into fragment
mma_emitter
.
ldmatrix_a
(
A_local
,
A_shared
,
ki
)
mma_emitter
.
ldmatrix_a
(
A_local
,
A_shared
,
ki
)
...
...
examples/gemm/example_gemm_persistent.py
View file @
29051439
...
@@ -5,22 +5,12 @@ import argparse
...
@@ -5,22 +5,12 @@ import argparse
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
matmul_non_persistent
(
M
,
def
matmul_non_persistent
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
threads
,
num_stages
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
N
,
K
,
block_M
,
block_N
,
block_K
,
threads
,
num_stages
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
...
@@ -43,18 +33,9 @@ def matmul_non_persistent(M,
...
@@ -43,18 +33,9 @@ def matmul_non_persistent(M,
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
matmul_persistent
(
M
,
def
matmul_persistent
(
N
,
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
threads
,
num_stages
,
dtype
=
"float16"
,
accum_dtype
=
"float"
,
use_persistent_primitive
=
True
K
,
):
block_M
,
block_N
,
block_K
,
threads
,
num_stages
,
dtype
=
"float16"
,
accum_dtype
=
"float"
,
use_persistent_primitive
=
True
):
sm_num
=
driver
.
get_num_sms
()
sm_num
=
driver
.
get_num_sms
()
m_blocks
=
T
.
ceildiv
(
M
,
block_M
)
m_blocks
=
T
.
ceildiv
(
M
,
block_M
)
n_blocks
=
T
.
ceildiv
(
N
,
block_N
)
n_blocks
=
T
.
ceildiv
(
N
,
block_N
)
...
@@ -63,9 +44,9 @@ def matmul_persistent(M,
...
@@ -63,9 +44,9 @@ def matmul_persistent(M,
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
with
T
.
Kernel
(
sm_num
,
threads
=
threads
)
as
(
block_id
):
with
T
.
Kernel
(
sm_num
,
threads
=
threads
)
as
(
block_id
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
...
@@ -90,9 +71,9 @@ def matmul_persistent(M,
...
@@ -90,9 +71,9 @@ def matmul_persistent(M,
@
T
.
prim_func
@
T
.
prim_func
def
main_persistent_primitive
(
def
main_persistent_primitive
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
with
T
.
Kernel
(
sm_num
,
threads
=
threads
)
as
(
block_id
):
with
T
.
Kernel
(
sm_num
,
threads
=
threads
)
as
(
block_id
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
...
@@ -100,8 +81,7 @@ def matmul_persistent(M,
...
@@ -100,8 +81,7 @@ def matmul_persistent(M,
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
C_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
dtype
)
C_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
dtype
)
for
bx
,
by
in
T
.
Persistent
(
for
bx
,
by
in
T
.
Persistent
([
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
)],
sm_num
,
block_id
):
[
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
)],
sm_num
,
block_id
):
T
.
clear
(
C_local
)
T
.
clear
(
C_local
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
A
[
bx
*
block_M
,
k
*
block_K
],
A_shared
)
T
.
copy
(
A
[
bx
*
block_M
,
k
*
block_K
],
A_shared
)
...
@@ -128,18 +108,15 @@ def main(M=4096, N=4096, K=4096):
...
@@ -128,18 +108,15 @@ def main(M=4096, N=4096, K=4096):
num_stages
=
3
num_stages
=
3
persistent_kernel
=
matmul_persistent
(
M
,
N
,
K
,
BLOCK_M
,
BLOCK_N
,
BLOCK_K
,
threads
,
num_stages
)
persistent_kernel
=
matmul_persistent
(
M
,
N
,
K
,
BLOCK_M
,
BLOCK_N
,
BLOCK_K
,
threads
,
num_stages
)
persistent_profiler
=
persistent_kernel
.
get_profiler
(
persistent_profiler
=
persistent_kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Randn
)
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Randn
)
persistent_profiler
.
assert_allclose
(
ref_program
,
rtol
=
0.01
,
atol
=
0.01
)
persistent_profiler
.
assert_allclose
(
ref_program
,
rtol
=
0.01
,
atol
=
0.01
)
print
(
"Persistent GEMM: All check passed."
)
print
(
"Persistent GEMM: All check passed."
)
persistent_latency
=
persistent_profiler
.
do_bench
(
warmup
=
500
)
persistent_latency
=
persistent_profiler
.
do_bench
(
warmup
=
500
)
print
(
f
"Persistent GEMM Latency:
{
persistent_latency
}
ms"
)
print
(
f
"Persistent GEMM Latency:
{
persistent_latency
}
ms"
)
print
(
f
"Persistent GEMM TFlops:
{
total_flops
/
persistent_latency
*
1e-9
}
TFlops"
)
print
(
f
"Persistent GEMM TFlops:
{
total_flops
/
persistent_latency
*
1e-9
}
TFlops"
)
non_persistent_kernel
=
matmul_non_persistent
(
M
,
N
,
K
,
BLOCK_M
,
BLOCK_N
,
BLOCK_K
,
threads
,
non_persistent_kernel
=
matmul_non_persistent
(
M
,
N
,
K
,
BLOCK_M
,
BLOCK_N
,
BLOCK_K
,
threads
,
num_stages
)
num_stages
)
non_persistent_profiler
=
non_persistent_kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Randn
)
non_persistent_profiler
=
non_persistent_kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Randn
)
non_persistent_profiler
.
assert_allclose
(
ref_program
,
rtol
=
0.01
,
atol
=
0.01
)
non_persistent_profiler
.
assert_allclose
(
ref_program
,
rtol
=
0.01
,
atol
=
0.01
)
print
(
"Non-Persistent GEMM: All check passed."
)
print
(
"Non-Persistent GEMM: All check passed."
)
non_persistent_latency
=
non_persistent_profiler
.
do_bench
(
warmup
=
500
)
non_persistent_latency
=
non_persistent_profiler
.
do_bench
(
warmup
=
500
)
...
@@ -151,9 +128,9 @@ def main(M=4096, N=4096, K=4096):
...
@@ -151,9 +128,9 @@ def main(M=4096, N=4096, K=4096):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--M
'
,
type
=
int
,
default
=
8192
,
help
=
'
M dimension
'
)
parser
.
add_argument
(
"
--M
"
,
type
=
int
,
default
=
8192
,
help
=
"
M dimension
"
)
parser
.
add_argument
(
'
--N
'
,
type
=
int
,
default
=
8192
,
help
=
'
N dimension
'
)
parser
.
add_argument
(
"
--N
"
,
type
=
int
,
default
=
8192
,
help
=
"
N dimension
"
)
parser
.
add_argument
(
'
--K
'
,
type
=
int
,
default
=
8192
,
help
=
'
K dimension
'
)
parser
.
add_argument
(
"
--K
"
,
type
=
int
,
default
=
8192
,
help
=
"
K dimension
"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
M
,
N
,
K
=
args
.
M
,
args
.
N
,
args
.
K
M
,
N
,
K
=
args
.
M
,
args
.
N
,
args
.
K
main
(
M
,
N
,
K
)
main
(
M
,
N
,
K
)
examples/gemm/example_gemm_schedule.py
View file @
29051439
...
@@ -4,12 +4,11 @@ import tilelang.language as T
...
@@ -4,12 +4,11 @@ import tilelang.language as T
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
@
T
.
prim_func
def
gemm_schedule
(
def
gemm_schedule
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
...
...
examples/gemm_fp8/example_tilelang_gemm_amd.py
View file @
29051439
...
@@ -17,10 +17,8 @@ def supply_prog(args):
...
@@ -17,10 +17,8 @@ def supply_prog(args):
a_param
,
b_param
=
args
a_param
,
b_param
=
args
M
,
K
=
a_param
.
shape
M
,
K
=
a_param
.
shape
N
,
_
=
b_param
.
shape
N
,
_
=
b_param
.
shape
a
=
(
torch
.
randn
(
M
,
K
,
dtype
=
torch
.
float16
,
device
=
'cuda'
)
*
a
=
(
torch
.
randn
(
M
,
K
,
dtype
=
torch
.
float16
,
device
=
"cuda"
)
*
0.01
).
to
(
dtype
=
torch
.
float8_e4m3fnuz
)
0.01
).
to
(
dtype
=
torch
.
float8_e4m3fnuz
)
b
=
(
torch
.
randn
(
N
,
K
,
dtype
=
torch
.
float16
,
device
=
"cuda"
)
*
0.01
).
to
(
dtype
=
torch
.
float8_e4m3fnuz
)
b
=
(
torch
.
randn
(
N
,
K
,
dtype
=
torch
.
float16
,
device
=
'cuda'
)
*
0.01
).
to
(
dtype
=
torch
.
float8_e4m3fnuz
)
return
[
a
,
b
]
return
[
a
,
b
]
...
@@ -35,27 +33,24 @@ def get_configs():
...
@@ -35,27 +33,24 @@ def get_configs():
valid_configs
=
[]
valid_configs
=
[]
for
m
,
n
,
k
,
stages
,
t
,
kp
,
gemm_type
in
itertools
.
product
(
block_Ms
,
block_Ns
,
block_Ks
,
for
m
,
n
,
k
,
stages
,
t
,
kp
,
gemm_type
in
itertools
.
product
(
block_Ms
,
block_Ns
,
block_Ks
,
num_stages
,
num_threads
,
k_packs
,
gemm_types
):
num_stages
,
num_threads
,
k_packs
,
valid_configs
.
append
(
gemm_types
):
{
valid_configs
.
append
({
"block_M"
:
m
,
"block_
M
"
:
m
,
"block_
N
"
:
n
,
"block_
N
"
:
n
,
"block_
K
"
:
k
,
"block_K"
:
k
,
"num_stages"
:
stages
,
"num_
stages"
:
stages
,
"num_
threads"
:
t
,
"num_threads
"
:
t
,
"k_pack
"
:
kp
,
"k_pack"
:
kp
,
"gemm_type"
:
gemm_type
,
"gemm_type"
:
gemm_type
,
}
}
)
)
return
valid_configs
return
valid_configs
@
tilelang
.
autotune
(
@
tilelang
.
autotune
(
configs
=
get_configs
(),
configs
=
get_configs
(),
cache_input_tensors
=
True
,
ref_prog
=
ref_program
,
manual_check_prog
=
manual_check_prog
,
supply_prog
=
supply_prog
cache_input_tensors
=
True
,
)
ref_prog
=
ref_program
,
manual_check_prog
=
manual_check_prog
,
supply_prog
=
supply_prog
)
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
fp8_matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
num_stages
,
num_threads
,
k_pack
,
gemm_type
):
def
fp8_matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
num_stages
,
num_threads
,
k_pack
,
gemm_type
):
dtype
=
"float8_e4m3fnuz"
dtype
=
"float8_e4m3fnuz"
...
@@ -63,12 +58,11 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
...
@@ -63,12 +58,11 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
@
T
.
prim_func
@
T
.
prim_func
def
gemm_fp8_rs
(
def
gemm_fp8_rs
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
accum_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
accum_dtype
),
):
):
with
T
.
Kernel
(
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
num_threads
)
as
(
bx
,
by
):
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
num_threads
)
as
(
bx
,
by
):
A_local
=
T
.
alloc_fragment
((
block_M
,
block_K
),
dtype
)
A_local
=
T
.
alloc_fragment
((
block_M
,
block_K
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_N
,
block_K
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_N
,
block_K
),
dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
@@ -77,24 +71,17 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
...
@@ -77,24 +71,17 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
A
[
by
*
block_M
,
k
*
block_K
],
A_local
)
T
.
copy
(
A
[
by
*
block_M
,
k
*
block_K
],
A_local
)
T
.
copy
(
B
[
bx
*
block_N
,
k
*
block_K
],
B_shared
)
T
.
copy
(
B
[
bx
*
block_N
,
k
*
block_K
],
B_shared
)
T
.
gemm
(
T
.
gemm
(
A_local
,
B_shared
,
C_local
,
transpose_B
=
True
,
k_pack
=
k_pack
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
A_local
,
B_shared
,
C_local
,
transpose_B
=
True
,
k_pack
=
k_pack
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
copy
(
C_local
,
C
[
by
*
block_M
,
bx
*
block_N
])
T
.
copy
(
C_local
,
C
[
by
*
block_M
,
bx
*
block_N
])
@
T
.
prim_func
@
T
.
prim_func
def
gemm_fp8_ss
(
def
gemm_fp8_ss
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
accum_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
accum_dtype
),
):
):
with
T
.
Kernel
(
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
num_threads
)
as
(
bx
,
by
):
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
num_threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_N
,
block_K
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_N
,
block_K
),
dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
@@ -103,13 +90,7 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
...
@@ -103,13 +90,7 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
A
[
by
*
block_M
,
k
*
block_K
],
A_shared
)
T
.
copy
(
A
[
by
*
block_M
,
k
*
block_K
],
A_shared
)
T
.
copy
(
B
[
bx
*
block_N
,
k
*
block_K
],
B_shared
)
T
.
copy
(
B
[
bx
*
block_N
,
k
*
block_K
],
B_shared
)
T
.
gemm
(
T
.
gemm
(
A_shared
,
B_shared
,
C_local
,
transpose_B
=
True
,
k_pack
=
k_pack
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
A_shared
,
B_shared
,
C_local
,
transpose_B
=
True
,
k_pack
=
k_pack
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
copy
(
C_local
,
C
[
by
*
block_M
,
bx
*
block_N
])
T
.
copy
(
C_local
,
C
[
by
*
block_M
,
bx
*
block_N
])
...
@@ -123,10 +104,8 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
...
@@ -123,10 +104,8 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
def
test_gemm_fp8
(
M
,
N
,
K
):
def
test_gemm_fp8
(
M
,
N
,
K
):
kernel
=
fp8_matmul
(
M
,
N
,
K
)
kernel
=
fp8_matmul
(
M
,
N
,
K
)
a
=
(
torch
.
randn
(
M
,
K
,
dtype
=
torch
.
float16
,
device
=
'cuda'
)
*
a
=
(
torch
.
randn
(
M
,
K
,
dtype
=
torch
.
float16
,
device
=
"cuda"
)
*
0.01
).
to
(
dtype
=
torch
.
float8_e4m3fnuz
)
0.01
).
to
(
dtype
=
torch
.
float8_e4m3fnuz
)
b
=
(
torch
.
randn
(
N
,
K
,
dtype
=
torch
.
float16
,
device
=
"cuda"
)
*
0.01
).
to
(
dtype
=
torch
.
float8_e4m3fnuz
)
b
=
(
torch
.
randn
(
N
,
K
,
dtype
=
torch
.
float16
,
device
=
'cuda'
)
*
0.01
).
to
(
dtype
=
torch
.
float8_e4m3fnuz
)
c
=
kernel
(
a
,
b
)
c
=
kernel
(
a
,
b
)
ref_c
=
ref_program
(
a
,
b
)
ref_c
=
ref_program
(
a
,
b
)
torch_assert_close
(
c
,
ref_c
,
rtol
=
1e-2
,
atol
=
1e-2
)
torch_assert_close
(
c
,
ref_c
,
rtol
=
1e-2
,
atol
=
1e-2
)
...
...
examples/gemm_fp8/example_tilelang_gemm_fp8.py
View file @
29051439
...
@@ -13,12 +13,11 @@ def calc_diff(x, y):
...
@@ -13,12 +13,11 @@ def calc_diff(x, y):
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
,
accum_dtype
=
"float"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
,
accum_dtype
=
"float"
):
@
T
.
prim_func
@
T
.
prim_func
def
gemm_fp8
(
def
gemm_fp8
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
...
@@ -41,8 +40,8 @@ def test_gemm_fp8(M, N, K, dtype):
...
@@ -41,8 +40,8 @@ def test_gemm_fp8(M, N, K, dtype):
kernel
=
matmul
(
M
,
N
,
K
,
128
,
128
,
64
,
dtype
)
kernel
=
matmul
(
M
,
N
,
K
,
128
,
128
,
64
,
dtype
)
a
=
torch
.
randn
(
M
,
K
,
dtype
=
torch
.
float16
,
device
=
'
cuda
'
).
to
(
dtype
=
torch_dtype
)
a
=
torch
.
randn
(
M
,
K
,
dtype
=
torch
.
float16
,
device
=
"
cuda
"
).
to
(
dtype
=
torch_dtype
)
b
=
torch
.
randn
(
N
,
K
,
dtype
=
torch
.
float16
,
device
=
'
cuda
'
).
to
(
dtype
=
torch_dtype
)
b
=
torch
.
randn
(
N
,
K
,
dtype
=
torch
.
float16
,
device
=
"
cuda
"
).
to
(
dtype
=
torch_dtype
)
c
=
kernel
(
a
,
b
)
c
=
kernel
(
a
,
b
)
...
@@ -57,8 +56,8 @@ def test_gemm_fp8(M, N, K, dtype):
...
@@ -57,8 +56,8 @@ def test_gemm_fp8(M, N, K, dtype):
def
main
():
def
main
():
test_gemm_fp8
(
1024
,
1024
,
1024
,
'
float8_e4m3
'
)
test_gemm_fp8
(
1024
,
1024
,
1024
,
"
float8_e4m3
"
)
test_gemm_fp8
(
1024
,
1024
,
1024
,
'
float8_e5m2
'
)
test_gemm_fp8
(
1024
,
1024
,
1024
,
"
float8_e5m2
"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py
View file @
29051439
...
@@ -13,9 +13,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"):
...
@@ -13,9 +13,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"):
@
T
.
prim_func
@
T
.
prim_func
def
gemm_fp8_2xAcc
(
def
gemm_fp8_2xAcc
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
accum_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
accum_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
...
@@ -59,14 +59,14 @@ def test_gemm_fp8(M, N, K, dtype):
...
@@ -59,14 +59,14 @@ def test_gemm_fp8(M, N, K, dtype):
kernel
=
matmul
(
M
,
N
,
K
,
128
,
128
,
64
,
dtype
)
kernel
=
matmul
(
M
,
N
,
K
,
128
,
128
,
64
,
dtype
)
a
=
torch
.
rand
(
M
,
K
,
dtype
=
torch
.
float16
,
device
=
'
cuda
'
)
a
=
torch
.
rand
(
M
,
K
,
dtype
=
torch
.
float16
,
device
=
"
cuda
"
)
a
=
(
100
*
(
2
*
a
-
1
)).
to
(
dtype
=
torch_dtype
)
a
=
(
100
*
(
2
*
a
-
1
)).
to
(
dtype
=
torch_dtype
)
b
=
torch
.
rand
(
N
,
K
,
dtype
=
torch
.
float16
,
device
=
'
cuda
'
)
b
=
torch
.
rand
(
N
,
K
,
dtype
=
torch
.
float16
,
device
=
"
cuda
"
)
b
=
(
100
*
(
2
*
b
-
1
)).
to
(
dtype
=
torch_dtype
)
b
=
(
100
*
(
2
*
b
-
1
)).
to
(
dtype
=
torch_dtype
)
c
=
kernel
(
a
,
b
)
c
=
kernel
(
a
,
b
)
ref_c
=
(
a
.
float
()
@
b
.
float
().
T
)
ref_c
=
a
.
float
()
@
b
.
float
().
T
diff
=
calc_diff
(
c
,
ref_c
)
diff
=
calc_diff
(
c
,
ref_c
)
print
(
f
"diff:
{
diff
}
"
)
print
(
f
"diff:
{
diff
}
"
)
...
@@ -74,8 +74,8 @@ def test_gemm_fp8(M, N, K, dtype):
...
@@ -74,8 +74,8 @@ def test_gemm_fp8(M, N, K, dtype):
def
main
():
def
main
():
test_gemm_fp8
(
1024
,
1024
,
8192
,
'
float8_e4m3
'
)
test_gemm_fp8
(
1024
,
1024
,
8192
,
"
float8_e4m3
"
)
test_gemm_fp8
(
1024
,
1024
,
8192
,
'
float8_e5m2
'
)
test_gemm_fp8
(
1024
,
1024
,
8192
,
"
float8_e5m2
"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
View file @
29051439
...
@@ -5,7 +5,8 @@ from tvm import DataType
...
@@ -5,7 +5,8 @@ from tvm import DataType
import
tilelang.language
as
T
import
tilelang.language
as
T
from
tilelang.intrinsics
import
get_swizzle_layout
from
tilelang.intrinsics
import
get_swizzle_layout
from
tilelang.intrinsics.mma_macro_generator
import
(
from
tilelang.intrinsics.mma_macro_generator
import
(
TensorCoreIntrinEmitter
,)
TensorCoreIntrinEmitter
,
)
from
tilelang.transform
import
simplify_prim_func
from
tilelang.transform
import
simplify_prim_func
from
tilelang.utils.tensor
import
map_torch_type
from
tilelang.utils.tensor
import
map_torch_type
...
@@ -110,12 +111,11 @@ def tl_matmul(
...
@@ -110,12 +111,11 @@ def tl_matmul(
@
T
.
prim_func
@
T
.
prim_func
def
gemm_fp8_intrinsic
(
def
gemm_fp8_intrinsic
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
C_shared
=
T
.
alloc_shared
(
C_shared_shape
,
out_dtype
,
scope
=
shared_scope
)
C_shared
=
T
.
alloc_shared
(
C_shared_shape
,
out_dtype
,
scope
=
shared_scope
)
...
@@ -123,10 +123,12 @@ def tl_matmul(
...
@@ -123,10 +123,12 @@ def tl_matmul(
B_local
=
T
.
alloc_local
((
warp_cols
*
local_size_b
),
in_dtype
)
B_local
=
T
.
alloc_local
((
warp_cols
*
local_size_b
),
in_dtype
)
C_local
=
T
.
alloc_local
((
warp_rows
*
warp_cols
*
local_size_c
),
accum_dtype
)
C_local
=
T
.
alloc_local
((
warp_rows
*
warp_cols
*
local_size_c
),
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
A_shared
:
make_swizzle_layout
(
A_shared
),
{
B_shared
:
make_swizzle_layout
(
B_shared
),
A_shared
:
make_swizzle_layout
(
A_shared
),
})
B_shared
:
make_swizzle_layout
(
B_shared
),
}
)
# Improve L2 Cache
# Improve L2 Cache
T
.
use_swizzle
(
panel_size
=
10
)
T
.
use_swizzle
(
panel_size
=
10
)
...
@@ -134,7 +136,6 @@ def tl_matmul(
...
@@ -134,7 +136,6 @@ def tl_matmul(
T
.
clear
(
C_local
)
T
.
clear
(
C_local
)
for
ko
in
T
.
Pipelined
((
K
//
block_K
),
num_stages
=
stage
):
for
ko
in
T
.
Pipelined
((
K
//
block_K
),
num_stages
=
stage
):
# Load A into shared memory
# Load A into shared memory
for
i
,
k
in
T
.
Parallel
(
block_M
,
block_K
):
for
i
,
k
in
T
.
Parallel
(
block_M
,
block_K
):
A_shared
[
i
,
k
]
=
A
[
by
*
block_M
+
i
,
ko
*
block_K
+
k
]
A_shared
[
i
,
k
]
=
A
[
by
*
block_M
+
i
,
ko
*
block_K
+
k
]
...
@@ -144,7 +145,6 @@ def tl_matmul(
...
@@ -144,7 +145,6 @@ def tl_matmul(
B_shared
[
j
,
k
]
=
B
[
bx
*
block_N
+
j
,
ko
*
block_K
+
k
]
B_shared
[
j
,
k
]
=
B
[
bx
*
block_N
+
j
,
ko
*
block_K
+
k
]
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
# Load A into fragment
# Load A into fragment
mma_emitter
.
ldmatrix_a
(
mma_emitter
.
ldmatrix_a
(
A_local
,
A_local
,
...
...
examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py
View file @
29051439
...
@@ -26,9 +26,9 @@ def matmul(
...
@@ -26,9 +26,9 @@ def matmul(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
@@ -121,6 +121,4 @@ for tvm_fp8_dtype in ["float8_e4m3", "float8_e5m2"]:
...
@@ -121,6 +121,4 @@ for tvm_fp8_dtype in ["float8_e4m3", "float8_e5m2"]:
profiler
=
jit_kernel
.
get_profiler
()
profiler
=
jit_kernel
.
get_profiler
()
latency
=
profiler
.
do_bench
()
latency
=
profiler
.
do_bench
()
print
(
f
"[
{
tvm_fp8_dtype
}
->
{
tvm_acc_dtype
}
] Latency:
{
latency
}
ms"
)
print
(
f
"[
{
tvm_fp8_dtype
}
->
{
tvm_acc_dtype
}
] Latency:
{
latency
}
ms"
)
print
(
print
(
f
"[
{
tvm_fp8_dtype
}
->
{
tvm_acc_dtype
}
] Flops:
{
2
*
M
*
N
*
K
/
(
latency
/
1e3
)
/
1e12
}
TFLOPS"
)
f
"[
{
tvm_fp8_dtype
}
->
{
tvm_acc_dtype
}
] Flops:
{
2
*
M
*
N
*
K
/
(
latency
/
1e3
)
/
1e12
}
TFLOPS"
)
examples/gemm_sm100/gemm_mma.py
View file @
29051439
...
@@ -5,12 +5,11 @@ import tilelang.language as T
...
@@ -5,12 +5,11 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
# @tilelang.jit
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
# Initialize Kernel Context
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
256
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
256
)
as
(
bx
,
by
):
...
@@ -62,7 +61,8 @@ jit_kernel = tilelang.compile(
...
@@ -62,7 +61,8 @@ jit_kernel = tilelang.compile(
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
print
(
jit_kernel
.
get_kernel_source
())
print
(
jit_kernel
.
get_kernel_source
())
# 3. Test the kernel in Python with PyTorch data
# 3. Test the kernel in Python with PyTorch data
import
torch
import
torch
...
...
examples/gemm_sm100/gemm_tcgen5mma.py
View file @
29051439
...
@@ -25,9 +25,9 @@ def matmul(
...
@@ -25,9 +25,9 @@ def matmul(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
@@ -40,15 +40,7 @@ def matmul(
...
@@ -40,15 +40,7 @@ def matmul(
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
A
[
by
*
block_M
,
k
*
block_K
],
A_shared
)
T
.
copy
(
A
[
by
*
block_M
,
k
*
block_K
],
A_shared
)
T
.
copy
(
B
[
bx
*
block_N
,
k
*
block_K
],
B_shared
)
T
.
copy
(
B
[
bx
*
block_N
,
k
*
block_K
],
B_shared
)
T
.
gemm
(
T
.
gemm
(
A_shared
,
B_shared
,
C_tmem
,
trans_A
,
trans_B
,
mbar
=
mbar
,
wg_wait
=-
1
,
clear_accum
=
k
==
0
)
A_shared
,
B_shared
,
C_tmem
,
trans_A
,
trans_B
,
mbar
=
mbar
,
wg_wait
=-
1
,
clear_accum
=
k
==
0
)
T
.
mbarrier_wait_parity
(
mbar
,
k
%
2
)
T
.
mbarrier_wait_parity
(
mbar
,
k
%
2
)
T
.
copy
(
C_tmem
,
C_local
)
T
.
copy
(
C_tmem
,
C_local
)
...
@@ -66,8 +58,7 @@ in_dtype, out_dtype, accum_dtype = "bfloat16", "bfloat16", "float"
...
@@ -66,8 +58,7 @@ in_dtype, out_dtype, accum_dtype = "bfloat16", "bfloat16", "float"
num_stages
=
2
num_stages
=
2
threads
=
256
threads
=
256
func
=
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
func
=
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
accum_dtype
,
num_stages
,
threads
)
accum_dtype
,
num_stages
,
threads
)
jit_kernel
=
tilelang
.
compile
(
jit_kernel
=
tilelang
.
compile
(
func
,
func
,
out_idx
=
[
2
],
out_idx
=
[
2
],
...
@@ -75,7 +66,8 @@ jit_kernel = tilelang.compile(
...
@@ -75,7 +66,8 @@ jit_kernel = tilelang.compile(
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
print
(
jit_kernel
.
get_kernel_source
())
print
(
jit_kernel
.
get_kernel_source
())
...
@@ -88,4 +80,4 @@ torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
...
@@ -88,4 +80,4 @@ torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
profiler
=
jit_kernel
.
get_profiler
()
profiler
=
jit_kernel
.
get_profiler
()
latency
=
profiler
.
do_bench
()
latency
=
profiler
.
do_bench
()
print
(
f
"Latency:
{
latency
}
ms"
)
print
(
f
"Latency:
{
latency
}
ms"
)
print
(
f
"Flops:
{
2
*
M
*
N
*
K
/
(
latency
/
1e3
)
/
1e12
}
TFLOPS"
)
print
(
f
"Flops:
{
2
*
M
*
N
*
K
/
(
latency
/
1e3
)
/
1e12
}
TFLOPS"
)
examples/gemm_sp/example_custom_compress.py
View file @
29051439
...
@@ -17,77 +17,76 @@ torch.manual_seed(42)
...
@@ -17,77 +17,76 @@ torch.manual_seed(42)
DEFAULT_CONFIG
=
{
# take best config from autotune script
DEFAULT_CONFIG
=
{
# take best config from autotune script
"4090"
:
{
"4090"
:
{
'float'
:
{
"float"
:
{
'block_M'
:
128
,
"block_M"
:
128
,
'block_N'
:
64
,
"block_N"
:
64
,
'block_K'
:
64
,
"block_K"
:
64
,
'num_stages'
:
1
,
"num_stages"
:
1
,
'thread_num'
:
128
,
"thread_num"
:
128
,
'policy'
:
T
.
GemmWarpPolicy
.
Square
,
"policy"
:
T
.
GemmWarpPolicy
.
Square
,
'enable_rasterization'
:
True
"enable_rasterization"
:
True
,
},
"float16"
:
{
"block_M"
:
256
,
"block_N"
:
128
,
"block_K"
:
64
,
"num_stages"
:
2
,
"thread_num"
:
128
,
"policy"
:
T
.
GemmWarpPolicy
.
Square
,
"enable_rasterization"
:
True
,
},
},
'float16'
:
{
'block_M'
:
256
,
'block_N'
:
128
,
'block_K'
:
64
,
'num_stages'
:
2
,
'thread_num'
:
128
,
'policy'
:
T
.
GemmWarpPolicy
.
Square
,
'enable_rasterization'
:
True
}
},
},
"h20"
:
{
"h20"
:
{
'float'
:
{
"float"
:
{
'block_M'
:
128
,
"block_M"
:
128
,
'block_N'
:
64
,
"block_N"
:
64
,
'block_K'
:
128
,
"block_K"
:
128
,
'num_stages'
:
3
,
"num_stages"
:
3
,
'thread_num'
:
128
,
"thread_num"
:
128
,
'policy'
:
T
.
GemmWarpPolicy
.
Square
,
"policy"
:
T
.
GemmWarpPolicy
.
Square
,
'enable_rasterization'
:
True
"enable_rasterization"
:
True
,
},
"float16"
:
{
"block_M"
:
128
,
"block_N"
:
64
,
"block_K"
:
128
,
"num_stages"
:
3
,
"thread_num"
:
128
,
"policy"
:
T
.
GemmWarpPolicy
.
Square
,
"enable_rasterization"
:
True
,
},
},
'float16'
:
{
},
'block_M'
:
128
,
'block_N'
:
64
,
'block_K'
:
128
,
'num_stages'
:
3
,
'thread_num'
:
128
,
'policy'
:
T
.
GemmWarpPolicy
.
Square
,
'enable_rasterization'
:
True
}
}
}
}
ARCH_INFO
=
{
"8.0"
:
(
16
,
"int16"
),
"8.9"
:
(
16
,
"int16"
),
"9.0"
:
(
8
,
"uint8"
)}
ARCH_INFO
=
{
"8.0"
:
(
16
,
"int16"
),
"8.9"
:
(
16
,
"int16"
),
"9.0"
:
(
8
,
"uint8"
)}
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
matmul_sp_fp16_custom_compress
(
M
,
N
,
K
,
accum_dtype
,
block_M
,
block_N
,
block_K
,
num_stages
,
def
matmul_sp_fp16_custom_compress
(
thread_num
,
policy
,
enable_rasterization
,
use_cutlass_layout
):
M
,
N
,
K
,
accum_dtype
,
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
policy
,
enable_rasterization
,
use_cutlass_layout
):
e_factor
,
e_dtype
=
(
16
,
"int16"
)
e_factor
,
e_dtype
=
(
16
,
"int16"
)
@
T
.
prim_func
@
T
.
prim_func
def
gemm_sp_fp16_custom_compress
(
def
gemm_sp_fp16_custom_compress
(
A_sparse
:
T
.
Tensor
((
M
,
K
//
2
),
'
float16
'
),
A_sparse
:
T
.
Tensor
((
M
,
K
//
2
),
"
float16
"
),
E
:
T
.
Tensor
((
M
,
K
//
e_factor
),
e_dtype
),
E
:
T
.
Tensor
((
M
,
K
//
e_factor
),
e_dtype
),
B
:
T
.
Tensor
((
K
,
N
),
'
float16
'
),
B
:
T
.
Tensor
((
K
,
N
),
"
float16
"
),
C
:
T
.
Tensor
((
M
,
N
),
accum_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
accum_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
2
),
'
float16
'
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
2
),
"
float16
"
)
E_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
e_factor
),
e_dtype
)
E_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
e_factor
),
e_dtype
)
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
'
float16
'
)
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
"
float16
"
)
C_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
accum_dtype
)
C_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
accum_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
if
use_cutlass_layout
:
if
use_cutlass_layout
:
T
.
annotate_layout
({
T
.
annotate_layout
(
E
:
{
make_cutlass_metadata_layout
(
E
:
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
"float16"
,
arch
=
"8.0"
,
block_k
=
block_K
),
E
,
mma_dtype
=
"float16"
,
arch
=
"8.0"
,
block_k
=
block_K
),
E_shared
:
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
"float16"
,
arch
=
"8.0"
,
block_k
=
block_K
),
E_shared
:
}
make_cutlass_metadata_layout
(
)
E_shared
,
mma_dtype
=
"float16"
,
arch
=
"8.0"
,
block_k
=
block_K
),
})
T
.
clear
(
C_local
)
T
.
clear
(
C_local
)
T
.
disable_warp_group_reg_alloc
()
T
.
disable_warp_group_reg_alloc
()
T
.
use_swizzle
(
panel_size
=
10
,
enable
=
enable_rasterization
)
T
.
use_swizzle
(
panel_size
=
10
,
enable
=
enable_rasterization
)
...
@@ -108,8 +107,7 @@ def torch_compress(dense):
...
@@ -108,8 +107,7 @@ def torch_compress(dense):
A naive compression function, where each 4-bit meta matches 4 elements in original matrix in row major layout.
A naive compression function, where each 4-bit meta matches 4 elements in original matrix in row major layout.
"""
"""
if
dense
.
dim
()
!=
2
:
if
dense
.
dim
()
!=
2
:
raise
RuntimeError
(
raise
RuntimeError
(
f
"Expected 2-dimensional dense tensor, got
{
dense
.
dim
()
}
-dimensional tensor"
)
f
"Expected 2-dimensional dense tensor, got
{
dense
.
dim
()
}
-dimensional tensor"
)
m
,
k
=
dense
.
shape
m
,
k
=
dense
.
shape
...
@@ -131,9 +129,7 @@ def torch_compress(dense):
...
@@ -131,9 +129,7 @@ def torch_compress(dense):
if
m
%
32
!=
0
:
if
m
%
32
!=
0
:
raise
RuntimeError
(
f
"Number of rows of dense matrix
{
m
}
must be divisible by 32"
)
raise
RuntimeError
(
f
"Number of rows of dense matrix
{
m
}
must be divisible by 32"
)
if
k
%
(
4
*
quadbits_per_meta_elem
)
!=
0
:
if
k
%
(
4
*
quadbits_per_meta_elem
)
!=
0
:
raise
RuntimeError
(
raise
RuntimeError
(
f
"Number of columns of dense matrix
{
k
}
must be divisible by
{
4
*
quadbits_per_meta_elem
}
"
)
f
"Number of columns of dense matrix
{
k
}
must be divisible by
{
4
*
quadbits_per_meta_elem
}
"
)
if
dense
.
dtype
!=
torch
.
float
:
if
dense
.
dtype
!=
torch
.
float
:
ksparse
=
4
ksparse
=
4
...
@@ -194,19 +190,13 @@ def torch_compress(dense):
...
@@ -194,19 +190,13 @@ def torch_compress(dense):
sparse1
=
dense_4
.
gather
(
-
1
,
idxs1
.
unsqueeze
(
-
1
))
sparse1
=
dense_4
.
gather
(
-
1
,
idxs1
.
unsqueeze
(
-
1
))
sparse
=
torch
.
stack
((
sparse0
,
sparse1
),
dim
=-
1
).
view
(
m
,
k
//
2
)
sparse
=
torch
.
stack
((
sparse0
,
sparse1
),
dim
=-
1
).
view
(
m
,
k
//
2
)
else
:
else
:
sparse
=
dense_2
.
gather
(
-
1
,
sparse
=
dense_2
.
gather
(
-
1
,
idxs0
.
unsqueeze
(
-
1
)
//
2
).
view
(
m
,
k
//
2
)
# type: ignore[possibly-undefined]
idxs0
.
unsqueeze
(
-
1
)
//
2
).
view
(
m
,
k
//
2
)
# type: ignore[possibly-undefined]
meta_4
=
idxs0
|
(
idxs1
<<
2
)
meta_4
=
idxs0
|
(
idxs1
<<
2
)
meta_n
=
meta_4
.
view
((
-
1
,
meta_ncols
,
quadbits_per_meta_elem
)).
to
(
meta_dtype
)
meta_n
=
meta_4
.
view
((
-
1
,
meta_ncols
,
quadbits_per_meta_elem
)).
to
(
meta_dtype
)
if
quadbits_per_meta_elem
==
4
:
if
quadbits_per_meta_elem
==
4
:
meta
=
(
meta
=
meta_n
[:,
:,
0
]
|
(
meta_n
[:,
:,
1
]
<<
4
)
|
(
meta_n
[:,
:,
2
]
<<
8
)
|
(
meta_n
[:,
:,
3
]
<<
12
)
meta_n
[:,
:,
0
]
|
(
meta_n
[:,
:,
1
]
<<
4
)
|
(
meta_n
[:,
:,
2
]
<<
8
)
|
(
meta_n
[:,
:,
3
]
<<
12
))
elif
quadbits_per_meta_elem
==
8
:
elif
quadbits_per_meta_elem
==
8
:
meta
=
(
meta
=
(
meta_n
[:,
:,
0
]
meta_n
[:,
:,
0
]
...
@@ -216,7 +206,8 @@ def torch_compress(dense):
...
@@ -216,7 +206,8 @@ def torch_compress(dense):
|
(
meta_n
[:,
:,
4
]
<<
16
)
|
(
meta_n
[:,
:,
4
]
<<
16
)
|
(
meta_n
[:,
:,
5
]
<<
20
)
|
(
meta_n
[:,
:,
5
]
<<
20
)
|
(
meta_n
[:,
:,
6
]
<<
24
)
|
(
meta_n
[:,
:,
6
]
<<
24
)
|
(
meta_n
[:,
:,
7
]
<<
28
))
|
(
meta_n
[:,
:,
7
]
<<
28
)
)
return
(
sparse
,
meta
)
return
(
sparse
,
meta
)
...
@@ -234,9 +225,11 @@ def decode_metadata(meta: torch.Tensor) -> torch.Tensor:
...
@@ -234,9 +225,11 @@ def decode_metadata(meta: torch.Tensor) -> torch.Tensor:
@
tilelang
.
jit
(
@
tilelang
.
jit
(
out_idx
=
[
1
,
2
],
pass_configs
=
{
out_idx
=
[
1
,
2
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TIR_DISABLE_VECTORIZE
:
True
,
tilelang
.
PassConfigKey
.
TIR_DISABLE_VECTORIZE
:
True
,
})
},
)
def
compress_kernel
(
M
,
K
,
block_M
,
block_K
,
dtype
,
use_cutlass_layout
):
def
compress_kernel
(
M
,
K
,
block_M
,
block_K
,
dtype
,
use_cutlass_layout
):
e_factor
,
e_dtype
=
ARCH_INFO
[
"8.0"
]
e_factor
,
e_dtype
=
ARCH_INFO
[
"8.0"
]
e_K
=
K
//
e_factor
e_K
=
K
//
e_factor
...
@@ -249,23 +242,21 @@ def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
...
@@ -249,23 +242,21 @@ def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
@
T
.
prim_func
@
T
.
prim_func
def
kernel
(
def
kernel
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A_sp
:
T
.
Tensor
((
M
,
K
//
2
),
dtype
),
A_sp
:
T
.
Tensor
((
M
,
K
//
2
),
dtype
),
E
:
T
.
Tensor
((
M
,
e_K
),
e_dtype
),
E
:
T
.
Tensor
((
M
,
e_K
),
e_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
K
,
block_K
),
threads
=
block_M
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
K
,
block_K
),
threads
=
block_M
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
A_sp_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
2
),
dtype
)
A_sp_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
2
),
dtype
)
E_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
e_factor
),
e_dtype
)
E_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
e_factor
),
e_dtype
)
if
use_cutlass_layout
:
if
use_cutlass_layout
:
T
.
annotate_layout
({
T
.
annotate_layout
(
E
:
{
make_cutlass_metadata_layout
(
E
:
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
"float16"
,
arch
=
"8.0"
,
block_k
=
block_K
),
E
,
mma_dtype
=
"float16"
,
arch
=
"8.0"
,
block_k
=
block_K
),
E_shared
:
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
"float16"
,
arch
=
"8.0"
,
block_k
=
block_K
),
E_shared
:
}
make_cutlass_metadata_layout
(
)
E_shared
,
mma_dtype
=
"float16"
,
arch
=
"8.0"
,
block_k
=
block_K
),
})
T
.
clear
(
A_sp_shared
)
T
.
clear
(
A_sp_shared
)
T
.
clear
(
E_shared
)
T
.
clear
(
E_shared
)
# TODO: alloc_var seems buggy here
# TODO: alloc_var seems buggy here
...
@@ -295,8 +286,7 @@ def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
...
@@ -295,8 +286,7 @@ def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
non_zero_elt_log_idx
[
1
]
=
3
non_zero_elt_log_idx
[
1
]
=
3
for
i
in
T
.
serial
(
elem
):
for
i
in
T
.
serial
(
elem
):
val
=
non_zero_elt_log_idx
[
i
]
val
=
non_zero_elt_log_idx
[
i
]
E_shared
[
tm
,
a_k
//
e_factor
]
|=
T
.
shift_left
(
E_shared
[
tm
,
a_k
//
e_factor
]
|=
T
.
shift_left
(
val
,
4
*
(
g_i
%
(
e_factor
//
group
))
+
2
*
i
)
val
,
4
*
(
g_i
%
(
e_factor
//
group
))
+
2
*
i
)
T
.
copy
(
A_sp_shared
,
A_sp
[
bx
*
block_M
,
by
*
block_K
//
2
])
T
.
copy
(
A_sp_shared
,
A_sp
[
bx
*
block_M
,
by
*
block_K
//
2
])
T
.
copy
(
E_shared
,
E
[
bx
*
block_M
,
by
*
block_K
//
e_factor
])
T
.
copy
(
E_shared
,
E
[
bx
*
block_M
,
by
*
block_K
//
e_factor
])
...
@@ -304,41 +294,27 @@ def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
...
@@ -304,41 +294,27 @@ def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
def
main
():
def
main
():
parser
=
argparse
.
ArgumentParser
(
description
=
"Autotuned MatMul Benchmark"
)
parser
=
argparse
.
ArgumentParser
(
description
=
"Autotuned MatMul Benchmark"
)
parser
.
add_argument
(
"--m"
,
type
=
int
,
default
=
16384
,
help
=
"Matrix dimension M"
)
parser
.
add_argument
(
"--m"
,
type
=
int
,
default
=
16384
,
help
=
"Matrix dimension M"
)
parser
.
add_argument
(
"--n"
,
type
=
int
,
default
=
16384
,
help
=
"Matrix dimension N"
)
parser
.
add_argument
(
"--n"
,
type
=
int
,
default
=
16384
,
help
=
"Matrix dimension N"
)
parser
.
add_argument
(
"--k"
,
type
=
int
,
default
=
16384
,
help
=
"Matrix dimension K"
)
parser
.
add_argument
(
"--k"
,
type
=
int
,
default
=
16384
,
help
=
"Matrix dimension K"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--use_cutlass_layout"
,
action
=
"store_true"
,
help
=
"Use cutlass layout for E tensor"
)
"--use_cutlass_layout"
,
action
=
'store_true'
,
help
=
"Use cutlass layout for E tensor"
)
parser
.
add_argument
(
"--use_torch_compressor"
,
action
=
"store_true"
,
help
=
"Use torch sparse for reference"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--accum_dtype"
,
type
=
str
,
default
=
"float"
,
choices
=
[
"float"
,
"float16"
],
help
=
"Accumulation datatype"
)
"--use_torch_compressor"
,
action
=
'store_true'
,
help
=
"Use torch sparse for reference"
)
parser
.
add_argument
(
"--accum_dtype"
,
type
=
str
,
default
=
"float"
,
choices
=
[
"float"
,
"float16"
],
help
=
"Accumulation datatype"
)
parser
.
add_argument
(
"--cfg"
,
type
=
str
,
choices
=
[
"4090"
],
default
=
"4090"
)
parser
.
add_argument
(
"--cfg"
,
type
=
str
,
choices
=
[
"4090"
],
default
=
"4090"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
kernel
=
matmul_sp_fp16_custom_compress
(
kernel
=
matmul_sp_fp16_custom_compress
(
args
.
m
,
args
.
m
,
args
.
n
,
args
.
k
,
args
.
accum_dtype
,
**
DEFAULT_CONFIG
[
args
.
cfg
][
args
.
accum_dtype
],
use_cutlass_layout
=
args
.
use_cutlass_layout
args
.
n
,
)
args
.
k
,
args
.
accum_dtype
,
**
DEFAULT_CONFIG
[
args
.
cfg
][
args
.
accum_dtype
],
use_cutlass_layout
=
args
.
use_cutlass_layout
)
a
=
randn_semi_sparse
(
args
.
m
,
args
.
k
,
device
=
'
cuda
'
,
dtype
=
torch
.
half
)
a
=
randn_semi_sparse
(
args
.
m
,
args
.
k
,
device
=
"
cuda
"
,
dtype
=
torch
.
half
)
b
=
torch
.
randn
(
args
.
k
,
args
.
n
,
device
=
'
cuda
'
,
dtype
=
torch
.
half
)
b
=
torch
.
randn
(
args
.
k
,
args
.
n
,
device
=
"
cuda
"
,
dtype
=
torch
.
half
)
if
args
.
use_torch_compressor
:
if
args
.
use_torch_compressor
:
assert
not
args
.
use_cutlass_layout
,
"torch sparse must be used with naive layout"
assert
not
args
.
use_cutlass_layout
,
"torch sparse must be used with naive layout"
a_sparse
,
e
=
torch_compress
(
a
)
a_sparse
,
e
=
torch_compress
(
a
)
else
:
else
:
a_sparse
,
e
=
compress_kernel
(
a_sparse
,
e
=
compress_kernel
(
args
.
m
,
args
.
k
,
32
,
32
,
"float16"
,
use_cutlass_layout
=
args
.
use_cutlass_layout
)(
a
)
args
.
m
,
args
.
k
,
32
,
32
,
"float16"
,
use_cutlass_layout
=
args
.
use_cutlass_layout
)(
a
)
c
=
kernel
(
a_sparse
,
e
,
b
)
c
=
kernel
(
a_sparse
,
e
,
b
)
...
@@ -346,9 +322,7 @@ def main():
...
@@ -346,9 +322,7 @@ def main():
assert
not
c
.
isnan
().
any
(),
"Reference result contains NaNs, please report an issue"
assert
not
c
.
isnan
().
any
(),
"Reference result contains NaNs, please report an issue"
torch_assert_close
(
c
,
ref_c
.
to
(
c
.
dtype
),
rtol
=
1e-3
,
atol
=
1e-3
)
torch_assert_close
(
c
,
ref_c
.
to
(
c
.
dtype
),
rtol
=
1e-3
,
atol
=
1e-3
)
print
(
print
(
f
"Precision check passed. Max diff:
{
(
c
-
ref_c
).
abs
().
max
()
}
, Mean diff:
{
(
c
-
ref_c
).
abs
().
mean
()
}
"
)
f
"Precision check passed. Max diff:
{
(
c
-
ref_c
).
abs
().
max
()
}
, Mean diff:
{
(
c
-
ref_c
).
abs
().
mean
()
}
"
)
latency
=
do_bench
(
lambda
:
kernel
(
a_sparse
,
e
,
b
))
latency
=
do_bench
(
lambda
:
kernel
(
a_sparse
,
e
,
b
))
ref_latency
=
do_bench
(
lambda
:
a
@
b
)
ref_latency
=
do_bench
(
lambda
:
a
@
b
)
...
@@ -356,8 +330,8 @@ def main():
...
@@ -356,8 +330,8 @@ def main():
total_flops
=
2
*
args
.
m
*
args
.
n
*
args
.
k
total_flops
=
2
*
args
.
m
*
args
.
n
*
args
.
k
tflops
=
total_flops
/
latency
/
1e9
tflops
=
total_flops
/
latency
/
1e9
ref_tflops
=
total_flops
/
ref_latency
/
1e9
ref_tflops
=
total_flops
/
ref_latency
/
1e9
print
(
f
"Sparse TFLOPS:
{
tflops
:.
2
f
}
, Latency:
{
latency
/
1e3
}
s"
)
print
(
f
"Sparse TFLOPS:
{
tflops
:.
2
f
}
, Latency:
{
latency
/
1e3
}
s"
)
print
(
f
"Reference TFLOPS:
{
ref_tflops
:.
2
f
}
, Latency:
{
ref_latency
/
1e3
:
}
s"
)
print
(
f
"Reference TFLOPS:
{
ref_tflops
:.
2
f
}
, Latency:
{
ref_latency
/
1e3
:
}
s"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
examples/gemm_sp/example_gemm_sp.py
View file @
29051439
...
@@ -16,80 +16,77 @@ arch = nvcc.get_target_compute_version()
...
@@ -16,80 +16,77 @@ arch = nvcc.get_target_compute_version()
DEFAULT_CONFIG
=
{
# take best config from autotune script
DEFAULT_CONFIG
=
{
# take best config from autotune script
"4090"
:
{
"4090"
:
{
'float'
:
{
"float"
:
{
'block_M'
:
128
,
"block_M"
:
128
,
'block_N'
:
64
,
"block_N"
:
64
,
'block_K'
:
64
,
"block_K"
:
64
,
'num_stages'
:
1
,
"num_stages"
:
1
,
'thread_num'
:
128
,
"thread_num"
:
128
,
'policy'
:
T
.
GemmWarpPolicy
.
Square
,
"policy"
:
T
.
GemmWarpPolicy
.
Square
,
'enable_rasterization'
:
True
"enable_rasterization"
:
True
,
},
"float16"
:
{
"block_M"
:
256
,
"block_N"
:
128
,
"block_K"
:
64
,
"num_stages"
:
2
,
"thread_num"
:
128
,
"policy"
:
T
.
GemmWarpPolicy
.
Square
,
"enable_rasterization"
:
True
,
},
},
'float16'
:
{
'block_M'
:
256
,
'block_N'
:
128
,
'block_K'
:
64
,
'num_stages'
:
2
,
'thread_num'
:
128
,
'policy'
:
T
.
GemmWarpPolicy
.
Square
,
'enable_rasterization'
:
True
}
},
},
"h20"
:
{
"h20"
:
{
'float'
:
{
"float"
:
{
'block_M'
:
128
,
"block_M"
:
128
,
'block_N'
:
64
,
"block_N"
:
64
,
'block_K'
:
128
,
"block_K"
:
128
,
'num_stages'
:
3
,
"num_stages"
:
3
,
'thread_num'
:
128
,
"thread_num"
:
128
,
'policy'
:
T
.
GemmWarpPolicy
.
Square
,
"policy"
:
T
.
GemmWarpPolicy
.
Square
,
'enable_rasterization'
:
True
"enable_rasterization"
:
True
,
},
"float16"
:
{
"block_M"
:
128
,
"block_N"
:
64
,
"block_K"
:
128
,
"num_stages"
:
3
,
"thread_num"
:
128
,
"policy"
:
T
.
GemmWarpPolicy
.
Square
,
"enable_rasterization"
:
True
,
},
},
'float16'
:
{
},
'block_M'
:
128
,
'block_N'
:
64
,
'block_K'
:
128
,
'num_stages'
:
3
,
'thread_num'
:
128
,
'policy'
:
T
.
GemmWarpPolicy
.
Square
,
'enable_rasterization'
:
True
}
}
}
}
ARCH_INFO
=
{
"8.0"
:
(
16
,
"int16"
),
"8.9"
:
(
16
,
"int16"
),
"9.0"
:
(
8
,
"uint8"
)}
ARCH_INFO
=
{
"8.0"
:
(
16
,
"int16"
),
"8.9"
:
(
16
,
"int16"
),
"9.0"
:
(
8
,
"uint8"
)}
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
matmul_sp_fp16
(
M
,
N
,
K
,
accum_dtype
,
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
policy
,
def
matmul_sp_fp16
(
M
,
N
,
K
,
accum_dtype
,
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
policy
,
enable_rasterization
):
enable_rasterization
):
e_factor
,
e_dtype
=
ARCH_INFO
[
arch
]
e_factor
,
e_dtype
=
ARCH_INFO
[
arch
]
@
T
.
prim_func
@
T
.
prim_func
def
gemm_sp_fp16
(
def
gemm_sp_fp16
(
A_sparse
:
T
.
Tensor
((
M
,
K
//
2
),
'
float16
'
),
A_sparse
:
T
.
Tensor
((
M
,
K
//
2
),
"
float16
"
),
E
:
T
.
Tensor
((
M
,
K
//
e_factor
),
e_dtype
),
E
:
T
.
Tensor
((
M
,
K
//
e_factor
),
e_dtype
),
B
:
T
.
Tensor
((
K
,
N
),
'
float16
'
),
B
:
T
.
Tensor
((
K
,
N
),
"
float16
"
),
C
:
T
.
Tensor
((
M
,
N
),
accum_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
accum_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
2
),
'
float16
'
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
2
),
"
float16
"
)
E_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
e_factor
),
e_dtype
)
E_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
e_factor
),
e_dtype
)
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
'
float16
'
)
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
"
float16
"
)
C_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
accum_dtype
)
C_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
accum_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
T
.
clear
(
C_local
)
T
.
clear
(
C_local
)
T
.
disable_warp_group_reg_alloc
()
T
.
disable_warp_group_reg_alloc
()
T
.
use_swizzle
(
panel_size
=
10
,
enable
=
enable_rasterization
)
T
.
use_swizzle
(
panel_size
=
10
,
enable
=
enable_rasterization
)
T
.
annotate_layout
({
T
.
annotate_layout
(
E
:
{
make_cutlass_metadata_layout
(
E
:
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
"float16"
,
block_k
=
block_K
,
arch
=
arch
),
E
,
mma_dtype
=
"float16"
,
block_k
=
block_K
,
arch
=
arch
),
E_shared
:
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
"float16"
,
block_k
=
block_K
,
arch
=
arch
),
E_shared
:
}
make_cutlass_metadata_layout
(
)
E_shared
,
mma_dtype
=
"float16"
,
block_k
=
block_K
,
arch
=
arch
),
})
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
A_sparse
[
by
*
block_M
,
k
*
block_K
//
2
],
A_shared
)
T
.
copy
(
A_sparse
[
by
*
block_M
,
k
*
block_K
//
2
],
A_shared
)
T
.
copy
(
E
[
by
*
block_M
,
k
*
block_K
//
e_factor
],
E_shared
)
T
.
copy
(
E
[
by
*
block_M
,
k
*
block_K
//
e_factor
],
E_shared
)
...
@@ -107,25 +104,15 @@ def main():
...
@@ -107,25 +104,15 @@ def main():
parser
.
add_argument
(
"--m"
,
type
=
int
,
default
=
16384
,
help
=
"Matrix dimension M"
)
parser
.
add_argument
(
"--m"
,
type
=
int
,
default
=
16384
,
help
=
"Matrix dimension M"
)
parser
.
add_argument
(
"--n"
,
type
=
int
,
default
=
16384
,
help
=
"Matrix dimension N"
)
parser
.
add_argument
(
"--n"
,
type
=
int
,
default
=
16384
,
help
=
"Matrix dimension N"
)
parser
.
add_argument
(
"--k"
,
type
=
int
,
default
=
16384
,
help
=
"Matrix dimension K"
)
parser
.
add_argument
(
"--k"
,
type
=
int
,
default
=
16384
,
help
=
"Matrix dimension K"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--accum_dtype"
,
type
=
str
,
default
=
"float"
,
choices
=
[
"float"
,
"float16"
],
help
=
"Accumulation datatype"
)
"--accum_dtype"
,
type
=
str
,
default
=
"float"
,
choices
=
[
"float"
,
"float16"
],
help
=
"Accumulation datatype"
)
parser
.
add_argument
(
"--cfg"
,
type
=
str
,
choices
=
[
"4090"
,
"h20"
],
default
=
"4090"
)
parser
.
add_argument
(
"--cfg"
,
type
=
str
,
choices
=
[
"4090"
,
"h20"
],
default
=
"4090"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
kernel
=
matmul_sp_fp16
(
args
.
m
,
args
.
n
,
args
.
k
,
args
.
accum_dtype
,
kernel
=
matmul_sp_fp16
(
args
.
m
,
args
.
n
,
args
.
k
,
args
.
accum_dtype
,
**
DEFAULT_CONFIG
[
args
.
cfg
][
args
.
accum_dtype
])
**
DEFAULT_CONFIG
[
args
.
cfg
][
args
.
accum_dtype
])
a
=
randn_semi_sparse
(
args
.
m
,
args
.
k
,
device
=
'
cuda
'
,
dtype
=
torch
.
half
)
a
=
randn_semi_sparse
(
args
.
m
,
args
.
k
,
device
=
"
cuda
"
,
dtype
=
torch
.
half
)
b
=
torch
.
randn
(
args
.
k
,
args
.
n
,
device
=
'
cuda
'
,
dtype
=
torch
.
half
)
b
=
torch
.
randn
(
args
.
k
,
args
.
n
,
device
=
"
cuda
"
,
dtype
=
torch
.
half
)
a_sparse
,
e
=
compress
(
a_sparse
,
e
=
compress
(
a
,
transposed
=
False
,
block_k
=
DEFAULT_CONFIG
[
args
.
cfg
][
args
.
accum_dtype
][
"block_K"
],
arch
=
arch
)
a
,
transposed
=
False
,
block_k
=
DEFAULT_CONFIG
[
args
.
cfg
][
args
.
accum_dtype
][
'block_K'
],
arch
=
arch
)
c
=
kernel
(
a_sparse
,
e
,
b
)
c
=
kernel
(
a_sparse
,
e
,
b
)
ref_c
=
a
@
b
ref_c
=
a
@
b
...
@@ -140,8 +127,8 @@ def main():
...
@@ -140,8 +127,8 @@ def main():
total_flops
=
2
*
args
.
m
*
args
.
n
*
args
.
k
total_flops
=
2
*
args
.
m
*
args
.
n
*
args
.
k
tflops
=
total_flops
/
latency
/
1e9
tflops
=
total_flops
/
latency
/
1e9
ref_tflops
=
total_flops
/
ref_latency
/
1e9
ref_tflops
=
total_flops
/
ref_latency
/
1e9
print
(
f
"Sparse TFLOPS:
{
tflops
:.
2
f
}
, Latency:
{
latency
/
1e3
}
s"
)
print
(
f
"Sparse TFLOPS:
{
tflops
:.
2
f
}
, Latency:
{
latency
/
1e3
}
s"
)
print
(
f
"Reference TFLOPS:
{
ref_tflops
:.
2
f
}
, Latency:
{
ref_latency
/
1e3
:
}
s"
)
print
(
f
"Reference TFLOPS:
{
ref_tflops
:.
2
f
}
, Latency:
{
ref_latency
/
1e3
:
}
s"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
examples/gemm_splitk/example_tilelang_gemm_splitk.py
View file @
29051439
...
@@ -3,27 +3,16 @@ import tilelang.language as T
...
@@ -3,27 +3,16 @@ import tilelang.language as T
@
tilelang
.
jit
@
tilelang
.
jit
def
matmul
(
M
,
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
split_k
,
dtype
=
"float16"
,
accum_dtype
=
"float"
,
out_dtype
=
"float32"
):
N
,
K
,
block_M
,
block_N
,
block_K
,
split_k
,
dtype
=
"float16"
,
accum_dtype
=
"float"
,
out_dtype
=
"float32"
):
splitK
=
K
//
split_k
splitK
=
K
//
split_k
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
split_k
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
split_k
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
C_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
out_dtype
)
C_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
out_dtype
)
...
...
examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py
View file @
29051439
...
@@ -3,27 +3,16 @@ import tilelang.language as T
...
@@ -3,27 +3,16 @@ import tilelang.language as T
@
tilelang
.
jit
@
tilelang
.
jit
def
matmul
(
M
,
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
split_k
,
dtype
=
"float16"
,
accum_dtype
=
"float"
,
out_dtype
=
"float32"
):
N
,
K
,
block_M
,
block_N
,
block_K
,
split_k
,
dtype
=
"float16"
,
accum_dtype
=
"float"
,
out_dtype
=
"float32"
):
splitK
=
K
//
split_k
splitK
=
K
//
split_k
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
split_k
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
split_k
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
C_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
out_dtype
)
C_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
out_dtype
)
...
...
examples/gemm_streamk/example_tilelang_gemm_streamk.py
View file @
29051439
...
@@ -39,7 +39,7 @@ total_tiles = num_block_m * num_block_n
...
@@ -39,7 +39,7 @@ total_tiles = num_block_m * num_block_n
# Two-tile SK + DP
# Two-tile SK + DP
streamk_tiles
=
total_tiles
%
streamk_programs
streamk_tiles
=
total_tiles
%
streamk_programs
if
(
total_tiles
-
streamk_tiles
>
streamk_programs
)
:
# (total_tiles // total_programs > 1)
if
total_tiles
-
streamk_tiles
>
streamk_programs
:
# (total_tiles // total_programs > 1)
streamk_tiles
+=
streamk_programs
streamk_tiles
+=
streamk_programs
blocking_tiles
=
total_tiles
-
streamk_tiles
blocking_tiles
=
total_tiles
-
streamk_tiles
...
@@ -135,7 +135,6 @@ def tl_matmul_streamk(
...
@@ -135,7 +135,6 @@ def tl_matmul_streamk(
C
:
T
.
Tensor
,
C
:
T
.
Tensor
,
C_local
:
T
.
LocalBuffer
,
C_local
:
T
.
LocalBuffer
,
):
):
for
p
in
T
.
serial
(
sm_patition_factor
):
for
p
in
T
.
serial
(
sm_patition_factor
):
tile_id
=
pid
+
streamk_tiles
+
p
*
total_sm
tile_id
=
pid
+
streamk_tiles
+
p
*
total_sm
pid_m
=
tile_id
//
T
.
ceildiv
(
N
,
block_N
)
pid_m
=
tile_id
//
T
.
ceildiv
(
N
,
block_N
)
...
@@ -150,12 +149,11 @@ def tl_matmul_streamk(
...
@@ -150,12 +149,11 @@ def tl_matmul_streamk(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
dtypeAB
),
A
:
T
.
Tensor
(
A_shape
,
dtypeAB
),
B
:
T
.
Tensor
(
B_shape
,
dtypeAB
),
B
:
T
.
Tensor
(
B_shape
,
dtypeAB
),
C
:
T
.
Tensor
((
M
,
N
),
dtypeC
),
C
:
T
.
Tensor
((
M
,
N
),
dtypeC
),
):
):
with
T
.
Kernel
(
streamk_programs
,
threads
=
threads
)
as
pid
:
with
T
.
Kernel
(
streamk_programs
,
threads
=
threads
)
as
pid
:
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
dtypeAB
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
dtypeAB
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
dtypeAB
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
dtypeAB
)
A_shared_full_tiles
=
T
.
alloc_shared
(
A_shared_shape
,
dtypeAB
)
A_shared_full_tiles
=
T
.
alloc_shared
(
A_shared_shape
,
dtypeAB
)
...
...
examples/gemv/example_gemv.py
View file @
29051439
...
@@ -20,12 +20,11 @@ def naive_gemv(
...
@@ -20,12 +20,11 @@ def naive_gemv(
dtype
:
str
=
"float16"
,
dtype
:
str
=
"float16"
,
accum_dtype
:
str
=
"float"
,
accum_dtype
:
str
=
"float"
,
):
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
K
,),
dtype
),
A
:
T
.
Tensor
((
K
,),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
N
,),
dtype
),
C
:
T
.
Tensor
((
N
,),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
BLOCK_N
))
as
bn
:
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
BLOCK_N
))
as
bn
:
tn
=
T
.
get_thread_binding
(
0
)
# tn = threadIdx.x
tn
=
T
.
get_thread_binding
(
0
)
# tn = threadIdx.x
...
@@ -38,8 +37,7 @@ def naive_gemv(
...
@@ -38,8 +37,7 @@ def naive_gemv(
A_shared
[
tk
]
=
A
[
bk
*
BLOCK_K
+
tk
]
A_shared
[
tk
]
=
A
[
bk
*
BLOCK_K
+
tk
]
B_shared
[
tn
,
tk
]
=
B
[
bn
*
BLOCK_N
+
tn
,
bk
*
BLOCK_K
+
tk
]
B_shared
[
tn
,
tk
]
=
B
[
bn
*
BLOCK_N
+
tn
,
bk
*
BLOCK_K
+
tk
]
for
tk
in
T
.
serial
(
BLOCK_K
):
for
tk
in
T
.
serial
(
BLOCK_K
):
C_reg
[
0
]
+=
A_shared
[
tk
].
astype
(
accum_dtype
)
*
B_shared
[
tn
,
C_reg
[
0
]
+=
A_shared
[
tk
].
astype
(
accum_dtype
)
*
B_shared
[
tn
,
tk
].
astype
(
accum_dtype
)
tk
].
astype
(
accum_dtype
)
C
[
bn
*
BLOCK_N
+
tn
]
=
C_reg
[
0
]
C
[
bn
*
BLOCK_N
+
tn
]
=
C_reg
[
0
]
return
main
return
main
...
@@ -54,12 +52,11 @@ def naive_splitk_gemv(
...
@@ -54,12 +52,11 @@ def naive_splitk_gemv(
dtype
:
str
=
"float16"
,
dtype
:
str
=
"float16"
,
accum_dtype
:
str
=
"float"
,
accum_dtype
:
str
=
"float"
,
):
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
K
,),
dtype
),
A
:
T
.
Tensor
((
K
,),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
N
,),
dtype
),
C
:
T
.
Tensor
((
N
,),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
BLOCK_N
),
threads
=
(
BLOCK_N
,
BLOCK_K
))
as
bn
:
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
BLOCK_N
),
threads
=
(
BLOCK_N
,
BLOCK_K
))
as
bn
:
tn
=
T
.
get_thread_binding
(
0
)
tn
=
T
.
get_thread_binding
(
0
)
...
@@ -95,9 +92,9 @@ def splitk_gemv(
...
@@ -95,9 +92,9 @@ def splitk_gemv(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
K
,),
dtype
),
A
:
T
.
Tensor
((
K
,),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
N
,),
dtype
),
C
:
T
.
Tensor
((
N
,),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
BLOCK_N
),
threads
=
(
BLOCK_N
,
reduce_threads
))
as
bn
:
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
BLOCK_N
),
threads
=
(
BLOCK_N
,
reduce_threads
))
as
bn
:
tn
=
T
.
get_thread_binding
(
0
)
tn
=
T
.
get_thread_binding
(
0
)
...
@@ -136,9 +133,9 @@ def splitk_gemv_vectorized(
...
@@ -136,9 +133,9 @@ def splitk_gemv_vectorized(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
K
,),
dtype
),
A
:
T
.
Tensor
((
K
,),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
N
,),
dtype
),
C
:
T
.
Tensor
((
N
,),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
BLOCK_N
),
threads
=
(
BLOCK_N
,
reduce_threads
))
as
bn
:
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
BLOCK_N
),
threads
=
(
BLOCK_N
,
reduce_threads
))
as
bn
:
tn
=
T
.
get_thread_binding
(
0
)
tn
=
T
.
get_thread_binding
(
0
)
...
@@ -177,9 +174,9 @@ def splitk_gemv_vectorized_tvm(
...
@@ -177,9 +174,9 @@ def splitk_gemv_vectorized_tvm(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
K
,),
dtype
),
A
:
T
.
Tensor
((
K
,),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
N
,),
dtype
),
C
:
T
.
Tensor
((
N
,),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
BLOCK_N
),
threads
=
(
BLOCK_N
,
reduce_threads
))
as
bn
:
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
BLOCK_N
),
threads
=
(
BLOCK_N
,
reduce_threads
))
as
bn
:
tn
=
T
.
get_thread_binding
(
0
)
tn
=
T
.
get_thread_binding
(
0
)
...
@@ -197,9 +194,9 @@ def splitk_gemv_vectorized_tvm(
...
@@ -197,9 +194,9 @@ def splitk_gemv_vectorized_tvm(
C_accum
[
0
]
+=
A_local
[
k
].
astype
(
accum_dtype
)
*
B_local
[
k
].
astype
(
accum_dtype
)
C_accum
[
0
]
+=
A_local
[
k
].
astype
(
accum_dtype
)
*
B_local
[
k
].
astype
(
accum_dtype
)
C_reduced
=
T
.
alloc_local
((
1
,),
accum_dtype
)
C_reduced
=
T
.
alloc_local
((
1
,),
accum_dtype
)
with
T
.
attr
(
with
T
.
attr
(
T
.
comm_reducer
(
lambda
x
,
y
:
x
+
y
,
[
T
.
Cast
(
accum_dtype
,
0
)]),
T
.
comm_reducer
(
lambda
x
,
y
:
x
+
y
,
[
T
.
Cast
(
accum_dtype
,
0
)]),
"reduce_scope"
,
"reduce_scope"
,
T
.
reinterpret
(
T
.
uint64
(
0
),
dtype
=
"handle"
),
T
.
reinterpret
(
T
.
uint64
(
0
),
dtype
=
"handle"
),
):
):
T
.
evaluate
(
T
.
evaluate
(
T
.
tvm_thread_allreduce
(
T
.
tvm_thread_allreduce
(
...
@@ -209,7 +206,8 @@ def splitk_gemv_vectorized_tvm(
...
@@ -209,7 +206,8 @@ def splitk_gemv_vectorized_tvm(
C_reduced
[
0
],
C_reduced
[
0
],
tk
,
tk
,
dtype
=
"handle"
,
dtype
=
"handle"
,
))
)
)
C
[
bn
*
BLOCK_N
+
tn
]
=
C_reduced
[
0
]
C
[
bn
*
BLOCK_N
+
tn
]
=
C_reduced
[
0
]
...
@@ -218,10 +216,8 @@ def splitk_gemv_vectorized_tvm(
...
@@ -218,10 +216,8 @@ def splitk_gemv_vectorized_tvm(
def
get_block_template_configs
():
def
get_block_template_configs
():
iter_params
=
dict
(
iter_params
=
dict
(
block_M
=
[
2
,
4
,
8
,
32
,
64
,
128
],
block_M
=
[
2
,
4
,
8
,
32
,
64
,
128
],
block_N
=
[
2
,
4
,
8
,
32
,
64
,
128
],
num_stages
=
[
0
,
1
,
2
,
3
,
4
],
threads
=
[
32
,
64
,
128
,
256
]
block_N
=
[
2
,
4
,
8
,
32
,
64
,
128
],
)
num_stages
=
[
0
,
1
,
2
,
3
,
4
],
threads
=
[
32
,
64
,
128
,
256
])
return
[
dict
(
zip
(
iter_params
,
values
))
for
values
in
itertools
.
product
(
*
iter_params
.
values
())]
return
[
dict
(
zip
(
iter_params
,
values
))
for
values
in
itertools
.
product
(
*
iter_params
.
values
())]
...
@@ -237,18 +233,9 @@ def get_block_template_configs():
...
@@ -237,18 +233,9 @@ def get_block_template_configs():
},
},
out_idx
=
[
2
],
out_idx
=
[
2
],
)
)
def
gemv_alloc_reducer
(
M
,
def
gemv_alloc_reducer
(
M
,
N
,
block_M
=
128
,
block_N
=
128
,
num_stages
=
2
,
threads
=
256
,
dtype
:
str
=
"float16"
,
accum_dtype
:
str
=
"float"
):
N
,
block_M
=
128
,
block_N
=
128
,
num_stages
=
2
,
threads
=
256
,
dtype
:
str
=
"float16"
,
accum_dtype
:
str
=
"float"
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
a
:
T
.
Tensor
((
M
,
N
),
dtype
),
x
:
T
.
Tensor
(
N
,
dtype
),
o
:
T
.
Tensor
(
M
,
def
main
(
a
:
T
.
Tensor
((
M
,
N
),
dtype
),
x
:
T
.
Tensor
(
N
,
dtype
),
o
:
T
.
Tensor
(
M
,
dtype
)):
# type: ignore
dtype
)):
# type: ignore
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
i0_m
:
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
i0_m
:
o_reducer
=
T
.
alloc_reducer
(
block_M
,
accum_dtype
,
replication
=
"all"
)
o_reducer
=
T
.
alloc_reducer
(
block_M
,
accum_dtype
,
replication
=
"all"
)
T
.
clear
(
o_reducer
)
T
.
clear
(
o_reducer
)
...
@@ -295,9 +282,9 @@ def get_autotuned_kernel(
...
@@ -295,9 +282,9 @@ def get_autotuned_kernel(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
K
,),
dtype
),
A
:
T
.
Tensor
((
K
,),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
N
,),
dtype
),
C
:
T
.
Tensor
((
N
,),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
BLOCK_N
),
threads
=
(
BLOCK_N
,
reduce_threads
))
as
bn
:
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
BLOCK_N
),
threads
=
(
BLOCK_N
,
reduce_threads
))
as
bn
:
tn
=
T
.
get_thread_binding
(
0
)
tn
=
T
.
get_thread_binding
(
0
)
...
@@ -315,9 +302,9 @@ def get_autotuned_kernel(
...
@@ -315,9 +302,9 @@ def get_autotuned_kernel(
C_accum
[
0
]
+=
A_local
[
k
].
astype
(
accum_dtype
)
*
B_local
[
k
].
astype
(
accum_dtype
)
C_accum
[
0
]
+=
A_local
[
k
].
astype
(
accum_dtype
)
*
B_local
[
k
].
astype
(
accum_dtype
)
C_reduced
=
T
.
alloc_local
((
1
,),
accum_dtype
)
C_reduced
=
T
.
alloc_local
((
1
,),
accum_dtype
)
with
T
.
attr
(
with
T
.
attr
(
T
.
comm_reducer
(
lambda
x
,
y
:
x
+
y
,
[
T
.
Cast
(
accum_dtype
,
0
)]),
T
.
comm_reducer
(
lambda
x
,
y
:
x
+
y
,
[
T
.
Cast
(
accum_dtype
,
0
)]),
"reduce_scope"
,
"reduce_scope"
,
T
.
reinterpret
(
T
.
uint64
(
0
),
dtype
=
"handle"
),
T
.
reinterpret
(
T
.
uint64
(
0
),
dtype
=
"handle"
),
):
):
T
.
evaluate
(
T
.
evaluate
(
T
.
tvm_thread_allreduce
(
T
.
tvm_thread_allreduce
(
...
@@ -327,7 +314,8 @@ def get_autotuned_kernel(
...
@@ -327,7 +314,8 @@ def get_autotuned_kernel(
C_reduced
[
0
],
C_reduced
[
0
],
tk
,
tk
,
dtype
=
"handle"
,
dtype
=
"handle"
,
))
)
)
C
[
bn
*
BLOCK_N
+
tn
]
=
C_reduced
[
0
]
C
[
bn
*
BLOCK_N
+
tn
]
=
C_reduced
[
0
]
...
@@ -355,8 +343,7 @@ def main(do_bench: bool = True):
...
@@ -355,8 +343,7 @@ def main(do_bench: bool = True):
check_correctness_and_bench
(
splitk_gemv
(
N
,
K
,
32
,
32
,
32
),
N
,
K
,
do_bench
=
do_bench
)
check_correctness_and_bench
(
splitk_gemv
(
N
,
K
,
32
,
32
,
32
),
N
,
K
,
do_bench
=
do_bench
)
check_correctness_and_bench
(
splitk_gemv_vectorized
(
N
,
K
,
2
,
32
),
N
,
K
,
do_bench
=
do_bench
)
check_correctness_and_bench
(
splitk_gemv_vectorized
(
N
,
K
,
2
,
32
),
N
,
K
,
do_bench
=
do_bench
)
check_correctness_and_bench
(
splitk_gemv_vectorized_tvm
(
N
,
K
,
2
,
32
),
N
,
K
,
do_bench
=
do_bench
)
check_correctness_and_bench
(
splitk_gemv_vectorized_tvm
(
N
,
K
,
2
,
32
),
N
,
K
,
do_bench
=
do_bench
)
check_correctness_and_bench
(
check_correctness_and_bench
(
gemv_alloc_reducer
(
N
,
K
,
block_M
=
128
,
block_N
=
128
),
N
,
K
,
do_bench
=
do_bench
)
gemv_alloc_reducer
(
N
,
K
,
block_M
=
128
,
block_N
=
128
),
N
,
K
,
do_bench
=
do_bench
)
print
(
"Test passed!"
)
print
(
"Test passed!"
)
...
...
examples/grouped_gemm/example_grouped_gemm_bwd.py
View file @
29051439
...
@@ -5,21 +5,8 @@ import tilelang
...
@@ -5,21 +5,8 @@ import tilelang
import
tilelang.language
as
T
import
tilelang.language
as
T
@
tilelang
.
jit
(
@
tilelang
.
jit
(
out_idx
=
[
2
],
pass_configs
=
{
"tl.disable_tma_lower"
:
True
,
"tl.disable_warp_specialized"
:
True
})
out_idx
=
[
2
],
pass_configs
=
{
def
grouped_gemm_fwd
(
batch_sum
,
batch_count
,
K
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
=
2
,
threads
=
128
,
dtype
=
"float16"
):
"tl.disable_tma_lower"
:
True
,
"tl.disable_warp_specialized"
:
True
})
def
grouped_gemm_fwd
(
batch_sum
,
batch_count
,
K
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
=
2
,
threads
=
128
,
dtype
=
"float16"
):
"""
"""
args:
args:
a (torch.Tensor): Input tensor of shape (M, K).
a (torch.Tensor): Input tensor of shape (M, K).
...
@@ -29,17 +16,14 @@ def grouped_gemm_fwd(batch_sum,
...
@@ -29,17 +16,14 @@ def grouped_gemm_fwd(batch_sum,
@
T
.
prim_func
@
T
.
prim_func
def
kernel
(
def
kernel
(
A
:
T
.
Tensor
([
batch_sum
,
K
],
dtype
),
# type: ignore
A
:
T
.
Tensor
([
batch_sum
,
K
],
dtype
),
# type: ignore
B
:
T
.
Tensor
([
batch_count
,
K
,
N
],
dtype
),
# type: ignore
B
:
T
.
Tensor
([
batch_count
,
K
,
N
],
dtype
),
# type: ignore
C
:
T
.
Tensor
([
batch_sum
,
N
],
dtype
),
# type: ignore
C
:
T
.
Tensor
([
batch_sum
,
N
],
dtype
),
# type: ignore
batch_sizes
:
T
.
Tensor
([
batch_count
],
"int32"
),
# type: ignore
batch_sizes
:
T
.
Tensor
([
batch_count
],
"int32"
),
# type: ignore
batch_offsets
:
T
.
Tensor
([
batch_count
],
"int32"
),
# type: ignore
batch_offsets
:
T
.
Tensor
([
batch_count
],
"int32"
),
# type: ignore
batch_padded_offsets
:
T
.
Tensor
([
batch_count
],
"int32"
),
# type: ignore
batch_padded_offsets
:
T
.
Tensor
([
batch_count
],
"int32"
),
# type: ignore
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
batch_sum
,
block_M
)
+
batch_count
,
T
.
ceildiv
(
N
,
block_N
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
batch_sum
,
block_M
)
+
batch_count
,
T
.
ceildiv
(
N
,
block_N
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
([
block_M
,
block_K
],
dtype
)
A_shared
=
T
.
alloc_shared
([
block_M
,
block_K
],
dtype
)
B_shared
=
T
.
alloc_shared
([
block_K
,
block_N
],
dtype
)
B_shared
=
T
.
alloc_shared
([
block_K
,
block_N
],
dtype
)
C_local
=
T
.
alloc_fragment
([
block_M
,
block_N
],
accum_dtype
)
C_local
=
T
.
alloc_fragment
([
block_M
,
block_N
],
accum_dtype
)
...
@@ -49,23 +33,17 @@ def grouped_gemm_fwd(batch_sum,
...
@@ -49,23 +33,17 @@ def grouped_gemm_fwd(batch_sum,
m_start_padded
=
bx
*
block_M
m_start_padded
=
bx
*
block_M
for
i
in
range
(
batch_count
):
for
i
in
range
(
batch_count
):
in_cur_batch_idx
=
(
m_start_padded
>=
batch_padded_offsets
[
i
]
)
in_cur_batch_idx
=
m_start_padded
>=
batch_padded_offsets
[
i
]
cur_batch_idx
[
0
]
=
T
.
if_then_else
(
in_cur_batch_idx
,
i
,
cur_batch_idx
[
0
])
cur_batch_idx
[
0
]
=
T
.
if_then_else
(
in_cur_batch_idx
,
i
,
cur_batch_idx
[
0
])
cur_batch_size
[
0
]
=
batch_sizes
[
cur_batch_idx
[
0
]]
cur_batch_size
[
0
]
=
batch_sizes
[
cur_batch_idx
[
0
]]
m_start
=
m_start_padded
-
batch_padded_offsets
[
cur_batch_idx
[
0
]]
+
batch_offsets
[
m_start
=
m_start_padded
-
batch_padded_offsets
[
cur_batch_idx
[
0
]]
+
batch_offsets
[
cur_batch_idx
[
0
]]
cur_batch_idx
[
0
]]
actual_rows
=
T
.
max
(
0
,
T
.
min
(
block_M
,
cur_batch_size
[
0
]
+
batch_padded_offsets
[
cur_batch_idx
[
0
]]
-
m_start_padded
))
actual_rows
=
T
.
max
(
0
,
T
.
min
(
block_M
,
cur_batch_size
[
0
]
+
batch_padded_offsets
[
cur_batch_idx
[
0
]]
-
m_start_padded
))
T
.
clear
(
C_local
)
T
.
clear
(
C_local
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
A
[
m_start
:
m_start
+
block_M
,
k
*
block_K
:(
k
+
1
)
*
block_K
],
A_shared
)
T
.
copy
(
A
[
m_start
:
m_start
+
block_M
,
k
*
block_K
:
(
k
+
1
)
*
block_K
],
A_shared
)
T
.
copy
(
T
.
copy
(
B
[
cur_batch_idx
[
0
],
k
*
block_K
:
(
k
+
1
)
*
block_K
,
by
*
block_N
:
(
by
+
1
)
*
block_N
],
B_shared
)
B
[
cur_batch_idx
[
0
],
k
*
block_K
:(
k
+
1
)
*
block_K
,
by
*
block_N
:(
by
+
1
)
*
block_N
],
B_shared
)
T
.
gemm
(
A_shared
,
B_shared
,
C_local
)
T
.
gemm
(
A_shared
,
B_shared
,
C_local
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
...
@@ -76,7 +54,6 @@ def grouped_gemm_fwd(batch_sum,
...
@@ -76,7 +54,6 @@ def grouped_gemm_fwd(batch_sum,
class
_GroupedGEMM
(
torch
.
autograd
.
Function
):
class
_GroupedGEMM
(
torch
.
autograd
.
Function
):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
a
,
b
,
batch_sizes
):
def
forward
(
ctx
,
a
,
b
,
batch_sizes
):
block_M
=
64
block_M
=
64
...
@@ -99,15 +76,11 @@ class _GroupedGEMM(torch.autograd.Function):
...
@@ -99,15 +76,11 @@ class _GroupedGEMM(torch.autograd.Function):
for
i
in
range
(
batch_count
-
1
):
for
i
in
range
(
batch_count
-
1
):
batch_offsets_list
.
append
(
batch_offsets_list
[
-
1
]
+
batch_sizes
[
i
])
batch_offsets_list
.
append
(
batch_offsets_list
[
-
1
]
+
batch_sizes
[
i
])
for
i
in
range
(
batch_count
-
1
):
for
i
in
range
(
batch_count
-
1
):
batch_padded_offsets_list
.
append
(
batch_padded_offsets_list
[
-
1
]
+
batch_padded_offsets_list
.
append
(
batch_padded_offsets_list
[
-
1
]
+
math
.
ceil
((
batch_sizes
[
i
]
+
1
)
/
padding_M
)
*
padding_M
)
math
.
ceil
((
batch_sizes
[
i
]
+
1
)
/
padding_M
)
*
padding_M
)
batch_offsets
=
torch
.
tensor
(
batch_offsets_list
,
device
=
a
.
device
,
dtype
=
torch
.
int32
)
batch_offsets
=
torch
.
tensor
(
batch_offsets_list
,
device
=
a
.
device
,
dtype
=
torch
.
int32
)
batch_padded_offsets
=
torch
.
tensor
(
batch_padded_offsets
=
torch
.
tensor
(
batch_padded_offsets_list
,
device
=
a
.
device
,
dtype
=
torch
.
int32
)
batch_padded_offsets_list
,
device
=
a
.
device
,
dtype
=
torch
.
int32
)
kernel
=
grouped_gemm_fwd
(
batch_sum
,
batch_count
,
K
,
N
,
block_M
,
block_N
,
block_K
,
kernel
=
grouped_gemm_fwd
(
batch_sum
,
batch_count
,
K
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
,
threads
)
num_stages
,
threads
)
o
=
kernel
(
a
,
b
,
batch_sizes
,
batch_offsets
,
batch_padded_offsets
)
o
=
kernel
(
a
,
b
,
batch_sizes
,
batch_offsets
,
batch_padded_offsets
)
ctx
.
save_for_backward
(
a
,
b
,
batch_sizes
,
batch_offsets
)
ctx
.
save_for_backward
(
a
,
b
,
batch_sizes
,
batch_offsets
)
...
@@ -135,8 +108,7 @@ class _GroupedGEMM(torch.autograd.Function):
...
@@ -135,8 +108,7 @@ class _GroupedGEMM(torch.autograd.Function):
return
x
return
x
A
,
B
,
batch_sizes
=
[
maybe_contiguous
(
x
)
for
x
in
(
A
,
B
,
batch_sizes
)]
A
,
B
,
batch_sizes
=
[
maybe_contiguous
(
x
)
for
x
in
(
A
,
B
,
batch_sizes
)]
kernel
=
grouped_gemm_bwd
(
ctx
.
batch_sum
,
ctx
.
batch_count
,
M
,
N
,
block_M
,
block_N
,
block_K
,
kernel
=
grouped_gemm_bwd
(
ctx
.
batch_sum
,
ctx
.
batch_count
,
M
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
,
threads
)
num_stages
,
threads
)
dB
=
kernel
(
A
,
grad_output
,
batch_sizes
,
batch_offsets
)
dB
=
kernel
(
A
,
grad_output
,
batch_sizes
,
batch_offsets
)
return
None
,
dB
,
None
return
None
,
dB
,
None
...
@@ -172,9 +144,7 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
...
@@ -172,9 +144,7 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
for
i
in
range
(
batch_count
-
1
):
for
i
in
range
(
batch_count
-
1
):
batch_offsets_list
.
append
(
batch_offsets_list
[
-
1
]
+
batch_sizes_list
[
i
])
batch_offsets_list
.
append
(
batch_offsets_list
[
-
1
]
+
batch_sizes_list
[
i
])
for
i
in
range
(
batch_count
-
1
):
for
i
in
range
(
batch_count
-
1
):
batch_padded_offsets_list
.
append
(
batch_padded_offsets_list
[
-
1
]
+
batch_padded_offsets_list
.
append
(
batch_padded_offsets_list
[
-
1
]
+
math
.
ceil
((
batch_sizes_list
[
i
]
+
1
)
/
padding_M
)
*
padding_M
)
math
.
ceil
((
batch_sizes_list
[
i
]
+
1
)
/
padding_M
)
*
padding_M
)
A
=
torch
.
randn
(
batch_sum
,
K
,
device
=
device
,
dtype
=
dtype
)
A
=
torch
.
randn
(
batch_sum
,
K
,
device
=
device
,
dtype
=
dtype
)
B
=
torch
.
randn
(
batch_count
,
K
,
M
,
device
=
device
,
dtype
=
dtype
)
B
=
torch
.
randn
(
batch_count
,
K
,
M
,
device
=
device
,
dtype
=
dtype
)
C
=
torch
.
empty
(
batch_sum
,
M
,
device
=
device
,
dtype
=
dtype
)
C
=
torch
.
empty
(
batch_sum
,
M
,
device
=
device
,
dtype
=
dtype
)
...
@@ -187,21 +157,8 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
...
@@ -187,21 +157,8 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
return
A
,
B
,
C
,
batch_sizes
,
batch_offsets
,
batch_padded_offsets
return
A
,
B
,
C
,
batch_sizes
,
batch_offsets
,
batch_padded_offsets
@
tilelang
.
jit
(
@
tilelang
.
jit
(
out_idx
=
[
2
],
pass_configs
=
{
"tl.disable_tma_lower"
:
True
,
"tl.disable_warp_specialized"
:
True
})
out_idx
=
[
2
],
pass_configs
=
{
def
grouped_gemm_bwd
(
batch_sum
,
batch_count
,
M
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
=
2
,
threads
=
128
,
dtype
=
"float16"
):
"tl.disable_tma_lower"
:
True
,
"tl.disable_warp_specialized"
:
True
})
def
grouped_gemm_bwd
(
batch_sum
,
batch_count
,
M
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
=
2
,
threads
=
128
,
dtype
=
"float16"
):
"""
"""
args:
args:
a (torch.Tensor): Input tensor of shape (M, K).
a (torch.Tensor): Input tensor of shape (M, K).
...
@@ -211,16 +168,13 @@ def grouped_gemm_bwd(batch_sum,
...
@@ -211,16 +168,13 @@ def grouped_gemm_bwd(batch_sum,
@
T
.
prim_func
@
T
.
prim_func
def
kernel
(
def
kernel
(
A
:
T
.
Tensor
([
batch_sum
,
M
],
dtype
),
# type: ignore
A
:
T
.
Tensor
([
batch_sum
,
M
],
dtype
),
# type: ignore
B
:
T
.
Tensor
([
batch_sum
,
N
],
dtype
),
# type: ignore
B
:
T
.
Tensor
([
batch_sum
,
N
],
dtype
),
# type: ignore
C
:
T
.
Tensor
([
batch_count
,
M
,
N
],
dtype
),
# type: ignore
C
:
T
.
Tensor
([
batch_count
,
M
,
N
],
dtype
),
# type: ignore
batch_sizes
:
T
.
Tensor
([
batch_count
],
"int32"
),
# type: ignore
batch_sizes
:
T
.
Tensor
([
batch_count
],
"int32"
),
# type: ignore
batch_offsets
:
T
.
Tensor
([
batch_count
],
"int32"
),
# type: ignore
batch_offsets
:
T
.
Tensor
([
batch_count
],
"int32"
),
# type: ignore
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
batch_count
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
batch_count
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
A_shared
=
T
.
alloc_shared
([
block_K
,
block_M
],
dtype
)
A_shared
=
T
.
alloc_shared
([
block_K
,
block_M
],
dtype
)
B_shared
=
T
.
alloc_shared
([
block_K
,
block_N
],
dtype
)
B_shared
=
T
.
alloc_shared
([
block_K
,
block_N
],
dtype
)
C_local
=
T
.
alloc_fragment
([
block_M
,
block_N
],
accum_dtype
)
C_local
=
T
.
alloc_fragment
([
block_M
,
block_N
],
accum_dtype
)
...
@@ -228,13 +182,9 @@ def grouped_gemm_bwd(batch_sum,
...
@@ -228,13 +182,9 @@ def grouped_gemm_bwd(batch_sum,
T
.
clear
(
C_local
)
T
.
clear
(
C_local
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
batch_sizes
[
bz
],
block_K
),
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
batch_sizes
[
bz
],
block_K
),
num_stages
=
num_stages
):
for
i
,
j
in
T
.
Parallel
(
block_K
,
block_M
):
for
i
,
j
in
T
.
Parallel
(
block_K
,
block_M
):
A_shared
[
i
,
j
]
=
T
.
if_then_else
(
A_shared
[
i
,
j
]
=
T
.
if_then_else
(
i
<
batch_sizes
[
bz
],
A
[
batch_offsets
[
bz
]
+
k
*
block_K
+
i
,
bx
*
block_M
+
j
],
0
)
i
<
batch_sizes
[
bz
],
A
[
batch_offsets
[
bz
]
+
k
*
block_K
+
i
,
bx
*
block_M
+
j
],
0
)
for
i
,
j
in
T
.
Parallel
(
block_K
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_K
,
block_N
):
B_shared
[
i
,
j
]
=
T
.
if_then_else
(
B_shared
[
i
,
j
]
=
T
.
if_then_else
(
i
<
batch_sizes
[
bz
],
B
[
batch_offsets
[
bz
]
+
k
*
block_K
+
i
,
by
*
block_N
+
j
],
0
)
i
<
batch_sizes
[
bz
],
B
[
batch_offsets
[
bz
]
+
k
*
block_K
+
i
,
by
*
block_N
+
j
],
0
)
T
.
gemm
(
A_shared
,
B_shared
,
C_local
,
transpose_A
=
True
)
T
.
gemm
(
A_shared
,
B_shared
,
C_local
,
transpose_A
=
True
)
T
.
copy
(
C_local
,
C
[
bz
,
bx
*
block_M
,
by
*
block_N
])
T
.
copy
(
C_local
,
C
[
bz
,
bx
*
block_M
,
by
*
block_N
])
...
@@ -242,23 +192,12 @@ def grouped_gemm_bwd(batch_sum,
...
@@ -242,23 +192,12 @@ def grouped_gemm_bwd(batch_sum,
return
kernel
return
kernel
def
run_tilelang_grouped_gemm
(
batch_sizes_list
,
def
run_tilelang_grouped_gemm
(
batch_sizes_list
,
K
,
M
,
block_M
,
block_N
,
block_K
,
trans_b
,
num_stages
=
2
,
threads
=
128
,
profile
=
False
):
K
,
M
,
block_M
,
block_N
,
block_K
,
trans_b
,
num_stages
=
2
,
threads
=
128
,
profile
=
False
):
padding_M
=
block_M
padding_M
=
block_M
device
=
torch
.
device
(
"cuda"
)
device
=
torch
.
device
(
"cuda"
)
dtype
=
torch
.
float16
dtype
=
torch
.
float16
A
,
B
,
C
,
batch_sizes
,
batch_offsets
,
batch_padded_offsets
=
construct_inputs
(
A
,
B
,
C
,
batch_sizes
,
batch_offsets
,
batch_padded_offsets
=
construct_inputs
(
batch_sizes_list
,
K
,
M
,
False
,
padding_M
,
device
,
dtype
)
batch_sizes_list
,
K
,
M
,
False
,
padding_M
,
device
,
dtype
)
A
.
requires_grad_
(
False
)
A
.
requires_grad_
(
False
)
B
.
requires_grad_
(
True
)
B
.
requires_grad_
(
True
)
...
@@ -273,10 +212,7 @@ def run_tilelang_grouped_gemm(batch_sizes_list,
...
@@ -273,10 +212,7 @@ def run_tilelang_grouped_gemm(batch_sizes_list,
O
.
backward
(
dO
,
retain_graph
=
True
)
O
.
backward
(
dO
,
retain_graph
=
True
)
dB
,
B
.
grad
=
B
.
grad
.
clone
(),
None
dB
,
B
.
grad
=
B
.
grad
.
clone
(),
None
if
(
if
torch
.
allclose
(
O
,
O_ref
,
rtol
=
1e-2
,
atol
=
1e-2
)
and
torch
.
allclose
(
dB
,
dB_ref
,
rtol
=
1e-2
,
atol
=
1e-2
):
torch
.
allclose
(
O
,
O_ref
,
rtol
=
1e-2
,
atol
=
1e-2
)
and
\
torch
.
allclose
(
dB
,
dB_ref
,
rtol
=
1e-2
,
atol
=
1e-2
)
):
print
(
"✅ Tilelang and Torch match"
)
print
(
"✅ Tilelang and Torch match"
)
else
:
else
:
print
(
"❌ Tilelang and Torch mismatch"
)
print
(
"❌ Tilelang and Torch mismatch"
)
...
@@ -284,12 +220,11 @@ def run_tilelang_grouped_gemm(batch_sizes_list,
...
@@ -284,12 +220,11 @@ def run_tilelang_grouped_gemm(batch_sizes_list,
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
parser
.
add_argument
(
"--batch_sizes"
,
type
=
str
,
default
=
"64, 128"
,
help
=
"comma-separated batch sizes"
)
'--batch_sizes'
,
type
=
str
,
default
=
"64, 128"
,
help
=
'comma-separated batch sizes'
)
parser
.
add_argument
(
"--K"
,
type
=
int
,
default
=
8192
,
help
=
"reduce dim"
)
parser
.
add_argument
(
'--K'
,
type
=
int
,
default
=
8192
,
help
=
'reduce dim'
)
parser
.
add_argument
(
"--M"
,
type
=
int
,
default
=
8192
,
help
=
"output dim"
)
parser
.
add_argument
(
'--M'
,
type
=
int
,
default
=
8192
,
help
=
'output dim'
)
parser
.
add_argument
(
"--trans_b"
,
action
=
"store_true"
,
help
=
"transpose B"
)
parser
.
add_argument
(
'--trans_b'
,
action
=
"store_true"
,
help
=
"transpose B"
)
parser
.
add_argument
(
"--profile"
,
action
=
"store_true"
,
help
=
"profile"
)
parser
.
add_argument
(
'--profile'
,
action
=
"store_true"
,
help
=
"profile"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
batch_sizes_list
=
[
int
(
x
)
for
x
in
args
.
batch_sizes
.
split
(
","
)]
batch_sizes_list
=
[
int
(
x
)
for
x
in
args
.
batch_sizes
.
split
(
","
)]
...
@@ -301,14 +236,4 @@ if __name__ == "__main__":
...
@@ -301,14 +236,4 @@ if __name__ == "__main__":
num_stages
=
2
num_stages
=
2
threads
=
256
threads
=
256
run_tilelang_grouped_gemm
(
run_tilelang_grouped_gemm
(
batch_sizes_list
,
K
,
M
,
block_M
,
block_N
,
block_K
,
trans_b
,
num_stages
,
threads
,
profile
=
args
.
profile
)
batch_sizes_list
,
K
,
M
,
block_M
,
block_N
,
block_K
,
trans_b
,
num_stages
,
threads
,
profile
=
args
.
profile
)
examples/grouped_gemm/example_grouped_gemm_fwd.py
View file @
29051439
...
@@ -18,8 +18,7 @@ def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False):
...
@@ -18,8 +18,7 @@ def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False):
torch.Tensor: Resulting tensor after grouped matrix multiplication.
torch.Tensor: Resulting tensor after grouped matrix multiplication.
"""
"""
assert
a
.
shape
[
0
]
==
sum
(
batch_sizes
),
"Sum of batch_sizes must equal the first dimension of a"
assert
a
.
shape
[
0
]
==
sum
(
batch_sizes
),
"Sum of batch_sizes must equal the first dimension of a"
assert
b
.
shape
[
0
]
==
len
(
assert
b
.
shape
[
0
]
==
len
(
batch_sizes
),
"The first dimension of b must match the length of batch_sizes"
batch_sizes
),
"The first dimension of b must match the length of batch_sizes"
# Initialize output tensor
# Initialize output tensor
output
=
torch
.
empty
((
sum
(
batch_sizes
),
b
.
shape
[
2
]),
device
=
a
.
device
,
dtype
=
a
.
dtype
)
output
=
torch
.
empty
((
sum
(
batch_sizes
),
b
.
shape
[
2
]),
device
=
a
.
device
,
dtype
=
a
.
dtype
)
...
@@ -38,15 +37,7 @@ def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False):
...
@@ -38,15 +37,7 @@ def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False):
@
tilelang
.
jit
(
out_idx
=
[
2
])
@
tilelang
.
jit
(
out_idx
=
[
2
])
def
grouped_gemm
(
batch_sizes_list
,
def
grouped_gemm
(
batch_sizes_list
,
K
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
=
2
,
threads
=
128
,
dtype
=
"float16"
):
K
,
N
,
block_M
,
block_N
,
block_K
,
num_stages
=
2
,
threads
=
128
,
dtype
=
"float16"
):
"""
"""
args:
args:
a (torch.Tensor): Input tensor of shape (M, K).
a (torch.Tensor): Input tensor of shape (M, K).
...
@@ -59,14 +50,13 @@ def grouped_gemm(batch_sizes_list,
...
@@ -59,14 +50,13 @@ def grouped_gemm(batch_sizes_list,
@
T
.
prim_func
@
T
.
prim_func
def
kernel
(
def
kernel
(
A
:
T
.
Tensor
([
batch_sum
,
K
],
dtype
),
# type: ignore
A
:
T
.
Tensor
([
batch_sum
,
K
],
dtype
),
# type: ignore
B
:
T
.
Tensor
([
batch_count
,
K
,
N
],
dtype
),
# type: ignore
B
:
T
.
Tensor
([
batch_count
,
K
,
N
],
dtype
),
# type: ignore
C
:
T
.
Tensor
([
batch_sum
,
N
],
dtype
),
# type: ignore
C
:
T
.
Tensor
([
batch_sum
,
N
],
dtype
),
# type: ignore
batch_sizes
:
T
.
Tensor
([
batch_count
],
"int32"
),
# type: ignore
batch_sizes
:
T
.
Tensor
([
batch_count
],
"int32"
),
# type: ignore
batch_offsets
:
T
.
Tensor
([
batch_count
],
"int32"
),
# type: ignore
batch_offsets
:
T
.
Tensor
([
batch_count
],
"int32"
),
# type: ignore
batch_padded_offsets
:
T
.
Tensor
([
batch_count
],
"int32"
),
# type: ignore
batch_padded_offsets
:
T
.
Tensor
([
batch_count
],
"int32"
),
# type: ignore
):
):
with
T
.
Kernel
(
total_m_blocks
,
T
.
ceildiv
(
N
,
block_N
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
total_m_blocks
,
T
.
ceildiv
(
N
,
block_N
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
([
block_M
,
block_K
],
dtype
)
A_shared
=
T
.
alloc_shared
([
block_M
,
block_K
],
dtype
)
B_shared
=
T
.
alloc_shared
([
block_K
,
block_N
],
dtype
)
B_shared
=
T
.
alloc_shared
([
block_K
,
block_N
],
dtype
)
...
@@ -77,23 +67,17 @@ def grouped_gemm(batch_sizes_list,
...
@@ -77,23 +67,17 @@ def grouped_gemm(batch_sizes_list,
m_start_padded
=
bx
*
block_M
m_start_padded
=
bx
*
block_M
for
i
in
range
(
batch_count
):
for
i
in
range
(
batch_count
):
in_cur_batch_idx
=
(
m_start_padded
>=
batch_padded_offsets
[
i
]
)
in_cur_batch_idx
=
m_start_padded
>=
batch_padded_offsets
[
i
]
cur_batch_idx
[
0
]
=
T
.
if_then_else
(
in_cur_batch_idx
,
i
,
cur_batch_idx
[
0
])
cur_batch_idx
[
0
]
=
T
.
if_then_else
(
in_cur_batch_idx
,
i
,
cur_batch_idx
[
0
])
cur_batch_size
[
0
]
=
batch_sizes
[
cur_batch_idx
[
0
]]
cur_batch_size
[
0
]
=
batch_sizes
[
cur_batch_idx
[
0
]]
m_start
=
m_start_padded
-
batch_padded_offsets
[
cur_batch_idx
[
0
]]
+
batch_offsets
[
m_start
=
m_start_padded
-
batch_padded_offsets
[
cur_batch_idx
[
0
]]
+
batch_offsets
[
cur_batch_idx
[
0
]]
cur_batch_idx
[
0
]]
actual_rows
=
T
.
max
(
0
,
T
.
min
(
block_M
,
cur_batch_size
[
0
]
+
batch_padded_offsets
[
cur_batch_idx
[
0
]]
-
m_start_padded
))
actual_rows
=
T
.
max
(
0
,
T
.
min
(
block_M
,
cur_batch_size
[
0
]
+
batch_padded_offsets
[
cur_batch_idx
[
0
]]
-
m_start_padded
))
T
.
clear
(
C_local
)
T
.
clear
(
C_local
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
A
[
m_start
:
m_start
+
block_M
,
k
*
block_K
:(
k
+
1
)
*
block_K
],
A_shared
)
T
.
copy
(
A
[
m_start
:
m_start
+
block_M
,
k
*
block_K
:
(
k
+
1
)
*
block_K
],
A_shared
)
T
.
copy
(
T
.
copy
(
B
[
cur_batch_idx
[
0
],
k
*
block_K
:
(
k
+
1
)
*
block_K
,
by
*
block_N
:
(
by
+
1
)
*
block_N
],
B_shared
)
B
[
cur_batch_idx
[
0
],
k
*
block_K
:(
k
+
1
)
*
block_K
,
by
*
block_N
:(
by
+
1
)
*
block_N
],
B_shared
)
T
.
gemm
(
A_shared
,
B_shared
,
C_local
)
T
.
gemm
(
A_shared
,
B_shared
,
C_local
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
...
@@ -111,8 +95,7 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
...
@@ -111,8 +95,7 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
for
i
in
range
(
batch_count
-
1
):
for
i
in
range
(
batch_count
-
1
):
batch_offsets_list
.
append
(
batch_offsets_list
[
-
1
]
+
batch_sizes_list
[
i
])
batch_offsets_list
.
append
(
batch_offsets_list
[
-
1
]
+
batch_sizes_list
[
i
])
for
i
in
range
(
batch_count
-
1
):
for
i
in
range
(
batch_count
-
1
):
batch_padded_offsets_list
.
append
(
batch_padded_offsets_list
[
-
1
]
+
batch_padded_offsets_list
.
append
(
batch_padded_offsets_list
[
-
1
]
+
math
.
ceil
((
batch_sizes_list
[
i
])
/
padding_M
)
*
padding_M
)
math
.
ceil
((
batch_sizes_list
[
i
])
/
padding_M
)
*
padding_M
)
A
=
torch
.
randn
(
batch_sum
,
K
,
device
=
device
,
dtype
=
dtype
)
A
=
torch
.
randn
(
batch_sum
,
K
,
device
=
device
,
dtype
=
dtype
)
B
=
torch
.
randn
(
batch_count
,
K
,
M
,
device
=
device
,
dtype
=
dtype
)
B
=
torch
.
randn
(
batch_count
,
K
,
M
,
device
=
device
,
dtype
=
dtype
)
C
=
torch
.
empty
(
batch_sum
,
M
,
device
=
device
,
dtype
=
dtype
)
C
=
torch
.
empty
(
batch_sum
,
M
,
device
=
device
,
dtype
=
dtype
)
...
@@ -125,27 +108,16 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
...
@@ -125,27 +108,16 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
return
A
,
B
,
C
,
batch_sizes
,
batch_offsets
,
batch_padded_offsets
return
A
,
B
,
C
,
batch_sizes
,
batch_offsets
,
batch_padded_offsets
def
run_tilelang_grouped_gemm
(
batch_sizes_list
,
def
run_tilelang_grouped_gemm
(
batch_sizes_list
,
K
,
M
,
block_M
,
block_N
,
block_K
,
trans_b
,
num_stages
=
2
,
threads
=
128
,
profile
=
False
):
K
,
M
,
block_M
,
block_N
,
block_K
,
trans_b
,
num_stages
=
2
,
threads
=
128
,
profile
=
False
):
padding_M
=
block_M
padding_M
=
block_M
batch_sum
=
sum
(
batch_sizes_list
)
batch_sum
=
sum
(
batch_sizes_list
)
kernel
=
grouped_gemm
(
kernel
=
grouped_gemm
(
tuple
(
batch_sizes_list
),
K
,
M
,
block_M
,
block_N
,
block_K
,
num_stages
,
threads
)
tuple
(
batch_sizes_list
),
K
,
M
,
block_M
,
block_N
,
block_K
,
num_stages
,
threads
)
# print(kernel.get_kernel_source())
# print(kernel.get_kernel_source())
device
=
torch
.
device
(
"cuda"
)
device
=
torch
.
device
(
"cuda"
)
dtype
=
torch
.
float16
dtype
=
torch
.
float16
A
,
B
,
C
,
batch_sizes
,
batch_offsets
,
batch_padded_offsets
=
construct_inputs
(
A
,
B
,
C
,
batch_sizes
,
batch_offsets
,
batch_padded_offsets
=
construct_inputs
(
batch_sizes_list
,
K
,
M
,
trans_b
,
padding_M
,
device
,
dtype
)
batch_sizes_list
,
K
,
M
,
trans_b
,
padding_M
,
device
,
dtype
)
out
=
kernel
(
A
,
B
,
batch_sizes
,
batch_offsets
,
batch_padded_offsets
)
out
=
kernel
(
A
,
B
,
batch_sizes
,
batch_offsets
,
batch_padded_offsets
)
ref_output
=
torch_gmm
(
A
,
B
,
batch_sizes
,
batch_offsets
,
trans_b
)
ref_output
=
torch_gmm
(
A
,
B
,
batch_sizes
,
batch_offsets
,
trans_b
)
# print(out)
# print(out)
...
@@ -157,8 +129,7 @@ def run_tilelang_grouped_gemm(batch_sizes_list,
...
@@ -157,8 +129,7 @@ def run_tilelang_grouped_gemm(batch_sizes_list,
if
profile
:
if
profile
:
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Auto
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Auto
)
latency
=
profiler
.
do_bench
(
latency
=
profiler
.
do_bench
(
warmup
=
500
,
input_tensors
=
[
A
,
B
,
batch_sizes
,
batch_offsets
,
batch_padded_offsets
])
warmup
=
500
,
input_tensors
=
[
A
,
B
,
batch_sizes
,
batch_offsets
,
batch_padded_offsets
])
print
(
f
"Latency:
{
latency
}
ms"
)
print
(
f
"Latency:
{
latency
}
ms"
)
print
(
f
"TFlops:
{
batch_sum
*
K
*
M
*
2
/
latency
*
1e-9
}
TFlops"
)
print
(
f
"TFlops:
{
batch_sum
*
K
*
M
*
2
/
latency
*
1e-9
}
TFlops"
)
...
@@ -173,12 +144,11 @@ def test_grouped_gemm():
...
@@ -173,12 +144,11 @@ def test_grouped_gemm():
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
parser
.
add_argument
(
"--batch_sizes"
,
type
=
str
,
default
=
"64, 128"
,
help
=
"comma-separated batch sizes"
)
'--batch_sizes'
,
type
=
str
,
default
=
"64, 128"
,
help
=
'comma-separated batch sizes'
)
parser
.
add_argument
(
"--K"
,
type
=
int
,
default
=
8192
,
help
=
"reduce dim"
)
parser
.
add_argument
(
'--K'
,
type
=
int
,
default
=
8192
,
help
=
'reduce dim'
)
parser
.
add_argument
(
"--M"
,
type
=
int
,
default
=
8192
,
help
=
"output dim"
)
parser
.
add_argument
(
'--M'
,
type
=
int
,
default
=
8192
,
help
=
'output dim'
)
parser
.
add_argument
(
"--trans_b"
,
action
=
"store_true"
,
help
=
"transpose B"
)
parser
.
add_argument
(
'--trans_b'
,
action
=
"store_true"
,
help
=
"transpose B"
)
parser
.
add_argument
(
"--profile"
,
action
=
"store_true"
,
help
=
"profile"
)
parser
.
add_argument
(
'--profile'
,
action
=
"store_true"
,
help
=
"profile"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
batch_sizes_list
=
[
int
(
x
)
for
x
in
args
.
batch_sizes
.
split
(
","
)]
batch_sizes_list
=
[
int
(
x
)
for
x
in
args
.
batch_sizes
.
split
(
","
)]
...
@@ -190,14 +160,4 @@ if __name__ == "__main__":
...
@@ -190,14 +160,4 @@ if __name__ == "__main__":
num_stages
=
2
num_stages
=
2
threads
=
256
threads
=
256
run_tilelang_grouped_gemm
(
run_tilelang_grouped_gemm
(
batch_sizes_list
,
K
,
M
,
block_M
,
block_N
,
block_K
,
trans_b
,
num_stages
,
threads
,
profile
=
args
.
profile
)
batch_sizes_list
,
K
,
M
,
block_M
,
block_N
,
block_K
,
trans_b
,
num_stages
,
threads
,
profile
=
args
.
profile
)
Prev
1
…
4
5
6
7
8
9
10
11
12
…
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment