OpenDAS / tilelang · Commits

Commit 29051439 (unverified)
Authored Dec 12, 2025 by Lei Wang; committed via GitHub on Dec 12, 2025
Parent: e84b24bc

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

The commit touches 467 files in total; this page shows 20 changed files with 215 additions and 291 deletions (+215 −291).
Changed files on this page:

  examples/warp_specialize/example_warp_specialize_flashmla.py                 +53  −91
  examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py   +5  −12
  examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py        +4  −12
  examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py        +4  −12
  examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py        +8  −14
  examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py      +0   −1
  format.sh                                                                     +1   −1
  maint/gemm_v2/correctness_evaluation.py                                      +56  −44
  maint/gemm_v2/correctness_evaluation_sm70.py                                 +14  −11
  maint/gemm_v2/correctness_evaluation_tcgen05.py                              +12  −17
  maint/gemm_v2/latency.py                                                      +3   −4
  maint/gemm_v2/latency_gemm.py                                                 +3   −4
  maint/gemm_v2/latency_mha_fwd_bhsd.py                                        +40  −58
  maint/host_checks/01_num_args_mismatch.py                                     +1   −0
  maint/host_checks/02_pointer_type_error.py                                    +1   −0
  maint/host_checks/03_ndim_mismatch.py                                         +2   −2
  maint/host_checks/04_dtype_mismatch.py                                        +2   −2
  maint/host_checks/05_shape_mismatch.py                                        +2   −2
  maint/host_checks/06_strides_mismatch.py                                      +2   −2
  maint/host_checks/07_device_type_mismatch.py                                  +2   −2
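Nearly every hunk on this page is a mechanical restyle from yapf to ruff format (which follows Black's conventions): double quotes for strings, spaces around ':' in slices with compound bounds, spaces around '**' when an operand is parenthesized, exploded closing brackets, and long wrapped calls joined onto single lines. A small before/after sketch of these patterns, using invented function and variable names rather than anything taken from the diff:

    # Illustrative sketch only: hypothetical names, but the same edits ruff format applies below.
    import argparse

    # Before (yapf style):
    #   scale = (dim + pe_dim)**0.5
    #   tile = buf[i * size:(i + 1) * size]
    #   parser.add_argument('--dim', type=int, default=512, help='head dim')

    # After (ruff format):
    def scaled_tile(buf, i, size, dim, pe_dim):
        scale = (dim + pe_dim) ** 0.5          # spaces around ** when an operand is not a bare name/number
        tile = buf[i * size : (i + 1) * size]  # spaces around ':' when slice bounds are expressions
        return scale, tile

    def build_parser():
        parser = argparse.ArgumentParser()
        parser.add_argument("--dim", type=int, default=512, help="head dim")  # double quotes
        return parser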
examples/warp_specialize/example_warp_specialize_flashmla.py

@@ -9,7 +9,7 @@ import argparse
 @tilelang.jit(out_idx=[6])
 def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
-    scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504  # log2(e)
+    scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504  # log2(e)
     dtype = "float16"
     accum_dtype = "float"
     kv_group_num = heads // kv_head_num

@@ -81,10 +81,12 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
             cur_kv_head = hid // (kv_group_num // block_H)

-            T.annotate_layout({
-                O_shared_l: tilelang.layout.make_swizzled_layout(O_shared_l),
-                O_shared_r: tilelang.layout.make_swizzled_layout(O_shared_r),
-            })
+            T.annotate_layout(
+                {
+                    O_shared_l: tilelang.layout.make_swizzled_layout(O_shared_l),
+                    O_shared_r: tilelang.layout.make_swizzled_layout(O_shared_r),
+                }
+            )

             # barriers_Q
             q_shared_ready_barrier = T.alloc_barrier(arrive_count=256)

@@ -108,9 +110,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
             tx = T.get_thread_binding()

-            T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :h_dim], Q_shared_l)
-            T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, h_dim:], Q_shared_r)
-            T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
+            T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :h_dim], Q_shared_l)
+            T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, h_dim:], Q_shared_r)
+            T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
             T.barrier_arrive(q_shared_ready_barrier)
             T.barrier_wait(q_shared_ready_barrier, 0)

@@ -123,25 +125,18 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                 T.fill(acc_o_l, 0)
                 T.fill(logsum_0, 0)
-                T.copy(KV[bid, block_N:2 * block_N, cur_kv_head, :h_dim], KV_shared_1_l)
+                T.copy(KV[bid, block_N : 2 * block_N, cur_kv_head, :h_dim], KV_shared_1_l)
                 T.barrier_arrive(kv_shared_1_l_is_ready)
-                T.copy(KV[bid, block_N:2 * block_N, cur_kv_head, h_dim:], KV_shared_1_r)
+                T.copy(KV[bid, block_N : 2 * block_N, cur_kv_head, h_dim:], KV_shared_1_r)
                 T.barrier_arrive(kv_shared_1_r_is_ready)
-                T.copy(K_pe[bid, block_N:2 * block_N, cur_kv_head, :], K_pe_shared_1)
+                T.copy(K_pe[bid, block_N : 2 * block_N, cur_kv_head, :], K_pe_shared_1)
                 T.barrier_arrive(kv_shared_1_pe_is_ready)

                 for k in T.serial(loop_range):
                     T.barrier_wait(kv_shared_0_l_is_ready, k % 2)
                     T.gemm(Q_shared_l, KV_shared_0_l, acc_s_0, transpose_B=True, clear_accum=True, wg_wait=-1)
                     T.barrier_wait(kv_shared_0_r_is_ready, k % 2)
                     T.gemm(Q_shared_r, KV_shared_0_r, acc_s_0, transpose_B=True, wg_wait=-1)

@@ -161,8 +156,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                     for i, j in T.Parallel(block_H, block_N):
                         acc_s_0[i, j] = T.exp2(acc_s_0[i, j] * scale - scores_max[i] * scale)
                     for i in T.Parallel(block_H):
                         scores_scale_0[i] = T.exp2(scores_max_prev_0[i] * scale - scores_max[i] * scale)
                     T.reduce_sum(acc_s_0, scores_sum_0, dim=1)

@@ -182,9 +176,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                     T.barrier_wait(scale_1_ready_barrier, k % 2)
                     if k < loop_range - 1:
-                        T.copy(KV[bid, (2 * k + 2) * block_N:(2 * k + 3) * block_N, cur_kv_head, :h_dim], KV_shared_0_l)
+                        T.copy(KV[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, :h_dim], KV_shared_0_l)
                         T.barrier_arrive(kv_shared_0_l_is_ready)
                     # Step 11.

@@ -204,15 +196,10 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                     T.gemm(SP1_shared, KV_shared_1_l, acc_o_l)
                     if k < loop_range - 1:
-                        T.copy(KV[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N, cur_kv_head, :h_dim], KV_shared_1_l)
+                        T.copy(KV[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, :h_dim], KV_shared_1_l)
                         T.barrier_arrive(kv_shared_1_l_is_ready)
-                        T.copy(K_pe[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N, cur_kv_head, :], K_pe_shared_1)
+                        T.copy(K_pe[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, :], K_pe_shared_1)
                         T.barrier_arrive(kv_shared_1_pe_is_ready)
                     T.copy(logsum_0, logsum)

@@ -221,8 +208,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                     for i, j in T.Parallel(block_H, h_dim):
                         acc_o_l[i, j] /= logsum[i]
                     T.copy(acc_o_l, O_shared_l)
-                    T.copy(O_shared_l, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :h_dim])
+                    T.copy(O_shared_l, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :h_dim])
                 else:
                     T.copy(Q_pe_shared, Q_pe_local_1)

@@ -237,16 +223,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                     T.barrier_arrive(kv_shared_0_pe_is_ready)

                     for k in T.serial(loop_range):
                         # Step 2.
                         T.barrier_wait(kv_shared_1_l_is_ready, k % 2)
                         T.gemm(Q_shared_l, KV_shared_1_l, acc_s_1, transpose_B=True, clear_accum=True, wg_wait=-1)
                         T.barrier_wait(kv_shared_1_r_is_ready, k % 2)
                         T.gemm(Q_shared_r, KV_shared_1_r, acc_s_1, transpose_B=True, wg_wait=-1)

@@ -265,8 +244,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                         T.copy(scores_max_1, scores_max)
                         for i in T.Parallel(block_H):
                             scores_scale_1[i] = T.exp2(scores_max_prev_1[i] * scale - scores_max[i] * scale)
                         # Step 8.
                         for i, j in T.Parallel(block_H, block_N):

@@ -279,8 +257,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                             acc_o_r[i, j] = acc_o_r[i, j] * (scores_scale_0[i] * scores_scale_1[i])
                         for i in T.Parallel(block_H):
                             logsum_1[i] = logsum_1[i] * scores_scale_1[i] * scores_scale_0[i] + scores_sum_1[i]
                         T.barrier_arrive(scale_1_ready_barrier)

@@ -291,9 +268,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                         T.barrier_arrive(s_shared_ready_barrier)
                         if k < loop_range - 1:
-                            T.copy(KV[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N, cur_kv_head, h_dim:], KV_shared_1_r)
+                            T.copy(KV[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, h_dim:], KV_shared_1_r)
                             T.barrier_arrive(kv_shared_1_r_is_ready)
                         T.barrier_wait(p0_1_1_ready_barrier, k % 2)

@@ -301,15 +276,10 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                         T.gemm(SP0_shared, KV_shared_0_r, acc_o_r)
                         if k < loop_range - 1:
-                            T.copy(KV[bid, (2 * k + 2) * block_N:(2 * k + 3) * block_N, cur_kv_head, h_dim:], KV_shared_0_r)
+                            T.copy(KV[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, h_dim:], KV_shared_0_r)
                             T.barrier_arrive(kv_shared_0_r_is_ready)
-                            T.copy(K_pe[bid, (2 * k + 2) * block_N:(2 * k + 3) * block_N, cur_kv_head, :], K_pe_shared_0)
+                            T.copy(K_pe[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, :], K_pe_shared_0)
                             T.barrier_arrive(kv_shared_0_pe_is_ready)
                         T.barrier_wait(lse_0_ready_barrier, 0)

@@ -319,8 +289,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                         for i, j in T.Parallel(block_H, h_dim):
                             acc_o_r[i, j] /= logsum[i]
                         T.copy(acc_o_r, O_shared_r)
-                        T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, h_dim:])
+                        T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, h_dim:])

     @T.prim_func
     def main_no_split(

@@ -352,31 +321,24 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
     dim = q.shape[-1]
     pe_dim = q_pe.shape[-1]
     num_head_groups = q.shape[1] // kv.shape[2]
-    scale = (dim + pe_dim)**0.5
-    q = rearrange(q, 'b (h g) d -> b g h d', g=num_head_groups)  # [batch_size, num_head_groups, groups, dim]
+    scale = (dim + pe_dim) ** 0.5
+    q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups)  # [batch_size, num_head_groups, groups, dim]
-    q_pe = rearrange(q_pe, 'b (h g) d -> b g h d', g=num_head_groups)  # [batch_size, num_head_groups, groups, pe_dim]
+    q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups)  # [batch_size, num_head_groups, groups, pe_dim]
-    kv = rearrange(kv, 'b n h d -> b h n d')  # [batch_size, groups, seqlen_kv, dim]
+    kv = rearrange(kv, "b n h d -> b h n d")  # [batch_size, groups, seqlen_kv, dim]
-    k_pe = rearrange(k_pe, 'b n h d -> b h n d')  # [batch_size, num_head_groups, groups, pe_dim]
+    k_pe = rearrange(k_pe, "b n h d -> b h n d")  # [batch_size, num_head_groups, groups, pe_dim]
     query = torch.concat([q, q_pe], dim=-1)
     key = torch.concat([kv, k_pe], dim=-1)
-    scores = einsum(query, key, 'b g h d, b h s d -> b g h s')  # [batch_size, num_head_groups, groups, seqlen_kv]
+    scores = einsum(query, key, "b g h d, b h s d -> b g h s")  # [batch_size, num_head_groups, groups, seqlen_kv]
     attention = F.softmax(scores / scale, dim=-1)  # [batch_size, num_head_groups, groups, seqlen_kv]
-    out = einsum(attention, kv, 'b g h s, b h s d -> b g h d')  # [batch_size, num_head_groups, groups, dim]
-    out = rearrange(out, 'b g h d -> b (h g) d')  # [batch_size, heads, dim]
+    out = einsum(attention, kv, "b g h s, b h s d -> b g h d")  # [batch_size, num_head_groups, groups, dim]
+    out = rearrange(out, "b g h d -> b (h g) d")  # [batch_size, heads, dim]
     return out

@@ -399,12 +361,12 @@ def main(batch=1, heads=64, kv_heads=1, kv_ctx=1024, dim=512, pe_dim=64):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=1, help='batch size')
-    parser.add_argument('--heads', type=int, default=128, help='q heads number')
-    parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
-    parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
-    parser.add_argument('--dim', type=int, default=512, help='head dim')
-    parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim')
+    parser.add_argument("--batch", type=int, default=1, help="batch size")
+    parser.add_argument("--heads", type=int, default=128, help="q heads number")
+    parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number")
+    parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length")
+    parser.add_argument("--dim", type=int, default=512, help="head dim")
+    parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim")
     args = parser.parse_args()
     batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
     main(batch, heads, kv_heads, kv_ctx, dim, pe_dim)
examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py

@@ -8,7 +8,6 @@ tilelang.disable_cache()
 # @tilelang.jit
 @tilelang.jit(out_idx=[2])
 def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
-
     num_stages = 2
     mbarrier_list = [128, 128] * num_stages

@@ -32,19 +31,13 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
             for ko in range(T.ceildiv(K, block_K)):
                 with T.ws(1):
                     T.mbarrier_wait_parity(mbarrier=ko % num_stages + num_stages, parity=((ko // num_stages) % num_stages) ^ 1)
-                    T.copy(A[by * block_M:(by + 1) * block_M, ko * block_K:(ko + 1) * block_K], A_shared[ko % num_stages, :, :])
-                    T.copy(B[ko * block_K:(ko + 1) * block_K, bx * block_N:(bx + 1) * block_N], B_shared[ko % num_stages, :, :])
+                    T.copy(A[by * block_M : (by + 1) * block_M, ko * block_K : (ko + 1) * block_K], A_shared[ko % num_stages, :, :])
+                    T.copy(B[ko * block_K : (ko + 1) * block_K, bx * block_N : (bx + 1) * block_N], B_shared[ko % num_stages, :, :])
                     T.mbarrier_arrive(mbarrier=ko % num_stages)
                 with T.ws(0):
                     T.mbarrier_wait_parity(mbarrier=ko % num_stages, parity=(ko // num_stages) % num_stages)
                     T.gemm(A_shared[ko % num_stages, :, :], B_shared[ko % num_stages, :, :], C_local)
                     T.mbarrier_arrive(mbarrier=ko % num_stages + num_stages)
             with T.ws(0):
examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py

@@ -5,15 +5,7 @@ import tilelang.language as T
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
 @tilelang.jit(out_idx=[2])
-def matmul_warp_specialize_copy_0_gemm_1(M,
-                                         N,
-                                         K,
-                                         block_M,
-                                         block_N,
-                                         block_K,
-                                         dtype="float16",
-                                         accum_dtype="float"):
-
+def matmul_warp_specialize_copy_0_gemm_1(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
     @T.prim_func
     def main(
         A: T.Tensor((M, K), dtype),
examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py

@@ -5,15 +5,7 @@ import tilelang.language as T
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
 @tilelang.jit(out_idx=[2])
-def matmul_warp_specialize_copy_1_gemm_0(M,
-                                         N,
-                                         K,
-                                         block_M,
-                                         block_N,
-                                         block_K,
-                                         dtype="float16",
-                                         accum_dtype="float"):
-
+def matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
     @T.prim_func
     def main(
         A: T.Tensor((M, K), dtype),
examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py

@@ -5,18 +5,12 @@ import tilelang.language as T
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
 @tilelang.jit(
-    out_idx=[2], pass_configs={
+    out_idx=[2],
+    pass_configs={
         tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
-    })
-def matmul_warp_specialize_copy_1_gemm_0(M,
-                                         N,
-                                         K,
-                                         block_M,
-                                         block_N,
-                                         block_K,
-                                         dtype="float16",
-                                         accum_dtype="float"):
+    },
+)
+def matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
     warp_group_num = 2
     threads = 128 * warp_group_num
examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py

@@ -6,7 +6,6 @@ import tilelang.language as T
 # @tilelang.jit
 @tilelang.jit(out_idx=[2])
 def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
-
     @T.prim_func
     def main(
         A: T.Tensor[(M, K), dtype],
format.sh

@@ -9,7 +9,7 @@
 # bash format.sh --all
 #
 #
-# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase.
+# Ruff (format) + Clang formatter (if installed). This script formats all changed files from the last mergebase.
 # You are encouraged to run this locally before pushing changes for review.

 # Cause the script to exit if a single command fails
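The script itself is shell, but the workflow its comment describes — find the merge-base, collect the files that changed since then, and run the formatter over them — is easy to mirror. A minimal Python sketch of that workflow, assuming git and ruff are on PATH and that origin/main is the comparison branch (both assumptions for illustration, not taken from format.sh):

    import subprocess

    def format_changed_python_files(base_ref: str = "origin/main") -> None:
        # Find the merge-base with the (assumed) default branch.
        merge_base = subprocess.run(
            ["git", "merge-base", base_ref, "HEAD"],
            check=True, capture_output=True, text=True,
        ).stdout.strip()
        # Collect files added/copied/modified since that point.
        changed = subprocess.run(
            ["git", "diff", "--name-only", "--diff-filter=ACM", merge_base],
            check=True, capture_output=True, text=True,
        ).stdout.splitlines()
        py_files = [path for path in changed if path.endswith(".py")]
        if py_files:
            # `ruff format` rewrites files in place, taking over the role yapf used to play here.
            subprocess.run(["ruff", "format", *py_files], check=True)

    if __name__ == "__main__":
        format_changed_python_files()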
maint/gemm_v2/correctness_evaluation.py

@@ -66,7 +66,8 @@ def _compile_and_check(
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
             # tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False,
-        })
+        },
+    )
     print(kernel.get_kernel_source())

@@ -394,37 +395,48 @@ M_VALUES = [64, 128, 256]
 N_VALUES = [16, 32, 64, 128, 256, 512]
 K_VALUES = [16, 32, 64, 128]
 K_VALUES_8Bit = [32, 64, 128]

-FALSE_TRUE_CASES = ([
-    pytest.param(
-        k,
-        "float16",
-        "float16",
-        "float16",
-        id=f"K{k}-float16-float16-float16",
-    ) for k in K_VALUES
-] + [
-    pytest.param(
-        k,
-        "int8",
-        "int32",
-        "int32",
-        id="K32-int8-int32-int32",
-    ) for k in K_VALUES_8Bit
-] + [
-    pytest.param(
-        k,
-        "float8_e5m2",
-        "float32",
-        "float32",
-        id="K32-float8_e5m2-float32-float32",
-    ) for k in K_VALUES_8Bit
-] + [
-    pytest.param(
-        k,
-        "float8_e4m3",
-        "float32",
-        "float32",
-        id="K32-float8_e4m3-float32-float32",
-    ) for k in K_VALUES_8Bit
-])
+FALSE_TRUE_CASES = (
+    [
+        pytest.param(
+            k,
+            "float16",
+            "float16",
+            "float16",
+            id=f"K{k}-float16-float16-float16",
+        )
+        for k in K_VALUES
+    ]
+    + [
+        pytest.param(
+            k,
+            "int8",
+            "int32",
+            "int32",
+            id="K32-int8-int32-int32",
+        )
+        for k in K_VALUES_8Bit
+    ]
+    + [
+        pytest.param(
+            k,
+            "float8_e5m2",
+            "float32",
+            "float32",
+            id="K32-float8_e5m2-float32-float32",
+        )
+        for k in K_VALUES_8Bit
+    ]
+    + [
+        pytest.param(
+            k,
+            "float8_e4m3",
+            "float32",
+            "float32",
+            id="K32-float8_e4m3-float32-float32",
+        )
+        for k in K_VALUES_8Bit
+    ]
+)


 def _ensure_torch_dtypes(*dtype_names):
maint/gemm_v2/correctness_evaluation_sm70.py

@@ -67,7 +67,8 @@ def _compile_and_check(
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
             # tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False,
-        })
+        },
+    )
     print(kernel.get_kernel_source())

@@ -213,14 +214,15 @@ def run_gemm_rs(
 M_VALUES = [64, 128]
 N_VALUES = [32, 64, 128]
 K_VALUES = [16, 32, 64]

-FALSE_TRUE_CASES = ([
+FALSE_TRUE_CASES = [
     pytest.param(
         k,
         "float16",
         "float16",
         "float16",
         id=f"K{k}-float16-float16-float16",
-    ) for k in K_VALUES
+    )
+    for k in K_VALUES
 ] + [
     pytest.param(
         k,

@@ -228,8 +230,9 @@ FALSE_TRUE_CASES = ([
         "float16",
         "float32",
         id=f"K{k}-float16-float16-float32",
-    ) for k in K_VALUES
-])
+    )
+    for k in K_VALUES
+]


 def _ensure_torch_dtypes(*dtype_names):
maint/gemm_v2/correctness_evaluation_tcgen05.py

@@ -42,15 +42,7 @@ def matmul(
             for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                 T.copy(A[by * block_M, k * block_K], A_shared)
                 T.copy(B[bx * block_N, k * block_K], B_shared)
-                T.gemm(
-                    A_shared,
-                    B_shared,
-                    C_tmem,
-                    trans_A,
-                    trans_B,
-                    mbar=mbar,
-                    wg_wait=-1,
-                    clear_accum=k == 0)
+                T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, wg_wait=-1, clear_accum=k == 0)
                 T.mbarrier_wait_parity(mbar, k % 2)

             T.copy(C_tmem, C_local)

@@ -74,7 +66,8 @@ def _compile_and_check(
         pass_configs={
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
     print(kernel.get_kernel_source())

@@ -138,14 +131,15 @@ M_VALUES = [32, 64, 128, 256]
 N_VALUES = [64, 128, 256, 512]
 K_VALUES = [16, 32, 64, 128]
 K_VALUES_8Bit = [32, 64, 128]

-FALSE_TRUE_CASES = ([
+FALSE_TRUE_CASES = [
     pytest.param(
         k,
         "float16",
         "float32",
         "float32",
         id=f"K{k}-float16-float-float",
-    ) for k in K_VALUES
+    )
+    for k in K_VALUES
 ] + [
     pytest.param(
         k,

@@ -153,8 +147,9 @@ FALSE_TRUE_CASES = ([
         "float32",
         "float32",
         id="K32-float8_e5m2-float32-float32",
-    ) for k in K_VALUES_8Bit
-])
+    )
+    for k in K_VALUES_8Bit
+]

 TRANS_CASES = [
     pytest.param(False, True, id="nt"),
maint/gemm_v2/latency.py

@@ -14,7 +14,6 @@ use_v2 = args.use_v2
 # if not specified, it will be inferred from the input tensors during compile time
 @tilelang.jit
 def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
-
     @T.prim_func
     def matmul_relu_kernel(
         A: T.Tensor((M, K), dtype),
maint/gemm_v2/latency_gemm.py

@@ -14,7 +14,6 @@ use_v2 = args.use_v2
 # if not specified, it will be inferred from the input tensors during compile time
 @tilelang.jit
 def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
-
     @T.prim_func
     def matmul_relu_kernel(
         A: T.Tensor((M, K), dtype),
maint/gemm_v2/latency_mha_fwd_bhsd.py

@@ -8,13 +8,13 @@ import argparse
 from functools import partial

 parser = argparse.ArgumentParser()
-parser.add_argument('--batch', type=int, default=128, help='batch size')
-parser.add_argument('--heads', type=int, default=16, help='heads')
-parser.add_argument('--seq_q', type=int, default=1024, help='query sequence length')
-parser.add_argument('--seq_kv', type=int, default=1024, help='key/value sequence length')
-parser.add_argument('--dim', type=int, default=256, help='dim')
-parser.add_argument('--is_causal', action='store_true', help='causal')
-parser.add_argument('--tune', action='store_true', help='tune configs')
+parser.add_argument("--batch", type=int, default=128, help="batch size")
+parser.add_argument("--heads", type=int, default=16, help="heads")
+parser.add_argument("--seq_q", type=int, default=1024, help="query sequence length")
+parser.add_argument("--seq_kv", type=int, default=1024, help="key/value sequence length")
+parser.add_argument("--dim", type=int, default=256, help="dim")
+parser.add_argument("--is_causal", action="store_true", help="causal")
+parser.add_argument("--tune", action="store_true", help="tune configs")
 parser.add_argument("--use_v2", action="store_true")
 args = parser.parse_args()

@@ -29,20 +29,13 @@ def get_configs():
 @autotune(configs=get_configs(), warmup=10, rep=10)
 @tilelang.jit(
-    out_idx=[3], pass_configs={
+    out_idx=[3],
+    pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    })
-def flashattn(batch,
-              heads,
-              seq_q,
-              seq_kv,
-              dim,
-              is_causal,
-              block_M=64,
-              block_N=64,
-              num_stages=0,
-              threads=128):
-    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
+    },
+)
+def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=0, threads=128):
+    scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)

     q_shape = [batch, heads, seq_q, dim]
     kv_shape = [batch, heads, seq_kv, dim]
     dtype = "float16"

@@ -62,7 +55,7 @@ def flashattn(batch,
             by: T.int32,
             bz: T.int32,
         ):
-            T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
+            T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
             if is_causal:
                 for i, j in T.Parallel(block_M, block_N):
                     q_idx = bx * block_M + i + past_len

@@ -85,7 +78,7 @@ def flashattn(batch,
             by: T.int32,
             bz: T.int32,
         ):
-            T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
+            T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
             # T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
             if use_v2:
                 T.gemm_v2(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

@@ -152,43 +145,42 @@ def flashattn(batch,
             scores_sum = T.alloc_fragment([block_M], accum_dtype)
             logsum = T.alloc_fragment([block_M], accum_dtype)

-            T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
+            T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
             T.fill(acc_o, 0)
             T.fill(logsum, 0)
             T.fill(scores_max, -T.infinity(accum_dtype))

             loop_range = (
-                T.min(T.ceildiv(seq_kv, block_N), T.ceildiv(
-                    (bx + 1) * block_M + past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))
+                T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
+                if is_causal
+                else T.ceildiv(seq_kv, block_N)
+            )

             for k in T.Pipelined(loop_range, num_stages=num_stages):
                 MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                 Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
                 Rescale(acc_o, scores_scale)
                 MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
             for i, j in T.Parallel(block_M, dim):
                 acc_o[i, j] /= logsum[i]
             T.copy(acc_o, O_shared)
-            T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
+            T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])

     return main


 def ref_program(Q, K, V, is_causal):
     dim = Q.size(-1)
-    scores = torch.einsum('bhqd,bhkd->bhqk', Q, K)
+    scores = torch.einsum("bhqd,bhkd->bhqk", Q, K)
     scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
     if is_causal:
         seq_q = Q.size(2)
         seq_kv = K.size(2)
         mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
         mask = mask.unsqueeze(0).unsqueeze(0)
-        scores = scores.masked_fill(mask == 0, float('-inf'))
+        scores = scores.masked_fill(mask == 0, float("-inf"))
     attention_weights = F.softmax(scores, dim=-1)
-    output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V)
+    output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V)
     return output

@@ -206,18 +198,8 @@ def main(
     if is_causal:
         total_flops *= 0.5

-    if (not tune):
-        kernel = flashattn(
-            batch,
-            heads,
-            seq_q,
-            seq_kv,
-            dim,
-            is_causal,
-            block_M=64,
-            block_N=64,
-            num_stages=0,
-            threads=128)
+    if not tune:
+        kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=0, threads=128)
         print(kernel.get_kernel_source())
         ref_program_processed = partial(ref_program, is_causal=is_causal)
maint/host_checks/01_num_args_mismatch.py

@@ -3,6 +3,7 @@
 Note: The adapter-level wrapper expects only inputs (A, B) because C is marked as output.
 Calling with the wrong number of inputs raises a ValueError before host entry.
 """
+
 import torch
 from common import build_matmul_kernel
maint/host_checks/02_pointer_type_error.py

@@ -3,6 +3,7 @@
 We pass an integer for A; wrapper forwards it to the host where a pointer is expected.
 Expected: error like "Expect buffer A_handle to be pointer or tensor" (exact name depends on kernel param).
 """
+
 import torch
 from common import build_matmul_kernel
maint/host_checks/03_ndim_mismatch.py

-"""Reproduce: ndim (rank) mismatch for A.
-"""
+"""Reproduce: ndim (rank) mismatch for A."""
+
 import torch
 from common import build_matmul_kernel
maint/host_checks/04_dtype_mismatch.py

-"""Reproduce: dtype mismatch for A (float32 vs expected float16).
-"""
+"""Reproduce: dtype mismatch for A (float32 vs expected float16)."""
+
 import torch
 from common import build_matmul_kernel
maint/host_checks/05_shape_mismatch.py

-"""Reproduce: shape constant/symbol mismatch on A.
-"""
+"""Reproduce: shape constant/symbol mismatch on A."""
+
 import torch
 from common import build_matmul_kernel
maint/host_checks/06_strides_mismatch.py

-"""Reproduce: strides check failure (non-contiguous A via transpose).
-"""
+"""Reproduce: strides check failure (non-contiguous A via transpose)."""
+
 import torch
 from common import build_matmul_kernel
maint/host_checks/07_device_type_mismatch.py

-"""Reproduce: device_type mismatch by passing CPU tensors to a CUDA kernel.
-"""
+"""Reproduce: device_type mismatch by passing CPU tensors to a CUDA kernel."""
+
 import torch
 from common import build_matmul_kernel