OpenDAS / tilelang · Commit a9d823b8 (unverified)

Authored Nov 05, 2025 by Yu Cheng; committed by GitHub on Nov 05, 2025.
Parent: 298ab480

[Example] Update GQA varlen fwd (#1173)

* [Example] Update GQA varlen fwd
* fix

Showing 1 changed file with 53 additions and 41 deletions (+53 / -41):

examples/flash_attention/example_gqa_fwd_varlen.py (view file @ a9d823b8)
@@ -24,21 +24,32 @@ def attention_ref(
     dtype_og = q.dtype
     if upcast:
         q, k, v = q.float(), k.float(), v.float()
-    dim = q.shape[-1]
-    scale = (1.0 / dim)**0.5
-    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
-    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
+    b, T, Hq, D = q.shape
+    S = k.shape[1]
+    scale = (1.0 / D)**0.5
+    k = repeat(k, "b s h d -> b s (h g) d", g=Hq // k.shape[2])
+    v = repeat(v, "b s h d -> b s (h g) d", g=Hq // v.shape[2])
     scores = torch.einsum("bthd,bshd->bhts", q, k)
+    left, right = window_size
+    left = S if left is None or left < 0 else int(left)
+    right = S if right is None or right < 0 else int(right)
+    t_idx = torch.arange(T, device=scores.device)[:, None]
+    s_idx = torch.arange(S, device=scores.device)[None, :]
+    visible_ts = (s_idx >= (t_idx - left)) & (s_idx <= (t_idx + right))
+    visible_mask = visible_ts.unsqueeze(0).unsqueeze(0)
     if key_padding_mask is not None:
-        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
+        k_keep = rearrange(key_padding_mask, "b s -> b 1 1 s")
+        visible_mask = visible_mask & k_keep
+    neg_inf = torch.finfo(scores.dtype).min
     scores = scores * scale
+    scores = scores.masked_fill(~visible_mask, neg_inf)
     attention = torch.softmax(scores, dim=-1).to(v.dtype)
     if query_padding_mask is not None:
-        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
+        q_keep = rearrange(query_padding_mask, "b t -> b 1 t 1")
+        attention = attention.masked_fill(~q_keep, 0.0)
     output = torch.einsum("bhts,bshd->bthd", attention, v)
     if query_padding_mask is not None:
-        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
+        output = output.masked_fill(rearrange(~query_padding_mask, "b t -> b t 1 1"), 0.0)
     return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
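For reference, the sliding-window visibility logic that the updated attention_ref builds (window bounds, key-padding intersection, finfo-min fill) can be exercised in isolation with plain PyTorch. This is only an illustrative sketch mirroring the names in the diff, not part of the example file:

# Illustrative sketch of the visibility mask used by the updated attention_ref.
# Standalone PyTorch; tensor names mirror the diff but this is not the example file.
import torch

def build_visible_mask(T_len, S_len, window_size, key_padding_mask=None, device="cpu"):
    left, right = window_size
    left = S_len if left is None or left < 0 else int(left)
    right = S_len if right is None or right < 0 else int(right)
    t_idx = torch.arange(T_len, device=device)[:, None]   # query positions, (T, 1)
    s_idx = torch.arange(S_len, device=device)[None, :]   # key positions,   (1, S)
    visible_ts = (s_idx >= (t_idx - left)) & (s_idx <= (t_idx + right))  # (T, S)
    visible_mask = visible_ts.unsqueeze(0).unsqueeze(0)                  # (1, 1, T, S)
    if key_padding_mask is not None:                                     # (B, S) bool
        visible_mask = visible_mask & key_padding_mask[:, None, None, :]
    return visible_mask

# Usage: mask scores of shape (B, H, T, S) before softmax.
scores = torch.randn(2, 4, 8, 8)
mask = build_visible_mask(8, 8, (None, 0))  # (None, 0) reproduces a causal window
scores = scores.masked_fill(~mask, torch.finfo(scores.dtype).min)
probs = torch.softmax(scores, dim=-1)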
@@ -91,53 +102,53 @@ def flashattn(batch_size,
             scores_sum = T.alloc_fragment([block_M], accum_dtype)
             logsum = T.alloc_fragment([block_M], accum_dtype)
 
+            T.annotate_layout({
+                O_shared: tilelang.layout.make_swizzled_layout(O_shared),
+                Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
+            })
+
             batch_idx = bz
             head_idx = by
             kv_head_idx = head_idx // groups
             q_start_idx = cu_seqlens_q[batch_idx]
-            k_start_idx = cu_seqlens_k[batch_idx]
-            v_start_idx = cu_seqlens_k[batch_idx]
+            kv_start_idx = cu_seqlens_k[batch_idx]
             q_end_idx = cu_seqlens_q[batch_idx + 1]
             k_end_idx = cu_seqlens_k[batch_idx + 1]
-            v_end_idx = cu_seqlens_k[batch_idx + 1]
             q_current_seqlen = q_end_idx - q_start_idx
-            k_current_seqlen = k_end_idx - k_start_idx
-            v_current_seqlen = v_end_idx - v_start_idx
+            kv_current_seqlen = k_end_idx - kv_start_idx
 
             T.copy(
                 Q_unpad[q_start_idx + bx * block_M:q_start_idx + (bx + 1) * block_M, head_idx, :],
                 Q_shared)
-            for i, d in T.Parallel(block_M, dim):
-                if bx * block_M + i >= q_current_seqlen:
-                    Q_shared[i, d] = 0
 
             T.fill(acc_o, 0)
             T.fill(logsum, 0)
             T.fill(scores_max, -T.infinity(accum_dtype))
 
-            loop_range = T.ceildiv(k_current_seqlen, block_N)
+            loop_range = (
+                T.min(T.ceildiv(q_current_seqlen + (bx + 1) * block_M, block_N),
+                      T.ceildiv(kv_current_seqlen, block_N))
+                if is_causal else T.ceildiv(kv_current_seqlen, block_N))
 
             for k in T.Pipelined(loop_range, num_stages=num_stages):
                 T.copy(
-                    K_unpad[k_start_idx + k * block_N:k_start_idx + (k + 1) * block_N,
+                    K_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N,
                             kv_head_idx, :],
                     K_shared)
-                for i, d in T.Parallel(block_N, dim):
-                    if k * block_N + i >= k_current_seqlen:
-                        K_shared[i, d] = 0
                 if is_causal:
                     for i, j in T.Parallel(block_M, block_N):
-                        acc_s[i, j] = T.if_then_else(
-                            (bx * block_M + i >= k * block_N + j) and
-                            (bx * block_M + i >= q_current_seqlen or
-                             k * block_N + j >= k_current_seqlen), -1e9, 0)
+                        acc_s[i, j] = T.if_then_else(
+                            (bx * block_M + i < k * block_N + j) or
+                            (bx * block_M + i >= q_current_seqlen or
+                             k * block_N + j >= kv_current_seqlen),
+                            -T.infinity(acc_s.dtype), 0)
                 else:
                     for i, j in T.Parallel(block_M, block_N):
-                        acc_s[i, j] = T.if_then_else(
-                            (bx * block_M + i >= q_current_seqlen or
-                             k * block_N + j >= k_current_seqlen), -1e9, 0)
+                        acc_s[i, j] = T.if_then_else(
+                            (bx * block_M + i >= q_current_seqlen or
+                             k * block_N + j >= kv_current_seqlen),
+                            -T.infinity(acc_s.dtype), 0)
                 T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
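The causal branch now initializes each score element to negative infinity when the key position lies in the query's future or when either index runs past its sequence length, instead of the previous -1e9 sentinel. A host-side model of that per-tile predicate (plain NumPy, purely illustrative of the device loop above):

# Host-side model of the per-tile score initialization in the causal branch:
# positions in the query's "future", past the end of this sequence's queries,
# or past the end of its keys start at -inf; everything else starts at 0.
import numpy as np

def init_score_tile(bx, k, block_M, block_N, q_current_seqlen, kv_current_seqlen):
    tile = np.zeros((block_M, block_N), dtype=np.float32)
    for i in range(block_M):
        for j in range(block_N):
            q_pos = bx * block_M + i
            kv_pos = k * block_N + j
            masked = (q_pos < kv_pos) or (q_pos >= q_current_seqlen) or (kv_pos >= kv_current_seqlen)
            tile[i, j] = -np.inf if masked else 0.0
    return tile

# Example: first query tile against the second key tile of a 100-token sequence.
print(init_score_tile(bx=0, k=1, block_M=64, block_N=64,
                      q_current_seqlen=100, kv_current_seqlen=100))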
@@ -145,6 +156,9 @@ def flashattn(batch_size,
                 T.fill(scores_max, -T.infinity(accum_dtype))
                 T.reduce_max(acc_s, scores_max, dim=1, clear=False)
+                for i in T.Parallel(block_M):
+                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                 for i in T.Parallel(block_M):
                     scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                 for i, j in T.Parallel(block_M, block_N):
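The added T.max loop keeps the running row maximum monotone across key tiles, which is the usual online-softmax correction; the subsequent exp2 rescale shrinks the previously accumulated output by the ratio of the old and new maxima. A compact host-side sketch of those two steps in PyTorch (illustrative only, with the 1/sqrt(d) and log2(e) factors folded into a single scale as in the kernel):

# Host-side sketch of the online-softmax bookkeeping around the added T.max loop:
# the running maximum may only grow, and the accumulator is rescaled by
# exp2(old_max*scale - new_max*scale) before the new tile is folded in.
import torch

def online_softmax_step(scores_max_prev, tile_scores, acc_o, scale):
    # tile_scores: (block_M, block_N) raw QK^T scores for the current key tile
    tile_max = tile_scores.max(dim=1).values
    scores_max = torch.maximum(scores_max_prev, tile_max)           # running max
    scores_scale = torch.exp2(scores_max_prev * scale - scores_max * scale)
    acc_o = acc_o * scores_scale[:, None]                           # rescale old contribution
    probs = torch.exp2(tile_scores * scale - scores_max[:, None] * scale)
    return scores_max, scores_scale, acc_o, probs

# Usage with toy shapes (block_M=4, block_N=8, head dim 16):
prev_max = torch.full((4,), -float("inf"))
acc_o = torch.zeros(4, 16)
scores = torch.randn(4, 8)
new_max, rescale, acc_o, probs = online_softmax_step(prev_max, scores, acc_o, scale=1.0)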
@@ -158,11 +172,8 @@ def flashattn(batch_size,
                     acc_o[i, j] *= scores_scale[i]
                 T.copy(
-                    V_unpad[v_start_idx + k * block_N:v_start_idx + (k + 1) * block_N,
+                    V_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N,
                             kv_head_idx, :],
                     V_shared)
-                for i, d in T.Parallel(block_N, dim):
-                    if k * block_N + i >= v_current_seqlen:
-                        V_shared[i, d] = 0
                 T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
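K and V are now sliced from the packed ("unpadded") tensors with the shared kv_start_idx offset rather than separate k_/v_ offsets. The following standalone PyTorch snippet shows how one block_N-sized V tile for one kv head is addressed through cu_seqlens_k; the kernel performs the same slice when copying into V_shared:

# Host-side model of the V-tile slice: V_unpad is packed over all sequences,
# cu_seqlens_k gives each sequence's start offset, and a tile covers block_N rows.
import torch

total_tokens, kv_heads, dim = 300, 4, 64
V_unpad = torch.randn(total_tokens, kv_heads, dim)
cu_seqlens_k = torch.tensor([0, 120, 300])   # two sequences of length 120 and 180

batch_idx, kv_head_idx, k, block_N = 1, 2, 0, 64
kv_start_idx = int(cu_seqlens_k[batch_idx])

# Rows [kv_start_idx + k*block_N, kv_start_idx + (k+1)*block_N) of the chosen kv head.
v_tile = V_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N, kv_head_idx, :]
print(v_tile.shape)  # torch.Size([64, 64])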
@@ -191,8 +202,7 @@ def main(batch: int = 1,
     tilelang.testing.set_random_seed(0)
-    causal = False
-    if causal:
+    if is_causal:
         total_flops *= 0.5
     tilelang.testing.set_random_seed(0)
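The FLOP count is now halved only when is_causal is actually set, since a causal mask discards roughly half of the score matrix. The total_flops expression itself is outside this excerpt; the sketch below uses the common 4 * batch * heads * seqlen_q * seqlen_k * dim convention purely as an assumption:

# Assumed FLOP-count convention for forward flash attention; the actual
# total_flops expression in the example is not shown in this diff excerpt.
def attention_fwd_flops(batch, heads, seqlen_q, seqlen_k, dim, is_causal):
    # QK^T and P@V each cost 2 * seqlen_q * seqlen_k * dim multiply-adds per head.
    total_flops = 4 * batch * heads * seqlen_q * seqlen_k * dim
    if is_causal:
        total_flops *= 0.5   # causal masking skips roughly half of the score tiles
    return total_flops

print(attention_fwd_flops(1, 64, 2048, 2048, 128, is_causal=True))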
@@ -201,9 +211,9 @@ def main(batch: int = 1,
     device = torch.device("cuda")
     head_kv = heads // groups
-    q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device, requires_grad=True)
-    k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True)
-    v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True)
+    q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device)
+    k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device)
+    v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device)
     query_padding_mask = generate_random_padding_mask(q_seqlen, batch, device, mode="random")
     key_padding_mask = generate_random_padding_mask(k_seqlen, batch, device, mode="random")
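q, k and v no longer request gradients, since the forward example only benchmarks and checks the output. generate_random_padding_mask and the unpad helpers are defined elsewhere in the example; the snippet below is one plausible, hypothetical way a padding mask turns into the cu_seqlens offsets and packed tensors that the kernel consumes:

# Hypothetical illustration of varlen packing: per-batch boolean padding masks
# become cumulative sequence offsets plus one flat "unpadded" token tensor.
import torch

def pack_varlen(x, padding_mask):
    # x: (batch, seqlen, heads, dim); padding_mask: (batch, seqlen) bool, True = keep
    seqlens = padding_mask.sum(dim=1, dtype=torch.int32)
    cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    x_unpad = x[padding_mask]            # (total_tokens, heads, dim)
    return x_unpad, cu_seqlens, int(seqlens.max())

batch, seqlen, heads, dim = 2, 8, 4, 16
q = torch.randn(batch, seqlen, heads, dim)
mask = torch.tensor([[1] * 8, [1] * 5 + [0] * 3], dtype=torch.bool)
q_unpad, cu_seqlens_q, max_seqlen_q = pack_varlen(q, mask)
print(q_unpad.shape, cu_seqlens_q)       # torch.Size([13, 4, 16]) tensor([ 0,  8, 13], ...)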
@@ -236,10 +246,10 @@ def main(batch: int = 1,
         heads,
         dim,
         is_causal,
-        block_M=64,
-        block_N=64,
-        num_stages=1,
-        threads=128)
+        block_M=128,
+        block_N=128,
+        num_stages=2,
+        threads=256)
     out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)
     out = output_pad_fn(out_unpad)
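The default tile configuration grows from 64x64 blocks with 1 pipeline stage and 128 threads to 128x128 blocks with 2 stages and 256 threads. output_pad_fn then scatters the packed kernel output back to the padded layout for comparison with the reference; a hypothetical standalone version of that re-padding step could look like this (not the example's own helper):

# Hypothetical re-padding helper: inverse of the varlen packing, scattering the
# flat (total_tokens, heads, dim) kernel output back into a zero-padded tensor.
import torch

def pad_output(out_unpad, padding_mask, seqlen):
    batch = padding_mask.shape[0]
    heads, dim = out_unpad.shape[1], out_unpad.shape[2]
    out = torch.zeros(batch, seqlen, heads, dim, dtype=out_unpad.dtype, device=out_unpad.device)
    out[padding_mask] = out_unpad        # place each kept token back at its padded position
    return out

# Usage with the toy packing shapes from the previous sketch:
mask = torch.tensor([[1] * 8, [1] * 5 + [0] * 3], dtype=torch.bool)
out_unpad = torch.randn(int(mask.sum()), 4, 16)
out = pad_output(out_unpad, mask, seqlen=8)
print(out.shape)  # torch.Size([2, 8, 4, 16])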
@@ -255,7 +265,9 @@ def main(batch: int = 1,
     torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)
     print("All checks passed.✅")
     latency = do_bench(
-        lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q))
+        lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q),
+        _n_warmup=5,
+        _n_repeat=5)
     print("Tile-lang: {:.2f} ms".format(latency))
     print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
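do_bench is now called with explicit warm-up and repeat counts (_n_warmup=5, _n_repeat=5). For readers without the tilelang helper, an equivalent measurement can be sketched with CUDA events; note that the TFLOPs line converts a millisecond latency via total_flops / latency * 1e-9:

# Stand-in for do_bench using CUDA events (assumes a CUDA device is available);
# reproduces the warm-up/repeat structure and the ms -> TFLOPs conversion.
import torch

def bench_ms(fn, n_warmup=5, n_repeat=5):
    for _ in range(n_warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(n_repeat):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / n_repeat   # milliseconds per call

# latency = bench_ms(lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q))
# tflops = total_flops / latency * 1e-9   # flops / (latency * 1e-3 s) / 1e12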