OpenDAS / tilelang / Commits

Unverified commit d66b83c9, authored Oct 20, 2025 by Yu Cheng, committed by GitHub Oct 20, 2025
[Example] Update GQA varlen fwd and MHA varlen fwd (#1071)
parent e57ef582

Showing 3 changed files with 405 additions and 158 deletions (+405, -158)
examples/flash_attention/example_gqa_fwd_varlen.py   (+276, -0)
examples/flash_attention/example_mha_fwd_varlen.py   (+7, -158)
examples/flash_attention/varlen_utils.py             (+122, -0)
examples/flash_attention/example_gqa_fwd_varlen.py (new file, mode 100644)
# ruff: noqa
import argparse
import torch
import tilelang
import tilelang.language as T
import tilelang.testing
from einops import rearrange, repeat
from tilelang.profiler import do_bench
from varlen_utils import generate_random_padding_mask, generate_qkv

tilelang.disable_cache()


def attention_ref(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    causal=False,
    window_size=(-1, -1),
    upcast=True,
):
    if causal:
        window_size = (window_size[0], 0)
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
    dim = q.shape[-1]
    scale = (1.0 / dim)**0.5
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
    scores = torch.einsum("bthd,bshd->bhts", q, k)
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
    scores = scores * scale
    attention = torch.softmax(scores, dim=-1).to(v.dtype)
    if query_padding_mask is not None:
        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    output = torch.einsum("bhts,bshd->bthd", attention, v)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
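
The reference above operates on padded [batch, seqlen, heads, dim] tensors. The kernel below instead takes unpadded ("varlen") inputs: every sequence in the batch is concatenated along a single token axis, with per-batch offsets carried in the prefix-sum array cu_seqlens. A minimal sketch of that layout (the lengths/tokens names here are illustrative, not from this file):

import torch

lengths = torch.tensor([3, 5, 2])                        # per-sequence lengths
cu_seqlens = torch.cat(
    [torch.zeros(1, dtype=torch.int32),
     torch.cumsum(lengths, 0).to(torch.int32)])          # tensor([0, 3, 8, 10])
tokens = torch.randn(int(lengths.sum()), 4)              # packed [total_tokens, dim]
# Sequence b occupies tokens[cu_seqlens[b]:cu_seqlens[b + 1]] in the packed tensor.
seq1 = tokens[cu_seqlens[1]:cu_seqlens[2]]               # rows 3..7, length 5
assert seq1.shape[0] == lengths[1]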
@tilelang.jit(
    out_idx=[6],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
def flashattn(batch_size,
              groups,
              UQ,
              UKV,
              heads,
              dim,
              is_causal,
              block_M=64,
              block_N=64,
              num_stages=1,
              threads=128):
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [UQ, heads, dim]
    kv_shape = [UKV, head_kv, dim]
    o_shape = [UQ, heads, dim]
    dtype = "float16"
    accum_dtype = "float"
    @T.prim_func
    def main(
            Q_unpad: T.Tensor(q_shape, dtype),
            K_unpad: T.Tensor(kv_shape, dtype),
            V_unpad: T.Tensor(kv_shape, dtype),
            cu_seqlens_q: T.Tensor([batch_size + 1], "int32"),
            cu_seqlens_k: T.Tensor([batch_size + 1], "int32"),
            max_seqlen_q: T.int32,
            Output_unpad: T.Tensor(o_shape, dtype),
    ):
        with T.Kernel(
                T.ceildiv(max_seqlen_q, block_M), heads, batch_size,
                threads=threads) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim], dtype)
            K_shared = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_N, dim], dtype)
            O_shared = T.alloc_shared([block_M, dim], dtype)
            acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
            acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_M], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
            scores_scale = T.alloc_fragment([block_M], accum_dtype)
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            batch_idx = bz
            head_idx = by
            kv_head_idx = head_idx // groups

            q_start_idx = cu_seqlens_q[batch_idx]
            k_start_idx = cu_seqlens_k[batch_idx]
            v_start_idx = cu_seqlens_k[batch_idx]
            q_end_idx = cu_seqlens_q[batch_idx + 1]
            k_end_idx = cu_seqlens_k[batch_idx + 1]
            v_end_idx = cu_seqlens_k[batch_idx + 1]
            q_current_seqlen = q_end_idx - q_start_idx
            k_current_seqlen = k_end_idx - k_start_idx
            v_current_seqlen = v_end_idx - v_start_idx

            T.copy(
                Q_unpad[q_start_idx + bx * block_M:q_start_idx + (bx + 1) * block_M,
                        head_idx, :], Q_shared)
            for i, d in T.Parallel(block_M, dim):
                if bx * block_M + i >= q_current_seqlen:
                    Q_shared[i, d] = 0

            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            loop_range = T.ceildiv(k_current_seqlen, block_N)
            for k in T.Pipelined(loop_range, num_stages=num_stages):
                T.copy(
                    K_unpad[k_start_idx + k * block_N:k_start_idx + (k + 1) * block_N,
                            kv_head_idx, :], K_shared)
                for i, d in T.Parallel(block_N, dim):
                    if k * block_N + i >= k_current_seqlen:
                        K_shared[i, d] = 0
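                # acc_s is pre-seeded with 0 / -inf below *before* the QK^T GEMM
                # accumulates into it, so masked entries enter the GEMM at -inf,
                # stay -inf, and drop out after the exp2 softmax step.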
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        # -inf where the key is ahead of the query (causal) or
                        # where either index runs past the real sequence length
                        acc_s[i, j] = T.if_then_else(
                            (bx * block_M + i < k * block_N + j) or
                            (bx * block_M + i >= q_current_seqlen or
                             k * block_N + j >= k_current_seqlen),
                            -T.infinity(acc_s.dtype), 0)
                else:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(
                            (bx * block_M + i >= q_current_seqlen or
                             k * block_N + j >= k_current_seqlen),
                            -T.infinity(acc_s.dtype), 0)
                T.gemm(
                    Q_shared,
                    K_shared,
                    acc_s,
                    transpose_B=True,
                    policy=T.GemmWarpPolicy.FullRow)
                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_M):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(block_M):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                T.copy(acc_s, acc_s_cast)
                for i, j in T.Parallel(block_M, dim):
                    acc_o[i, j] *= scores_scale[i]
                T.copy(
                    V_unpad[v_start_idx + k * block_N:v_start_idx + (k + 1) * block_N,
                            kv_head_idx, :], V_shared)
                for i, d in T.Parallel(block_N, dim):
                    if k * block_N + i >= v_current_seqlen:
                        V_shared[i, d] = 0
                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, O_shared)
            for i, d in T.Parallel(block_M, dim):
                if bx * block_M + i < q_current_seqlen:
                    Output_unpad[q_start_idx + bx * block_M + i, head_idx, d] = O_shared[i, d]

    return main
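
The loop above is the standard FlashAttention online softmax: each K/V tile updates a running row max (scores_max), rescales the partial accumulators by exp2((m_old - m_new) * scale) (scores_scale), and grows the running denominator (logsum). Folding log2(e) into scale lets the kernel use exp2, which maps to a cheaper instruction than exp on NVIDIA GPUs, since exp(x) == exp2(x * log2(e)). A standalone sketch of the recurrence for one query row (my illustration, not from the commit), checked against the one-shot softmax denominator:

import math
import torch

dim, n_blocks, block_N = 64, 4, 16
scale = (1.0 / dim)**0.5 * 1.44269504  # same constant as the kernel
scores = torch.randn(n_blocks * block_N, dtype=torch.float64)

m = torch.tensor(float("-inf"), dtype=torch.float64)  # running max   (scores_max)
l = torch.tensor(0.0, dtype=torch.float64)            # running denom (logsum)
for b in range(n_blocks):                             # one K/V tile per step
    s = scores[b * block_N:(b + 1) * block_N]
    m_new = torch.maximum(m, s.max())
    l = l * torch.exp2(m * scale - m_new * scale)     # rescale old partial sum
    l = l + torch.exp2(s * scale - m_new * scale).sum()
    m = m_new

# The streamed denominator matches the one-shot softmax denominator.
ref = torch.exp((scores - scores.max()) / math.sqrt(dim)).sum()
assert torch.allclose(l, ref)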
def main(batch: int = 1,
         heads: int = 64,
         q_seqlen: int = 2048,
         k_seqlen: int = 2048,
         dim: int = 128,
         groups: int = 16,
         is_causal: bool = False):
    assert heads % groups == 0, "heads must be divisible by groups"

    flops_per_matmul = 2.0 * batch * heads * q_seqlen * k_seqlen * dim
    total_flops = 2 * flops_per_matmul
    if is_causal:
        total_flops *= 0.5

    tilelang.testing.set_random_seed(0)
    dtype = torch.float16
    device = torch.device("cuda")
    head_kv = heads // groups
    q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device, requires_grad=True)
    k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True)
    v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True)

    query_padding_mask = generate_random_padding_mask(q_seqlen, batch, device, mode="random")
    key_padding_mask = generate_random_padding_mask(k_seqlen, batch, device, mode="random")

    (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        q,
        k,
        v,
        output_pad_fn,
        _,
        _,
    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)

    UQ = q_unpad.shape[0]
    UKV = k_unpad.shape[0]

    kernel = flashattn(
        batch, groups, UQ, UKV, heads, dim, is_causal,
        block_M=64, block_N=64, num_stages=1, threads=128)
    out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)
    out = output_pad_fn(out_unpad)

    out_ref, _ = attention_ref(
        q,
        k,
        v,
        query_padding_mask=query_padding_mask,
        key_padding_mask=key_padding_mask,
        causal=is_causal,
    )
    torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)
    print("All checks passed.✅")

    latency = do_bench(
        lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q))
    print("Tile-lang: {:.2f} ms".format(latency))
    print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=8, help='batch size')
    parser.add_argument('--heads', type=int, default=64, help='query heads')
    parser.add_argument('--groups', type=int, default=16, help='groups')
    parser.add_argument('--q_seqlen', type=int, default=2048, help='query sequence length')
    parser.add_argument('--k_seqlen', type=int, default=2048, help='key/value sequence length')
    parser.add_argument('--dim', type=int, default=128, help='head dim')
    parser.add_argument('--is_causal', action='store_true', help='causal attention')
    args = parser.parse_args()
    main(args.batch, args.heads, args.q_seqlen, args.k_seqlen, args.dim, args.groups,
         args.is_causal)
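
To exercise the example end to end (assuming a CUDA machine with tilelang, einops, and flash-attn's bert_padding importable for varlen_utils), an invocation such as python example_gqa_fwd_varlen.py --batch 8 --heads 64 --groups 16 --q_seqlen 2048 --k_seqlen 2048 --dim 128 should print the correctness check followed by latency and throughput.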
examples/flash_attention/example_mha_fwd_varlen.py

@@ -7,158 +7,7 @@ import argparse
 import torch
 from einops import rearrange, repeat
-from bert_padding import pad_input, unpad_input
+from varlen_utils import generate_random_padding_mask, generate_qkv
-def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
-    ...
-
-def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None,
-                 kvpacked=False, qkvpacked=False):
-    ...

(Both helper bodies are deleted here and move essentially verbatim into the new
examples/flash_attention/varlen_utils.py shown below; the only difference is that
generate_qkv's standalone shape checks, assert k.shape == (batch_size, seqlen_k,
nheads_k, d) and the matching v.shape assert, are dropped in the move.)
-def construct_local_mask(
-    seqlen_q,
-    seqlen_k,
-    window_size=(-1, -1),  # -1 means infinite window size
-    query_padding_mask=None,
-    key_padding_mask=None,
-    device=None,
-    key_leftpad=None,
-):
-    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
-    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
-    if key_leftpad is not None:
-        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
-        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
-        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
-    sk = (
-        seqlen_k
-        if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1"))
-    sq = (
-        seqlen_q
-        if query_padding_mask is None else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1"))
-    if window_size[0] < 0:
-        return col_idx > row_idx + sk - sq + window_size[1]
-    else:
-        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
-        return torch.logical_or(
-            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
-            col_idx < row_idx + sk - sq - window_size[0],
-        )
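
construct_local_mask (deleted above, and not carried over to varlen_utils.py) builds flash-attn's bottom-right-aligned sliding-window mask: True marks positions to suppress, and the sk - sq shift anchors the diagonal so the last query row attends through the last key. A tiny sketch of the infinite-left-window branch (my illustration, not from the diff):

import torch

# seqlen_q=2, seqlen_k=4, causal window (window_size=(-1, 0)), no padding:
row = torch.arange(2).unsqueeze(1)   # row_idx, shape [2, 1]
col = torch.arange(4)                # col_idx, shape [4]
mask = col > row + 4 - 2 + 0         # col_idx > row_idx + sk - sq + window_size[1]
print(mask)
# tensor([[False, False, False,  True],
#         [False, False, False, False]])
# Query 0 already sees keys 0..2 because the diagonal is shifted by sk - sq = 2.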
 def attention_ref(
...
@@ -359,7 +208,7 @@ def flashattn(batch_size,
     return main
-def main(batch: int = 2, heads: int = 16, seq_len: int = 256, dim: int = 32):
+def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128):
     flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
     total_flops = 2 * flops_per_matmul
...
@@ -431,15 +280,15 @@ def main(batch: int = 2, heads: int = 16, seq_len: int = 256, dim: int = 32):
     fla_out = output_pad_fn(fla_out_unpad)
     torch.testing.assert_close(out, fla_out, rtol=1e-2, atol=1e-2)
-    print("Assert Equal Passed")
+    print("All checks passed.✅")

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=2, help='batch size')
+    parser.add_argument('--batch', type=int, default=8, help='batch size')
-    parser.add_argument('--heads', type=int, default=16, help='heads')
+    parser.add_argument('--heads', type=int, default=64, help='heads')
-    parser.add_argument('--seq_len', type=int, default=256, help='sequence length')
+    parser.add_argument('--seq_len', type=int, default=2048, help='sequence length')
-    parser.add_argument('--dim', type=int, default=32, help='dim')
+    parser.add_argument('--dim', type=int, default=128, help='dim')
     args = parser.parse_args()
     main(args.batch, args.heads, args.seq_len, args.dim)
examples/flash_attention/varlen_utils.py (new file, mode 100644)
# ruff: noqa
import torch
from einops import rearrange, repeat
from bert_padding import pad_input, unpad_input


def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
    assert mode in ["full", "random", "third"]
    if mode == "full":
        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
    elif mode == "random":
        lengths = torch.randint(
            max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device)
    elif mode == "third":
        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
    padding_mask = (
        repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths)
    return padding_mask
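
For intuition, a sketch (illustrative values, not from the file) of the mask this helper returns: True marks real tokens, False marks padding that the varlen path strips.

import torch
from einops import repeat

lengths = torch.tensor([[3], [5]])                       # pretend random draws
mask = repeat(torch.arange(5), "s -> b s", b=2) < lengths
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])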
def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None,
                 kvpacked=False, qkvpacked=False):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d)
        query_padding_mask: (batch_size, seqlen), bool
        key_padding_mask: (batch_size, seqlen), bool
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    _, seqlen_k, nheads_k, _ = k.shape

    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask)
        output_pad_fn = lambda output_unpad: pad_input(
            output_unpad, indices_q, batch_size, seqlen_q)
    else:
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        cu_seqlens_q = torch.arange(
            0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
            device=q_unpad.device)
        max_seqlen_q = seqlen_q
        output_pad_fn = lambda output_unpad: rearrange(
            output_unpad, "(b s) h d -> b s h d", b=batch_size)

    if key_padding_mask is not None:
        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
        v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(
            0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
            device=k_unpad.device)
        max_seqlen_k = seqlen_k

    if qkvpacked:
        assert (query_padding_mask == key_padding_mask).all()
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        if query_padding_mask is not None:
            dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
        else:
            dqkv_pad_fn = lambda dqkv_unpad: rearrange(
                dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size)
        return (
            qkv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            max_seqlen_q,
            qkv.detach().requires_grad_(),
            output_pad_fn,
            dqkv_pad_fn,
        )
    elif kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
        else:
            dkv_pad_fn = lambda dkv_unpad: rearrange(
                dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size)
        return (
            q_unpad.detach().requires_grad_(),
            kv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            kv.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        )
    else:
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
        else:
            dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)
        return (
            q_unpad.detach().requires_grad_(),
            k_unpad.detach().requires_grad_(),
            v_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            k.detach().requires_grad_(),
            v.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        )
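
A minimal round-trip sketch of how the examples consume these helpers (illustrative shapes; assumes a CUDA device and that bert_padding is importable):

import torch
from varlen_utils import generate_random_padding_mask, generate_qkv

device = torch.device("cuda")
q = torch.randn(2, 128, 8, 64, dtype=torch.float16, device=device)
k = torch.randn(2, 128, 8, 64, dtype=torch.float16, device=device)
v = torch.randn(2, 128, 8, 64, dtype=torch.float16, device=device)
q_mask = generate_random_padding_mask(128, 2, device, mode="random")
k_mask = generate_random_padding_mask(128, 2, device, mode="random")

(q_unpad, k_unpad, v_unpad, cu_q, cu_k, max_q, max_k,
 q, k, v, output_pad_fn, _, _) = generate_qkv(q, k, v, q_mask, k_mask)

# q_unpad: [total_q_tokens, 8, 64]; cu_q: [batch + 1] int32 offsets.
# A varlen attention kernel produces out_unpad with q_unpad's shape;
# output_pad_fn(out_unpad) then maps it back to padded [2, 128, 8, 64].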