Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
29051439
Unverified
Commit
29051439
authored
Dec 12, 2025
by
Lei Wang
Committed by
GitHub
Dec 12, 2025
Browse files
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
parent
e84b24bc
Changes
426
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
215 additions
and
291 deletions
+215
-291
examples/warp_specialize/example_warp_specialize_flashmla.py
examples/warp_specialize/example_warp_specialize_flashmla.py
+53
-91
examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py
...ialize/example_warp_specialize_gemm_barrierpipe_stage2.py
+5
-12
examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py
..._specialize/example_warp_specialize_gemm_copy_0_gemm_1.py
+4
-12
examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py
..._specialize/example_warp_specialize_gemm_copy_1_gemm_0.py
+4
-12
examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py
..._specialize/example_warp_specialize_gemm_copy_gemm_0_1.py
+8
-14
examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py
...pecialize/example_warp_specialize_gemm_softpipe_stage2.py
+0
-1
format.sh
format.sh
+1
-1
maint/gemm_v2/correctness_evaluation.py
maint/gemm_v2/correctness_evaluation.py
+56
-44
maint/gemm_v2/correctness_evaluation_sm70.py
maint/gemm_v2/correctness_evaluation_sm70.py
+14
-11
maint/gemm_v2/correctness_evaluation_tcgen05.py
maint/gemm_v2/correctness_evaluation_tcgen05.py
+12
-17
maint/gemm_v2/latency.py
maint/gemm_v2/latency.py
+3
-4
maint/gemm_v2/latency_gemm.py
maint/gemm_v2/latency_gemm.py
+3
-4
maint/gemm_v2/latency_mha_fwd_bhsd.py
maint/gemm_v2/latency_mha_fwd_bhsd.py
+40
-58
maint/host_checks/01_num_args_mismatch.py
maint/host_checks/01_num_args_mismatch.py
+1
-0
maint/host_checks/02_pointer_type_error.py
maint/host_checks/02_pointer_type_error.py
+1
-0
maint/host_checks/03_ndim_mismatch.py
maint/host_checks/03_ndim_mismatch.py
+2
-2
maint/host_checks/04_dtype_mismatch.py
maint/host_checks/04_dtype_mismatch.py
+2
-2
maint/host_checks/05_shape_mismatch.py
maint/host_checks/05_shape_mismatch.py
+2
-2
maint/host_checks/06_strides_mismatch.py
maint/host_checks/06_strides_mismatch.py
+2
-2
maint/host_checks/07_device_type_mismatch.py
maint/host_checks/07_device_type_mismatch.py
+2
-2
No files found.
Too many changes to show.
To preserve performance only
426 of 426+
files are displayed.
Plain diff
Email patch
examples/warp_specialize/example_warp_specialize_flashmla.py
View file @
29051439
...
...
@@ -9,7 +9,7 @@ import argparse
@
tilelang
.
jit
(
out_idx
=
[
6
])
def
flashattn
(
batch
,
heads
,
kv_head_num
,
seqlen_kv
,
dim
,
pe_dim
,
block_N
,
block_H
,
num_split
):
scale
=
(
1.0
/
(
dim
+
pe_dim
))
**
0.5
*
1.44269504
# log2(e)
scale
=
(
1.0
/
(
dim
+
pe_dim
))
**
0.5
*
1.44269504
# log2(e)
dtype
=
"float16"
accum_dtype
=
"float"
kv_group_num
=
heads
//
kv_head_num
...
...
@@ -19,11 +19,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
@
T
.
macro
def
flash_attn
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
with
T
.
Kernel
(
heads
//
min
(
block_H
,
kv_group_num
),
batch
,
threads
=
256
)
as
(
hid
,
bid
):
# smem_sQ
...
...
@@ -81,10 +81,12 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
cur_kv_head
=
hid
//
(
kv_group_num
//
block_H
)
T
.
annotate_layout
({
O_shared_l
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared_l
),
O_shared_r
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared_r
),
})
T
.
annotate_layout
(
{
O_shared_l
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared_l
),
O_shared_r
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared_r
),
}
)
# barriers_Q
q_shared_ready_barrier
=
T
.
alloc_barrier
(
arrive_count
=
256
)
...
...
@@ -108,9 +110,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
tx
=
T
.
get_thread_binding
()
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:
h_dim
],
Q_shared_l
)
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
h_dim
:],
Q_shared_r
)
T
.
copy
(
Q_pe
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:
h_dim
],
Q_shared_l
)
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
h_dim
:],
Q_shared_r
)
T
.
copy
(
Q_pe
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
barrier_arrive
(
q_shared_ready_barrier
)
T
.
barrier_wait
(
q_shared_ready_barrier
,
0
)
...
...
@@ -123,25 +125,18 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T
.
fill
(
acc_o_l
,
0
)
T
.
fill
(
logsum_0
,
0
)
T
.
copy
(
KV
[
bid
,
block_N
:
2
*
block_N
,
cur_kv_head
,
:
h_dim
],
KV_shared_1_l
)
T
.
copy
(
KV
[
bid
,
block_N
:
2
*
block_N
,
cur_kv_head
,
:
h_dim
],
KV_shared_1_l
)
T
.
barrier_arrive
(
kv_shared_1_l_is_ready
)
T
.
copy
(
KV
[
bid
,
block_N
:
2
*
block_N
,
cur_kv_head
,
h_dim
:],
KV_shared_1_r
)
T
.
copy
(
KV
[
bid
,
block_N
:
2
*
block_N
,
cur_kv_head
,
h_dim
:],
KV_shared_1_r
)
T
.
barrier_arrive
(
kv_shared_1_r_is_ready
)
T
.
copy
(
K_pe
[
bid
,
block_N
:
2
*
block_N
,
cur_kv_head
,
:],
K_pe_shared_1
)
T
.
copy
(
K_pe
[
bid
,
block_N
:
2
*
block_N
,
cur_kv_head
,
:],
K_pe_shared_1
)
T
.
barrier_arrive
(
kv_shared_1_pe_is_ready
)
for
k
in
T
.
serial
(
loop_range
):
T
.
barrier_wait
(
kv_shared_0_l_is_ready
,
k
%
2
)
T
.
gemm
(
Q_shared_l
,
KV_shared_0_l
,
acc_s_0
,
transpose_B
=
True
,
clear_accum
=
True
,
wg_wait
=-
1
)
T
.
gemm
(
Q_shared_l
,
KV_shared_0_l
,
acc_s_0
,
transpose_B
=
True
,
clear_accum
=
True
,
wg_wait
=-
1
)
T
.
barrier_wait
(
kv_shared_0_r_is_ready
,
k
%
2
)
T
.
gemm
(
Q_shared_r
,
KV_shared_0_r
,
acc_s_0
,
transpose_B
=
True
,
wg_wait
=-
1
)
...
...
@@ -161,8 +156,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
acc_s_0
[
i
,
j
]
=
T
.
exp2
(
acc_s_0
[
i
,
j
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
in
T
.
Parallel
(
block_H
):
scores_scale_0
[
i
]
=
T
.
exp2
(
scores_max_prev_0
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
scores_scale_0
[
i
]
=
T
.
exp2
(
scores_max_prev_0
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
T
.
reduce_sum
(
acc_s_0
,
scores_sum_0
,
dim
=
1
)
...
...
@@ -182,9 +176,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T
.
barrier_wait
(
scale_1_ready_barrier
,
k
%
2
)
if
k
<
loop_range
-
1
:
T
.
copy
(
KV
[
bid
,
(
2
*
k
+
2
)
*
block_N
:(
2
*
k
+
3
)
*
block_N
,
cur_kv_head
,
:
h_dim
],
KV_shared_0_l
)
T
.
copy
(
KV
[
bid
,
(
2
*
k
+
2
)
*
block_N
:
(
2
*
k
+
3
)
*
block_N
,
cur_kv_head
,
:
h_dim
],
KV_shared_0_l
)
T
.
barrier_arrive
(
kv_shared_0_l_is_ready
)
# Step 11.
...
...
@@ -204,15 +196,10 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T
.
gemm
(
SP1_shared
,
KV_shared_1_l
,
acc_o_l
)
if
k
<
loop_range
-
1
:
T
.
copy
(
KV
[
bid
,
(
2
*
k
+
3
)
*
block_N
:(
2
*
k
+
4
)
*
block_N
,
cur_kv_head
,
:
h_dim
],
KV_shared_1_l
)
T
.
copy
(
KV
[
bid
,
(
2
*
k
+
3
)
*
block_N
:
(
2
*
k
+
4
)
*
block_N
,
cur_kv_head
,
:
h_dim
],
KV_shared_1_l
)
T
.
barrier_arrive
(
kv_shared_1_l_is_ready
)
T
.
copy
(
K_pe
[
bid
,
(
2
*
k
+
3
)
*
block_N
:(
2
*
k
+
4
)
*
block_N
,
cur_kv_head
,
:],
K_pe_shared_1
)
T
.
copy
(
K_pe
[
bid
,
(
2
*
k
+
3
)
*
block_N
:
(
2
*
k
+
4
)
*
block_N
,
cur_kv_head
,
:],
K_pe_shared_1
)
T
.
barrier_arrive
(
kv_shared_1_pe_is_ready
)
T
.
copy
(
logsum_0
,
logsum
)
...
...
@@ -221,8 +208,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
for
i
,
j
in
T
.
Parallel
(
block_H
,
h_dim
):
acc_o_l
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o_l
,
O_shared_l
)
T
.
copy
(
O_shared_l
,
Output
[
bid
,
hid
*
VALID_BLOCK_H
:(
hid
+
1
)
*
VALID_BLOCK_H
,
:
h_dim
])
T
.
copy
(
O_shared_l
,
Output
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:
h_dim
])
else
:
T
.
copy
(
Q_pe_shared
,
Q_pe_local_1
)
...
...
@@ -237,16 +223,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T
.
barrier_arrive
(
kv_shared_0_pe_is_ready
)
for
k
in
T
.
serial
(
loop_range
):
# Step 2.
T
.
barrier_wait
(
kv_shared_1_l_is_ready
,
k
%
2
)
T
.
gemm
(
Q_shared_l
,
KV_shared_1_l
,
acc_s_1
,
transpose_B
=
True
,
clear_accum
=
True
,
wg_wait
=-
1
)
T
.
gemm
(
Q_shared_l
,
KV_shared_1_l
,
acc_s_1
,
transpose_B
=
True
,
clear_accum
=
True
,
wg_wait
=-
1
)
T
.
barrier_wait
(
kv_shared_1_r_is_ready
,
k
%
2
)
T
.
gemm
(
Q_shared_r
,
KV_shared_1_r
,
acc_s_1
,
transpose_B
=
True
,
wg_wait
=-
1
)
...
...
@@ -265,8 +244,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T
.
copy
(
scores_max_1
,
scores_max
)
for
i
in
T
.
Parallel
(
block_H
):
scores_scale_1
[
i
]
=
T
.
exp2
(
scores_max_prev_1
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
scores_scale_1
[
i
]
=
T
.
exp2
(
scores_max_prev_1
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
# Step 8.
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
...
...
@@ -279,8 +257,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
acc_o_r
[
i
,
j
]
=
acc_o_r
[
i
,
j
]
*
(
scores_scale_0
[
i
]
*
scores_scale_1
[
i
])
for
i
in
T
.
Parallel
(
block_H
):
logsum_1
[
i
]
=
logsum_1
[
i
]
*
scores_scale_1
[
i
]
*
scores_scale_0
[
i
]
+
scores_sum_1
[
i
]
logsum_1
[
i
]
=
logsum_1
[
i
]
*
scores_scale_1
[
i
]
*
scores_scale_0
[
i
]
+
scores_sum_1
[
i
]
T
.
barrier_arrive
(
scale_1_ready_barrier
)
...
...
@@ -291,9 +268,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T
.
barrier_arrive
(
s_shared_ready_barrier
)
if
k
<
loop_range
-
1
:
T
.
copy
(
KV
[
bid
,
(
2
*
k
+
3
)
*
block_N
:(
2
*
k
+
4
)
*
block_N
,
cur_kv_head
,
h_dim
:],
KV_shared_1_r
)
T
.
copy
(
KV
[
bid
,
(
2
*
k
+
3
)
*
block_N
:
(
2
*
k
+
4
)
*
block_N
,
cur_kv_head
,
h_dim
:],
KV_shared_1_r
)
T
.
barrier_arrive
(
kv_shared_1_r_is_ready
)
T
.
barrier_wait
(
p0_1_1_ready_barrier
,
k
%
2
)
...
...
@@ -301,15 +276,10 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T
.
gemm
(
SP0_shared
,
KV_shared_0_r
,
acc_o_r
)
if
k
<
loop_range
-
1
:
T
.
copy
(
KV
[
bid
,
(
2
*
k
+
2
)
*
block_N
:(
2
*
k
+
3
)
*
block_N
,
cur_kv_head
,
h_dim
:],
KV_shared_0_r
)
T
.
copy
(
KV
[
bid
,
(
2
*
k
+
2
)
*
block_N
:
(
2
*
k
+
3
)
*
block_N
,
cur_kv_head
,
h_dim
:],
KV_shared_0_r
)
T
.
barrier_arrive
(
kv_shared_0_r_is_ready
)
T
.
copy
(
K_pe
[
bid
,
(
2
*
k
+
2
)
*
block_N
:(
2
*
k
+
3
)
*
block_N
,
cur_kv_head
,
:],
K_pe_shared_0
)
T
.
copy
(
K_pe
[
bid
,
(
2
*
k
+
2
)
*
block_N
:
(
2
*
k
+
3
)
*
block_N
,
cur_kv_head
,
:],
K_pe_shared_0
)
T
.
barrier_arrive
(
kv_shared_0_pe_is_ready
)
T
.
barrier_wait
(
lse_0_ready_barrier
,
0
)
...
...
@@ -319,18 +289,17 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
for
i
,
j
in
T
.
Parallel
(
block_H
,
h_dim
):
acc_o_r
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o_r
,
O_shared_r
)
T
.
copy
(
O_shared_r
,
Output
[
bid
,
hid
*
VALID_BLOCK_H
:(
hid
+
1
)
*
VALID_BLOCK_H
,
h_dim
:])
T
.
copy
(
O_shared_r
,
Output
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
h_dim
:])
@
T
.
prim_func
def
main_no_split
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
flash_attn
(
Q
,
Q_pe
,
KV
,
K_pe
,
Output
)
...
...
@@ -352,31 +321,24 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
dim
=
q
.
shape
[
-
1
]
pe_dim
=
q_pe
.
shape
[
-
1
]
num_head_groups
=
q
.
shape
[
1
]
//
kv
.
shape
[
2
]
scale
=
(
dim
+
pe_dim
)
**
0.5
q
=
rearrange
(
q
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, dim]
scale
=
(
dim
+
pe_dim
)
**
0.5
q
=
rearrange
(
q
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, dim]
q_pe
=
rearrange
(
q_pe
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, pe_dim]
q_pe
=
rearrange
(
q_pe
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, pe_dim]
kv
=
rearrange
(
kv
,
'
b n h d -> b h n d
'
)
# [batch_size, groups, seqlen_kv, dim]
kv
=
rearrange
(
kv
,
"
b n h d -> b h n d
"
)
# [batch_size, groups, seqlen_kv, dim]
k_pe
=
rearrange
(
k_pe
,
'
b n h d -> b h n d
'
)
# [batch_size, num_head_groups, groups, pe_dim]
k_pe
=
rearrange
(
k_pe
,
"
b n h d -> b h n d
"
)
# [batch_size, num_head_groups, groups, pe_dim]
query
=
torch
.
concat
([
q
,
q_pe
],
dim
=-
1
)
key
=
torch
.
concat
([
kv
,
k_pe
],
dim
=-
1
)
scores
=
einsum
(
query
,
key
,
'b g h d, b h s d -> b g h s'
)
# [batch_size, num_head_groups, groups, seqlen_kv]
scores
=
einsum
(
query
,
key
,
"b g h d, b h s d -> b g h s"
)
# [batch_size, num_head_groups, groups, seqlen_kv]
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, groups, seqlen_kv]
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, groups, seqlen_kv]
out
=
einsum
(
attention
,
kv
,
'b g h s, b h s d -> b g h d'
)
# [batch_size, num_head_groups, groups, dim]
out
=
rearrange
(
out
,
'b g h d -> b (h g) d'
)
# [batch_size, heads, dim]
out
=
einsum
(
attention
,
kv
,
"b g h s, b h s d -> b g h d"
)
# [batch_size, num_head_groups, groups, dim]
out
=
rearrange
(
out
,
"b g h d -> b (h g) d"
)
# [batch_size, heads, dim]
return
out
...
...
@@ -399,12 +361,12 @@ def main(batch=1, heads=64, kv_heads=1, kv_ctx=1024, dim=512, pe_dim=64):
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
1
,
help
=
'
batch size
'
)
parser
.
add_argument
(
'
--heads
'
,
type
=
int
,
default
=
128
,
help
=
'
q heads number
'
)
parser
.
add_argument
(
'
--kv_heads
'
,
type
=
int
,
default
=
1
,
help
=
'
kv heads number
'
)
parser
.
add_argument
(
'
--kv_ctx
'
,
type
=
int
,
default
=
8192
,
help
=
'
kv context length
'
)
parser
.
add_argument
(
'
--dim
'
,
type
=
int
,
default
=
512
,
help
=
'
head dim
'
)
parser
.
add_argument
(
'
--pe_dim
'
,
type
=
int
,
default
=
64
,
help
=
'
pe head dim
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
1
,
help
=
"
batch size
"
)
parser
.
add_argument
(
"
--heads
"
,
type
=
int
,
default
=
128
,
help
=
"
q heads number
"
)
parser
.
add_argument
(
"
--kv_heads
"
,
type
=
int
,
default
=
1
,
help
=
"
kv heads number
"
)
parser
.
add_argument
(
"
--kv_ctx
"
,
type
=
int
,
default
=
8192
,
help
=
"
kv context length
"
)
parser
.
add_argument
(
"
--dim
"
,
type
=
int
,
default
=
512
,
help
=
"
head dim
"
)
parser
.
add_argument
(
"
--pe_dim
"
,
type
=
int
,
default
=
64
,
help
=
"
pe head dim
"
)
args
=
parser
.
parse_args
()
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
=
args
.
batch
,
args
.
heads
,
args
.
kv_heads
,
args
.
kv_ctx
,
args
.
dim
,
args
.
pe_dim
main
(
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
)
examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py
View file @
29051439
...
...
@@ -8,7 +8,6 @@ tilelang.disable_cache()
# @tilelang.jit
@
tilelang
.
jit
(
out_idx
=
[
2
])
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
num_stages
=
2
mbarrier_list
=
[
128
,
128
]
*
num_stages
...
...
@@ -32,19 +31,13 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
for
ko
in
range
(
T
.
ceildiv
(
K
,
block_K
)):
with
T
.
ws
(
1
):
T
.
mbarrier_wait_parity
(
mbarrier
=
ko
%
num_stages
+
num_stages
,
parity
=
((
ko
//
num_stages
)
%
num_stages
)
^
1
)
T
.
copy
(
A
[
by
*
block_M
:(
by
+
1
)
*
block_M
,
ko
*
block_K
:(
ko
+
1
)
*
block_K
],
A_shared
[
ko
%
num_stages
,
:,
:])
T
.
copy
(
B
[
ko
*
block_K
:(
ko
+
1
)
*
block_K
,
bx
*
block_N
:(
bx
+
1
)
*
block_N
],
B_shared
[
ko
%
num_stages
,
:,
:])
T
.
mbarrier_wait_parity
(
mbarrier
=
ko
%
num_stages
+
num_stages
,
parity
=
((
ko
//
num_stages
)
%
num_stages
)
^
1
)
T
.
copy
(
A
[
by
*
block_M
:
(
by
+
1
)
*
block_M
,
ko
*
block_K
:
(
ko
+
1
)
*
block_K
],
A_shared
[
ko
%
num_stages
,
:,
:])
T
.
copy
(
B
[
ko
*
block_K
:
(
ko
+
1
)
*
block_K
,
bx
*
block_N
:
(
bx
+
1
)
*
block_N
],
B_shared
[
ko
%
num_stages
,
:,
:])
T
.
mbarrier_arrive
(
mbarrier
=
ko
%
num_stages
)
with
T
.
ws
(
0
):
T
.
mbarrier_wait_parity
(
mbarrier
=
ko
%
num_stages
,
parity
=
(
ko
//
num_stages
)
%
num_stages
)
T
.
gemm
(
A_shared
[
ko
%
num_stages
,
:,
:],
B_shared
[
ko
%
num_stages
,
:,
:],
C_local
)
T
.
mbarrier_wait_parity
(
mbarrier
=
ko
%
num_stages
,
parity
=
(
ko
//
num_stages
)
%
num_stages
)
T
.
gemm
(
A_shared
[
ko
%
num_stages
,
:,
:],
B_shared
[
ko
%
num_stages
,
:,
:],
C_local
)
T
.
mbarrier_arrive
(
mbarrier
=
ko
%
num_stages
+
num_stages
)
with
T
.
ws
(
0
):
...
...
examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py
View file @
29051439
...
...
@@ -5,20 +5,12 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
@
tilelang
.
jit
(
out_idx
=
[
2
])
def
matmul_warp_specialize_copy_0_gemm_1
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
matmul_warp_specialize_copy_0_gemm_1
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
256
)
as
(
bx
,
by
):
...
...
examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py
View file @
29051439
...
...
@@ -5,20 +5,12 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
@
tilelang
.
jit
(
out_idx
=
[
2
])
def
matmul_warp_specialize_copy_1_gemm_0
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
matmul_warp_specialize_copy_1_gemm_0
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
256
)
as
(
bx
,
by
):
...
...
examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py
View file @
29051439
...
...
@@ -5,26 +5,20 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
@
tilelang
.
jit
(
out_idx
=
[
2
],
pass_configs
=
{
out_idx
=
[
2
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
})
def
matmul_warp_specialize_copy_1_gemm_0
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
},
)
def
matmul_warp_specialize_copy_1_gemm_0
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
warp_group_num
=
2
threads
=
128
*
warp_group_num
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
...
...
examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py
View file @
29051439
...
...
@@ -6,7 +6,6 @@ import tilelang.language as T
# @tilelang.jit
@
tilelang
.
jit
(
out_idx
=
[
2
])
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
[(
M
,
K
),
dtype
],
...
...
format.sh
View file @
29051439
...
...
@@ -9,7 +9,7 @@
# bash format.sh --all
#
#
#
YAPF
+ Clang formatter (if installed). This script formats all changed files from the last mergebase.
#
Ruff (format)
+ Clang formatter (if installed). This script formats all changed files from the last mergebase.
# You are encouraged to run this locally before pushing changes for review.
# Cause the script to exit if a single command fails
...
...
maint/gemm_v2/correctness_evaluation.py
View file @
29051439
...
...
@@ -28,9 +28,9 @@ def matmul(
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
"shared.dyn"
)
...
...
@@ -66,7 +66,8 @@ def _compile_and_check(
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
# tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False,
})
},
)
print
(
kernel
.
get_kernel_source
())
...
...
@@ -151,9 +152,9 @@ def matmul_rs(
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
"shared.dyn"
)
...
...
@@ -238,9 +239,9 @@ def matmul_sr(
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
"shared.dyn"
)
...
...
@@ -326,9 +327,9 @@ def matmul_rr(
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
"shared.dyn"
)
...
...
@@ -394,37 +395,48 @@ M_VALUES = [64, 128, 256]
N_VALUES
=
[
16
,
32
,
64
,
128
,
256
,
512
]
K_VALUES
=
[
16
,
32
,
64
,
128
]
K_VALUES_8Bit
=
[
32
,
64
,
128
]
FALSE_TRUE_CASES
=
([
pytest
.
param
(
k
,
"float16"
,
"float16"
,
"float16"
,
id
=
f
"K
{
k
}
-float16-float16-float16"
,
)
for
k
in
K_VALUES
]
+
[
pytest
.
param
(
k
,
"int8"
,
"int32"
,
"int32"
,
id
=
"K32-int8-int32-int32"
,
)
for
k
in
K_VALUES_8Bit
]
+
[
pytest
.
param
(
k
,
"float8_e5m2"
,
"float32"
,
"float32"
,
id
=
"K32-float8_e5m2-float32-float32"
,
)
for
k
in
K_VALUES_8Bit
]
+
[
pytest
.
param
(
k
,
"float8_e4m3"
,
"float32"
,
"float32"
,
id
=
"K32-float8_e4m3-float32-float32"
,
)
for
k
in
K_VALUES_8Bit
])
FALSE_TRUE_CASES
=
(
[
pytest
.
param
(
k
,
"float16"
,
"float16"
,
"float16"
,
id
=
f
"K
{
k
}
-float16-float16-float16"
,
)
for
k
in
K_VALUES
]
+
[
pytest
.
param
(
k
,
"int8"
,
"int32"
,
"int32"
,
id
=
"K32-int8-int32-int32"
,
)
for
k
in
K_VALUES_8Bit
]
+
[
pytest
.
param
(
k
,
"float8_e5m2"
,
"float32"
,
"float32"
,
id
=
"K32-float8_e5m2-float32-float32"
,
)
for
k
in
K_VALUES_8Bit
]
+
[
pytest
.
param
(
k
,
"float8_e4m3"
,
"float32"
,
"float32"
,
id
=
"K32-float8_e4m3-float32-float32"
,
)
for
k
in
K_VALUES_8Bit
]
)
def
_ensure_torch_dtypes
(
*
dtype_names
):
...
...
maint/gemm_v2/correctness_evaluation_sm70.py
View file @
29051439
...
...
@@ -28,9 +28,9 @@ def matmul(
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
"shared.dyn"
)
...
...
@@ -67,7 +67,8 @@ def _compile_and_check(
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
# tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False,
})
},
)
print
(
kernel
.
get_kernel_source
())
...
...
@@ -150,9 +151,9 @@ def matmul_rs(
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
"shared.dyn"
)
...
...
@@ -213,14 +214,15 @@ def run_gemm_rs(
M_VALUES
=
[
64
,
128
]
N_VALUES
=
[
32
,
64
,
128
]
K_VALUES
=
[
16
,
32
,
64
]
FALSE_TRUE_CASES
=
(
[
FALSE_TRUE_CASES
=
[
pytest
.
param
(
k
,
"float16"
,
"float16"
,
"float16"
,
id
=
f
"K
{
k
}
-float16-float16-float16"
,
)
for
k
in
K_VALUES
)
for
k
in
K_VALUES
]
+
[
pytest
.
param
(
k
,
...
...
@@ -228,8 +230,9 @@ FALSE_TRUE_CASES = ([
"float16"
,
"float32"
,
id
=
f
"K
{
k
}
-float16-float16-float32"
,
)
for
k
in
K_VALUES
])
)
for
k
in
K_VALUES
]
def
_ensure_torch_dtypes
(
*
dtype_names
):
...
...
maint/gemm_v2/correctness_evaluation_tcgen05.py
View file @
29051439
...
...
@@ -27,9 +27,9 @@ def matmul(
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
...
@@ -42,15 +42,7 @@ def matmul(
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
A
[
by
*
block_M
,
k
*
block_K
],
A_shared
)
T
.
copy
(
B
[
bx
*
block_N
,
k
*
block_K
],
B_shared
)
T
.
gemm
(
A_shared
,
B_shared
,
C_tmem
,
trans_A
,
trans_B
,
mbar
=
mbar
,
wg_wait
=-
1
,
clear_accum
=
k
==
0
)
T
.
gemm
(
A_shared
,
B_shared
,
C_tmem
,
trans_A
,
trans_B
,
mbar
=
mbar
,
wg_wait
=-
1
,
clear_accum
=
k
==
0
)
T
.
mbarrier_wait_parity
(
mbar
,
k
%
2
)
T
.
copy
(
C_tmem
,
C_local
)
...
...
@@ -74,7 +66,8 @@ def _compile_and_check(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
print
(
kernel
.
get_kernel_source
())
...
...
@@ -138,14 +131,15 @@ M_VALUES = [32, 64, 128, 256]
N_VALUES
=
[
64
,
128
,
256
,
512
]
K_VALUES
=
[
16
,
32
,
64
,
128
]
K_VALUES_8Bit
=
[
32
,
64
,
128
]
FALSE_TRUE_CASES
=
(
[
FALSE_TRUE_CASES
=
[
pytest
.
param
(
k
,
"float16"
,
"float32"
,
"float32"
,
id
=
f
"K
{
k
}
-float16-float-float"
,
)
for
k
in
K_VALUES
)
for
k
in
K_VALUES
]
+
[
pytest
.
param
(
k
,
...
...
@@ -153,8 +147,9 @@ FALSE_TRUE_CASES = ([
"float32"
,
"float32"
,
id
=
"K32-float8_e5m2-float32-float32"
,
)
for
k
in
K_VALUES_8Bit
])
)
for
k
in
K_VALUES_8Bit
]
TRANS_CASES
=
[
pytest
.
param
(
False
,
True
,
id
=
"nt"
),
...
...
maint/gemm_v2/latency.py
View file @
29051439
...
...
@@ -14,12 +14,11 @@ use_v2 = args.use_v2
# if not specified, it will be inferred from the input tensors during compile time
@
tilelang
.
jit
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
def
matmul_relu_kernel
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
...
...
maint/gemm_v2/latency_gemm.py
View file @
29051439
...
...
@@ -14,12 +14,11 @@ use_v2 = args.use_v2
# if not specified, it will be inferred from the input tensors during compile time
@
tilelang
.
jit
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
def
matmul_relu_kernel
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
...
...
maint/gemm_v2/latency_mha_fwd_bhsd.py
View file @
29051439
...
...
@@ -8,13 +8,13 @@ import argparse
from
functools
import
partial
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
128
,
help
=
'
batch size
'
)
parser
.
add_argument
(
'
--heads
'
,
type
=
int
,
default
=
16
,
help
=
'
heads
'
)
parser
.
add_argument
(
'
--seq_q
'
,
type
=
int
,
default
=
1024
,
help
=
'
query sequence length
'
)
parser
.
add_argument
(
'
--seq_kv
'
,
type
=
int
,
default
=
1024
,
help
=
'
key/value sequence length
'
)
parser
.
add_argument
(
'
--dim
'
,
type
=
int
,
default
=
256
,
help
=
'
dim
'
)
parser
.
add_argument
(
'
--is_causal
'
,
action
=
'
store_true
'
,
help
=
'
causal
'
)
parser
.
add_argument
(
'
--tune
'
,
action
=
'
store_true
'
,
help
=
'
tune configs
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
128
,
help
=
"
batch size
"
)
parser
.
add_argument
(
"
--heads
"
,
type
=
int
,
default
=
16
,
help
=
"
heads
"
)
parser
.
add_argument
(
"
--seq_q
"
,
type
=
int
,
default
=
1024
,
help
=
"
query sequence length
"
)
parser
.
add_argument
(
"
--seq_kv
"
,
type
=
int
,
default
=
1024
,
help
=
"
key/value sequence length
"
)
parser
.
add_argument
(
"
--dim
"
,
type
=
int
,
default
=
256
,
help
=
"
dim
"
)
parser
.
add_argument
(
"
--is_causal
"
,
action
=
"
store_true
"
,
help
=
"
causal
"
)
parser
.
add_argument
(
"
--tune
"
,
action
=
"
store_true
"
,
help
=
"
tune configs
"
)
parser
.
add_argument
(
"--use_v2"
,
action
=
"store_true"
)
args
=
parser
.
parse_args
()
...
...
@@ -29,20 +29,13 @@ def get_configs():
@
autotune
(
configs
=
get_configs
(),
warmup
=
10
,
rep
=
10
)
@
tilelang
.
jit
(
out_idx
=
[
3
],
pass_configs
=
{
out_idx
=
[
3
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
def
flashattn
(
batch
,
heads
,
seq_q
,
seq_kv
,
dim
,
is_causal
,
block_M
=
64
,
block_N
=
64
,
num_stages
=
0
,
threads
=
128
):
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
},
)
def
flashattn
(
batch
,
heads
,
seq_q
,
seq_kv
,
dim
,
is_causal
,
block_M
=
64
,
block_N
=
64
,
num_stages
=
0
,
threads
=
128
):
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
q_shape
=
[
batch
,
heads
,
seq_q
,
dim
]
kv_shape
=
[
batch
,
heads
,
seq_kv
,
dim
]
dtype
=
"float16"
...
...
@@ -62,7 +55,7 @@ def flashattn(batch,
by
:
T
.
int32
,
bz
:
T
.
int32
,
):
T
.
copy
(
K
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
K_shared
)
T
.
copy
(
K
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
K_shared
)
if
is_causal
:
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
q_idx
=
bx
*
block_M
+
i
+
past_len
...
...
@@ -85,7 +78,7 @@ def flashattn(batch,
by
:
T
.
int32
,
bz
:
T
.
int32
,
):
T
.
copy
(
V
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
V_shared
)
T
.
copy
(
V
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
V_shared
)
# T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
if
use_v2
:
T
.
gemm_v2
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
...
...
@@ -94,13 +87,13 @@ def flashattn(batch,
@
T
.
macro
def
Softmax
(
acc_s
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
accum_dtype
),
acc_s_cast
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
dtype
),
scores_max
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_max_prev
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_scale
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_sum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
logsum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
acc_s
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
accum_dtype
),
acc_s_cast
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
dtype
),
scores_max
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_max_prev
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_scale
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_sum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
logsum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
):
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
...
@@ -125,18 +118,18 @@ def flashattn(batch,
@
T
.
macro
def
Rescale
(
acc_o
:
T
.
FragmentBuffer
([
block_M
,
dim
],
accum_dtype
),
scores_scale
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
acc_o
:
T
.
FragmentBuffer
([
block_M
,
dim
],
accum_dtype
),
scores_scale
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
@
T
.
prim_func
def
main
(
Q
:
T
.
Tensor
(
q_shape
,
dtype
),
K
:
T
.
Tensor
(
kv_shape
,
dtype
),
V
:
T
.
Tensor
(
kv_shape
,
dtype
),
Output
:
T
.
Tensor
(
q_shape
,
dtype
),
Q
:
T
.
Tensor
(
q_shape
,
dtype
),
K
:
T
.
Tensor
(
kv_shape
,
dtype
),
V
:
T
.
Tensor
(
kv_shape
,
dtype
),
Output
:
T
.
Tensor
(
q_shape
,
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
seq_q
,
block_M
),
heads
,
batch
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_M
,
dim
],
dtype
)
...
...
@@ -152,43 +145,42 @@ def flashattn(batch,
scores_sum
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
logsum
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
T
.
copy
(
Q
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:],
Q_shared
)
T
.
copy
(
Q
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
loop_range
=
(
T
.
min
(
T
.
ceildiv
(
seq_kv
,
block_N
),
T
.
ceildiv
(
(
bx
+
1
)
*
block_
M
+
past_len
,
block_N
))
if
is_causal
else
T
.
ceildiv
(
seq_kv
,
block_N
)
)
T
.
min
(
T
.
ceildiv
(
seq_kv
,
block_N
),
T
.
ceildiv
((
bx
+
1
)
*
block_M
+
past_len
,
block_N
))
if
is_causal
else
T
.
ceildiv
(
seq_kv
,
block_
N
)
)
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
MMA0
(
K
,
Q_shared
,
K_shared
,
acc_s
,
k
,
bx
,
by
,
bz
)
Softmax
(
acc_s
,
acc_s_cast
,
scores_max
,
scores_max_prev
,
scores_scale
,
scores_sum
,
logsum
)
Softmax
(
acc_s
,
acc_s_cast
,
scores_max
,
scores_max_prev
,
scores_scale
,
scores_sum
,
logsum
)
Rescale
(
acc_o
,
scores_scale
)
MMA1
(
V
,
V_shared
,
acc_s_cast
,
acc_o
,
k
,
by
,
bz
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
O_shared
,
Output
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:])
T
.
copy
(
O_shared
,
Output
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:])
return
main
def
ref_program
(
Q
,
K
,
V
,
is_causal
):
dim
=
Q
.
size
(
-
1
)
scores
=
torch
.
einsum
(
'
bhqd,bhkd->bhqk
'
,
Q
,
K
)
scores
=
torch
.
einsum
(
"
bhqd,bhkd->bhqk
"
,
Q
,
K
)
scores
=
scores
/
torch
.
sqrt
(
torch
.
tensor
(
dim
,
dtype
=
scores
.
dtype
))
if
is_causal
:
seq_q
=
Q
.
size
(
2
)
seq_kv
=
K
.
size
(
2
)
mask
=
torch
.
tril
(
torch
.
ones
(
seq_q
,
seq_kv
,
device
=
scores
.
device
),
seq_kv
-
seq_q
)
mask
=
mask
.
unsqueeze
(
0
).
unsqueeze
(
0
)
scores
=
scores
.
masked_fill
(
mask
==
0
,
float
(
'
-inf
'
))
scores
=
scores
.
masked_fill
(
mask
==
0
,
float
(
"
-inf
"
))
attention_weights
=
F
.
softmax
(
scores
,
dim
=-
1
)
output
=
torch
.
einsum
(
'
bhqk,bhkd->bhqd
'
,
attention_weights
,
V
)
output
=
torch
.
einsum
(
"
bhqk,bhkd->bhqd
"
,
attention_weights
,
V
)
return
output
...
...
@@ -206,18 +198,8 @@ def main(
if
is_causal
:
total_flops
*=
0.5
if
(
not
tune
):
kernel
=
flashattn
(
batch
,
heads
,
seq_q
,
seq_kv
,
dim
,
is_causal
,
block_M
=
64
,
block_N
=
64
,
num_stages
=
0
,
threads
=
128
)
if
not
tune
:
kernel
=
flashattn
(
batch
,
heads
,
seq_q
,
seq_kv
,
dim
,
is_causal
,
block_M
=
64
,
block_N
=
64
,
num_stages
=
0
,
threads
=
128
)
print
(
kernel
.
get_kernel_source
())
ref_program_processed
=
partial
(
ref_program
,
is_causal
=
is_causal
)
...
...
maint/host_checks/01_num_args_mismatch.py
View file @
29051439
...
...
@@ -3,6 +3,7 @@
Note: The adapter-level wrapper expects only inputs (A, B) because C is marked as output.
Calling with the wrong number of inputs raises a ValueError before host entry.
"""
import
torch
from
common
import
build_matmul_kernel
...
...
maint/host_checks/02_pointer_type_error.py
View file @
29051439
...
...
@@ -3,6 +3,7 @@
We pass an integer for A; wrapper forwards it to the host where a pointer is expected.
Expected: error like "Expect buffer A_handle to be pointer or tensor" (exact name depends on kernel param).
"""
import
torch
from
common
import
build_matmul_kernel
...
...
maint/host_checks/03_ndim_mismatch.py
View file @
29051439
"""Reproduce: ndim (rank) mismatch for A.
"""
"""Reproduce: ndim (rank) mismatch for A.
"""
import
torch
from
common
import
build_matmul_kernel
...
...
maint/host_checks/04_dtype_mismatch.py
View file @
29051439
"""Reproduce: dtype mismatch for A (float32 vs expected float16).
"""
"""Reproduce: dtype mismatch for A (float32 vs expected float16).
"""
import
torch
from
common
import
build_matmul_kernel
...
...
maint/host_checks/05_shape_mismatch.py
View file @
29051439
"""Reproduce: shape constant/symbol mismatch on A.
"""
"""Reproduce: shape constant/symbol mismatch on A.
"""
import
torch
from
common
import
build_matmul_kernel
...
...
maint/host_checks/06_strides_mismatch.py
View file @
29051439
"""Reproduce: strides check failure (non-contiguous A via transpose).
"""
"""Reproduce: strides check failure (non-contiguous A via transpose).
"""
import
torch
from
common
import
build_matmul_kernel
...
...
maint/host_checks/07_device_type_mismatch.py
View file @
29051439
"""Reproduce: device_type mismatch by passing CPU tensors to a CUDA kernel.
"""
"""Reproduce: device_type mismatch by passing CPU tensors to a CUDA kernel.
"""
import
torch
from
common
import
build_matmul_kernel
...
...
Prev
1
…
6
7
8
9
10
11
12
13
14
…
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment