wangkx1 / tilelang — Commit bc2d5632

Commit bc2d5632, "init", authored Jan 15, 2026 by root.
Pipeline #3222 failed in 0 seconds. Changes: 257. Pipelines: 1.

Showing 17 changed files with 1443 additions and 0 deletions (+1443, -0).
- examples/gemm/example_gemm_schedule.py (+69, -0)
- examples/gemm/test_example_gemm.py (+26, -0)
- examples/gemm_fp8/README.md (+2, -0)
- examples/gemm_fp8/example_tilelang_gemm_amd.py (+137, -0)
- examples/gemm_fp8/example_tilelang_gemm_fp8.py (+65, -0)
- examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py (+82, -0)
- examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py (+223, -0)
- examples/gemm_fp8/test_example_gemm_fp8.py (+20, -0)
- examples/gemm_sm100/README.md (+106, -0)
- examples/gemm_sm100/gemm_mma.py (+94, -0)
- examples/gemm_sm100/gemm_tcgen5mma.py (+91, -0)
- examples/gemm_sp/example_gemm_sp.py (+152, -0)
- examples/gemm_splitk/example_tilelang_gemm_splitk.py (+71, -0)
- examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py (+70, -0)
- examples/gemm_splitk/test_example_gemm_splitk.py (+16, -0)
- examples/gemm_streamk/example_tilelang_gemm_streamk.py (+205, -0)
- examples/gemm_streamk/test_example_tilelang_gemm_splitk.py (+14, -0)
examples/gemm/example_gemm_schedule.py (new file, mode 100644)

```python
import tilelang
import tilelang.language as T


@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def gemm_schedule(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Enable rasterization for better L2 cache locality
            T.use_swizzle(panel_size=10)

            # Clear the local accumulation buffer
            T.clear(C_local)

            # Auto-pipeline the computation
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, ko * block_K], A_shared)

                # Instead of using
                #   T.copy(B[ko * block_K, bx * block_N], B_shared)
                # we can also use Parallel to auto-map the thread
                # bindings and vectorize the copy operation.
                for k, j in T.Parallel(block_K, block_N):
                    B_shared[k, j] = B[ko * block_K + k, bx * block_N + j]

                T.gemm(A_shared, B_shared, C_local)

            T.copy(C_local, C[by * block_M, bx * block_N])

    return gemm_schedule


def main():
    kernel = matmul(1024, 1024, 1024, 128, 128, 32)

    import torch

    a = torch.randn(1024, 1024).cuda().half()
    b = torch.randn(1024, 1024).cuda().half()

    c = kernel(a, b)
    ref_c = a @ b

    print("c:")
    print(c)
    print("ref_c:")
    print(ref_c)

    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
    print("All checks passed.")

    # Get the generated CUDA source
    print("CUDA Source:")
    print(kernel.get_kernel_source())


if __name__ == "__main__":
    main()
```
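Because the `matmul` factory above is wrapped in `@tilelang.jit(out_idx=[-1])`, the last parameter (`C`) is allocated by the runtime and returned, so callers pass only `A` and `B`. A minimal sketch of reusing the same factory for a different problem size (the 512-sized inputs here are purely illustrative; the tile sizes are the ones the example itself uses):

```python
import torch

# Illustrative problem size; same tile configuration as main() above.
kernel_512 = matmul(512, 512, 512, 128, 128, 32)

a = torch.randn(512, 512).cuda().half()
b = torch.randn(512, 512).cuda().half()

c = kernel_512(a, b)  # C is created and returned because out_idx=[-1]
torch.testing.assert_close(c, a @ b, rtol=1e-2, atol=1e-2)
```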
examples/gemm/test_example_gemm.py (new file, mode 100644)

```python
import tilelang.testing

import example_gemm_autotune
import example_gemm_intrinsics
import example_gemm_schedule
import example_gemm


def test_example_gemm_autotune():
    # enable roller for fast tuning
    example_gemm_autotune.main(M=1024, N=1024, K=1024, with_roller=True)


def test_example_gemm_intrinsics():
    example_gemm_intrinsics.main(M=1024, N=1024, K=1024)


def test_example_gemm_schedule():
    example_gemm_schedule.main()


def test_example_gemm():
    example_gemm.main()


if __name__ == "__main__":
    tilelang.testing.main()
```
examples/gemm_fp8/README.md (new file, mode 100644)

**Notes**: fp8 is currently only supported through MMA instructions rather than `T.gemm`, because the CUTLASS version bundled with tilelang is too old; we should update the CUTLASS version in the future.
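A minimal way to exercise the fp8 examples in this directory, mirroring what `test_example_gemm_fp8.py` further below does (this assumes the scripts are run from inside `examples/gemm_fp8/` so the sibling modules are importable):

```python
# Run the examples' own entry points, as the accompanying test file does.
import example_tilelang_gemm_fp8            # T.gemm-based fp8 example
import example_tilelang_gemm_fp8_intrinsic  # MMA-intrinsic-based fp8 example

example_tilelang_gemm_fp8.main()
example_tilelang_gemm_fp8_intrinsic.main()
```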
examples/gemm_fp8/example_tilelang_gemm_amd.py (new file, mode 100644)

```python
import torch
import tilelang
import tilelang.language as T
from tilelang.utils.tensor import torch_assert_close
import itertools


def ref_program(A, B):
    return (A.half() @ B.half().T).to(dtype=torch.float32)


def manual_check_prog(C, C_ref):
    torch_assert_close(C[0], C_ref[0], rtol=0.01, atol=0.1)


def supply_prog(args):
    a_param, b_param = args
    M, K = a_param.shape
    N, _ = b_param.shape
    a = (torch.randn(M, K, dtype=torch.float16, device='cuda') * 0.01).to(
        dtype=torch.float8_e4m3fnuz)
    b = (torch.randn(N, K, dtype=torch.float16, device='cuda') * 0.01).to(
        dtype=torch.float8_e4m3fnuz)
    return [a, b]


def get_configs():
    block_Ms = [32, 64, 128]
    block_Ns = [32, 64, 128]
    block_Ks = [64, 128]
    num_stages = [0]
    num_threads = [256]
    k_packs = [1, 2]
    gemm_types = ["ss", "rs"]

    valid_configs = []

    for m, n, k, stages, t, kp, gemm_type in itertools.product(
            block_Ms, block_Ns, block_Ks, num_stages, num_threads, k_packs, gemm_types):
        valid_configs.append({
            "block_M": m,
            "block_N": n,
            "block_K": k,
            "num_stages": stages,
            "num_threads": t,
            "k_pack": kp,
            "gemm_type": gemm_type,
        })
    return valid_configs


@tilelang.autotune(
    configs=get_configs(),
    cache_input_tensors=True,
    ref_prog=ref_program,
    manual_check_prog=manual_check_prog,
    supply_prog=supply_prog)
@tilelang.jit(out_idx=[-1])
def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pack, gemm_type):
    dtype = "float8_e4m3fnuz"
    accum_dtype = "float"

    @T.prim_func
    def gemm_fp8_rs(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
            A_local = T.alloc_fragment((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_local)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(
                    A_local,
                    B_shared,
                    C_local,
                    transpose_B=True,
                    k_pack=k_pack,
                    policy=T.GemmWarpPolicy.FullRow)
            T.copy(C_local, C[by * block_M, bx * block_N])

    @T.prim_func
    def gemm_fp8_ss(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(
                    A_shared,
                    B_shared,
                    C_local,
                    transpose_B=True,
                    k_pack=k_pack,
                    policy=T.GemmWarpPolicy.FullRow)
            T.copy(C_local, C[by * block_M, bx * block_N])

    if gemm_type == "ss":
        return gemm_fp8_ss
    elif gemm_type == "rs":
        return gemm_fp8_rs
    else:
        raise ValueError(f"Invalid gemm_type: {gemm_type}")


def test_gemm_fp8(M, N, K):
    kernel = fp8_matmul(M, N, K)
    a = (torch.randn(M, K, dtype=torch.float16, device='cuda') * 0.01).to(
        dtype=torch.float8_e4m3fnuz)
    b = (torch.randn(N, K, dtype=torch.float16, device='cuda') * 0.01).to(
        dtype=torch.float8_e4m3fnuz)

    c = kernel(a, b)
    ref_c = ref_program(a, b)
    torch_assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
    print("passed~")


if __name__ == "__main__":
    test_gemm_fp8(512, 512, 512)
```
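Note that `test_gemm_fp8` above passes only the problem shape; the remaining parameters are filled in by the autotuner from `get_configs()`. For a sense of the tuning cost, the search space is the Cartesian product of the lists in `get_configs()`:

```python
# Size of the autotuning search space enumerated by get_configs() above:
# block_Ms x block_Ns x block_Ks x num_stages x num_threads x k_packs x gemm_types
print(3 * 3 * 2 * 1 * 1 * 2 * 2)  # 72 candidate configurations
```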
examples/gemm_fp8/example_tilelang_gemm_fp8.py (new file, mode 100644)

```python
import torch
import tilelang
import tilelang.language as T
from tilelang.utils.tensor import map_torch_type


def calc_diff(x, y):
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim


@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"):

    @T.prim_func
    def gemm_fp8(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return gemm_fp8


def test_gemm_fp8(M, N, K, dtype):
    torch_dtype = map_torch_type(dtype)

    kernel = matmul(M, N, K, 128, 128, 64, dtype)

    a = torch.randn(M, K, dtype=torch.float16, device='cuda').to(dtype=torch_dtype)
    b = torch.randn(N, K, dtype=torch.float16, device='cuda').to(dtype=torch_dtype)

    c = kernel(a, b)

    ref_c = (a.half() @ b.half().T).to(dtype=torch_dtype)

    print(c)
    print(ref_c)

    diff = calc_diff(c, ref_c)
    print(f"diff: {diff}")
    assert diff < 1e-3


def main():
    test_gemm_fp8(1024, 1024, 1024, 'float8_e4m3')
    test_gemm_fp8(1024, 1024, 1024, 'float8_e5m2')


if __name__ == "__main__":
    main()
```
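The `calc_diff` helper above (reused in the 2xAcc example below) reports one minus a normalized similarity between the kernel output and the reference:

$$\text{diff}(x, y) = 1 - \frac{2\sum_i x_i y_i}{\sum_i \left(x_i^2 + y_i^2\right)}$$

which is 0 when the two tensors are identical and grows as they diverge, so the `diff < 1e-3` assertion is a relative-error style check that tolerates the coarse fp8 quantization.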
examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py (new file, mode 100644)

```python
import torch
import tilelang
import tilelang.language as T
from tilelang.utils.tensor import map_torch_type


@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"):
    # For fp8 gemm, do one promotion after 4 wgmma instructions, i.e. block_K = 128.
    # If block_K < 128, promote after 128 / block_K iterations.
    # If block_K >= 128, promote after every iteration.
    update_interval = 128 // block_K if block_K < 128 else 1

    @T.prim_func
    def gemm_fp8_2xAcc(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            T.clear(C_local_accum)
            K_iters = T.ceildiv(K, block_K)
            for k in T.Pipelined(K_iters, num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
                # Promote to enable 2x accumulation
                if (k + 1) % update_interval == 0:
                    for i, j in T.Parallel(block_M, block_N):
                        C_local_accum[i, j] += C_local[i, j]
                    T.clear(C_local)
            # Tail processing
            if K_iters % update_interval != 0:
                for i, j in T.Parallel(block_M, block_N):
                    C_local_accum[i, j] += C_local[i, j]
            # TMA store
            T.copy(C_local_accum, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return gemm_fp8_2xAcc


def calc_diff(x, y):
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim


def test_gemm_fp8(M, N, K, dtype):
    torch_dtype = map_torch_type(dtype)

    kernel = matmul(M, N, K, 128, 128, 64, dtype)

    a = torch.rand(M, K, dtype=torch.float16, device='cuda')
    a = (100 * (2 * a - 1)).to(dtype=torch_dtype)
    b = torch.rand(N, K, dtype=torch.float16, device='cuda')
    b = (100 * (2 * b - 1)).to(dtype=torch_dtype)

    c = kernel(a, b)

    ref_c = (a.float() @ b.float().T)

    diff = calc_diff(c, ref_c)
    print(f"diff: {diff}")
    assert diff < 1e-3


def main():
    test_gemm_fp8(1024, 1024, 8192, 'float8_e4m3')
    test_gemm_fp8(1024, 1024, 8192, 'float8_e5m2')


if __name__ == "__main__":
    main()
```
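As a concrete check of the promotion interval used above: with the `block_K = 64` that `test_gemm_fp8` passes, `update_interval = 128 // 64 = 2`, so the low-precision partials in `C_local` are folded into `C_local_accum` every two K-iterations, i.e. every 128 elements of K. A small sketch of the same formula for a few tile widths:

```python
# Same formula as in matmul() above.
for block_K in (32, 64, 128, 256):
    update_interval = 128 // block_K if block_K < 128 else 1
    print(block_K, update_interval)  # 32 -> 4, 64 -> 2, 128 -> 1, 256 -> 1
```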
examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py (new file, mode 100644)

```python
import torch
from tilelang import tvm as tvm
import tilelang.testing
from tvm import DataType
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
    TensorCoreIntrinEmitter,)
from tilelang.transform import simplify_prim_func
from tilelang.utils.tensor import map_torch_type

tilelang.testing.set_random_seed(0)


def make_swizzle_layout(shared_buf):
    dtype = shared_buf.dtype
    shape = shared_buf.shape

    can_swizzle = shape[-1] * DataType(dtype).bits == 512
    if not can_swizzle:
        return T.Layout(shape, lambda *args: args)

    def transform_func(i, j):
        new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
        return [new_warp_i, new_warp_j]

    return T.Layout(shape, transform_func)


@tilelang.jit(out_idx=[2])
@simplify_prim_func
def tl_matmul(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype,
):
    assert in_dtype in [
        "float16",
        "float8_e4m3",
        "float8_e5m2",
        "int8",
    ], "Currently only float16, float8 and int8 are supported"
    assert out_dtype in [
        "float16",
        "float32",
        "int32",
    ], "Currently only float16, float32 and int32 are supported"

    micro_size_x = micro_size_y = micro_size_k = 16

    is_float8 = in_dtype in ["float8_e4m3", "float8_e5m2"]
    if out_dtype == "int32" or is_float8:
        micro_size_k = 32

    # This is a debug config
    block_row_warps = 2
    block_col_warps = 2

    warp_row_tiles = 32
    warp_col_tiles = 32
    chunk = 32 if in_dtype == "float16" else 64
    shared_scope = "shared.dyn"

    # Pipeline stages
    stage = 2

    block_M = block_row_warps * warp_row_tiles
    block_N = block_col_warps * warp_col_tiles
    block_K = chunk

    A_shape = (M, K)
    B_shape = (N, K)
    A_shared_shape = (block_M, block_K)
    B_shared_shape = (block_N, block_K)
    C_shared_shape = (
        block_M // micro_size_x,
        block_N // micro_size_y,
        micro_size_x,
        micro_size_y,
    )

    warp_size = 32
    threads = warp_size * (block_row_warps * block_col_warps)
    local_size_a = (micro_size_x * micro_size_k) // warp_size
    local_size_b = (micro_size_y * micro_size_k) // warp_size
    local_size_c = (micro_size_x * micro_size_y) // warp_size
    warp_rows = warp_row_tiles // micro_size_x
    warp_cols = warp_col_tiles // micro_size_y

    # MMA wrapper to auto generate code for MMA
    mma_emitter = TensorCoreIntrinEmitter(
        a_dtype=in_dtype,
        b_dtype=in_dtype,
        accum_dtype=accum_dtype,
        a_transposed=False,
        b_transposed=True,
        block_row_warps=block_row_warps,
        block_col_warps=block_col_warps,
        warp_row_tiles=warp_row_tiles,
        warp_col_tiles=warp_col_tiles,
        chunk=chunk,
    )

    @T.prim_func
    def gemm_fp8_intrinsic(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):

            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
            C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
            A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
            C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

            T.annotate_layout({
                A_shared: make_swizzle_layout(A_shared),
                B_shared: make_swizzle_layout(B_shared),
            })

            # Improve L2 cache locality
            T.use_swizzle(panel_size=10)

            T.clear(C_local)

            for ko in T.Pipelined((K // block_K), num_stages=stage):

                # Load A into shared memory
                for i, k in T.Parallel(block_M, block_K):
                    A_shared[i, k] = A[by * block_M + i, ko * block_K + k]

                # Load B into shared memory
                for j, k in T.Parallel(block_N, block_K):
                    B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]

                for ki in T.serial(0, (block_K // micro_size_k)):

                    # Load A into fragment
                    mma_emitter.ldmatrix_a(
                        A_local,
                        A_shared,
                        ki,
                    )

                    # Load B into fragment
                    mma_emitter.ldmatrix_b(
                        B_local,
                        B_shared,
                        ki,
                    )

                    # Perform matrix multiplication
                    mma_emitter.mma(A_local, B_local, C_local)

            # Perform STMatrix
            mma_emitter.stmatrix(
                C_local,
                C_shared,
            )

            # Store shared into global
            for i, j in T.Parallel(block_M, block_N):
                C[by * block_M + i, bx * block_N + j] = C_shared[
                    i // micro_size_x,
                    j // micro_size_y,
                    i % micro_size_x,
                    j % micro_size_y,
                ]

    return gemm_fp8_intrinsic


def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
    kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
    src_code = kernel.get_kernel_source()
    print(src_code)
    # src_code is the generated cuda source
    assert src_code is not None

    in_dtype = map_torch_type(in_dtype)
    out_dtype = map_torch_type(out_dtype)
    accum_dtype = map_torch_type(accum_dtype)

    if in_dtype in {torch.int8, torch.int32}:
        A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda()
        B = torch.randint(-128, 128, (N, K), dtype=torch.int8).to(in_dtype).cuda()
    elif in_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
        A = torch.randn(M, K).to(in_dtype).cuda()
        B = torch.randn(N, K).to(in_dtype).cuda()
    else:
        A = torch.randn(M, K).to(in_dtype).cuda() - 0.5
        B = torch.randn(N, K).to(in_dtype).cuda() - 0.5

    C = torch.zeros(M, N, device="cuda", dtype=accum_dtype)

    profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)

    C = profiler(A, B)

    latency = profiler.do_bench(warmup=25)

    # Ensure that the latency is not None
    assert latency is not None

    # Get reference result
    ref_c = torch.matmul(A.to(accum_dtype), B.T.to(accum_dtype)).to(out_dtype)
    print(C)
    print(ref_c)
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)


def main():
    assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3", "float32", "float32")
    assert_tl_matmul_correctness(128, 128, 128, "float8_e5m2", "float32", "float32")


if __name__ == "__main__":
    main()
```
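For the fp8 cases exercised in `main()` above, the hard-coded warp configuration works out as follows; a quick sketch deriving the tile sizes from the constants in `tl_matmul`:

```python
# Values follow from the constants in tl_matmul for in_dtype = "float8_e4m3"/"float8_e5m2".
block_row_warps = block_col_warps = 2
warp_row_tiles = warp_col_tiles = 32
chunk = 64          # in_dtype is not float16
micro_size_k = 32   # fp8 path

block_M = block_row_warps * warp_row_tiles           # 64
block_N = block_col_warps * warp_col_tiles           # 64
block_K = chunk                                      # 64
threads = 32 * (block_row_warps * block_col_warps)   # 128
inner_mma_steps = block_K // micro_size_k            # 2 ldmatrix/mma steps per K tile
print(block_M, block_N, block_K, threads, inner_mma_steps)  # 64 64 64 128 2
```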
examples/gemm_fp8/test_example_gemm_fp8.py (new file, mode 100644)

```python
import tilelang.testing

import example_tilelang_gemm_fp8_2xAcc
import example_tilelang_gemm_fp8_intrinsic
import example_tilelang_gemm_fp8


def test_example_tilelang_gemm_fp8_2xAcc():
    example_tilelang_gemm_fp8_2xAcc.main()


def test_example_tilelang_gemm_fp8_intrinsic():
    example_tilelang_gemm_fp8_intrinsic.main()


def test_example_tilelang_gemm_fp8():
    example_tilelang_gemm_fp8.main()


if __name__ == "__main__":
    tilelang.testing.main()
```
examples/gemm_sm100/README.md (new file, mode 100644)

# TileLang SM100 Support (Preview)

This directory contains examples for TileLang's experimental SM100 architecture support. **This is a preview version** with limited functionality.

## Current Limitations (Manual Implementation Required)

### 1. Manual TCGEN5.MMA Management

Users must manually handle TCGEN5MMA operations using:

- `T.alloc_tmem()` - Allocate Tensor Memory
- `T.gemm()` with `wg_wait=-1` - Launch TCGEN5MMA without waiting
- Manual synchronization with mbarrier

### 2. Manual mbarrier Synchronization

TCGEN5MMA is asynchronous and requires explicit synchronization:

```python
mbar = T.alloc_barrier(1)  # expect-arrive-count = 1
T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, wg_wait=-1, clear_accum=k == 0)
T.mbarrier_wait_parity(mbar, k % 2)  # Manual phase calculation required
```

## Examples

### TCGEN5MMA Example (`gemm_tcgen5mma.py`)

Demonstrates TCGEN5MMA operations with:

- Tensor Memory allocation
- Manual mbarrier synchronization
- TCGEN5MMA gemm operations

### Traditional MMA Example (`gemm_mma.py`)

Shows standard MMA operations that work across architectures, for comparison.

## Code Example

The following code is based on `gemm_tcgen5mma.py`, demonstrating TCGEN5MMA matrix multiplication:

```python
import torch
import tilelang
import tilelang.language as T


@T.prim_func
def main(
        A: T.Tensor((M, K), "bfloat16"),
        B: T.Tensor((N, K), "bfloat16"),
        C: T.Tensor((M, N), "bfloat16"),
):
    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
        # 1. Allocate memory buffers
        A_shared = T.alloc_shared((block_M, block_K), "bfloat16")  # A matrix shared memory
        B_shared = T.alloc_shared((block_N, block_K), "bfloat16")  # B matrix shared memory
        C_tmem = T.alloc_tmem([block_M, block_N], "float")  # TCGEN5MMA output to Tensor Memory
        mbar = T.alloc_barrier(1)  # mbarrier synchronization primitive
        C_local = T.alloc_fragment((block_M, block_N), "float")  # Register storage
        C_shared = T.alloc_shared((block_M, block_N), "bfloat16")  # Output shared memory

        # 2. Main computation loop
        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
            # Data loading: global memory to shared memory
            T.copy(A[by * block_M, k * block_K], A_shared)
            T.copy(B[bx * block_N, k * block_K], B_shared)

            # TCGEN5MMA computation: asynchronous launch, output to Tensor Memory
            T.gemm(
                A_shared,
                B_shared,
                C_tmem,
                trans_A=False,
                trans_B=True,
                mbar=mbar,
                wg_wait=-1,
                clear_accum=k == 0)

            # Critical: wait for TCGEN5MMA completion
            T.mbarrier_wait_parity(mbar, k % 2)

        # 3. Output processing (only a subset of threads)
        T.copy(C_tmem, C_local)  # Tensor Memory → registers
        T.copy(C_local, C_shared)  # registers → shared memory

        # 4. Write back to global memory
        T.copy(C_shared, C[by * block_M, bx * block_N])
```

### Compilation and Usage

```python
# Parameter setup
M, N, K = 4096, 4096, 8192
block_M, block_N, block_K = 128, 256, 128

# Compile kernel
jit_kernel = tilelang.compile(
    func,
    out_idx=[2],
    target="cuda",
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,  # Required
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,  # Required
    })

# Run test
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
c = jit_kernel(a, b)

# Verify correctness
ref_c = (a.to(torch.float) @ b.T.to(torch.float)).to(torch.bfloat16)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)

# Performance benchmark
profiler = jit_kernel.get_profiler()
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
print(f"Performance: {2 * M * N * K / (latency / 1e3) / 1e12:.2f} TFLOPS")
```
examples/gemm_sm100/gemm_mma.py (new file, mode 100644)

```python
import tilelang
import tilelang.language as T


# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # Initialize kernel context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Clear local accumulation
            T.clear(C_local)

            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
                # Copy tile of A
                # This is a sugar syntax for a parallelized copy:
                #   for i, k in T.Parallel(block_M, block_K):
                #       A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
                T.copy(A[by * block_M, ko * block_K], A_shared)

                # Copy tile of B
                T.copy(B[bx * block_N, ko * block_K], B_shared)

                # Perform a tile-level GEMM on the shared buffers
                # Currently we dispatch to cute/hip on NVIDIA/AMD GPUs
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)

            # Copy result back to global memory
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


M = 128  # M = T.dynamic("m") if you want to use dynamic shape
N = 128
K = 32
block_M = 128
block_N = 128
block_K = 32

# 1. Define the kernel (matmul) and compile/lower it into an executable module
func = matmul(M, N, K, block_M, block_N, block_K)

# 2. Compile the kernel into a torch function
# out_idx specifies the index of the output buffer in the argument list
# if out_idx is specified, the tensor will be created during runtime
# target currently can be "cuda", "hip" or "cpu".
jit_kernel = tilelang.compile(
    func,
    out_idx=[2],
    target="cuda",
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    })
print(jit_kernel.get_kernel_source())

# 3. Test the kernel in Python with PyTorch data
import torch

# Create random input tensors on the GPU
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(N, K, device="cuda", dtype=torch.float16)

# Run the kernel
c = jit_kernel(a, b)
print(c)

# Reference multiplication using PyTorch
ref_c = a @ b.T

# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")

# 4. Retrieve and inspect the generated CUDA source (optional)
# cuda_source = jit_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)

# 5. Profile latency with the kernel
profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
```
examples/gemm_sm100/gemm_tcgen5mma.py (new file, mode 100644)

```python
import torch
import tilelang
import tilelang.language as T


def matmul(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype)
            mbar = T.alloc_barrier(1)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_N), out_dtype)

            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(
                    A_shared,
                    B_shared,
                    C_tmem,
                    trans_A,
                    trans_B,
                    mbar=mbar,
                    wg_wait=-1,
                    clear_accum=k == 0)
                T.mbarrier_wait_parity(mbar, k % 2)

            T.copy(C_tmem, C_local)
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return main


M, N, K = 4096, 4096, 8192
block_M, block_N, block_K = 128, 256, 128
trans_A, trans_B = False, True
in_dtype, out_dtype, accum_dtype = "bfloat16", "bfloat16", "float"
num_stages = 2
threads = 256

func = matmul(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype,
              accum_dtype, num_stages, threads)

jit_kernel = tilelang.compile(
    func,
    out_idx=[2],
    target="cuda",
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    })
print(jit_kernel.get_kernel_source())

a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)

c = jit_kernel(a, b)

ref_c = (a.to(torch.float) @ b.T.to(torch.float)).to(torch.bfloat16)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)

profiler = jit_kernel.get_profiler()
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
print(f"Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS")
```
examples/gemm_sp/example_gemm_sp.py (new file, mode 100644)

```python
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.

import argparse

import tilelang
import tilelang.language as T
from tilelang.layout import make_metadata_layout
from tilelang.utils.sparse import compress, randn_semi_sparse
from tilelang.contrib import nvcc
from triton.testing import do_bench

import torch

arch = nvcc.get_target_compute_version()
ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}

default_config = {  # take the best config from the autotune script
    "4090": {
        'float': {
            'block_M': 128,
            'block_N': 64,
            'block_K': 64,
            'num_stages': 1,
            'thread_num': 128,
            'policy': T.GemmWarpPolicy.Square,
            'enable_rasterization': True
        },
        'float16': {
            'block_M': 256,
            'block_N': 128,
            'block_K': 64,
            'num_stages': 2,
            'thread_num': 128,
            'policy': T.GemmWarpPolicy.Square,
            'enable_rasterization': True
        }
    },
    "h20": {
        'float': {
            'block_M': 128,
            'block_N': 64,
            'block_K': 128,
            'num_stages': 3,
            'thread_num': 128,
            'policy': T.GemmWarpPolicy.Square,
            'enable_rasterization': True
        },
        'float16': {
            'block_M': 128,
            'block_N': 64,
            'block_K': 128,
            'num_stages': 3,
            'thread_num': 128,
            'policy': T.GemmWarpPolicy.Square,
            'enable_rasterization': True
        }
    }
}


@tilelang.jit(out_idx=[-1])
def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num,
                   policy, enable_rasterization):
    e_factor, e_dtype = ARCH_INFO[arch]

    @T.prim_func
    def gemm_sp_fp16(
            A_sparse: T.Tensor((M, K // 2), 'float16'),
            E: T.Tensor((M, K // e_factor), e_dtype),
            B: T.Tensor((K, N), 'float16'),
            C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K // 2), 'float16')
            E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
            B_shared = T.alloc_shared((block_K, block_N), 'float16')
            C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            T.disable_warp_group_reg_alloc()
            T.use_swizzle(panel_size=10, enable=enable_rasterization)
            T.annotate_layout({
                E:
                    make_metadata_layout(
                        E, mma_dtype="float16", backend="cutlass", block_k=block_K, arch=arch),
                E_shared:
                    make_metadata_layout(
                        E_shared,
                        mma_dtype="float16",
                        backend="cutlass",
                        block_k=block_K,
                        arch=arch),
            })
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
                T.copy(E[by * block_M, k * block_K // e_factor], E_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm_sp(A_shared, E_shared, B_shared, C_local, False, False, policy=policy)
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return gemm_sp_fp16


def main():
    parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
    parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
    parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
    parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
    parser.add_argument(
        "--accum_dtype",
        type=str,
        default="float",
        choices=["float", "float16"],
        help="Accumulation datatype")
    parser.add_argument("--cfg", type=str, choices=["4090", "h20"], required=True)
    args = parser.parse_args()

    kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype,
                            **default_config[args.cfg][args.accum_dtype])

    a = randn_semi_sparse(args.m, args.k, device='cuda', dtype=torch.half)
    b = torch.randn(args.k, args.n, device='cuda', dtype=torch.half)
    a_sparse, e = compress(
        a,
        transposed=False,
        block_k=default_config[args.cfg][args.accum_dtype]['block_K'],
        arch=arch)

    c = kernel(a_sparse, e, b)

    ref_c = a @ b

    assert not c.isnan().any(), "Result contains NaNs, please report an issue"
    torch.testing.assert_close(c, ref_c.to(c.dtype), rtol=1e-2, atol=1e-2)
    print(f"Precision check passed. diff: {(c - ref_c).abs().mean()}")

    latency = do_bench(lambda: kernel(a_sparse, e, b))
    ref_latency = do_bench(lambda: a @ b)

    total_flops = 2 * args.m * args.n * args.k

    tflops = total_flops / latency / 1e9
    ref_tflops = total_flops / ref_latency / 1e9
    print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3} s")
    print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency / 1e3} s")


if __name__ == "__main__":
    main()
```
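With the 2:4 semi-sparse layout, the kernel never sees the dense `A`: `compress()` produces `A_sparse` with half the K columns plus a metadata tensor `E` whose width and dtype come from the `ARCH_INFO` entry for the current compute capability. A small sketch of the expected operand shapes for the default sizes above, assuming an sm80/sm89 device where `e_factor = 16` and the metadata dtype is int16:

```python
# Expected operand shapes for the defaults above (M = N = K = 16384),
# assuming e_factor = 16 as on compute capability 8.0 / 8.9.
M = N = K = 16384
e_factor = 16
print("A_sparse:", (M, K // 2))         # (16384, 8192), float16
print("E:       ", (M, K // e_factor))  # (16384, 1024), int16 metadata
print("B:       ", (K, N))              # (16384, 16384), float16
```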
examples/gemm_splitk/example_tilelang_gemm_splitk.py (new file, mode 100644)

```python
import tilelang
import tilelang.language as T


@tilelang.jit
def matmul(M,
           N,
           K,
           block_M,
           block_N,
           block_K,
           split_k,
           dtype="float16",
           accum_dtype="float",
           out_dtype="float32"):
    splitK = K // split_k

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(
                T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k,
                threads=128) as (bx, by, bz):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_shared = T.alloc_shared((block_M, block_N), out_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(splitK, block_K), num_stages=0):
                T.copy(A[by * block_M, bz * splitK + ko * block_K], A_shared)
                T.copy(B[bz * splitK + ko * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C_shared)
            for i, j in T.Parallel(block_M, block_N):
                T.atomic_add(C[by * block_M + i, bx * block_N + j], C_shared[i, j])

    return main


def main():
    M = 1024
    N = 1024
    K = 1024
    block_M = 128
    block_N = 128
    block_K = 32
    split_k = 4

    kernel = matmul(M, N, K, block_M, block_N, block_K, split_k)
    import torch

    torch.random.manual_seed(42)
    a = torch.randn(M, K).cuda().half()
    b = torch.randn(K, N).cuda().half()
    c = torch.zeros(M, N).cuda().float()
    kernel(a, b, c)
    ref_c = a @ b
    torch.testing.assert_close(c, ref_c.to(c.dtype), rtol=1e-2, atol=1e-2)


if __name__ == "__main__":
    main()
```
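The split-K decomposition implemented above can be written out as follows: with $s = \texttt{split\_k}$ and $K_s = K / s$, each block index `bz` computes a partial product over its own slice of the reduction dimension,

$$C = \sum_{z=0}^{s-1} A[:,\, zK_s : (z+1)K_s]\; B[zK_s : (z+1)K_s,\, :],$$

and the per-element `T.atomic_add` accumulates the partials into `C`. This is why `main()` zero-initializes `C` with `torch.zeros` before launching the kernel and keeps it in float32.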
examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py (new file, mode 100644)

```python
import tilelang
import tilelang.language as T


@tilelang.jit
def matmul(M,
           N,
           K,
           block_M,
           block_N,
           block_K,
           split_k,
           dtype="float16",
           accum_dtype="float",
           out_dtype="float32"):
    splitK = K // split_k

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(
                T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k,
                threads=128) as (bx, by, bz):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_shared = T.alloc_shared((block_M, block_N), out_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(splitK, block_K), num_stages=0):
                T.copy(A[by * block_M, bz * splitK + ko * block_K], A_shared)
                T.copy(B[bz * splitK + ko * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C_shared)
            T.atomic_add(C[by * block_M, bx * block_N], C_shared)

    return main


def main():
    M = 1024
    N = 1024
    K = 1024
    block_M = 128
    block_N = 128
    block_K = 32
    split_k = 4

    kernel = matmul(M, N, K, block_M, block_N, block_K, split_k)
    import torch

    torch.random.manual_seed(42)
    a = torch.randn(M, K).cuda().half()
    b = torch.randn(K, N).cuda().half()
    c = torch.zeros(M, N).cuda().float()
    kernel(a, b, c)
    ref_c = a @ b
    torch.testing.assert_close(c, ref_c.to(c.dtype), rtol=1e-2, atol=1e-2)


if __name__ == "__main__":
    main()
```
examples/gemm_splitk/test_example_gemm_splitk.py (new file, mode 100644)

```python
import tilelang.testing

import example_tilelang_gemm_splitk
import example_tilelang_gemm_splitk_vectorize_atomicadd


def test_example_tilelang_gemm_splitk():
    example_tilelang_gemm_splitk.main()


def test_example_tilelang_gemm_splitk_vectorize_atomicadd():
    example_tilelang_gemm_splitk_vectorize_atomicadd.main()


if __name__ == "__main__":
    tilelang.testing.main()
```
examples/gemm_streamk/example_tilelang_gemm_streamk.py (new file, mode 100644)

```python
import torch
import torch.backends
import tilelang
from tilelang import language as T
import math


def cdiv(a, b):
    return math.ceil(a / b)


# disable tf32
torch.backends.cuda.matmul.allow_tf32 = False

m = 256
n = 1024
k = 512

total_sm = 108

torch.random.manual_seed(0)
# uniform distribution from -1 to 1
A = torch.rand(m, k, device="cuda", dtype=torch.float16) * 2 - 1
B = torch.rand(n, k, device="cuda", dtype=torch.float16) * 2 - 1

streamk_programs = total_sm
BLOCK_SIZE_M = 16
BLOCK_SIZE_N = 128
BLOCK_SIZE_K = 32
two_tiles = False
M, K = A.shape
N, K = B.shape

# accumulator types
# compute grid (work to do per SM on the first wave)
num_block_m = tilelang.cdiv(M, BLOCK_SIZE_M)
num_block_n = tilelang.cdiv(N, BLOCK_SIZE_N)
iters_per_tile = tilelang.cdiv(K, BLOCK_SIZE_K)
total_tiles = num_block_m * num_block_n

# Two-tile SK + DP
streamk_tiles = total_tiles % streamk_programs
if total_tiles - streamk_tiles > streamk_programs:  # (total_tiles // total_programs > 1)
    streamk_tiles += streamk_programs

blocking_tiles = total_tiles - streamk_tiles
streamk_iters = streamk_tiles * iters_per_tile

streamk_full_tiles = streamk_iters // streamk_programs
streamk_partial_tiles = streamk_iters % streamk_programs

print(f"{total_tiles=}")
print(f"{iters_per_tile=}")

sm_patition_factor = max(blocking_tiles // total_sm, 1)


@tilelang.jit
def tl_matmul_streamk(
    M,
    N,
    K,
    streamk_tiles,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    dtypeAB,
    dtypeC,
    accum_dtype,
    num_stages,
    threads,
):
    assert not trans_A

    A_shape = (M, K) if not trans_A else (K, M)
    B_shape = (K, N) if not trans_B else (N, K)
    A_shared_shape = (block_M, block_K) if not trans_A else (block_K, block_M)
    B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K)

    @T.macro
    def compute_first_wave(
        pid: T.int32,
        A_buf: T.Tensor,
        A_buf_shared: T.SharedBuffer,
        B_buf: T.Tensor,
        B_buf_shared: T.SharedBuffer,
        C: T.Tensor,
        C_local: T.LocalBuffer,
    ):
        start_iter = T.alloc_fragment((1,), "int32", "local")
        end_iter = T.alloc_fragment((1,), "int32", "local")

        start_iter[0] = pid * streamk_full_tiles + T.min(pid, streamk_partial_tiles)
        last_iter = (pid + 1) * streamk_full_tiles + T.min(pid + 1, streamk_partial_tiles)

        while start_iter[0] < last_iter:
            end_iter[0] = T.min(
                start_iter[0] + (iters_per_tile - (start_iter[0] % iters_per_tile)),
                last_iter,
            )

            tile_id = start_iter[0] // iters_per_tile
            remain_iters = start_iter[0] % iters_per_tile
            pid_m = tile_id // T.ceildiv(N, block_N)
            pid_n = tile_id % T.ceildiv(N, block_N)

            T.clear(C_local)
            for k in T.Pipelined(end_iter[0] - start_iter[0], num_stages=num_stages):
                T.copy(
                    A_buf[pid_m * block_M, (k + (start_iter[0] % iters_per_tile)) * block_K],
                    A_buf_shared,
                )
                T.copy(
                    B_buf[pid_n * block_N, (k + (start_iter[0] % iters_per_tile)) * block_K],
                    B_buf_shared,
                )
                T.gemm(A_buf_shared, B_buf_shared, C_local, transpose_B=trans_B)

            # last iteration of the tile always happens before its start on another SM
            if remain_iters == 0 and (end_iter[0] % iters_per_tile == 0):
                T.copy(C_local, C[pid_m * block_M, pid_n * block_N])
            else:
                for i, j in T.Parallel(block_M, block_N):
                    T.atomic_add(C[pid_m * block_M + i, pid_n * block_N + j], C_local[i, j])

            start_iter[0] = end_iter[0]

    @T.macro
    def compute_full_tiles(
        pid: T.int32,
        A_buf: T.Tensor,
        A_shared: T.SharedBuffer,
        B_buf: T.Tensor,
        B_shared: T.SharedBuffer,
        C: T.Tensor,
        C_local: T.LocalBuffer,
    ):
        for p in T.serial(sm_patition_factor):
            tile_id = pid + streamk_tiles + p * total_sm
            pid_m = tile_id // T.ceildiv(N, block_N)
            pid_n = tile_id % T.ceildiv(N, block_N)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
                T.copy(A_buf[pid_m * block_M, k * block_K], A_shared)
                T.copy(B_buf[pid_n * block_N, k * block_K], B_shared)
                T.gemm(A_shared, B_shared, C_local, transpose_B=trans_B)
            T.copy(C_local, C[pid_m * block_M, pid_n * block_N])

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, dtypeAB),
            B: T.Tensor(B_shape, dtypeAB),
            C: T.Tensor((M, N), dtypeC),
    ):
        with T.Kernel(streamk_programs, threads=threads) as pid:
            A_shared = T.alloc_shared(A_shared_shape, dtypeAB)
            B_shared = T.alloc_shared(B_shared_shape, dtypeAB)
            A_shared_full_tiles = T.alloc_shared(A_shared_shape, dtypeAB)
            B_shared_full_tiles = T.alloc_shared(B_shared_shape, dtypeAB)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            compute_first_wave(pid, A, A_shared, B, B_shared, C, C_local)

            if sm_patition_factor > 0:
                compute_full_tiles(pid, A, A_shared_full_tiles, B, B_shared_full_tiles, C, C_local)

    return main


def main():
    kernel = tl_matmul_streamk(
        m,
        n,
        k,
        streamk_tiles,
        BLOCK_SIZE_M,
        BLOCK_SIZE_N,
        BLOCK_SIZE_K,
        False,
        True,
        "float16",
        "float16",
        "float32",
        2,
        64,
    )

    print(kernel.get_kernel_source())

    b_c = torch.zeros((m, n), device="cuda", dtype=torch.float16)
    kernel(A, B, b_c)

    C = torch.matmul(A, B.T)

    print(b_c)
    print(C)

    torch.testing.assert_close(C, b_c, rtol=1e-2, atol=1e-2)


if __name__ == "__main__":
    main()
```
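For the module-level sizes above (m = 256, n = 1024, k = 512, 108 SMs, 16x128x32 tiles), the stream-K partitioning arithmetic works out as follows; a quick sketch reproducing it with plain Python:

```python
from math import ceil

m, n, k, total_sm = 256, 1024, 512, 108
BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 16, 128, 32

num_block_m = ceil(m / BLOCK_SIZE_M)             # 16
num_block_n = ceil(n / BLOCK_SIZE_N)             # 8
iters_per_tile = ceil(k / BLOCK_SIZE_K)          # 16
total_tiles = num_block_m * num_block_n          # 128

streamk_tiles = total_tiles % total_sm           # 20
if total_tiles - streamk_tiles > total_sm:       # 108 > 108 is False, so no extra wave
    streamk_tiles += total_sm

blocking_tiles = total_tiles - streamk_tiles     # 108: exactly one data-parallel tile per SM
streamk_iters = streamk_tiles * iters_per_tile   # 320
streamk_full_tiles = streamk_iters // total_sm   # 2 K-iterations per SM in the first wave...
streamk_partial_tiles = streamk_iters % total_sm # ...plus 1 extra for the first 104 SMs
print(streamk_tiles, blocking_tiles, streamk_full_tiles, streamk_partial_tiles)  # 20 108 2 104
```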
examples/gemm_streamk/test_example_tilelang_gemm_splitk.py (new file, mode 100644)

```python
import tilelang.testing

from example_tilelang_gemm_streamk import main


# not fully supported on sm90
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_le(8, 9)
def test_example_tilelang_gemm_streamk():
    main()


if __name__ == "__main__":
    tilelang.testing.main()
```