Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
29051439
Unverified
Commit
29051439
authored
Dec 12, 2025
by
Lei Wang
Committed by
GitHub
Dec 12, 2025
Browse files
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
parent
e84b24bc
Changes
426
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
215 additions
and
291 deletions
+215
-291
examples/warp_specialize/example_warp_specialize_flashmla.py
examples/warp_specialize/example_warp_specialize_flashmla.py
+53
-91
examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py
...ialize/example_warp_specialize_gemm_barrierpipe_stage2.py
+5
-12
examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py
..._specialize/example_warp_specialize_gemm_copy_0_gemm_1.py
+4
-12
examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py
..._specialize/example_warp_specialize_gemm_copy_1_gemm_0.py
+4
-12
examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py
..._specialize/example_warp_specialize_gemm_copy_gemm_0_1.py
+8
-14
examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py
...pecialize/example_warp_specialize_gemm_softpipe_stage2.py
+0
-1
format.sh
format.sh
+1
-1
maint/gemm_v2/correctness_evaluation.py
maint/gemm_v2/correctness_evaluation.py
+56
-44
maint/gemm_v2/correctness_evaluation_sm70.py
maint/gemm_v2/correctness_evaluation_sm70.py
+14
-11
maint/gemm_v2/correctness_evaluation_tcgen05.py
maint/gemm_v2/correctness_evaluation_tcgen05.py
+12
-17
maint/gemm_v2/latency.py
maint/gemm_v2/latency.py
+3
-4
maint/gemm_v2/latency_gemm.py
maint/gemm_v2/latency_gemm.py
+3
-4
maint/gemm_v2/latency_mha_fwd_bhsd.py
maint/gemm_v2/latency_mha_fwd_bhsd.py
+40
-58
maint/host_checks/01_num_args_mismatch.py
maint/host_checks/01_num_args_mismatch.py
+1
-0
maint/host_checks/02_pointer_type_error.py
maint/host_checks/02_pointer_type_error.py
+1
-0
maint/host_checks/03_ndim_mismatch.py
maint/host_checks/03_ndim_mismatch.py
+2
-2
maint/host_checks/04_dtype_mismatch.py
maint/host_checks/04_dtype_mismatch.py
+2
-2
maint/host_checks/05_shape_mismatch.py
maint/host_checks/05_shape_mismatch.py
+2
-2
maint/host_checks/06_strides_mismatch.py
maint/host_checks/06_strides_mismatch.py
+2
-2
maint/host_checks/07_device_type_mismatch.py
maint/host_checks/07_device_type_mismatch.py
+2
-2
No files found.
Too many changes to show.
To preserve performance only
426 of 426+
files are displayed.
Plain diff
Email patch
examples/warp_specialize/example_warp_specialize_flashmla.py
View file @
29051439
...
@@ -9,7 +9,7 @@ import argparse
...
@@ -9,7 +9,7 @@ import argparse
@
tilelang
.
jit
(
out_idx
=
[
6
])
@
tilelang
.
jit
(
out_idx
=
[
6
])
def
flashattn
(
batch
,
heads
,
kv_head_num
,
seqlen_kv
,
dim
,
pe_dim
,
block_N
,
block_H
,
num_split
):
def
flashattn
(
batch
,
heads
,
kv_head_num
,
seqlen_kv
,
dim
,
pe_dim
,
block_N
,
block_H
,
num_split
):
scale
=
(
1.0
/
(
dim
+
pe_dim
))
**
0.5
*
1.44269504
# log2(e)
scale
=
(
1.0
/
(
dim
+
pe_dim
))
**
0.5
*
1.44269504
# log2(e)
dtype
=
"float16"
dtype
=
"float16"
accum_dtype
=
"float"
accum_dtype
=
"float"
kv_group_num
=
heads
//
kv_head_num
kv_group_num
=
heads
//
kv_head_num
...
@@ -19,11 +19,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
...
@@ -19,11 +19,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
@
T
.
macro
@
T
.
macro
def
flash_attn
(
def
flash_attn
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
):
with
T
.
Kernel
(
heads
//
min
(
block_H
,
kv_group_num
),
batch
,
threads
=
256
)
as
(
hid
,
bid
):
with
T
.
Kernel
(
heads
//
min
(
block_H
,
kv_group_num
),
batch
,
threads
=
256
)
as
(
hid
,
bid
):
# smem_sQ
# smem_sQ
...
@@ -81,10 +81,12 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
...
@@ -81,10 +81,12 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
cur_kv_head
=
hid
//
(
kv_group_num
//
block_H
)
cur_kv_head
=
hid
//
(
kv_group_num
//
block_H
)
T
.
annotate_layout
({
T
.
annotate_layout
(
O_shared_l
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared_l
),
{
O_shared_r
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared_r
),
O_shared_l
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared_l
),
})
O_shared_r
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared_r
),
}
)
# barriers_Q
# barriers_Q
q_shared_ready_barrier
=
T
.
alloc_barrier
(
arrive_count
=
256
)
q_shared_ready_barrier
=
T
.
alloc_barrier
(
arrive_count
=
256
)
...
@@ -108,9 +110,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
...
@@ -108,9 +110,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
tx
=
T
.
get_thread_binding
()
tx
=
T
.
get_thread_binding
()
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:
h_dim
],
Q_shared_l
)
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:
h_dim
],
Q_shared_l
)
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
h_dim
:],
Q_shared_r
)
T
.
copy
(
Q
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
h_dim
:],
Q_shared_r
)
T
.
copy
(
Q_pe
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
copy
(
Q_pe
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:],
Q_pe_shared
)
T
.
barrier_arrive
(
q_shared_ready_barrier
)
T
.
barrier_arrive
(
q_shared_ready_barrier
)
T
.
barrier_wait
(
q_shared_ready_barrier
,
0
)
T
.
barrier_wait
(
q_shared_ready_barrier
,
0
)
...
@@ -123,25 +125,18 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
...
@@ -123,25 +125,18 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T
.
fill
(
acc_o_l
,
0
)
T
.
fill
(
acc_o_l
,
0
)
T
.
fill
(
logsum_0
,
0
)
T
.
fill
(
logsum_0
,
0
)
T
.
copy
(
KV
[
bid
,
block_N
:
2
*
block_N
,
cur_kv_head
,
:
h_dim
],
KV_shared_1_l
)
T
.
copy
(
KV
[
bid
,
block_N
:
2
*
block_N
,
cur_kv_head
,
:
h_dim
],
KV_shared_1_l
)
T
.
barrier_arrive
(
kv_shared_1_l_is_ready
)
T
.
barrier_arrive
(
kv_shared_1_l_is_ready
)
T
.
copy
(
KV
[
bid
,
block_N
:
2
*
block_N
,
cur_kv_head
,
h_dim
:],
KV_shared_1_r
)
T
.
copy
(
KV
[
bid
,
block_N
:
2
*
block_N
,
cur_kv_head
,
h_dim
:],
KV_shared_1_r
)
T
.
barrier_arrive
(
kv_shared_1_r_is_ready
)
T
.
barrier_arrive
(
kv_shared_1_r_is_ready
)
T
.
copy
(
K_pe
[
bid
,
block_N
:
2
*
block_N
,
cur_kv_head
,
:],
K_pe_shared_1
)
T
.
copy
(
K_pe
[
bid
,
block_N
:
2
*
block_N
,
cur_kv_head
,
:],
K_pe_shared_1
)
T
.
barrier_arrive
(
kv_shared_1_pe_is_ready
)
T
.
barrier_arrive
(
kv_shared_1_pe_is_ready
)
for
k
in
T
.
serial
(
loop_range
):
for
k
in
T
.
serial
(
loop_range
):
T
.
barrier_wait
(
kv_shared_0_l_is_ready
,
k
%
2
)
T
.
barrier_wait
(
kv_shared_0_l_is_ready
,
k
%
2
)
T
.
gemm
(
T
.
gemm
(
Q_shared_l
,
KV_shared_0_l
,
acc_s_0
,
transpose_B
=
True
,
clear_accum
=
True
,
wg_wait
=-
1
)
Q_shared_l
,
KV_shared_0_l
,
acc_s_0
,
transpose_B
=
True
,
clear_accum
=
True
,
wg_wait
=-
1
)
T
.
barrier_wait
(
kv_shared_0_r_is_ready
,
k
%
2
)
T
.
barrier_wait
(
kv_shared_0_r_is_ready
,
k
%
2
)
T
.
gemm
(
Q_shared_r
,
KV_shared_0_r
,
acc_s_0
,
transpose_B
=
True
,
wg_wait
=-
1
)
T
.
gemm
(
Q_shared_r
,
KV_shared_0_r
,
acc_s_0
,
transpose_B
=
True
,
wg_wait
=-
1
)
...
@@ -161,8 +156,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
...
@@ -161,8 +156,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
acc_s_0
[
i
,
j
]
=
T
.
exp2
(
acc_s_0
[
i
,
j
]
*
scale
-
scores_max
[
i
]
*
scale
)
acc_s_0
[
i
,
j
]
=
T
.
exp2
(
acc_s_0
[
i
,
j
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
in
T
.
Parallel
(
block_H
):
for
i
in
T
.
Parallel
(
block_H
):
scores_scale_0
[
i
]
=
T
.
exp2
(
scores_max_prev_0
[
i
]
*
scale
-
scores_scale_0
[
i
]
=
T
.
exp2
(
scores_max_prev_0
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
scores_max
[
i
]
*
scale
)
T
.
reduce_sum
(
acc_s_0
,
scores_sum_0
,
dim
=
1
)
T
.
reduce_sum
(
acc_s_0
,
scores_sum_0
,
dim
=
1
)
...
@@ -182,9 +176,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
...
@@ -182,9 +176,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T
.
barrier_wait
(
scale_1_ready_barrier
,
k
%
2
)
T
.
barrier_wait
(
scale_1_ready_barrier
,
k
%
2
)
if
k
<
loop_range
-
1
:
if
k
<
loop_range
-
1
:
T
.
copy
(
T
.
copy
(
KV
[
bid
,
(
2
*
k
+
2
)
*
block_N
:
(
2
*
k
+
3
)
*
block_N
,
cur_kv_head
,
:
h_dim
],
KV_shared_0_l
)
KV
[
bid
,
(
2
*
k
+
2
)
*
block_N
:(
2
*
k
+
3
)
*
block_N
,
cur_kv_head
,
:
h_dim
],
KV_shared_0_l
)
T
.
barrier_arrive
(
kv_shared_0_l_is_ready
)
T
.
barrier_arrive
(
kv_shared_0_l_is_ready
)
# Step 11.
# Step 11.
...
@@ -204,15 +196,10 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
...
@@ -204,15 +196,10 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T
.
gemm
(
SP1_shared
,
KV_shared_1_l
,
acc_o_l
)
T
.
gemm
(
SP1_shared
,
KV_shared_1_l
,
acc_o_l
)
if
k
<
loop_range
-
1
:
if
k
<
loop_range
-
1
:
T
.
copy
(
KV
[
bid
,
(
2
*
k
+
3
)
*
block_N
:
(
2
*
k
+
4
)
*
block_N
,
cur_kv_head
,
:
h_dim
],
KV_shared_1_l
)
T
.
copy
(
KV
[
bid
,
(
2
*
k
+
3
)
*
block_N
:(
2
*
k
+
4
)
*
block_N
,
cur_kv_head
,
:
h_dim
],
KV_shared_1_l
)
T
.
barrier_arrive
(
kv_shared_1_l_is_ready
)
T
.
barrier_arrive
(
kv_shared_1_l_is_ready
)
T
.
copy
(
T
.
copy
(
K_pe
[
bid
,
(
2
*
k
+
3
)
*
block_N
:
(
2
*
k
+
4
)
*
block_N
,
cur_kv_head
,
:],
K_pe_shared_1
)
K_pe
[
bid
,
(
2
*
k
+
3
)
*
block_N
:(
2
*
k
+
4
)
*
block_N
,
cur_kv_head
,
:],
K_pe_shared_1
)
T
.
barrier_arrive
(
kv_shared_1_pe_is_ready
)
T
.
barrier_arrive
(
kv_shared_1_pe_is_ready
)
T
.
copy
(
logsum_0
,
logsum
)
T
.
copy
(
logsum_0
,
logsum
)
...
@@ -221,8 +208,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
...
@@ -221,8 +208,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
for
i
,
j
in
T
.
Parallel
(
block_H
,
h_dim
):
for
i
,
j
in
T
.
Parallel
(
block_H
,
h_dim
):
acc_o_l
[
i
,
j
]
/=
logsum
[
i
]
acc_o_l
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o_l
,
O_shared_l
)
T
.
copy
(
acc_o_l
,
O_shared_l
)
T
.
copy
(
O_shared_l
,
Output
[
bid
,
T
.
copy
(
O_shared_l
,
Output
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
:
h_dim
])
hid
*
VALID_BLOCK_H
:(
hid
+
1
)
*
VALID_BLOCK_H
,
:
h_dim
])
else
:
else
:
T
.
copy
(
Q_pe_shared
,
Q_pe_local_1
)
T
.
copy
(
Q_pe_shared
,
Q_pe_local_1
)
...
@@ -237,16 +223,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
...
@@ -237,16 +223,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T
.
barrier_arrive
(
kv_shared_0_pe_is_ready
)
T
.
barrier_arrive
(
kv_shared_0_pe_is_ready
)
for
k
in
T
.
serial
(
loop_range
):
for
k
in
T
.
serial
(
loop_range
):
# Step 2.
# Step 2.
T
.
barrier_wait
(
kv_shared_1_l_is_ready
,
k
%
2
)
T
.
barrier_wait
(
kv_shared_1_l_is_ready
,
k
%
2
)
T
.
gemm
(
T
.
gemm
(
Q_shared_l
,
KV_shared_1_l
,
acc_s_1
,
transpose_B
=
True
,
clear_accum
=
True
,
wg_wait
=-
1
)
Q_shared_l
,
KV_shared_1_l
,
acc_s_1
,
transpose_B
=
True
,
clear_accum
=
True
,
wg_wait
=-
1
)
T
.
barrier_wait
(
kv_shared_1_r_is_ready
,
k
%
2
)
T
.
barrier_wait
(
kv_shared_1_r_is_ready
,
k
%
2
)
T
.
gemm
(
Q_shared_r
,
KV_shared_1_r
,
acc_s_1
,
transpose_B
=
True
,
wg_wait
=-
1
)
T
.
gemm
(
Q_shared_r
,
KV_shared_1_r
,
acc_s_1
,
transpose_B
=
True
,
wg_wait
=-
1
)
...
@@ -265,8 +244,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
...
@@ -265,8 +244,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T
.
copy
(
scores_max_1
,
scores_max
)
T
.
copy
(
scores_max_1
,
scores_max
)
for
i
in
T
.
Parallel
(
block_H
):
for
i
in
T
.
Parallel
(
block_H
):
scores_scale_1
[
i
]
=
T
.
exp2
(
scores_max_prev_1
[
i
]
*
scale
-
scores_scale_1
[
i
]
=
T
.
exp2
(
scores_max_prev_1
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
scores_max
[
i
]
*
scale
)
# Step 8.
# Step 8.
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
...
@@ -279,8 +257,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
...
@@ -279,8 +257,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
acc_o_r
[
i
,
j
]
=
acc_o_r
[
i
,
j
]
*
(
scores_scale_0
[
i
]
*
scores_scale_1
[
i
])
acc_o_r
[
i
,
j
]
=
acc_o_r
[
i
,
j
]
*
(
scores_scale_0
[
i
]
*
scores_scale_1
[
i
])
for
i
in
T
.
Parallel
(
block_H
):
for
i
in
T
.
Parallel
(
block_H
):
logsum_1
[
i
]
=
logsum_1
[
i
]
*
scores_scale_1
[
i
]
*
scores_scale_0
[
logsum_1
[
i
]
=
logsum_1
[
i
]
*
scores_scale_1
[
i
]
*
scores_scale_0
[
i
]
+
scores_sum_1
[
i
]
i
]
+
scores_sum_1
[
i
]
T
.
barrier_arrive
(
scale_1_ready_barrier
)
T
.
barrier_arrive
(
scale_1_ready_barrier
)
...
@@ -291,9 +268,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
...
@@ -291,9 +268,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T
.
barrier_arrive
(
s_shared_ready_barrier
)
T
.
barrier_arrive
(
s_shared_ready_barrier
)
if
k
<
loop_range
-
1
:
if
k
<
loop_range
-
1
:
T
.
copy
(
T
.
copy
(
KV
[
bid
,
(
2
*
k
+
3
)
*
block_N
:
(
2
*
k
+
4
)
*
block_N
,
cur_kv_head
,
h_dim
:],
KV_shared_1_r
)
KV
[
bid
,
(
2
*
k
+
3
)
*
block_N
:(
2
*
k
+
4
)
*
block_N
,
cur_kv_head
,
h_dim
:],
KV_shared_1_r
)
T
.
barrier_arrive
(
kv_shared_1_r_is_ready
)
T
.
barrier_arrive
(
kv_shared_1_r_is_ready
)
T
.
barrier_wait
(
p0_1_1_ready_barrier
,
k
%
2
)
T
.
barrier_wait
(
p0_1_1_ready_barrier
,
k
%
2
)
...
@@ -301,15 +276,10 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
...
@@ -301,15 +276,10 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T
.
gemm
(
SP0_shared
,
KV_shared_0_r
,
acc_o_r
)
T
.
gemm
(
SP0_shared
,
KV_shared_0_r
,
acc_o_r
)
if
k
<
loop_range
-
1
:
if
k
<
loop_range
-
1
:
T
.
copy
(
KV
[
bid
,
(
2
*
k
+
2
)
*
block_N
:
(
2
*
k
+
3
)
*
block_N
,
cur_kv_head
,
h_dim
:],
KV_shared_0_r
)
T
.
copy
(
KV
[
bid
,
(
2
*
k
+
2
)
*
block_N
:(
2
*
k
+
3
)
*
block_N
,
cur_kv_head
,
h_dim
:],
KV_shared_0_r
)
T
.
barrier_arrive
(
kv_shared_0_r_is_ready
)
T
.
barrier_arrive
(
kv_shared_0_r_is_ready
)
T
.
copy
(
T
.
copy
(
K_pe
[
bid
,
(
2
*
k
+
2
)
*
block_N
:
(
2
*
k
+
3
)
*
block_N
,
cur_kv_head
,
:],
K_pe_shared_0
)
K_pe
[
bid
,
(
2
*
k
+
2
)
*
block_N
:(
2
*
k
+
3
)
*
block_N
,
cur_kv_head
,
:],
K_pe_shared_0
)
T
.
barrier_arrive
(
kv_shared_0_pe_is_ready
)
T
.
barrier_arrive
(
kv_shared_0_pe_is_ready
)
T
.
barrier_wait
(
lse_0_ready_barrier
,
0
)
T
.
barrier_wait
(
lse_0_ready_barrier
,
0
)
...
@@ -319,18 +289,17 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
...
@@ -319,18 +289,17 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
for
i
,
j
in
T
.
Parallel
(
block_H
,
h_dim
):
for
i
,
j
in
T
.
Parallel
(
block_H
,
h_dim
):
acc_o_r
[
i
,
j
]
/=
logsum
[
i
]
acc_o_r
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o_r
,
O_shared_r
)
T
.
copy
(
acc_o_r
,
O_shared_r
)
T
.
copy
(
O_shared_r
,
Output
[
bid
,
hid
*
VALID_BLOCK_H
:(
hid
+
1
)
*
VALID_BLOCK_H
,
T
.
copy
(
O_shared_r
,
Output
[
bid
,
hid
*
VALID_BLOCK_H
:
(
hid
+
1
)
*
VALID_BLOCK_H
,
h_dim
:])
h_dim
:])
@
T
.
prim_func
@
T
.
prim_func
def
main_no_split
(
def
main_no_split
(
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
Q_pe
:
T
.
Tensor
([
batch
,
heads
,
pe_dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
KV
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
K_pe
:
T
.
Tensor
([
batch
,
seqlen_kv
,
kv_head_num
,
pe_dim
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output_partial
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
):
flash_attn
(
Q
,
Q_pe
,
KV
,
K_pe
,
Output
)
flash_attn
(
Q
,
Q_pe
,
KV
,
K_pe
,
Output
)
...
@@ -352,31 +321,24 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
...
@@ -352,31 +321,24 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
dim
=
q
.
shape
[
-
1
]
dim
=
q
.
shape
[
-
1
]
pe_dim
=
q_pe
.
shape
[
-
1
]
pe_dim
=
q_pe
.
shape
[
-
1
]
num_head_groups
=
q
.
shape
[
1
]
//
kv
.
shape
[
2
]
num_head_groups
=
q
.
shape
[
1
]
//
kv
.
shape
[
2
]
scale
=
(
dim
+
pe_dim
)
**
0.5
scale
=
(
dim
+
pe_dim
)
**
0.5
q
=
rearrange
(
q
=
rearrange
(
q
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, dim]
q
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, dim]
q_pe
=
rearrange
(
q_pe
=
rearrange
(
q_pe
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, pe_dim]
q_pe
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, pe_dim]
kv
=
rearrange
(
kv
,
'
b n h d -> b h n d
'
)
# [batch_size, groups, seqlen_kv, dim]
kv
=
rearrange
(
kv
,
"
b n h d -> b h n d
"
)
# [batch_size, groups, seqlen_kv, dim]
k_pe
=
rearrange
(
k_pe
,
'
b n h d -> b h n d
'
)
# [batch_size, num_head_groups, groups, pe_dim]
k_pe
=
rearrange
(
k_pe
,
"
b n h d -> b h n d
"
)
# [batch_size, num_head_groups, groups, pe_dim]
query
=
torch
.
concat
([
q
,
q_pe
],
dim
=-
1
)
query
=
torch
.
concat
([
q
,
q_pe
],
dim
=-
1
)
key
=
torch
.
concat
([
kv
,
k_pe
],
dim
=-
1
)
key
=
torch
.
concat
([
kv
,
k_pe
],
dim
=-
1
)
scores
=
einsum
(
scores
=
einsum
(
query
,
key
,
"b g h d, b h s d -> b g h s"
)
# [batch_size, num_head_groups, groups, seqlen_kv]
query
,
key
,
'b g h d, b h s d -> b g h s'
)
# [batch_size, num_head_groups, groups, seqlen_kv]
attention
=
F
.
softmax
(
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, groups, seqlen_kv]
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, groups, seqlen_kv]
out
=
einsum
(
attention
,
kv
,
out
=
einsum
(
attention
,
kv
,
"b g h s, b h s d -> b g h d"
)
# [batch_size, num_head_groups, groups, dim]
'b g h s, b h s d -> b g h d'
)
# [batch_size, num_head_groups, groups, dim]
out
=
rearrange
(
out
,
"b g h d -> b (h g) d"
)
# [batch_size, heads, dim]
out
=
rearrange
(
out
,
'b g h d -> b (h g) d'
)
# [batch_size, heads, dim]
return
out
return
out
...
@@ -399,12 +361,12 @@ def main(batch=1, heads=64, kv_heads=1, kv_ctx=1024, dim=512, pe_dim=64):
...
@@ -399,12 +361,12 @@ def main(batch=1, heads=64, kv_heads=1, kv_ctx=1024, dim=512, pe_dim=64):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
1
,
help
=
'
batch size
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
1
,
help
=
"
batch size
"
)
parser
.
add_argument
(
'
--heads
'
,
type
=
int
,
default
=
128
,
help
=
'
q heads number
'
)
parser
.
add_argument
(
"
--heads
"
,
type
=
int
,
default
=
128
,
help
=
"
q heads number
"
)
parser
.
add_argument
(
'
--kv_heads
'
,
type
=
int
,
default
=
1
,
help
=
'
kv heads number
'
)
parser
.
add_argument
(
"
--kv_heads
"
,
type
=
int
,
default
=
1
,
help
=
"
kv heads number
"
)
parser
.
add_argument
(
'
--kv_ctx
'
,
type
=
int
,
default
=
8192
,
help
=
'
kv context length
'
)
parser
.
add_argument
(
"
--kv_ctx
"
,
type
=
int
,
default
=
8192
,
help
=
"
kv context length
"
)
parser
.
add_argument
(
'
--dim
'
,
type
=
int
,
default
=
512
,
help
=
'
head dim
'
)
parser
.
add_argument
(
"
--dim
"
,
type
=
int
,
default
=
512
,
help
=
"
head dim
"
)
parser
.
add_argument
(
'
--pe_dim
'
,
type
=
int
,
default
=
64
,
help
=
'
pe head dim
'
)
parser
.
add_argument
(
"
--pe_dim
"
,
type
=
int
,
default
=
64
,
help
=
"
pe head dim
"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
=
args
.
batch
,
args
.
heads
,
args
.
kv_heads
,
args
.
kv_ctx
,
args
.
dim
,
args
.
pe_dim
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
=
args
.
batch
,
args
.
heads
,
args
.
kv_heads
,
args
.
kv_ctx
,
args
.
dim
,
args
.
pe_dim
main
(
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
)
main
(
batch
,
heads
,
kv_heads
,
kv_ctx
,
dim
,
pe_dim
)
examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py
View file @
29051439
...
@@ -8,7 +8,6 @@ tilelang.disable_cache()
...
@@ -8,7 +8,6 @@ tilelang.disable_cache()
# @tilelang.jit
# @tilelang.jit
@
tilelang
.
jit
(
out_idx
=
[
2
])
@
tilelang
.
jit
(
out_idx
=
[
2
])
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
num_stages
=
2
num_stages
=
2
mbarrier_list
=
[
128
,
128
]
*
num_stages
mbarrier_list
=
[
128
,
128
]
*
num_stages
...
@@ -32,19 +31,13 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
...
@@ -32,19 +31,13 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
for
ko
in
range
(
T
.
ceildiv
(
K
,
block_K
)):
for
ko
in
range
(
T
.
ceildiv
(
K
,
block_K
)):
with
T
.
ws
(
1
):
with
T
.
ws
(
1
):
T
.
mbarrier_wait_parity
(
T
.
mbarrier_wait_parity
(
mbarrier
=
ko
%
num_stages
+
num_stages
,
parity
=
((
ko
//
num_stages
)
%
num_stages
)
^
1
)
mbarrier
=
ko
%
num_stages
+
num_stages
,
T
.
copy
(
A
[
by
*
block_M
:
(
by
+
1
)
*
block_M
,
ko
*
block_K
:
(
ko
+
1
)
*
block_K
],
A_shared
[
ko
%
num_stages
,
:,
:])
parity
=
((
ko
//
num_stages
)
%
num_stages
)
^
1
)
T
.
copy
(
B
[
ko
*
block_K
:
(
ko
+
1
)
*
block_K
,
bx
*
block_N
:
(
bx
+
1
)
*
block_N
],
B_shared
[
ko
%
num_stages
,
:,
:])
T
.
copy
(
A
[
by
*
block_M
:(
by
+
1
)
*
block_M
,
ko
*
block_K
:(
ko
+
1
)
*
block_K
],
A_shared
[
ko
%
num_stages
,
:,
:])
T
.
copy
(
B
[
ko
*
block_K
:(
ko
+
1
)
*
block_K
,
bx
*
block_N
:(
bx
+
1
)
*
block_N
],
B_shared
[
ko
%
num_stages
,
:,
:])
T
.
mbarrier_arrive
(
mbarrier
=
ko
%
num_stages
)
T
.
mbarrier_arrive
(
mbarrier
=
ko
%
num_stages
)
with
T
.
ws
(
0
):
with
T
.
ws
(
0
):
T
.
mbarrier_wait_parity
(
T
.
mbarrier_wait_parity
(
mbarrier
=
ko
%
num_stages
,
parity
=
(
ko
//
num_stages
)
%
num_stages
)
mbarrier
=
ko
%
num_stages
,
parity
=
(
ko
//
num_stages
)
%
num_stages
)
T
.
gemm
(
A_shared
[
ko
%
num_stages
,
:,
:],
B_shared
[
ko
%
num_stages
,
:,
:],
C_local
)
T
.
gemm
(
A_shared
[
ko
%
num_stages
,
:,
:],
B_shared
[
ko
%
num_stages
,
:,
:],
C_local
)
T
.
mbarrier_arrive
(
mbarrier
=
ko
%
num_stages
+
num_stages
)
T
.
mbarrier_arrive
(
mbarrier
=
ko
%
num_stages
+
num_stages
)
with
T
.
ws
(
0
):
with
T
.
ws
(
0
):
...
...
examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py
View file @
29051439
...
@@ -5,20 +5,12 @@ import tilelang.language as T
...
@@ -5,20 +5,12 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
# @tilelang.jit
@
tilelang
.
jit
(
out_idx
=
[
2
])
@
tilelang
.
jit
(
out_idx
=
[
2
])
def
matmul_warp_specialize_copy_0_gemm_1
(
M
,
def
matmul_warp_specialize_copy_0_gemm_1
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
# Initialize Kernel Context
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
256
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
256
)
as
(
bx
,
by
):
...
...
examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py
View file @
29051439
...
@@ -5,20 +5,12 @@ import tilelang.language as T
...
@@ -5,20 +5,12 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
# @tilelang.jit
@
tilelang
.
jit
(
out_idx
=
[
2
])
@
tilelang
.
jit
(
out_idx
=
[
2
])
def
matmul_warp_specialize_copy_1_gemm_0
(
M
,
def
matmul_warp_specialize_copy_1_gemm_0
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
# Initialize Kernel Context
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
256
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
256
)
as
(
bx
,
by
):
...
...
examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py
View file @
29051439
...
@@ -5,26 +5,20 @@ import tilelang.language as T
...
@@ -5,26 +5,20 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
# @tilelang.jit
@
tilelang
.
jit
(
@
tilelang
.
jit
(
out_idx
=
[
2
],
pass_configs
=
{
out_idx
=
[
2
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
})
},
def
matmul_warp_specialize_copy_1_gemm_0
(
M
,
)
N
,
def
matmul_warp_specialize_copy_1_gemm_0
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
warp_group_num
=
2
warp_group_num
=
2
threads
=
128
*
warp_group_num
threads
=
128
*
warp_group_num
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
# Initialize Kernel Context
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
...
...
examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py
View file @
29051439
...
@@ -6,7 +6,6 @@ import tilelang.language as T
...
@@ -6,7 +6,6 @@ import tilelang.language as T
# @tilelang.jit
# @tilelang.jit
@
tilelang
.
jit
(
out_idx
=
[
2
])
@
tilelang
.
jit
(
out_idx
=
[
2
])
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
[(
M
,
K
),
dtype
],
A
:
T
.
Tensor
[(
M
,
K
),
dtype
],
...
...
format.sh
View file @
29051439
...
@@ -9,7 +9,7 @@
...
@@ -9,7 +9,7 @@
# bash format.sh --all
# bash format.sh --all
#
#
#
#
#
YAPF
+ Clang formatter (if installed). This script formats all changed files from the last mergebase.
#
Ruff (format)
+ Clang formatter (if installed). This script formats all changed files from the last mergebase.
# You are encouraged to run this locally before pushing changes for review.
# You are encouraged to run this locally before pushing changes for review.
# Cause the script to exit if a single command fails
# Cause the script to exit if a single command fails
...
...
maint/gemm_v2/correctness_evaluation.py
View file @
29051439
...
@@ -28,9 +28,9 @@ def matmul(
...
@@ -28,9 +28,9 @@ def matmul(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
"shared.dyn"
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
"shared.dyn"
)
...
@@ -66,7 +66,8 @@ def _compile_and_check(
...
@@ -66,7 +66,8 @@ def _compile_and_check(
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
# tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False,
# tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False,
})
},
)
print
(
kernel
.
get_kernel_source
())
print
(
kernel
.
get_kernel_source
())
...
@@ -151,9 +152,9 @@ def matmul_rs(
...
@@ -151,9 +152,9 @@ def matmul_rs(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
"shared.dyn"
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
"shared.dyn"
)
...
@@ -238,9 +239,9 @@ def matmul_sr(
...
@@ -238,9 +239,9 @@ def matmul_sr(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
"shared.dyn"
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
"shared.dyn"
)
...
@@ -326,9 +327,9 @@ def matmul_rr(
...
@@ -326,9 +327,9 @@ def matmul_rr(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
"shared.dyn"
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
"shared.dyn"
)
...
@@ -394,37 +395,48 @@ M_VALUES = [64, 128, 256]
...
@@ -394,37 +395,48 @@ M_VALUES = [64, 128, 256]
N_VALUES
=
[
16
,
32
,
64
,
128
,
256
,
512
]
N_VALUES
=
[
16
,
32
,
64
,
128
,
256
,
512
]
K_VALUES
=
[
16
,
32
,
64
,
128
]
K_VALUES
=
[
16
,
32
,
64
,
128
]
K_VALUES_8Bit
=
[
32
,
64
,
128
]
K_VALUES_8Bit
=
[
32
,
64
,
128
]
FALSE_TRUE_CASES
=
([
FALSE_TRUE_CASES
=
(
pytest
.
param
(
[
k
,
pytest
.
param
(
"float16"
,
k
,
"float16"
,
"float16"
,
"float16"
,
"float16"
,
id
=
f
"K
{
k
}
-float16-float16-float16"
,
"float16"
,
)
for
k
in
K_VALUES
id
=
f
"K
{
k
}
-float16-float16-float16"
,
]
+
[
pytest
.
param
(
)
k
,
for
k
in
K_VALUES
"int8"
,
]
"int32"
,
+
[
"int32"
,
pytest
.
param
(
id
=
"K32-int8-int32-int32"
,
k
,
)
for
k
in
K_VALUES_8Bit
]
+
[
"int8"
,
pytest
.
param
(
"int32"
,
k
,
"int32"
,
"float8_e5m2"
,
id
=
"K32-int8-int32-int32"
,
"float32"
,
)
"float32"
,
for
k
in
K_VALUES_8Bit
id
=
"K32-float8_e5m2-float32-float32"
,
]
)
for
k
in
K_VALUES_8Bit
+
[
]
+
[
pytest
.
param
(
pytest
.
param
(
k
,
k
,
"float8_e5m2"
,
"float8_e4m3"
,
"float32"
,
"float32"
,
"float32"
,
"float32"
,
id
=
"K32-float8_e5m2-float32-float32"
,
id
=
"K32-float8_e4m3-float32-float32"
,
)
)
for
k
in
K_VALUES_8Bit
for
k
in
K_VALUES_8Bit
])
]
+
[
pytest
.
param
(
k
,
"float8_e4m3"
,
"float32"
,
"float32"
,
id
=
"K32-float8_e4m3-float32-float32"
,
)
for
k
in
K_VALUES_8Bit
]
)
def
_ensure_torch_dtypes
(
*
dtype_names
):
def
_ensure_torch_dtypes
(
*
dtype_names
):
...
...
maint/gemm_v2/correctness_evaluation_sm70.py
View file @
29051439
...
@@ -28,9 +28,9 @@ def matmul(
...
@@ -28,9 +28,9 @@ def matmul(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
"shared.dyn"
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
"shared.dyn"
)
...
@@ -67,7 +67,8 @@ def _compile_and_check(
...
@@ -67,7 +67,8 @@ def _compile_and_check(
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
# tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False,
# tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False,
})
},
)
print
(
kernel
.
get_kernel_source
())
print
(
kernel
.
get_kernel_source
())
...
@@ -150,9 +151,9 @@ def matmul_rs(
...
@@ -150,9 +151,9 @@ def matmul_rs(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
"shared.dyn"
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
"shared.dyn"
)
...
@@ -213,14 +214,15 @@ def run_gemm_rs(
...
@@ -213,14 +214,15 @@ def run_gemm_rs(
M_VALUES
=
[
64
,
128
]
M_VALUES
=
[
64
,
128
]
N_VALUES
=
[
32
,
64
,
128
]
N_VALUES
=
[
32
,
64
,
128
]
K_VALUES
=
[
16
,
32
,
64
]
K_VALUES
=
[
16
,
32
,
64
]
FALSE_TRUE_CASES
=
(
[
FALSE_TRUE_CASES
=
[
pytest
.
param
(
pytest
.
param
(
k
,
k
,
"float16"
,
"float16"
,
"float16"
,
"float16"
,
"float16"
,
"float16"
,
id
=
f
"K
{
k
}
-float16-float16-float16"
,
id
=
f
"K
{
k
}
-float16-float16-float16"
,
)
for
k
in
K_VALUES
)
for
k
in
K_VALUES
]
+
[
]
+
[
pytest
.
param
(
pytest
.
param
(
k
,
k
,
...
@@ -228,8 +230,9 @@ FALSE_TRUE_CASES = ([
...
@@ -228,8 +230,9 @@ FALSE_TRUE_CASES = ([
"float16"
,
"float16"
,
"float32"
,
"float32"
,
id
=
f
"K
{
k
}
-float16-float16-float32"
,
id
=
f
"K
{
k
}
-float16-float16-float32"
,
)
for
k
in
K_VALUES
)
])
for
k
in
K_VALUES
]
def
_ensure_torch_dtypes
(
*
dtype_names
):
def
_ensure_torch_dtypes
(
*
dtype_names
):
...
...
maint/gemm_v2/correctness_evaluation_tcgen05.py
View file @
29051439
...
@@ -27,9 +27,9 @@ def matmul(
...
@@ -27,9 +27,9 @@ def matmul(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
@@ -42,15 +42,7 @@ def matmul(
...
@@ -42,15 +42,7 @@ def matmul(
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
A
[
by
*
block_M
,
k
*
block_K
],
A_shared
)
T
.
copy
(
A
[
by
*
block_M
,
k
*
block_K
],
A_shared
)
T
.
copy
(
B
[
bx
*
block_N
,
k
*
block_K
],
B_shared
)
T
.
copy
(
B
[
bx
*
block_N
,
k
*
block_K
],
B_shared
)
T
.
gemm
(
T
.
gemm
(
A_shared
,
B_shared
,
C_tmem
,
trans_A
,
trans_B
,
mbar
=
mbar
,
wg_wait
=-
1
,
clear_accum
=
k
==
0
)
A_shared
,
B_shared
,
C_tmem
,
trans_A
,
trans_B
,
mbar
=
mbar
,
wg_wait
=-
1
,
clear_accum
=
k
==
0
)
T
.
mbarrier_wait_parity
(
mbar
,
k
%
2
)
T
.
mbarrier_wait_parity
(
mbar
,
k
%
2
)
T
.
copy
(
C_tmem
,
C_local
)
T
.
copy
(
C_tmem
,
C_local
)
...
@@ -74,7 +66,8 @@ def _compile_and_check(
...
@@ -74,7 +66,8 @@ def _compile_and_check(
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
print
(
kernel
.
get_kernel_source
())
print
(
kernel
.
get_kernel_source
())
...
@@ -138,14 +131,15 @@ M_VALUES = [32, 64, 128, 256]
...
@@ -138,14 +131,15 @@ M_VALUES = [32, 64, 128, 256]
N_VALUES
=
[
64
,
128
,
256
,
512
]
N_VALUES
=
[
64
,
128
,
256
,
512
]
K_VALUES
=
[
16
,
32
,
64
,
128
]
K_VALUES
=
[
16
,
32
,
64
,
128
]
K_VALUES_8Bit
=
[
32
,
64
,
128
]
K_VALUES_8Bit
=
[
32
,
64
,
128
]
FALSE_TRUE_CASES
=
(
[
FALSE_TRUE_CASES
=
[
pytest
.
param
(
pytest
.
param
(
k
,
k
,
"float16"
,
"float16"
,
"float32"
,
"float32"
,
"float32"
,
"float32"
,
id
=
f
"K
{
k
}
-float16-float-float"
,
id
=
f
"K
{
k
}
-float16-float-float"
,
)
for
k
in
K_VALUES
)
for
k
in
K_VALUES
]
+
[
]
+
[
pytest
.
param
(
pytest
.
param
(
k
,
k
,
...
@@ -153,8 +147,9 @@ FALSE_TRUE_CASES = ([
...
@@ -153,8 +147,9 @@ FALSE_TRUE_CASES = ([
"float32"
,
"float32"
,
"float32"
,
"float32"
,
id
=
"K32-float8_e5m2-float32-float32"
,
id
=
"K32-float8_e5m2-float32-float32"
,
)
for
k
in
K_VALUES_8Bit
)
])
for
k
in
K_VALUES_8Bit
]
TRANS_CASES
=
[
TRANS_CASES
=
[
pytest
.
param
(
False
,
True
,
id
=
"nt"
),
pytest
.
param
(
False
,
True
,
id
=
"nt"
),
...
...
maint/gemm_v2/latency.py
View file @
29051439
...
@@ -14,12 +14,11 @@ use_v2 = args.use_v2
...
@@ -14,12 +14,11 @@ use_v2 = args.use_v2
# if not specified, it will be inferred from the input tensors during compile time
# if not specified, it will be inferred from the input tensors during compile time
@
tilelang
.
jit
@
tilelang
.
jit
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
@
T
.
prim_func
def
matmul_relu_kernel
(
def
matmul_relu_kernel
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
# Initialize Kernel Context
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
...
...
maint/gemm_v2/latency_gemm.py
View file @
29051439
...
@@ -14,12 +14,11 @@ use_v2 = args.use_v2
...
@@ -14,12 +14,11 @@ use_v2 = args.use_v2
# if not specified, it will be inferred from the input tensors during compile time
# if not specified, it will be inferred from the input tensors during compile time
@
tilelang
.
jit
@
tilelang
.
jit
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
@
T
.
prim_func
def
matmul_relu_kernel
(
def
matmul_relu_kernel
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
# Initialize Kernel Context
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
...
...
maint/gemm_v2/latency_mha_fwd_bhsd.py
View file @
29051439
...
@@ -8,13 +8,13 @@ import argparse
...
@@ -8,13 +8,13 @@ import argparse
from
functools
import
partial
from
functools
import
partial
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
128
,
help
=
'
batch size
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
128
,
help
=
"
batch size
"
)
parser
.
add_argument
(
'
--heads
'
,
type
=
int
,
default
=
16
,
help
=
'
heads
'
)
parser
.
add_argument
(
"
--heads
"
,
type
=
int
,
default
=
16
,
help
=
"
heads
"
)
parser
.
add_argument
(
'
--seq_q
'
,
type
=
int
,
default
=
1024
,
help
=
'
query sequence length
'
)
parser
.
add_argument
(
"
--seq_q
"
,
type
=
int
,
default
=
1024
,
help
=
"
query sequence length
"
)
parser
.
add_argument
(
'
--seq_kv
'
,
type
=
int
,
default
=
1024
,
help
=
'
key/value sequence length
'
)
parser
.
add_argument
(
"
--seq_kv
"
,
type
=
int
,
default
=
1024
,
help
=
"
key/value sequence length
"
)
parser
.
add_argument
(
'
--dim
'
,
type
=
int
,
default
=
256
,
help
=
'
dim
'
)
parser
.
add_argument
(
"
--dim
"
,
type
=
int
,
default
=
256
,
help
=
"
dim
"
)
parser
.
add_argument
(
'
--is_causal
'
,
action
=
'
store_true
'
,
help
=
'
causal
'
)
parser
.
add_argument
(
"
--is_causal
"
,
action
=
"
store_true
"
,
help
=
"
causal
"
)
parser
.
add_argument
(
'
--tune
'
,
action
=
'
store_true
'
,
help
=
'
tune configs
'
)
parser
.
add_argument
(
"
--tune
"
,
action
=
"
store_true
"
,
help
=
"
tune configs
"
)
parser
.
add_argument
(
"--use_v2"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--use_v2"
,
action
=
"store_true"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
@@ -29,20 +29,13 @@ def get_configs():
...
@@ -29,20 +29,13 @@ def get_configs():
@
autotune
(
configs
=
get_configs
(),
warmup
=
10
,
rep
=
10
)
@
autotune
(
configs
=
get_configs
(),
warmup
=
10
,
rep
=
10
)
@
tilelang
.
jit
(
@
tilelang
.
jit
(
out_idx
=
[
3
],
pass_configs
=
{
out_idx
=
[
3
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
},
def
flashattn
(
batch
,
)
heads
,
def
flashattn
(
batch
,
heads
,
seq_q
,
seq_kv
,
dim
,
is_causal
,
block_M
=
64
,
block_N
=
64
,
num_stages
=
0
,
threads
=
128
):
seq_q
,
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
seq_kv
,
dim
,
is_causal
,
block_M
=
64
,
block_N
=
64
,
num_stages
=
0
,
threads
=
128
):
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
q_shape
=
[
batch
,
heads
,
seq_q
,
dim
]
q_shape
=
[
batch
,
heads
,
seq_q
,
dim
]
kv_shape
=
[
batch
,
heads
,
seq_kv
,
dim
]
kv_shape
=
[
batch
,
heads
,
seq_kv
,
dim
]
dtype
=
"float16"
dtype
=
"float16"
...
@@ -62,7 +55,7 @@ def flashattn(batch,
...
@@ -62,7 +55,7 @@ def flashattn(batch,
by
:
T
.
int32
,
by
:
T
.
int32
,
bz
:
T
.
int32
,
bz
:
T
.
int32
,
):
):
T
.
copy
(
K
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
K_shared
)
T
.
copy
(
K
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
K_shared
)
if
is_causal
:
if
is_causal
:
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
q_idx
=
bx
*
block_M
+
i
+
past_len
q_idx
=
bx
*
block_M
+
i
+
past_len
...
@@ -85,7 +78,7 @@ def flashattn(batch,
...
@@ -85,7 +78,7 @@ def flashattn(batch,
by
:
T
.
int32
,
by
:
T
.
int32
,
bz
:
T
.
int32
,
bz
:
T
.
int32
,
):
):
T
.
copy
(
V
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
V_shared
)
T
.
copy
(
V
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
V_shared
)
# T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
# T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
if
use_v2
:
if
use_v2
:
T
.
gemm_v2
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm_v2
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
...
@@ -94,13 +87,13 @@ def flashattn(batch,
...
@@ -94,13 +87,13 @@ def flashattn(batch,
@
T
.
macro
@
T
.
macro
def
Softmax
(
def
Softmax
(
acc_s
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
accum_dtype
),
acc_s
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
accum_dtype
),
acc_s_cast
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
dtype
),
acc_s_cast
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
dtype
),
scores_max
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_max
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_max_prev
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_max_prev
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_scale
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_scale
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_sum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_sum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
logsum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
logsum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
):
):
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
@@ -125,18 +118,18 @@ def flashattn(batch,
...
@@ -125,18 +118,18 @@ def flashattn(batch,
@
T
.
macro
@
T
.
macro
def
Rescale
(
def
Rescale
(
acc_o
:
T
.
FragmentBuffer
([
block_M
,
dim
],
accum_dtype
),
acc_o
:
T
.
FragmentBuffer
([
block_M
,
dim
],
accum_dtype
),
scores_scale
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_scale
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
):
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
Q
:
T
.
Tensor
(
q_shape
,
dtype
),
Q
:
T
.
Tensor
(
q_shape
,
dtype
),
K
:
T
.
Tensor
(
kv_shape
,
dtype
),
K
:
T
.
Tensor
(
kv_shape
,
dtype
),
V
:
T
.
Tensor
(
kv_shape
,
dtype
),
V
:
T
.
Tensor
(
kv_shape
,
dtype
),
Output
:
T
.
Tensor
(
q_shape
,
dtype
),
Output
:
T
.
Tensor
(
q_shape
,
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
seq_q
,
block_M
),
heads
,
batch
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
T
.
ceildiv
(
seq_q
,
block_M
),
heads
,
batch
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_M
,
dim
],
dtype
)
Q_shared
=
T
.
alloc_shared
([
block_M
,
dim
],
dtype
)
...
@@ -152,43 +145,42 @@ def flashattn(batch,
...
@@ -152,43 +145,42 @@ def flashattn(batch,
scores_sum
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
scores_sum
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
logsum
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
logsum
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
T
.
copy
(
Q
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:],
Q_shared
)
T
.
copy
(
Q
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
loop_range
=
(
loop_range
=
(
T
.
min
(
T
.
min
(
T
.
ceildiv
(
seq_kv
,
block_N
),
T
.
ceildiv
((
bx
+
1
)
*
block_M
+
past_len
,
block_N
))
T
.
ceildiv
(
seq_kv
,
block_N
),
T
.
ceildiv
(
if
is_causal
(
bx
+
1
)
*
block_
M
+
else
T
.
ceildiv
(
seq_kv
,
block_
N
)
past_len
,
block_N
))
if
is_causal
else
T
.
ceildiv
(
seq_kv
,
block_N
)
)
)
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
MMA0
(
K
,
Q_shared
,
K_shared
,
acc_s
,
k
,
bx
,
by
,
bz
)
MMA0
(
K
,
Q_shared
,
K_shared
,
acc_s
,
k
,
bx
,
by
,
bz
)
Softmax
(
acc_s
,
acc_s_cast
,
scores_max
,
scores_max_prev
,
scores_scale
,
scores_sum
,
Softmax
(
acc_s
,
acc_s_cast
,
scores_max
,
scores_max_prev
,
scores_scale
,
scores_sum
,
logsum
)
logsum
)
Rescale
(
acc_o
,
scores_scale
)
Rescale
(
acc_o
,
scores_scale
)
MMA1
(
V
,
V_shared
,
acc_s_cast
,
acc_o
,
k
,
by
,
bz
)
MMA1
(
V
,
V_shared
,
acc_s_cast
,
acc_o
,
k
,
by
,
bz
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
acc_o
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
O_shared
,
Output
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:])
T
.
copy
(
O_shared
,
Output
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:])
return
main
return
main
def
ref_program
(
Q
,
K
,
V
,
is_causal
):
def
ref_program
(
Q
,
K
,
V
,
is_causal
):
dim
=
Q
.
size
(
-
1
)
dim
=
Q
.
size
(
-
1
)
scores
=
torch
.
einsum
(
'
bhqd,bhkd->bhqk
'
,
Q
,
K
)
scores
=
torch
.
einsum
(
"
bhqd,bhkd->bhqk
"
,
Q
,
K
)
scores
=
scores
/
torch
.
sqrt
(
torch
.
tensor
(
dim
,
dtype
=
scores
.
dtype
))
scores
=
scores
/
torch
.
sqrt
(
torch
.
tensor
(
dim
,
dtype
=
scores
.
dtype
))
if
is_causal
:
if
is_causal
:
seq_q
=
Q
.
size
(
2
)
seq_q
=
Q
.
size
(
2
)
seq_kv
=
K
.
size
(
2
)
seq_kv
=
K
.
size
(
2
)
mask
=
torch
.
tril
(
torch
.
ones
(
seq_q
,
seq_kv
,
device
=
scores
.
device
),
seq_kv
-
seq_q
)
mask
=
torch
.
tril
(
torch
.
ones
(
seq_q
,
seq_kv
,
device
=
scores
.
device
),
seq_kv
-
seq_q
)
mask
=
mask
.
unsqueeze
(
0
).
unsqueeze
(
0
)
mask
=
mask
.
unsqueeze
(
0
).
unsqueeze
(
0
)
scores
=
scores
.
masked_fill
(
mask
==
0
,
float
(
'
-inf
'
))
scores
=
scores
.
masked_fill
(
mask
==
0
,
float
(
"
-inf
"
))
attention_weights
=
F
.
softmax
(
scores
,
dim
=-
1
)
attention_weights
=
F
.
softmax
(
scores
,
dim
=-
1
)
output
=
torch
.
einsum
(
'
bhqk,bhkd->bhqd
'
,
attention_weights
,
V
)
output
=
torch
.
einsum
(
"
bhqk,bhkd->bhqd
"
,
attention_weights
,
V
)
return
output
return
output
...
@@ -206,18 +198,8 @@ def main(
...
@@ -206,18 +198,8 @@ def main(
if
is_causal
:
if
is_causal
:
total_flops
*=
0.5
total_flops
*=
0.5
if
(
not
tune
):
if
not
tune
:
kernel
=
flashattn
(
kernel
=
flashattn
(
batch
,
heads
,
seq_q
,
seq_kv
,
dim
,
is_causal
,
block_M
=
64
,
block_N
=
64
,
num_stages
=
0
,
threads
=
128
)
batch
,
heads
,
seq_q
,
seq_kv
,
dim
,
is_causal
,
block_M
=
64
,
block_N
=
64
,
num_stages
=
0
,
threads
=
128
)
print
(
kernel
.
get_kernel_source
())
print
(
kernel
.
get_kernel_source
())
ref_program_processed
=
partial
(
ref_program
,
is_causal
=
is_causal
)
ref_program_processed
=
partial
(
ref_program
,
is_causal
=
is_causal
)
...
...
maint/host_checks/01_num_args_mismatch.py
View file @
29051439
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
Note: The adapter-level wrapper expects only inputs (A, B) because C is marked as output.
Note: The adapter-level wrapper expects only inputs (A, B) because C is marked as output.
Calling with the wrong number of inputs raises a ValueError before host entry.
Calling with the wrong number of inputs raises a ValueError before host entry.
"""
"""
import
torch
import
torch
from
common
import
build_matmul_kernel
from
common
import
build_matmul_kernel
...
...
maint/host_checks/02_pointer_type_error.py
View file @
29051439
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
We pass an integer for A; wrapper forwards it to the host where a pointer is expected.
We pass an integer for A; wrapper forwards it to the host where a pointer is expected.
Expected: error like "Expect buffer A_handle to be pointer or tensor" (exact name depends on kernel param).
Expected: error like "Expect buffer A_handle to be pointer or tensor" (exact name depends on kernel param).
"""
"""
import
torch
import
torch
from
common
import
build_matmul_kernel
from
common
import
build_matmul_kernel
...
...
maint/host_checks/03_ndim_mismatch.py
View file @
29051439
"""Reproduce: ndim (rank) mismatch for A.
"""Reproduce: ndim (rank) mismatch for A.
"""
"""
import
torch
import
torch
from
common
import
build_matmul_kernel
from
common
import
build_matmul_kernel
...
...
maint/host_checks/04_dtype_mismatch.py
View file @
29051439
"""Reproduce: dtype mismatch for A (float32 vs expected float16).
"""Reproduce: dtype mismatch for A (float32 vs expected float16).
"""
"""
import
torch
import
torch
from
common
import
build_matmul_kernel
from
common
import
build_matmul_kernel
...
...
maint/host_checks/05_shape_mismatch.py
View file @
29051439
"""Reproduce: shape constant/symbol mismatch on A.
"""Reproduce: shape constant/symbol mismatch on A.
"""
"""
import
torch
import
torch
from
common
import
build_matmul_kernel
from
common
import
build_matmul_kernel
...
...
maint/host_checks/06_strides_mismatch.py
View file @
29051439
"""Reproduce: strides check failure (non-contiguous A via transpose).
"""Reproduce: strides check failure (non-contiguous A via transpose).
"""
"""
import
torch
import
torch
from
common
import
build_matmul_kernel
from
common
import
build_matmul_kernel
...
...
maint/host_checks/07_device_type_mismatch.py
View file @
29051439
"""Reproduce: device_type mismatch by passing CPU tensors to a CUDA kernel.
"""Reproduce: device_type mismatch by passing CPU tensors to a CUDA kernel.
"""
"""
import
torch
import
torch
from
common
import
build_matmul_kernel
from
common
import
build_matmul_kernel
...
...
Prev
1
…
6
7
8
9
10
11
12
13
14
…
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment