OpenDAS / tilelang · Commits

Commit 013adca0 (Unverified)
Authored Sep 05, 2025 by Wenhao Xie; committed via GitHub on Sep 05, 2025.
[Bugfix] Fix incorrect synchronization bug in minference example (#786)
* fix
* lint
Parent: e5b61e9b
Showing 1 changed file with 117 additions and 93 deletions.

examples/minference/example_vertical_slash_sparse_attn.py (+117 −93)
@@ -10,9 +10,7 @@ import triton.language as tl
 import tilelang
 import tilelang.language as T
 from tilelang.profiler import do_bench
+from tilelang.testing import torch_assert_close
 tilelang.disable_cache()
@@ -27,7 +25,9 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_size
     scale = (1.0 / dim)**0.5 * 1.44269504
     shape = [batch, heads, seq_len, dim]
-    count_shape = [batch, heads, (seq_len + block_M - 1) // block_M]
+    seq_blocks = (seq_len + block_M - 1) // block_M
+    count_shape = [batch, heads, seq_blocks]
     offset_shape = count_shape + [slash_size]
     index_shape = count_shape + [vertical_size]
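As a quick aside on the shape bookkeeping above: `seq_blocks` is just the ceiling division of `seq_len` by `block_M`, and the count/offset/index shapes hang off it. A tiny standalone sketch in plain Python, with illustrative sizes rather than the example's defaults:

# Standalone sketch of the shape bookkeeping in the hunk above.
# The concrete numbers are illustrative, not taken from the example.
batch, heads, seq_len, dim = 1, 32, 8192, 128
vertical_size, slash_size = 1000, 200
block_M = 64

# Number of query blocks, computed by ceiling division.
seq_blocks = (seq_len + block_M - 1) // block_M   # same as -(-seq_len // block_M)

count_shape = [batch, heads, seq_blocks]          # one count per query block
offset_shape = count_shape + [slash_size]         # slash (diagonal) block offsets
index_shape = count_shape + [vertical_size]       # vertical column indices

print(seq_blocks)       # 128
print(count_shape)      # [1, 32, 128]
print(offset_shape)     # [1, 32, 128, 200]
print(index_shape)      # [1, 32, 128, 1000]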
@@ -47,7 +47,7 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_size
             V: T.Tensor(shape, dtype),
             K_shared: T.SharedBuffer([block_N, dim], dtype),
             V_shared: T.SharedBuffer([block_N, dim], dtype),
-            column_index: T.SharedBuffer([vertical_size], int_dtype),
+            column_index: T.SharedBuffer([vertical_size_round], int_dtype),
             column_count: T.int32,
             k: T.int32,
             bz: T.int32,
@@ -80,8 +80,9 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_size
             scores_scale: T.FragmentBuffer([block_M], accum_dtype),
             scores_sum: T.FragmentBuffer([block_M], accum_dtype),
             logsum: T.FragmentBuffer([block_M], accum_dtype),
+            count: T.int32,
         ):
-            T.ptx_wait_group(1)
+            T.ptx_wait_group(count)
             for i, j in T.Parallel(block_M, block_N):
                 acc_s[i, j] = T.if_then_else(k + j < column_count, 0, -T.infinity(acc_s.dtype))
             T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
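Two things happen in this hunk: the hard-coded `T.ptx_wait_group(1)` becomes `T.ptx_wait_group(count)`, with `count` now supplied by the caller, and the scores tile is pre-masked so columns at or beyond `column_count` contribute nothing after the exponential. A small NumPy sketch of that masking idea; sizes and values are made up, not taken from the kernel:

import numpy as np

# Mimic the mask in the hunk above: for a tile starting at column offset k,
# positions j with k + j < column_count are valid (initialized to 0),
# everything past the valid column count gets -inf so exp2() turns it into 0.
block_M, block_N = 4, 8
k, column_count = 8, 13          # hypothetical tile offset and valid-column count

acc_s = np.where(k + np.arange(block_N) < column_count, 0.0, -np.inf)
acc_s = np.broadcast_to(acc_s, (block_M, block_N)).copy()

scores = np.random.rand(block_M, block_N)
acc_s += scores                  # stand-in for the Q @ K^T tile
probs = np.exp2(acc_s - acc_s.max(axis=1, keepdims=True))
print(probs[:, column_count - k:])   # columns past the valid range contribute exactly 0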
@@ -106,7 +107,7 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_size
                 logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]

     @T.prim_func
-    def vs_sparse_flashattn(
+    def vs_sparse_flashattn_ws(
             Q: T.Tensor(shape, dtype),
             K: T.Tensor(shape, dtype),
             V: T.Tensor(shape, dtype),
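The `logsum` update at the top of this hunk is the running-denominator step of online softmax: rescale the old denominator by `scores_scale` (the correction for a new running maximum) and add the current tile's row sums. A NumPy reference of one such step under the same base-2 exponent convention as the kernel; `scale` here is a hypothetical stand-in for the kernel's `(1.0 / dim)**0.5 * 1.44269504`:

import numpy as np

scale = 1.0  # hypothetical stand-in for the kernel's log2-scaled softmax factor

def online_softmax_step(m_prev, l_prev, acc_prev, scores, v_tile):
    """Fold one tile of scores/values into the running (max, denom, output) state."""
    m_new = max(m_prev, scores.max())
    correction = np.exp2((m_prev - m_new) * scale)   # scores_scale in the kernel
    p = np.exp2((scores - m_new) * scale)            # re-exponentiated tile
    l_new = l_prev * correction + p.sum()            # the logsum update from the diff
    acc_new = acc_prev * correction + p @ v_tile     # rescale old acc_o, add new tile
    return m_new, l_new, acc_new

# Processing tiles one at a time matches the full softmax over the concatenation.
rng = np.random.default_rng(0)
s1, s2 = rng.normal(size=4), rng.normal(size=4)
v1, v2 = rng.normal(size=(4, 2)), rng.normal(size=(4, 2))

m, l, acc = -np.inf, 0.0, np.zeros(2)
for s, v in ((s1, v1), (s2, v2)):
    m, l, acc = online_softmax_step(m, l, acc, s, v)

s_all, v_all = np.concatenate([s1, s2]), np.concatenate([v1, v2])
w = np.exp2((s_all - s_all.max()) * scale)
print(np.allclose(acc / l, (w @ v_all) / w.sum()))   # True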
@@ -116,13 +117,16 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_size
             ColumnCount: T.Tensor(count_shape, int_dtype),
             ColumnIndex: T.Tensor(index_shape, int_dtype),
     ):
-        with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bc, by, bz):
+        with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bc, by, bz):
             bx = T.ceildiv(seq_len, block_M) - 1 - bc
             Q_shared = T.alloc_shared([block_M, dim], dtype)
-            K_shared = T.alloc_shared([block_N, dim], dtype)
-            V_shared = T.alloc_shared([block_N, dim], dtype)
+            K_shared = T.alloc_shared([2, block_N, dim], dtype)
+            V_shared = T.alloc_shared([2, block_N, dim], dtype)
+            K_shared_1 = T.alloc_shared([block_N, dim], dtype)
+            V_shared_1 = T.alloc_shared([block_N, dim], dtype)
+            K_shared_2 = T.alloc_shared([block_N, dim], dtype)
+            V_shared_2 = T.alloc_shared([block_N, dim], dtype)
             O_shared = T.alloc_shared([block_M, dim], dtype)
             acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
             acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
@@ -137,10 +141,11 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_size
             column_count = T.alloc_local([1], int_dtype)
             column_index = T.alloc_shared([vertical_size_round], int_dtype, scope="shared")
-            K_shared_1 = T.alloc_shared([block_N, dim], dtype)
-            V_shared_1 = T.alloc_shared([block_N, dim], dtype)
-            K_shared_2 = T.alloc_shared([block_N, dim], dtype)
-            V_shared_2 = T.alloc_shared([block_N, dim], dtype)
+            T.create_list_of_mbarrier([128] * 9)
             T.annotate_layout({
                 O_shared: tilelang.layout.make_swizzled_layout(O_shared),
             })
             block_count[0] = BlockCount[bz, by, bx]
             column_count[0] = ColumnCount[bz, by, bx]
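The new `T.create_list_of_mbarrier([128] * 9)` sets up nine barriers, each expecting 128 arrivals, for the producer/consumer split that appears in the next hunk. The wait calls there use the parity expressions `(bi & 3) >> 1` and its `^ 1` complement, which flip every two iterations to match the two K/V buffers. A small Python sketch of the resulting schedule; the role of each barrier index is inferred from the code below and should be read as an illustration, not a spec:

# Sketch of the barrier/parity schedule implied by the diff (barrier roles inferred):
#   mbarriers 0,1 : "K tile in buffer bi % 2 is ready"   (producer arrives, consumer waits)
#   mbarriers 2,3 : "V tile in buffer bi % 2 is ready"
#   mbarriers 4,5 : "consumer is done with K buffer bi % 2"
#   mbarriers 6,7 : "consumer is done with V buffer bi % 2"
#   mbarrier  8   : "Q_shared has been filled"

for bi in range(8):
    buf = bi % 2
    consumer_parity = (bi & 3) >> 1          # what the consumer waits on for barrier buf
    producer_parity = consumer_parity ^ 1    # what the producer waits on for barrier buf + 4
    print(f"iter {bi}: buffer {buf}, consumer parity {consumer_parity}, producer parity {producer_parity}")

# The parity bit flips every two iterations (0, 0, 1, 1, 0, 0, 1, 1, ...), i.e. each time the
# same buffer index comes around again, which is what lets a phase-bit barrier distinguish
# this round's arrival from last round's.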
@@ -153,81 +158,103 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_size
                     if vi < vertical_size:
                         column_index[vi] = ColumnIndex[bz, by, bx, vi]
-            T.fill(acc_o, 0)
-            T.fill(logsum, 0)
-            T.fill(scores_max, -T.infinity(accum_dtype))
-            T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
-            for bi in T.Pipelined(block_count[0], num_stages=num_stages):
-                k = block_offset[bi]
-                T.copy(K[bz, by, k:k + block_N, :], K_shared)
-                for i, j in T.Parallel(block_M, block_N):
-                    acc_s[i, j] = T.if_then_else(bx * block_M + i >= k + j, 0, -T.infinity(acc_s.dtype))
-                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
-                T.copy(scores_max, scores_max_prev)
-                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
-                for i in T.Parallel(block_M):
-                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
-                for i, j in T.Parallel(block_M, block_N):
-                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
-                for i, j in T.Parallel(block_M, dim):
-                    acc_o[i, j] = acc_o[i, j] * scores_scale[i]
-                T.copy(acc_s, acc_s_cast)
-                T.copy(V[bz, by, k:k + block_N, :], V_shared)
-                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
-                T.reduce_sum(acc_s, scores_sum, dim=1)
-                for i in T.Parallel(block_M):
-                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
-            if column_count[0] != 0:
-                Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count[0], 0, bz, by)
-                for bi in T.serial(T.ceildiv(column_count[0], block_N) - 1):
-                    k = bi * block_N
-                    if bi % 2 == 0:
-                        Prefetch(K, V, K_shared_2, V_shared_2, column_index, column_count[0], k + block_N, bz, by)
-                        Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, k, column_count[0], Q_shared, K_shared_1, V_shared_1, scores_scale, scores_sum, logsum)
-                    else:
-                        Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count[0], k + block_N, bz, by)
-                        Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, k, column_count[0], Q_shared, K_shared_2, V_shared_2, scores_scale, scores_sum, logsum)
-                if T.ceildiv(column_count[0], block_N) % 2 == 0:
-                    Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, T.ceildiv(column_count[0], block_N) * block_N - block_N, column_count[0], Q_shared, K_shared_2, V_shared_2, scores_scale, scores_sum, logsum)
-                else:
-                    Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, T.ceildiv(column_count[0], block_N) * block_N - block_N, column_count[0], Q_shared, K_shared_1, V_shared_1, scores_scale, scores_sum, logsum)
-            for i, j in T.Parallel(block_M, dim):
-                acc_o[i, j] /= logsum[i]
-            T.copy(acc_o, O_shared)
-            T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
+            tid = T.get_thread_binding()
+            if tid >= 128:
+                T.annotate_producer_reg_dealloc()
+                T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
+                T.mbarrier_arrive(mbarrier=8)
+                for bi in T.serial(block_count[0]):
+                    k = block_offset[bi]
+                    T.mbarrier_wait_parity(mbarrier=bi % 2 + 4, parity=(((bi & 3) >> 1) ^ 1))
+                    T.copy(K[bz, by, k:k + block_N, :], K_shared[bi % 2, :, :])
+                    T.mbarrier_arrive(mbarrier=bi % 2)
+                    T.mbarrier_wait_parity(mbarrier=bi % 2 + 6, parity=(((bi & 3) >> 1) ^ 1))
+                    T.copy(V[bz, by, k:k + block_N, :], V_shared[bi % 2, :, :])
+                    T.mbarrier_arrive(mbarrier=bi % 2 + 2)
+            else:
+                T.annotate_consumer_reg_alloc()
+                T.fill(acc_o, 0)
+                T.fill(logsum, 0)
+                T.fill(scores_max, -T.infinity(accum_dtype))
+                T.mbarrier_wait_parity(mbarrier=8, parity=0)
+                for bi in T.serial(block_count[0]):
+                    k = block_offset[bi]
+                    for i, j in T.Parallel(block_M, block_N):
+                        acc_s[i, j] = T.if_then_else(bx * block_M + i >= k + j, 0, -T.infinity(acc_s.dtype))
+                    T.mbarrier_wait_parity(mbarrier=bi % 2, parity=((bi & 3) >> 1))
+                    T.gemm(Q_shared, K_shared[bi % 2, :, :], acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
+                    T.mbarrier_arrive(mbarrier=bi % 2 + 4)
+                    T.copy(scores_max, scores_max_prev)
+                    T.reduce_max(acc_s, scores_max, dim=1, clear=False)
+                    for i in T.Parallel(block_M):
+                        scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
+                    for i, j in T.Parallel(block_M, block_N):
+                        acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
+                    for i, j in T.Parallel(block_M, dim):
+                        acc_o[i, j] = acc_o[i, j] * scores_scale[i]
+                    T.copy(acc_s, acc_s_cast)
+                    T.mbarrier_wait_parity(mbarrier=bi % 2 + 2, parity=((bi & 3) >> 1))
+                    T.gemm(acc_s_cast, V_shared[bi % 2, :, :], acc_o, policy=T.GemmWarpPolicy.FullRow)
+                    T.mbarrier_arrive(mbarrier=bi % 2 + 6)
+                    T.reduce_sum(acc_s, scores_sum, dim=1)
+                    for i in T.Parallel(block_M):
+                        logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
+                if column_count[0] != 0:
+                    Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count[0], 0, bz, by)
+                    for bi in T.serial(T.ceildiv(column_count[0], block_N) - 1):
+                        k = bi * block_N
+                        if bi % 2 == 0:
+                            Prefetch(K, V, K_shared_2, V_shared_2, column_index, column_count[0], k + block_N, bz, by)
+                            Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, k, column_count[0], Q_shared, K_shared_1, V_shared_1, scores_scale, scores_sum, logsum, 1)
+                        else:
+                            Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count[0], k + block_N, bz, by)
+                            Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, k, column_count[0], Q_shared, K_shared_2, V_shared_2, scores_scale, scores_sum, logsum, 1)
+                    if T.ceildiv(column_count[0], block_N) % 2 == 0:
+                        Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, T.ceildiv(column_count[0], block_N) * block_N - block_N, column_count[0], Q_shared, K_shared_2, V_shared_2, scores_scale, scores_sum, logsum, 0)
+                    else:
+                        Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, T.ceildiv(column_count[0], block_N) * block_N - block_N, column_count[0], Q_shared, K_shared_1, V_shared_1, scores_scale, scores_sum, logsum, 0)
+                for i, j in T.Parallel(block_M, dim):
+                    acc_o[i, j] /= logsum[i]
+                T.copy(acc_o, O_shared)
+                T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
-    return vs_sparse_flashattn
+    return vs_sparse_flashattn_ws
     return kernel_func(block_M, block_N, num_stages, threads)
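One detail of the rewritten body that is easy to miss is the tail handling of the vertical-column loop: each iteration prefetches the next tile into whichever of `K_shared_1`/`K_shared_2` it is not currently computing from, so whether the final tile lives in buffer 1 or buffer 2 depends on whether the number of column tiles is even or odd. A plain-Python trace of that ping-pong; buffer names mirror the kernel's, and the tile counts are arbitrary:

# Why the tail Compute picks buffer 2 when the number of column tiles is even
# and buffer 1 when it is odd.
def trace(num_tiles):
    events = []
    events.append(("prefetch", 0, "buf1"))             # tile 0 always lands in buffer 1
    for bi in range(num_tiles - 1):                     # all but the last tile
        if bi % 2 == 0:
            events.append(("prefetch", bi + 1, "buf2"))
            events.append(("compute", bi, "buf1"))
        else:
            events.append(("prefetch", bi + 1, "buf1"))
            events.append(("compute", bi, "buf2"))
    tail_buf = "buf2" if num_tiles % 2 == 0 else "buf1" # the tail check in the kernel
    events.append(("compute", num_tiles - 1, tail_buf))
    return events

for n in (3, 4):
    print(n, trace(n))
# For every n, tile i is computed from exactly the buffer it was prefetched into.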
@@ -466,7 +493,7 @@ def vertical_slash_sparse_attention(
     s_idx = s_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=True)[0]
-    seqlens = torch.tensor([context_size], dtype=torch.int32, device=query.device)
+    seqlens = torch.tensor([context_size] * query.shape[0], dtype=torch.int32, device=query.device)
     sm_scale = head_dim**-0.5
     block_count, block_offset, column_count, column_index = convert_vertical_slash_indexes(
         seqlens,
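The only change here is that `seqlens` now carries one entry per batch element instead of a single-element tensor, which presumably is what `convert_vertical_slash_indexes` expects for batched inputs. A quick CPU-only illustration; the batch size and context size are made up:

import torch

context_size, batch_size = 4096, 3
query = torch.empty(batch_size, 2, context_size, 8, dtype=torch.float16)  # toy shape

old = torch.tensor([context_size], dtype=torch.int32)                     # shape [1]
new = torch.tensor([context_size] * query.shape[0], dtype=torch.int32)    # shape [batch]
print(old.shape, new.shape)   # torch.Size([1]) torch.Size([3])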
@@ -524,7 +551,6 @@ def main(argv=None):
     parser.add_argument("--slash_size", type=int, default=200)
     args = parser.parse_args(argv)
-    # vs_list = [[1000, 200], [1000, 600], [800, 600]]
     BATCH, N_HEADS, SEQ_LEN, D_HEAD = args.batch, args.heads, args.seq_len, args.head_dim
@@ -555,12 +581,10 @@ def main(argv=None):
     _attn = vertical_slash_sparse_attention(q, k, v, vertical_topk, slash)
-    triton_out = _attn(True)
     tilelang_out = _attn(False)
+    triton_out = _attn(True)
+    torch_assert_close(triton_out, tilelang_out, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.0)
     print("Pass topk sparse attention test with qlen == klen")
-    torch.testing.assert_close(triton_out, tilelang_out, atol=1e-2, rtol=1e-2)
     triton_time = do_bench(lambda: _attn(True))
     tilelang_time = do_bench(lambda: _attn(False))
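The test now runs the TileLang kernel before the Triton reference and switches from `torch.testing.assert_close` to `tilelang.testing.torch_assert_close`, which additionally takes `max_mismatched_ratio`. A rough standalone sketch of what a tolerance check with a mismatch budget looks like, assuming `max_mismatched_ratio` bounds the fraction of elements allowed to violate the `atol`/`rtol` envelope; the helper name below is hypothetical:

import torch

def assert_close_with_budget(actual, expected, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.0):
    # Elements outside atol + rtol * |expected| count as mismatches; only fail if
    # the mismatch fraction exceeds the budget.
    mismatch = (actual - expected).abs() > atol + rtol * expected.abs()
    ratio = mismatch.float().mean().item()
    if ratio > max_mismatched_ratio:
        raise AssertionError(f"{ratio:.4%} of elements mismatched (budget {max_mismatched_ratio:.4%})")

a = torch.randn(1024)
b = a + 1e-3 * torch.randn(1024)
assert_close_with_budget(a, b)   # passes: every element is within the tolerance envelope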