OpenDAS / tilelang / Commits

Commit 667632cc (unverified)
Merge branch 'main' into dcu
Authored Dec 22, 2025 by guchaoyang; committed by GitHub, Dec 22, 2025
Parents: d6dd2ddf, a874e4e8
Changes: 343 files

Showing 20 changed files with 1334 additions and 1635 deletions (+1334 / -1635).
Changed files shown in this view:

examples/elementwise/test_example_elementwise.py                    +4   -0
examples/flash_attention/README.md                                  +3   -1
examples/flash_attention/bert_padding.py                            +4   -12
examples/flash_attention/example_gqa_bwd.py                         +154 -190
examples/flash_attention/example_gqa_bwd_tma_reduce.py              +170 -209
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py       +208 -276
examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py         +98  -133
examples/flash_attention/example_gqa_fwd_bshd.py                    +59  -80
examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py    +52  -60
examples/flash_attention/example_gqa_fwd_varlen.py                  +61  -94
examples/flash_attention/example_mha_bwd_bhsd.py                    +92  -85
examples/flash_attention/example_mha_bwd_bshd.py                    +89  -82
examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py    +86  -87
examples/flash_attention/example_mha_fwd_bhsd.py                    +49  -61
examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py    +53  -66
examples/flash_attention/example_mha_fwd_bshd.py                    +45  -58
examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py    +51  -63
examples/flash_attention/example_mha_fwd_varlen.py                  +41  -47
examples/flash_attention/test_example_flash_attention.py            +6   -8
examples/flash_attention/varlen_utils.py                            +9   -23
Note: too many changes to show; to preserve performance only 343 of 343+ files are displayed.
examples/elementwise/test_example_elementwise.py

@@ -6,5 +6,9 @@ def test_example_elementwise_add():
    example_elementwise_add.main()


def test_example_elementwise_add_autotune():
    example_elementwise_add.main(use_autotune=True)


if __name__ == "__main__":
    tilelang.testing.main()
examples/flash_attention/README.md

@@ -77,6 +77,8 @@ def flash_attention(
            # Compute the maximum value per row on dimension 1 (block_N)
            T.reduce_max(acc_s, scores_max, dim=1, clear=False)
            for i in T.Parallel(block_M):
                scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
            # Compute the factor by which we need to rescale previous partial sums
            for i in T.Parallel(block_M):
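The lines added in this hunk keep a running row maximum across K/V tiles, which is the online-softmax rescaling used by every flash-attention kernel touched in this commit. As a side note, a minimal NumPy sketch of that idea for a single attention row (a hypothetical helper, not part of the repository or the tilelang API):

import numpy as np

def online_softmax_row(score_tiles):
    """Accumulate the softmax normalizer of one row that arrives tile by tile."""
    m = -np.inf            # running maximum of the scores seen so far
    s = 0.0                # running sum of exp(score - m)
    for tile in score_tiles:
        m_new = max(m, tile.max())
        # rescale the old partial sum before adding the new tile's contribution
        s = s * np.exp(m - m_new) + np.exp(tile - m_new).sum()
        m = m_new
    return m, s            # the row's log-sum-exp is m + log(s)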
examples/flash_attention/bert_padding.py

@@ -6,7 +6,6 @@ from einops import rearrange, repeat
class IndexFirstAxis(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, indices):
        ctx.save_for_backward(indices)

@@ -15,9 +14,7 @@ class IndexFirstAxis(torch.autograd.Function):
        second_dim = other_shape.numel()
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        # return input[indices]
        return torch.gather(
            rearrange(input, "b ... -> b (...)"), 0,
            repeat(indices, "z -> z d", d=second_dim)).reshape(-1, *other_shape)

    @staticmethod
    def backward(ctx, grad_output):

@@ -40,14 +37,12 @@ index_first_axis = IndexFirstAxis.apply
class IndexPutFirstAxis(torch.autograd.Function):

    @staticmethod
    def forward(ctx, values, indices, first_axis_dim):
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        output = torch.zeros(
            first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
        output[indices] = values
        # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)

@@ -66,7 +61,6 @@ index_put_first_axis = IndexPutFirstAxis.apply
class IndexFirstAxisResidual(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, indices):
        ctx.save_for_backward(indices)

@@ -177,9 +171,7 @@ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_leng
    """
    length = attention_mask_in_length.sum(dim=-1)
    seqlen = attention_mask_in_length.size(-1)
    attention_mask_2d = torch.arange(
        seqlen, device=length.device, dtype=length.dtype).expand(len(length), seqlen) < length.unsqueeze(1)
    real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
    seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
    indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
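The comment retained above ("torch.gather is a bit faster than indexing") relies on the gather expression computing exactly the same result as plain fancy indexing. A small self-contained check of that equivalence (shapes chosen arbitrarily for illustration, not taken from the repository):

import torch
from einops import rearrange, repeat

x = torch.randn(6, 3, 4)              # (b, ...) with other_shape = (3, 4)
indices = torch.tensor([0, 2, 5])
other_shape = x.shape[1:]
second_dim = other_shape.numel()

# the gather-based form used in IndexFirstAxis.forward
via_gather = torch.gather(
    rearrange(x, "b ... -> b (...)"), 0,
    repeat(indices, "z -> z d", d=second_dim)).reshape(-1, *other_shape)

# equivalent to selecting rows of the first axis directly
assert torch.equal(via_gather, x[indices])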
examples/flash_attention/example_gqa_bwd.py

@@ -6,17 +6,19 @@ import argparse
@tilelang.jit(
    out_idx=[3, 4],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
    scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim_qk]
    k_shape = [batch, seq_len, head_kv, dim_qk]
    v_shape = [batch, seq_len, head_kv, dim_v]
    dtype = "float16"
    accum_dtype = "float"
    dtype = T.float16
    accum_dtype = T.float32

    @T.prim_func
    def flash_fwd(

@@ -40,25 +42,25 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
            logsum = T.alloc_fragment([block_M], accum_dtype)

            T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
            T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))
            loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_range, num_stages=1):
                T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared)
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                                     -T.infinity(acc_s.dtype))
                else:
                    T.clear(acc_s)
                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
                T.copy(scores_max, scores_max_prev)
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_M):
                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                for i in T.Parallel(block_M):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_M, dim_v):

@@ -72,21 +74,23 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
                logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
            for i, j in T.Parallel(block_M, dim_v):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
            for i in T.Parallel(block_M):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
            T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M])

    return flash_fwd


@tilelang.jit(
    out_idx=[2],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
    dtype = "float16"
    accum_dtype = "float"
    dtype = T.float16
    accum_dtype = T.float32
    shape = [batch, seq_len, heads, dim_v]
    blk = 32

@@ -103,29 +107,30 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
            delta = T.alloc_fragment([blk], accum_dtype)
            T.clear(acc)
            for k in range(T.ceildiv(dim_v, blk)):
                T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o)
                T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do)
                for i, j in T.Parallel(blk, blk):
                    acc[i, j] += o[i, j] * do[i, j]
            T.reduce_sum(acc, delta, 1)
            T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk])

    return flash_bwd_prep


def make_dq_layout(dQ):
    # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
    return T.Layout(dQ.shape,
                    lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])


@tilelang.jit(
    out_idx=[1],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def flashattn_bwd_postprocess(batch, heads, seq_len, dim_qk):
    dtype = "float16"
    accum_dtype = "float"
    dtype = T.float16
    accum_dtype = T.float32
    shape = [batch, seq_len, heads, dim_qk]
    blk = 64

@@ -137,35 +142,27 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim_qk):
    with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
        T.annotate_layout({dQ: make_dq_layout(dQ)})
        T.copy(
            dQ[bz, bx * blk:(bx + 1) * blk, by, :],
            dQ_out[bz, bx * blk:(bx + 1) * blk, by, :],
        )

    return flash_bwd_post


@tilelang.jit(pass_configs={
    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N,
                             threads=256, num_stages=2, groups=1):
    sm_scale = (1.0 / dim_qk)**0.5
    scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim_qk]
    k_shape = [batch, seq_len, head_kv, dim_qk]
    v_shape = [batch, seq_len, head_kv, dim_v]
    dtype = "float16"
    accum_dtype = "float"
    dtype = T.float16
    accum_dtype = T.float32

    @T.prim_func
    def flash_bwd(

@@ -197,35 +194,36 @@ def flashattn_bwd_atomic_add(batch,
            dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype)
            dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype)

            T.annotate_layout({
                dQ: make_dq_layout(dQ),
                K_shared: tilelang.layout.make_swizzled_layout(K_shared),
            })
            T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared)
            T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
            T.clear(dv)
            T.clear(dk)
            loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
            loop_ed = T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
                T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
                T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
                T.clear(dsT)
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
                T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale

@@ -237,37 +235,29 @@ def flashattn_bwd_atomic_add(batch,
                for i, j in T.Parallel(block_N, dim_qk):
                    T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
            T.copy(dv, dv_shared)
            T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared)
            T.copy(dk, dk_shared)
            T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared)

    return flash_bwd


@tilelang.jit(pass_configs={
    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_split(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N,
                        threads=256, num_stages=2, groups=1):
    sm_scale = (1.0 / dim_qk)**0.5
    scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim_qk]
    k_shape = [batch, seq_len, head_kv, dim_qk]
    v_shape = [batch, seq_len, head_kv, dim_v]
    dk_shape = [groups, batch, seq_len, head_kv, dim_qk]  # sum after kernel
    dv_shape = [groups, batch, seq_len, head_kv, dim_v]  # sum after kernel
    dtype = "float16"
    accum_dtype = "float"
    dtype = T.float16
    accum_dtype = T.float32

    @T.prim_func
    def flash_bwd(

@@ -299,37 +289,38 @@ def flashattn_bwd_split(batch,
            dv_shared = T.alloc_shared([block_M, dim_v], dtype)
            dk_shared = T.alloc_shared([block_M, dim_qk], dtype)

            T.annotate_layout({
                dQ: make_dq_layout(dQ),
                K_shared: tilelang.layout.make_swizzled_layout(K_shared),
                dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
                dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
            })
            T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared)
            T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
            T.clear(dv)
            T.clear(dk)
            loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
            loop_ed = T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
                T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
                T.clear(dsT)
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
                T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale

@@ -342,16 +333,15 @@ def flashattn_bwd_split(batch,
                    T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
            T.copy(dv, dv_shared)
            T.copy(dv_shared, dV[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
            T.copy(dk, dk_shared)
            T.copy(dk, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])

    return flash_bwd


@torch.compile
class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, causal, groups=1, use_atomic=True):
        BATCH, N_CTX, H, D_HEAD_QK = q.shape

@@ -369,7 +359,10 @@ class _attention(torch.autograd.Function):
    def backward(ctx, do):
        q, k, v, o, lse = ctx.saved_tensors
        BATCH, N_CTX, H, D_HEAD_QK = q.shape
        (HEAD_KV, D_HEAD_V,) = v.shape[-2], v.shape[-1]
        groups = H // HEAD_KV

        def maybe_contiguous(x):

@@ -386,17 +379,8 @@ class _attention(torch.autograd.Function):
        if ctx.use_atomic:
            kernel = flashattn_bwd_atomic_add(
                BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N,
                threads=256, num_stages=2, groups=groups)
            shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
            shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
            shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]

@@ -409,17 +393,8 @@ class _attention(torch.autograd.Function):
            dv = dv.to(torch.float16)
        else:
            kernel = flashattn_bwd_split(
                BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N,
                threads=256, num_stages=2, groups=groups)
            shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
            shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK]  # sum after kernel
            shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V]  # sum after kernel

@@ -441,53 +416,45 @@ def ref_program(Q, K, V, is_causal, groups=1):
    # K: [B, T, HK, D_QK]
    # V: [B, T, HV, D_V]
    # HQ = HKV * groups
    assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
    assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
    dim_qk = Q.size(-1)
    K = K.repeat_interleave(groups, dim=2)
    V = V.repeat_interleave(groups, dim=2)
    scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float("-inf"))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
    return output


def main(BATCH: int = 1,
         H: int = 32,
         N_CTX: int = 256,
         D_HEAD_QK: int = 192,
         D_HEAD_V: int = 128,
         groups: int = 16,
         causal: bool = False,
         use_atomic: bool = True):
    flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
    flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
    total_flops = 3 * flops_per_qk + 2 * flops_per_v
    if causal:
        total_flops *= 0.5
    Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
    head_kv = H // groups
    K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
    V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
    dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
    O = attention(Q, K, V, causal, groups, use_atomic)
    O.backward(dO, retain_graph=True)
    dQ, Q.grad = Q.grad.clone(), None

@@ -504,7 +471,7 @@ def main(BATCH: int = 1,
    torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
    print("All checks passed.✅")

    def run():
        O_ref.backward(dO, retain_graph=True)

@@ -524,17 +491,15 @@ def main(BATCH: int = 1,
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=8, help="Batch size")
    parser.add_argument("--h", type=int, default=32, help="Number of heads")
    parser.add_argument("--n_ctx", type=int, default=1024, help="Context size")
    parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K")
    parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V")
    parser.add_argument("--causal", action="store_true", help="Causal flag")
    parser.add_argument("--groups", type=int, default=16, help="groups")
    parser.add_argument("--use_atomic", action="store_true", default=False, help="Use atomic add for dK/dV")
    parser.add_argument("--use_split", action="store_true", default=False, help="Use split for dK/dV")
    args = parser.parse_args()

    # Handle backward compatibility and logic

@@ -546,5 +511,4 @@ if __name__ == "__main__":
    # Default: use atomic
    use_atomic = True
    main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic)
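Both backward variants in this file fold the 1/sqrt(dim_qk) softmax scale together with log2(e) so that the kernels can work in base 2 (T.exp2 / T.log2) instead of the natural exponential. A quick sanity check of that identity in plain PyTorch, using the example's default head dimension (nothing here is tilelang-specific):

import torch

dim_qk = 192                                  # the example's default D_HEAD_QK
scale = (1.0 / dim_qk) ** 0.5 * 1.44269504    # 1/sqrt(dim_qk) * log2(e), as in the kernels

x = torch.randn(8, dtype=torch.float64)
# exp2(x * log2(e) / sqrt(d)) equals exp(x / sqrt(d))
assert torch.allclose(torch.exp2(x * scale), torch.exp(x * (1.0 / dim_qk) ** 0.5))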
examples/flash_attention/example_gqa_bwd_tma_reduce.py

@@ -9,17 +9,19 @@ tilelang.disable_cache()
@tilelang.jit(
    out_idx=[3, 4],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
    scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim_qk]
    k_shape = [batch, seq_len, head_kv, dim_qk]
    v_shape = [batch, seq_len, head_kv, dim_v]
    dtype = "float16"
    accum_dtype = "float"
    dtype = T.float16
    accum_dtype = T.float32

    @T.prim_func
    def flash_fwd(

@@ -43,27 +45,27 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
            logsum = T.alloc_fragment([block_M], accum_dtype)

            T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
            T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            # Warning: in causal/varlen/unaligned seqlen scenarios, the -inf will cause undefined behavior in exp ops
            # We should set it to negative large number instead
            T.fill(scores_max, T.Cast(accum_dtype, -1e30))
            loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_range, num_stages=1):
                T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared)
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                                     T.Cast(accum_dtype, -1e30))
                else:
                    T.clear(acc_s)
                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
                T.copy(scores_max, scores_max_prev)
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_M):
                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                for i in T.Parallel(block_M):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_M, dim_v):

@@ -77,21 +79,23 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
                logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
            for i, j in T.Parallel(block_M, dim_v):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
            for i in T.Parallel(block_M):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
            T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M])

    return flash_fwd


@tilelang.jit(
    out_idx=[2],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
    dtype = "float16"
    accum_dtype = "float"
    dtype = T.float16
    accum_dtype = T.float32
    shape = [batch, seq_len, heads, dim_v]
    blk = 32

@@ -108,12 +112,12 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
            delta = T.alloc_fragment([blk], accum_dtype)
            T.clear(acc)
            for k in range(T.ceildiv(dim_v, blk)):
                T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o)
                T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do)
                for i, j in T.Parallel(blk, blk):
                    acc[i, j] += o[i, j] * do[i, j]
            T.reduce_sum(acc, delta, 1)
            T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk])

    return flash_bwd_prep

@@ -124,12 +128,14 @@ def make_dq_layout(dQ):
@tilelang.jit(
    out_idx=[3, 4, 5],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def flashattn_bwd_postprocess(batch, heads, head_kv, seq_len, dim_qk, dim_v):
    dtype = "float16"
    accum_dtype = "float"
    dtype = T.float16
    accum_dtype = T.float32
    q_shape = [batch, seq_len, heads, dim_qk]
    k_shape = [batch, seq_len, head_kv, dim_qk]
    v_shape = [batch, seq_len, head_kv, dim_v]

@@ -146,43 +152,34 @@ def flashattn_bwd_postprocess(batch, heads, head_kv, seq_len, dim_qk, dim_v):
    ):
        with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
            T.annotate_layout({dQ: make_dq_layout(dQ)})
            T.copy(dQ[bz, bx * blk:(bx + 1) * blk, by, :], dQ_out[bz, bx * blk:(bx + 1) * blk, by, :])
        with T.Kernel(T.ceildiv(seq_len, blk), head_kv, batch, threads=128) as (bx, by, bz):
            T.annotate_layout({
                dK: make_dq_layout(dK),
                dV: make_dq_layout(dV),
            })
            T.copy(dK[bz, bx * blk:(bx + 1) * blk, by, :], dK_out[bz, bx * blk:(bx + 1) * blk, by, :])
            T.copy(dV[bz, bx * blk:(bx + 1) * blk, by, :], dV_out[bz, bx * blk:(bx + 1) * blk, by, :])

    return flash_bwd_post


@tilelang.jit(pass_configs={
    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N,
                             threads=256, num_stages=2, groups=1):
    sm_scale = (1.0 / dim_qk)**0.5
    scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim_qk]
    k_shape = [batch, seq_len, head_kv, dim_qk]
    v_shape = [batch, seq_len, head_kv, dim_v]
    dtype = "float16"
    accum_dtype = "float"
    dtype = T.float16
    accum_dtype = T.float32

    @T.prim_func
    def flash_bwd(

@@ -215,37 +212,38 @@ def flashattn_bwd_atomic_add(batch,
            dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype)
            dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype)

            T.annotate_layout({
                dQ: make_dq_layout(dQ),
                dK: make_dq_layout(dK),
                dV: make_dq_layout(dV),
                K_shared: tilelang.layout.make_swizzled_layout(K_shared),
            })
            T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared)
            T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
            T.clear(dv)
            T.clear(dk)
            loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
            loop_ed = T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
                T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
                T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
                T.clear(dsT)
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
                T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale

@@ -255,41 +253,31 @@ def flashattn_bwd_atomic_add(batch,
                T.clear(dq)
                T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
                T.copy(dq, dq_shared)
                T.atomic_add(dQ[bz, k * block_N:(k + 1) * block_N, bx, :], dq_shared, use_tma=True)
            T.copy(dv, dv_shared)
            T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared, use_tma=True)
            T.copy(dk, dk_shared)
            T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared, use_tma=True)

    return flash_bwd


@tilelang.jit(pass_configs={
    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_split_novarlen(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N,
                                 threads=256, num_stages=2, groups=1):
    sm_scale = (1.0 / dim_qk)**0.5
    scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim_qk]
    k_shape = [batch, seq_len, head_kv, dim_qk]
    v_shape = [batch, seq_len, head_kv, dim_v]
    dk_shape = [groups, batch, seq_len, head_kv, dim_qk]  # sum after kernel
    dv_shape = [groups, batch, seq_len, head_kv, dim_v]  # sum after kernel
    dtype = "float16"
    accum_dtype = "float"
    dtype = T.float16
    accum_dtype = T.float32

    @T.prim_func
    def flash_bwd(

@@ -321,37 +309,38 @@ def flashattn_bwd_split_novarlen(batch,
            dv_shared = T.alloc_shared([block_M, dim_v], dtype)
            dk_shared = T.alloc_shared([block_M, dim_qk], dtype)

            T.annotate_layout({
                dQ: make_dq_layout(dQ),
                K_shared: tilelang.layout.make_swizzled_layout(K_shared),
                dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
                dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
            })
            T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared)
            T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
            T.clear(dv)
            T.clear(dk)
            loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
            loop_ed = T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
                T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
                T.clear(dsT)
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
                T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale

@@ -364,16 +353,15 @@ def flashattn_bwd_split_novarlen(batch,
                    T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
            T.copy(dv, dv_shared)
            T.copy(dv_shared, dV[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
            T.copy(dk, dk_shared)
            T.copy(dk, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])

    return flash_bwd


@torch.compile
class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, causal, groups=1, use_atomic=True):
        BATCH, N_CTX, H, D_HEAD_QK = q.shape

@@ -391,7 +379,10 @@ class _attention(torch.autograd.Function):
    def backward(ctx, do):
        q, k, v, o, lse = ctx.saved_tensors
        BATCH, N_CTX, H, D_HEAD_QK = q.shape
        (HEAD_KV, D_HEAD_V,) = v.shape[-2], v.shape[-1]
        groups = H // HEAD_KV

        def maybe_contiguous(x):

@@ -408,17 +399,8 @@ class _attention(torch.autograd.Function):
        if ctx.use_atomic:
            kernel = flashattn_bwd_atomic_add(
                BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N,
                threads=256, num_stages=2, groups=groups)
            shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
            shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
            shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]

@@ -429,17 +411,8 @@ class _attention(torch.autograd.Function):
            dq, dk, dv = mod_post(dq, dk, dv)
        else:
            kernel = flashattn_bwd_split_novarlen(
                BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N,
                threads=256, num_stages=2, groups=groups)
            shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
            shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK]  # sum after kernel
            shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V]  # sum after kernel

@@ -447,8 +420,7 @@ class _attention(torch.autograd.Function):
            dk = torch.empty(shape_k, dtype=torch.float16, device=q.device)
            dv = torch.empty(shape_v, dtype=torch.float16, device=q.device)
            kernel(q, k, v, do, lse, delta, dq, dk, dv)
            dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32),
                                torch.zeros_like(v, dtype=torch.float32))
            dk, dv = dk.sum(0), dv.sum(0)
        return dq, dk, dv, None, None, None

@@ -462,53 +434,45 @@ def ref_program(Q, K, V, is_causal, groups=1):
    # K: [B, T, HK, D_QK]
    # V: [B, T, HV, D_V]
    # HQ = HKV * groups
    assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
    assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
    dim_qk = Q.size(-1)
    K = K.repeat_interleave(groups, dim=2)
    V = V.repeat_interleave(groups, dim=2)
    scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float("-inf"))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
    return output


def main(BATCH: int = 1,
         H: int = 32,
         N_CTX: int = 256,
         D_HEAD_QK: int = 192,
         D_HEAD_V: int = 128,
         groups: int = 16,
         causal: bool = False,
         use_atomic: bool = True):
    flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
    flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
    total_flops = 3 * flops_per_qk + 2 * flops_per_v
    if causal:
        total_flops *= 0.5
    Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
    head_kv = H // groups
    K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
    V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
    dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
    O = attention(Q, K, V, causal, groups, use_atomic)
    O.backward(dO, retain_graph=True)
    dQ, Q.grad = Q.grad.clone(), None

@@ -525,7 +489,7 @@ def main(BATCH: int = 1,
    torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
    print("All checks passed.✅")

    def run():
        O_ref.backward(dO, retain_graph=True)

@@ -548,17 +512,15 @@ if __name__ == "__main__":
    print(f"Detected GPU compute capability: {arch}")
    assert float(arch) >= 9.0, "This example only supports GPU with compute capability >= 9.0"
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=8, help="Batch size")
    parser.add_argument("--h", type=int, default=32, help="Number of heads")
    parser.add_argument("--n_ctx", type=int, default=1024, help="Context size")
    parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K")
    parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V")
    parser.add_argument("--causal", action="store_true", help="Causal flag")
    parser.add_argument("--groups", type=int, default=16, help="groups")
    parser.add_argument("--use_atomic", action="store_true", default=False, help="Use atomic add for dK/dV")
    parser.add_argument("--use_split", action="store_true", default=False, help="Use split for dK/dV")
    args = parser.parse_args()

    # Handle backward compatibility and logic

@@ -570,5 +532,4 @@ if __name__ == "__main__":
    # Default: use atomic
    use_atomic = True
    main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic)
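The split (non-atomic) path in this file writes one dK/dV slice per query-head group into a [groups, batch, seq_len, head_kv, dim] buffer and reduces afterwards with dk, dv = dk.sum(0), dv.sum(0), instead of letting every block atomically accumulate into a single buffer. A toy illustration of why the two accumulation strategies agree (shapes invented for the example, not the kernel's real tile sizes):

import torch

groups, B, S, HKV, D = 4, 2, 16, 2, 8
partial_dk = torch.randn(groups, B, S, HKV, D)   # one partial gradient per query-head group

# atomic-add style: every group's contribution lands in the same buffer
dk_atomic = torch.zeros(B, S, HKV, D)
for g in range(groups):
    dk_atomic += partial_dk[g]

# split style: keep per-group slices and reduce once at the end
dk_split = partial_dk.sum(0)

assert torch.allclose(dk_atomic, dk_split)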
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py
View file @
667632cc
...
...
@@ -7,55 +7,42 @@ import argparse
from
einops
import
rearrange
,
repeat
from
bert_padding
import
pad_input
,
unpad_input
# tilelang.disable_cache()
def
generate_random_padding_mask
(
max_seqlen
,
batch_size
,
device
,
mode
=
"random"
):
assert
mode
in
[
"full"
,
"random"
,
"third"
]
if
mode
==
"full"
:
lengths
=
torch
.
full
((
batch_size
,
1
),
max_seqlen
,
device
=
device
,
dtype
=
torch
.
int32
)
elif
mode
==
"random"
:
lengths
=
torch
.
randint
(
max
(
1
,
max_seqlen
-
20
),
max_seqlen
+
1
,
(
batch_size
,
1
),
device
=
device
)
lengths
=
torch
.
randint
(
max
(
1
,
max_seqlen
-
20
),
max_seqlen
+
1
,
(
batch_size
,
1
),
device
=
device
)
elif
mode
==
"third"
:
lengths
=
torch
.
randint
(
max_seqlen
//
3
,
max_seqlen
+
1
,
(
batch_size
,
1
),
device
=
device
)
padding_mask
=
(
repeat
(
torch
.
arange
(
max_seqlen
,
device
=
device
),
"s -> b s"
,
b
=
batch_size
)
<
lengths
)
padding_mask
=
repeat
(
torch
.
arange
(
max_seqlen
,
device
=
device
),
"s -> b s"
,
b
=
batch_size
)
<
lengths
return
padding_mask
@
tilelang
.
jit
(
out_idx
=
[
5
,
6
],
pass_configs
=
{
out_idx
=
[
5
,
6
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
def
flashattn_fwd
(
batch
,
total_q
,
total_kv
,
N_CTX
,
heads
,
max_seq_len
,
dim_qk
,
dim_v
,
is_causal
,
block_M
,
block_N
,
groups
=
1
):
scale
=
(
1.0
/
dim_qk
)
**
0.5
*
1.44269504
# log2(e)
},
)
def
flashattn_fwd
(
batch
,
total_q
,
total_kv
,
N_CTX
,
heads
,
max_seq_len
,
dim_qk
,
dim_v
,
is_causal
,
block_M
,
block_N
,
groups
=
1
):
scale
=
(
1.0
/
dim_qk
)
**
0.5
*
1.44269504
# log2(e)
head_kv
=
heads
//
groups
q_shape
=
[
total_q
,
heads
,
dim_qk
]
k_shape
=
[
total_kv
,
head_kv
,
dim_qk
]
v_shape
=
[
total_kv
,
head_kv
,
dim_v
]
o_shape
=
[
total_q
,
heads
,
dim_v
]
dtype
=
"
float16
"
accum_dtype
=
"
float
"
dtype
=
T
.
float16
accum_dtype
=
T
.
float
32
@
T
.
prim_func
def
flash_fwd
(
Q
:
T
.
Tensor
(
q_shape
,
dtype
),
# type: ignore
K
:
T
.
Tensor
(
k_shape
,
dtype
),
# type: ignore
V
:
T
.
Tensor
(
v_shape
,
dtype
),
# type: ignore
cu_seqlens_q
:
T
.
Tensor
([
batch
+
1
],
"
int32
"
),
# type: ignore
cu_seqlens_k
:
T
.
Tensor
([
batch
+
1
],
"
int32
"
),
# type: ignore
cu_seqlens_q
:
T
.
Tensor
([
batch
+
1
],
T
.
        int32),  # type: ignore
        cu_seqlens_k: T.Tensor([batch + 1], T.int32),  # type: ignore
        Output: T.Tensor(o_shape, dtype),  # type: ignore
        lse: T.Tensor([batch, heads, N_CTX], accum_dtype),  # type: ignore
    ):
...
@@ -102,15 +89,17 @@ def flashattn_fwd(batch,
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
-                       acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and (bx * block_M + i < q_current_seqlen and k * block_N + j < k_current_seqlen), 0, T.Cast(accum_dtype, -1e30))
+                       acc_s[i, j] = T.if_then_else(
+                           (bx * block_M + i >= k * block_N + j)
+                           and (bx * block_M + i < q_current_seqlen and k * block_N + j < k_current_seqlen),
+                           0,
+                           T.Cast(accum_dtype, -1e30),
+                       )
                else:
                    for i, j in T.Parallel(block_M, block_N):
-                       acc_s[i, j] = T.if_then_else(bx * block_M + i < q_current_seqlen and k * block_N + j < k_current_seqlen, 0, T.Cast(accum_dtype, -1e30))
+                       acc_s[i, j] = T.if_then_else(
+                           bx * block_M + i < q_current_seqlen and k * block_N + j < k_current_seqlen,
+                           0,
+                           T.Cast(accum_dtype, -1e30))
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                for i, d in T.Parallel(block_N, dim_v):
                    if k * block_N + i < k_current_seqlen:
...
@@ -119,6 +108,8 @@ def flashattn_fwd(batch,
                        V_shared[i, d] = 0.0
                T.copy(scores_max, scores_max_prev)
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
+               for i in T.Parallel(block_M):
+                   scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                for i in T.Parallel(block_M):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_M, dim_v):
...
@@ -146,12 +137,14 @@ def flashattn_fwd(batch,
-@tilelang.jit(out_idx=[3], pass_configs={
+@tilelang.jit(
+    out_idx=[3],
+    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-})
+    },
+)
def flashattn_bwd_preprocess(batch, heads, total_q, N_CTX, max_seq_len, dim_v):
-   dtype = "float16"
-   accum_dtype = "float"
+   dtype = T.float16
+   accum_dtype = T.float32
    shape = [total_q, heads, dim_v]
    blk = 32
...
@@ -159,7 +152,7 @@ def flashattn_bwd_preprocess(batch, heads, total_q, N_CTX, max_seq_len, dim_v):
    def flash_bwd_prep(
        O: T.Tensor(shape, dtype),  # type: ignore
        dO: T.Tensor(shape, dtype),  # type: ignore
-       cu_seqlens_q: T.Tensor([batch + 1], "int32"),  # type: ignore
+       cu_seqlens_q: T.Tensor([batch + 1], T.int32),  # type: ignore
        Delta: T.Tensor([batch, heads, N_CTX], accum_dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(max_seq_len, blk), batch) as (bx, by, bz):
...
@@ -199,12 +192,14 @@ def make_dq_layout(dQ):
-@tilelang.jit(out_idx=[3, 4, 5], pass_configs={
+@tilelang.jit(
+    out_idx=[3, 4, 5],
+    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-})
+    },
+)
def flashattn_bwd_postprocess(total_q, total_kv, heads, head_kv, dim_qk, dim_v):
-   dtype = "float16"
-   accum_dtype = "float"
+   dtype = T.float16
+   accum_dtype = T.float32
    q_shape = [total_q, heads, dim_qk]
    k_shape = [total_kv, head_kv, dim_qk]
    v_shape = [total_kv, head_kv, dim_v]
...
@@ -221,44 +216,37 @@ def flashattn_bwd_postprocess(total_q, total_kv, heads, head_kv, dim_qk, dim_v):
    ):
        with T.Kernel(T.ceildiv(total_q, blk), heads, threads=128) as (bx, by):
            T.annotate_layout({dQ: make_dq_layout(dQ)})
            T.copy(dQ[bx * blk : (bx + 1) * blk, by, :], dQ_out[bx * blk : (bx + 1) * blk, by, :])
        with T.Kernel(T.ceildiv(total_kv, blk), head_kv, threads=128) as (bx, by):
-           T.annotate_layout({
+           T.annotate_layout(
+               {
                    dK: make_dq_layout(dK),
                    dV: make_dq_layout(dV),
-           })
+               }
+           )
            T.copy(dK[bx * blk : (bx + 1) * blk, by, :], dK_out[bx * blk : (bx + 1) * blk, by, :])
            T.copy(dV[bx * blk : (bx + 1) * blk, by, :], dV_out[bx * blk : (bx + 1) * blk, by, :])
    return flash_bwd_post
-@tilelang.jit(pass_configs={
+@tilelang.jit(
+    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-})
-def flashattn_bwd_atomic_add(batch, total_q, total_kv, N_CTX, heads, max_seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1):
-   sm_scale = (1.0 / dim_qk)**0.5
-   scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
+   }
+)
+def flashattn_bwd_atomic_add(
+   batch, total_q, total_kv, N_CTX, heads, max_seq_len, dim_qk, dim_v,
+   is_causal, block_M, block_N, threads=256, num_stages=2, groups=1,
+):
+   sm_scale = (1.0 / dim_qk) ** 0.5
+   scale = (1.0 / dim_qk) ** 0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [total_q, heads, dim_qk]
    k_shape = [total_kv, head_kv, dim_qk]
    v_shape = [total_kv, head_kv, dim_v]
    do_shape = [total_q, heads, dim_v]
-   dtype = "float16"
-   accum_dtype = "float"
+   dtype = T.float16
+   accum_dtype = T.float32
    @T.prim_func
    def flash_bwd(
...
@@ -268,14 +256,13 @@ def flashattn_bwd_atomic_add(batch,
        dO: T.Tensor(do_shape, dtype),  # type: ignore
        lse: T.Tensor([batch, heads, N_CTX], accum_dtype),  # type: ignore
        Delta: T.Tensor([batch, heads, N_CTX], accum_dtype),  # type: ignore
-       cu_seqlens_q: T.Tensor([batch + 1], "int32"),  # type: ignore
-       cu_seqlens_k: T.Tensor([batch + 1], "int32"),  # type: ignore
+       cu_seqlens_q: T.Tensor([batch + 1], T.int32),  # type: ignore
+       cu_seqlens_k: T.Tensor([batch + 1], T.int32),  # type: ignore
        dQ: T.Tensor(q_shape, accum_dtype),  # type: ignore
        dK: T.Tensor(k_shape, accum_dtype),  # type: ignore
        dV: T.Tensor(v_shape, accum_dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(max_seq_len, block_M), batch, threads=threads) as (bx, by, bz):
            K_shared = T.alloc_shared([block_M, dim_qk], dtype)
            dsT_shared = T.alloc_shared([block_M, block_N], dtype)
            q = T.alloc_shared([block_N, dim_qk], dtype)
...
@@ -301,58 +288,54 @@ def flashattn_bwd_atomic_add(batch,
            q_current_seqlen = q_end_idx - q_start_idx
            k_current_seqlen = k_end_idx - k_start_idx
-           T.annotate_layout({
+           T.annotate_layout(
+               {
                    dQ: make_dq_layout(dQ),
                    dK: make_dq_layout(dK),
                    dV: make_dq_layout(dV),
                    K_shared: tilelang.layout.make_swizzled_layout(K_shared),
-           })
+               }
+           )
            T.copy(K[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], K_shared)
            T.copy(V[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], V_shared)
            T.clear(dv)
            T.clear(dk)
            loop_st = T.min(T.floordiv(by * block_M, block_N), T.floordiv(q_current_seqlen, block_N)) if is_causal else 0
            loop_ed = T.ceildiv(q_current_seqlen, block_N)
            for k_base in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
                T.copy(Q[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], q)
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(lse[bz, bx, k_base * block_N : (k_base + 1) * block_N], lse_shared)
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
-                       qkT[i, j] = T.if_then_else((by * block_M + i <= k_base * block_N + j) and (by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen), qkT[i, j], 0)
+                       qkT[i, j] = T.if_then_else(
+                           (by * block_M + i <= k_base * block_N + j)
+                           and (by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen),
+                           qkT[i, j],
+                           0,
+                       )
                else:
                    for i, j in T.Parallel(block_M, block_N):
-                       qkT[i, j] = T.if_then_else(by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen, qkT[i, j], 0)
+                       qkT[i, j] = T.if_then_else(
+                           by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen, qkT[i, j], 0
+                       )
                T.copy(dO[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], do)
                T.clear(dsT)
                # dsT: (block_kv, block_q)
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
                T.copy(Delta[bz, bx, k_base * block_N : (k_base + 1) * block_N], delta)
                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
                T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
...
@@ -362,49 +345,40 @@ def flashattn_bwd_atomic_add(batch,
                T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
                T.copy(dq, dq_shared)
                T.atomic_add(
                    dQ[q_start_idx + k_base * block_N : q_start_idx + k_base * block_N + block_N, bx, :],
                    dq_shared,
                    memory_order="relaxed",
-                   use_tma=True)
+                   use_tma=True,
+               )
            T.copy(dv, dv_shared)
            T.atomic_add(
                dV[k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :],
                dv_shared,
                memory_order="relaxed",
-               use_tma=True)
+               use_tma=True,
+           )
            T.copy(dk, dk_shared)
            T.atomic_add(
                dK[k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :],
                dk_shared,
                memory_order="relaxed",
-               use_tma=True)
+               use_tma=True,
+           )
    return flash_bwd
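For orientation: `flashattn_bwd_atomic_add` lets every (head, kv-block) program add its partial dQ/dK/dV tiles into shared global accumulators via `T.atomic_add(..., memory_order="relaxed", use_tma=True)`, so the order in which programs finish does not matter. A rough host-side analogue of that accumulation pattern is sketched below; the sizes and the `index_add_` stand-in are illustrative assumptions, not the kernel API.

import torch

# Illustrative sizes only; not taken from the kernel configuration.
total_q, heads, dim_qk, block_N = 32, 2, 16, 8
dQ = torch.zeros(total_q, heads, dim_qk)            # float32 accumulator, like dQ in the kernel
for kv_block in range(4):                           # several programs touch the same query rows
    for q_block in range(total_q // block_N):
        rows = torch.arange(q_block * block_N, (q_block + 1) * block_N)
        partial_dq = torch.randn(block_N, heads, dim_qk)
        dQ.index_add_(0, rows, partial_dq)          # order-independent add, like T.atomic_add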
-@tilelang.jit(pass_configs={
+@tilelang.jit(
+    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-})
-def flashattn_bwd_split(batch, total_q, total_kv, N_CTX, heads, max_seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1):
-   sm_scale = (1.0 / dim_qk)**0.5
-   scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
+   }
+)
+def flashattn_bwd_split(
+   batch, total_q, total_kv, N_CTX, heads, max_seq_len, dim_qk, dim_v,
+   is_causal, block_M, block_N, threads=256, num_stages=2, groups=1,
+):
+   sm_scale = (1.0 / dim_qk) ** 0.5
+   scale = (1.0 / dim_qk) ** 0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [total_q, heads, dim_qk]
    k_shape = [total_kv, head_kv, dim_qk]
...
@@ -412,8 +386,8 @@ def flashattn_bwd_split(batch,
    do_shape = [total_q, heads, dim_v]
    dk_shape = [groups, total_kv, head_kv, dim_qk]  # sum after kernel
    dv_shape = [groups, total_kv, head_kv, dim_v]  # sum after kernel
-   dtype = "float16"
-   accum_dtype = "float"
+   dtype = T.float16
+   accum_dtype = T.float32
    @T.prim_func
    def flash_bwd(
...
@@ -423,14 +397,13 @@ def flashattn_bwd_split(batch,
        dO: T.Tensor(do_shape, dtype),  # type: ignore
        lse: T.Tensor([batch, heads, N_CTX], accum_dtype),  # type: ignore
        Delta: T.Tensor([batch, heads, N_CTX], accum_dtype),  # type: ignore
-       cu_seqlens_q: T.Tensor([batch + 1], "int32"),  # type: ignore
-       cu_seqlens_k: T.Tensor([batch + 1], "int32"),  # type: ignore
+       cu_seqlens_q: T.Tensor([batch + 1], T.int32),  # type: ignore
+       cu_seqlens_k: T.Tensor([batch + 1], T.int32),  # type: ignore
        dQ: T.Tensor(q_shape, accum_dtype),  # type: ignore
        dK: T.Tensor(dk_shape, dtype),  # type: ignore
        dV: T.Tensor(dv_shape, dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(max_seq_len, block_M), batch, threads=threads) as (bx, by, bz):
            K_shared = T.alloc_shared([block_M, dim_qk], dtype)
            dsT_shared = T.alloc_shared([block_M, block_N], dtype)
            q = T.alloc_shared([block_N, dim_qk], dtype)
...
@@ -455,59 +428,55 @@ def flashattn_bwd_split(batch,
            q_current_seqlen = q_end_idx - q_start_idx
            k_current_seqlen = k_end_idx - k_start_idx
-           T.annotate_layout({
+           T.annotate_layout(
+               {
                    dQ: make_dq_layout(dQ),
                    K_shared: tilelang.layout.make_swizzled_layout(K_shared),
                    dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
                    dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
-           })
+               }
+           )
            T.copy(K[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], K_shared)
            T.copy(V[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], V_shared)
            T.clear(dv)
            T.clear(dk)
            loop_st = T.min(T.floordiv(by * block_M, block_N), T.floordiv(q_current_seqlen, block_N)) if is_causal else 0
            loop_ed = T.ceildiv(q_current_seqlen, block_N)
            for k_base in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
                # Note: The padding zero of varlen should be considered in T.copy
                T.copy(Q[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], q)
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(dO[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], do)
                T.clear(dsT)
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(lse[bz, bx, k_base * block_N : (k_base + 1) * block_N], lse_shared)
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
-                       qkT[i, j] = T.if_then_else((by * block_M + i <= k_base * block_N + j) and (by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen), qkT[i, j], 0)
+                       qkT[i, j] = T.if_then_else(
+                           (by * block_M + i <= k_base * block_N + j)
+                           and (by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen),
+                           qkT[i, j],
+                           0,
+                       )
                else:
                    for i, j in T.Parallel(block_M, block_N):
-                       qkT[i, j] = T.if_then_else(by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen, qkT[i, j], 0)
+                       qkT[i, j] = T.if_then_else(
+                           by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen, qkT[i, j], 0
+                       )
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
                T.copy(Delta[bz, bx, k_base * block_N : (k_base + 1) * block_N], delta)
                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
...
@@ -518,62 +487,37 @@ def flashattn_bwd_split(batch,
                T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
                for i, j in T.Parallel(block_N, dim_qk):
                    if k_base * block_N + i < q_current_seqlen:
                        T.atomic_add(dQ[q_start_idx + k_base * block_N + i, bx, j], dq[i, j], memory_order="relaxed")
            T.copy(dv, dv_shared)
            T.copy(dv_shared, dV[bx % groups, k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :])
            T.copy(dk, dk_shared)
            T.copy(dk_shared, dK[bx % groups, k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :])
    return flash_bwd
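`flashattn_bwd_split` avoids atomics for dK/dV by giving each query-head group its own output slice (note `dk_shape = [groups, ...]  # sum after kernel`) and reducing on the host with `dk.sum(0)` / `dv.sum(0)` in the backward wrapper. A minimal sketch of that split-then-reduce idea follows; the shapes are illustrative assumptions.

import torch

# Illustrative sizes only.
groups, total_kv, head_kv, dim_qk = 4, 64, 2, 16
dK_split = torch.empty(groups, total_kv, head_kv, dim_qk, dtype=torch.float16)
for g in range(groups):
    # each group's program writes its own slice, so no atomics are needed
    dK_split[g] = torch.randn(total_kv, head_kv, dim_qk, dtype=torch.float16)
dK = dK_split.sum(0)  # host-side reduction, mirroring `dk, dv = dk.sum(0), dv.sum(0)`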
@torch.compile
class _attention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal, groups=1, use_atomic=True):
        BATCH, N_CTX, H, D_HEAD_QK = q.shape
        D_HEAD_V = v.shape[-1]
        block_M = 128
        block_N = 64
        q_unpad, indices_q, _, _ = unpad_input(q, (torch.arange(N_CTX, device=q.device).unsqueeze(0) < seqlens_q.unsqueeze(1)))
        k_unpad, indices_k, _, _ = unpad_input(k, (torch.arange(N_CTX, device=k.device).unsqueeze(0) < seqlens_k.unsqueeze(1)))
        v_unpad, _, _, _ = unpad_input(v, (torch.arange(N_CTX, device=v.device).unsqueeze(0) < seqlens_k.unsqueeze(1)))
        total_q = q_unpad.shape[0]
        total_kv = k_unpad.shape[0]
        mod = flashattn_fwd(BATCH, total_q, total_kv, N_CTX, H, max_seqlen_q, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups)
        o_unpad, lse = mod(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k)
        o = pad_input(o_unpad, indices_q, BATCH, N_CTX)
        ctx.save_for_backward(q_unpad, k_unpad, v_unpad, o_unpad, lse, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k)
        ctx.batch = BATCH
        ctx.causal = causal
        ctx.use_atomic = use_atomic
...
@@ -588,8 +532,7 @@ class _attention(torch.autograd.Function):
        N_CTX = do.shape[1]
        q, k, v, o, lse_clone, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
        # lse_clone = lse.clone()
        do_unpad, _, _, _ = unpad_input(do, (torch.arange(N_CTX, device=do.device).unsqueeze(0) < seqlens_q.unsqueeze(1)))
        total_q, H, D_HEAD_QK = q.shape
        total_kv, HEAD_KV, D_HEAD_V = v.shape
        groups = H // HEAD_KV
...
@@ -622,7 +565,8 @@ class _attention(torch.autograd.Function):
                block_N, threads=256, num_stages=2,
-               groups=groups)
+               groups=groups,
+           )
            dq = torch.zeros_like(q, dtype=torch.float32)
            dk = torch.zeros_like(k, dtype=torch.float32)
            dv = torch.zeros_like(v, dtype=torch.float32)
...
@@ -643,13 +587,13 @@ class _attention(torch.autograd.Function):
                block_N, threads=256, num_stages=2,
-               groups=groups)
+               groups=groups,
+           )
            dq = torch.zeros_like(q, dtype=torch.float32)
            dk = torch.empty(groups, *k.shape, dtype=torch.float16, device=q.device)
            dv = torch.empty(groups, *v.shape, dtype=torch.float16, device=q.device)
            kernel(q, k, v, do, lse_clone, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv)
            dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32), torch.zeros_like(v, dtype=torch.float32))
            dk, dv = dk.sum(0), dv.sum(0)
        dq = pad_input(dq, ctx.indices_q, BATCH, N_CTX)
...
@@ -668,15 +612,13 @@ def ref_program(Q, K, V, padding_mask, is_causal, groups=1):
    # HQ = HKV * groups
    # To handle precision issue
    Q, K, V = Q.float(), K.float(), V.float()
    assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
    assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
    dim_qk = Q.size(-1)
    K = K.repeat_interleave(groups, dim=2)
    V = V.repeat_interleave(groups, dim=2)
-   scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
+   scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype))
    if padding_mask is not None:
        scores.masked_fill_(rearrange(~padding_mask, "b s -> b 1 1 s"), float("-inf"))
...
@@ -684,41 +626,35 @@ def ref_program(Q, K, V, padding_mask, is_causal, groups=1):
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
-       scores = scores.masked_fill(mask == 0, float('-inf'))
+       scores = scores.masked_fill(mask == 0, float("-inf"))
    attention_weights = F.softmax(scores, dim=-1)
-   output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
+   output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
    if padding_mask is not None:
        output.masked_fill_(rearrange(~padding_mask, "b s -> b s 1 1"), 0.0)
    return output
def main(
    BATCH: int = 1,
    H: int = 32,
    N_CTX: int = 256,
    D_HEAD_QK: int = 192,
    D_HEAD_V: int = 128,
    groups: int = 16,
    causal: bool = False,
-   use_atomic: bool = True):
+   use_atomic: bool = True,
+):
    flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
    flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
    total_flops = 3 * flops_per_qk + 2 * flops_per_v
    if causal:
        total_flops *= 0.5
-   Q = (torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_())
+   Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
    head_kv = H // groups
-   K = (torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_())
-   V = (torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_())
-   dO = (torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_())
+   K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
+   V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
+   dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
    padding_mask = generate_random_padding_mask(N_CTX, BATCH, "cuda", mode="random")
    seqlens_q = padding_mask.sum(dim=-1, dtype=torch.int32)
    cu_seqlens_q = F.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0))
...
@@ -727,8 +663,7 @@ def main(BATCH: int = 1,
    # In training backward pass, seqlens_k should be the same as seqlens_q
    seqlens_k, cu_seqlens_k, max_seqlen_k = seqlens_q, cu_seqlens_q, max_seqlen_q
    O = attention(Q, K, V, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal, groups, use_atomic)
    O.backward(dO, retain_graph=True)
    dQ, Q.grad = Q.grad.clone(), None
    dK, K.grad = K.grad.clone(), None
...
@@ -770,17 +705,15 @@ if __name__ == "__main__":
    print(f"Detected GPU compute capability: {arch}")
    assert float(arch) >= 9.0, "This example only supports GPU with compute capability >= 9.0"
    parser = argparse.ArgumentParser()
-   parser.add_argument('--batch', type=int, default=8, help='Batch size')
-   parser.add_argument('--h', type=int, default=32, help='Number of heads')
-   parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
-   parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K')
-   parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V')
-   parser.add_argument('--causal', action='store_true', help='Causal flag')
-   parser.add_argument('--groups', type=int, default=16, help='groups')
-   parser.add_argument('--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
-   parser.add_argument('--use_split', action='store_true', default=False, help='Use split for dK/dV')
+   parser.add_argument("--batch", type=int, default=8, help="Batch size")
+   parser.add_argument("--h", type=int, default=32, help="Number of heads")
+   parser.add_argument("--n_ctx", type=int, default=1024, help="Context size")
+   parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K")
+   parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V")
+   parser.add_argument("--causal", action="store_true", help="Causal flag")
+   parser.add_argument("--groups", type=int, default=16, help="groups")
+   parser.add_argument("--use_atomic", action="store_true", default=False, help="Use atomic add for dK/dV")
+   parser.add_argument("--use_split", action="store_true", default=False, help="Use split for dK/dV")
    args = parser.parse_args()
    # Can be set to True/False for testing
    args.causal = True
...
@@ -794,5 +727,4 @@ if __name__ == "__main__":
    # Default: use atomic
    use_atomic = True
    main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic)
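The wrapper above packs variable-length sequences with `unpad_input` and addresses them through `cu_seqlens` offsets built from a padding mask. A small self-contained sketch of that bookkeeping follows (plain PyTorch; boolean indexing is used here in place of the real `unpad_input` helper, and all sizes are illustrative assumptions).

import torch
import torch.nn.functional as F

batch, n_ctx, dim = 3, 8, 4                                    # illustrative sizes
seqlens_q = torch.tensor([5, 8, 2], dtype=torch.int32)         # one length per sequence
cu_seqlens_q = F.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0))  # [0, 5, 13, 15]
keep = torch.arange(n_ctx).unsqueeze(0) < seqlens_q.unsqueeze(1)   # padding mask, True = real token
x = torch.randn(batch, n_ctx, dim)
x_unpad = x[keep]                                              # packed [total_q, dim] layout the kernels consume
assert x_unpad.shape[0] == int(cu_seqlens_q[-1])
# tokens of sequence b live at rows cu_seqlens_q[b] : cu_seqlens_q[b + 1]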
examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py  View file @ 667632cc
...
@@ -6,17 +6,19 @@ import argparse
-@tilelang.jit(out_idx=[3, 4], pass_configs={
+@tilelang.jit(
+    out_idx=[3, 4],
+    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-})
+    },
+)
def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
-   scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
+   scale = (1.0 / dim_qk) ** 0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim_qk]
    k_shape = [batch, seq_len, head_kv, dim_qk]
    v_shape = [batch, seq_len, head_kv, dim_v]
-   dtype = "float16"
-   accum_dtype = "float"
+   dtype = T.float16
+   accum_dtype = T.float32
    @T.prim_func
    def flash_fwd(
...
@@ -40,25 +42,25 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
            logsum = T.alloc_fragment([block_M], accum_dtype)
            T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
            T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))
-           loop_range = (T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
+           loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_range, num_stages=1):
                T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared)
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
                else:
                    T.clear(acc_s)
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared)
                T.copy(scores_max, scores_max_prev)
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_M):
                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                for i in T.Parallel(block_M):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_M, dim_v):
...
@@ -72,21 +74,23 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
            for i, j in T.Parallel(block_M, dim_v):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])
            for i in T.Parallel(block_M):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
            T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M])
    return flash_fwd
-@tilelang.jit(out_idx=[2], pass_configs={
+@tilelang.jit(
+    out_idx=[2],
+    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-})
+    },
+)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
-   dtype = "float16"
-   accum_dtype = "float"
+   dtype = T.float16
+   accum_dtype = T.float32
    shape = [batch, seq_len, heads, dim_v]
    blk = 32
...
@@ -103,38 +107,30 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
            delta = T.alloc_fragment([blk], accum_dtype)
            T.clear(acc)
            for k in range(T.ceildiv(dim_v, blk)):
                T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o)
                T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do)
                for i, j in T.Parallel(blk, blk):
                    acc[i, j] += o[i, j] * do[i, j]
            T.reduce_sum(acc, delta, 1)
            T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk])
    return flash_bwd_prep
-@tilelang.jit(pass_configs={
+@tilelang.jit(
+    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-})
-def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1):
-   sm_scale = (1.0 / dim_qk)**0.5
-   scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
+   }
+)
+def flashattn_bwd(
+   batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1
+):
+   sm_scale = (1.0 / dim_qk) ** 0.5
+   scale = (1.0 / dim_qk) ** 0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim_qk]
    k_shape = [batch, seq_len, head_kv, dim_qk]
    v_shape = [batch, seq_len, head_kv, dim_v]
-   dtype = "float16"
-   accum_dtype = "float"
+   dtype = T.float16
+   accum_dtype = T.float32
    @T.prim_func
    def flash_bwd(
...
@@ -167,45 +163,39 @@ def flashattn_bwd(batch,
            dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype)
            dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype)
-           T.annotate_layout({
+           T.annotate_layout(
+               {
                    K_shared: tilelang.layout.make_swizzled_layout(K_shared),
                    dq_shared: tilelang.layout.make_swizzled_layout(dq_shared),
                    dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
                    dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
-           })
+               }
+           )
            T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared)
            T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared)
            T.clear(dv)
            T.clear(dk)
            loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
            loop_ed = T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
                T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q)
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
                T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
                T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do)
                T.clear(dsT)
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
                T.wait_wgmma(1)
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
                T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)
                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
...
@@ -217,18 +207,17 @@ def flashattn_bwd(batch,
                T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1)
                T.wait_wgmma(0)
                T.copy(dq, dq_shared)
                T.atomic_add(dQ[bz, k * block_N : (k + 1) * block_N, bx, :], dq_shared)
            T.copy(dv, dv_shared)
            T.atomic_add(dV[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dv_shared)
            T.copy(dk, dk_shared)
            T.atomic_add(dK[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dk_shared)
    return flash_bwd
@torch.compile
class _attention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, causal, groups=1, use_atomic=True):
        BATCH, N_CTX, H, D_HEAD_QK = q.shape
...
@@ -246,7 +235,10 @@ class _attention(torch.autograd.Function):
    def backward(ctx, do):
        q, k, v, o, lse = ctx.saved_tensors
        BATCH, N_CTX, H, D_HEAD_QK = q.shape
-       HEAD_KV, D_HEAD_V, = v.shape[-2], v.shape[-1]
+       (
+           HEAD_KV,
+           D_HEAD_V,
+       ) = v.shape[-2], v.shape[-1]
        groups = H // HEAD_KV
        def maybe_contiguous(x):
...
@@ -260,18 +252,7 @@ class _attention(torch.autograd.Function):
        mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V)
        delta = mod_prep(o, do)
        kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups)
        shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
        shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
        shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]
...
@@ -294,52 +275,36 @@ def ref_program(Q, K, V, is_causal, groups=1):
    # K: [B, T, HK, D_QK]
    # V: [B, T, HV, D_V]
    # HQ = HKV * groups
    assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
    assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
    dim_qk = Q.size(-1)
    K = K.repeat_interleave(groups, dim=2)
    V = V.repeat_interleave(groups, dim=2)
-   scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
+   scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
-       scores = scores.masked_fill(mask == 0, float('-inf'))
+       scores = scores.masked_fill(mask == 0, float("-inf"))
    attention_weights = F.softmax(scores, dim=-1)
-   output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
+   output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
    return output
def main(BATCH: int = 1, H: int = 32, N_CTX: int = 256, D_HEAD_QK: int = 192, D_HEAD_V: int = 128, groups: int = 16, causal: bool = False):
    flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
    flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
    total_flops = 3 * flops_per_qk + 2 * flops_per_v
    if causal:
        total_flops *= 0.5
-   Q = (torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_())
+   Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
    head_kv = H // groups
-   K = (torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_())
-   V = (torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_())
-   dO = (torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_())
+   K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
+   V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
+   dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
    O = attention(Q, K, V, causal, groups)
    O.backward(dO, retain_graph=True)
    dQ, Q.grad = Q.grad.clone(), None
...
@@ -356,7 +321,7 @@ def main(BATCH: int = 1,
    torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
-   print('All checks passed.✅')
+   print("All checks passed.✅")
    def run():
        O_ref.backward(dO, retain_graph=True)
...
@@ -376,13 +341,13 @@ def main(BATCH: int = 1,
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
-   parser.add_argument('--batch', type=int, default=8, help='Batch size')
-   parser.add_argument('--h', type=int, default=32, help='Number of heads')
-   parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
-   parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K')
-   parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V')
-   parser.add_argument('--causal', action='store_true', help='Causal flag')
-   parser.add_argument('--groups', type=int, default=16, help='groups')
+   parser.add_argument("--batch", type=int, default=8, help="Batch size")
+   parser.add_argument("--h", type=int, default=32, help="Number of heads")
+   parser.add_argument("--n_ctx", type=int, default=1024, help="Context size")
+   parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K")
+   parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V")
+   parser.add_argument("--causal", action="store_true", help="Causal flag")
+   parser.add_argument("--groups", type=int, default=16, help="groups")
    args = parser.parse_args()
    main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal)
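All of these kernels fold `1/sqrt(d)` and `log2(e)` into a single `scale` and then work with `exp2`/`log2` instead of `exp`/`log`. A quick PyTorch check, for intuition only, that this reproduces the ordinary scaled softmax:

import torch

d = 64
scale = (1.0 / d) ** 0.5 * 1.44269504           # log2(e) / sqrt(d), as in the kernels
s = torch.randn(8)                               # one row of Q @ K^T (unscaled scores)
p_ref = torch.softmax(s / d ** 0.5, dim=0)       # ordinary scaled softmax
m = (s * scale).max()
p = torch.exp2(s * scale - m)                    # exp2 with the folded scale
torch.testing.assert_close(p / p.sum(), p_ref)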
examples/flash_attention/example_gqa_fwd_bshd.py  View file @ 667632cc
...
@@ -9,7 +9,6 @@ from functools import partial
class FlashAttentionTuneSpace:
    def __init__(
        self,
        block_sizes=(64, 128, 256),
...
@@ -40,7 +39,7 @@ def get_configs(user_config=None):
        warp_M = block_M // warp_count
        warp_N = block_N // warp_count
-       if (warp_M % config.warp_alignment != 0 or warp_N % config.warp_alignment != 0):
+       if warp_M % config.warp_alignment != 0 or warp_N % config.warp_alignment != 0:
            continue
        shared_mem = 2 * config.dtype_bytes * config.dim * (block_M + block_N)
...
@@ -48,36 +47,31 @@ def get_configs(user_config=None):
            continue
        for num_stages in config.num_stages_range:
-           valid_configs.append({
+           valid_configs.append(
+               {
                    "block_M": block_M,
                    "block_N": block_N,
                    "num_stages": num_stages,
                    "threads": threads,
-           })
+               }
+           )
    return valid_configs
@autotune(configs=get_configs(), warmup=10, rep=10)
-@tilelang.jit(out_idx=[3], pass_configs={
+@tilelang.jit(
+    out_idx=[3],
+    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-})
-def flashattn(batch, heads, seq_len, dim, is_causal, groups=1, block_M=64, block_N=64, num_stages=0, threads=128):
-   scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
+   },
+)
+def flashattn(
+   batch, heads, seq_len, dim, is_causal, groups=1, block_M=64, block_N=64, num_stages=0, threads=128
+):
+   scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim]
    kv_shape = [batch, seq_len, head_kv, dim]
-   dtype = "float16"
-   accum_dtype = "float"
+   dtype = T.float16
+   accum_dtype = T.float32
    @T.macro
    def MMA0(
...
@@ -90,13 +84,13 @@ def flashattn(batch,
        by: T.int32,
        bz: T.int32,
    ):
        T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared)
        if is_causal:
            for i, j in T.Parallel(block_M, block_N):
                acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
        else:
            T.clear(acc_s)
            for i, j in T.Parallel(block_M, block_N):
                acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
        T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
    @T.macro
...
@@ -109,7 +103,7 @@ def flashattn(batch,
        by: T.int32,
        bz: T.int32,
    ):
        T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared)
        T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
    @T.macro
...
@@ -125,6 +119,8 @@ def flashattn(batch,
        T.copy(scores_max, scores_max_prev)
        T.fill(scores_max, -T.infinity(accum_dtype))
        T.reduce_max(acc_s, scores_max, dim=1, clear=False)
+       for i in T.Parallel(block_M):
+           scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
        # To do causal softmax, we need to set the scores_max to 0 if it is -inf
        # This process is called Check_inf in FlashAttention3 code, and it only need to be done
        # in the first ceil_div(kBlockM, kBlockN) steps.
...
@@ -171,25 +167,24 @@ def flashattn(batch,
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)
            T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))
            loop_range = (
                T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)
            )
            for k in T.Pipelined(loop_range, num_stages=num_stages):
                MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
                Rescale(acc_o, scores_scale)
                MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])
    return main
...
@@ -199,50 +194,34 @@ def ref_program(Q, K, V, is_causal, groups=1):
    # K: [B, T, HK, D]
    # V: [B, T, HV, D]
    # HQ = HKV * groups
    assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
    assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
    dim = Q.size(-1)
    K = K.repeat_interleave(groups, dim=2)
    V = V.repeat_interleave(groups, dim=2)
-   scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
+   scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
-       scores = scores.masked_fill(mask == 0, float('-inf'))
+       scores = scores.masked_fill(mask == 0, float("-inf"))
    attention_weights = F.softmax(scores, dim=-1)
-   output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
+   output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
    return output
def main(batch: int = 1, heads: int = 64, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 16, tune: bool = False):
    flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
    total_flops = 2 * flops_per_matmul
    if is_causal:
        total_flops *= 0.5
-   if (not tune):
+   if not tune:
        kernel = flashattn(batch, heads, seq_len, dim, is_causal, groups=groups, block_M=64, block_N=64, num_stages=2, threads=128)
        ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups)
        profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
        profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
...
@@ -266,12 +245,12 @@ def main(batch: int = 1,
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
-   parser.add_argument('--batch', type=int, default=1, help='batch size')
-   parser.add_argument('--heads', type=int, default=64, help='heads')
-   parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
-   parser.add_argument('--dim', type=int, default=128, help='dim')
-   parser.add_argument('--is_causal', action='store_true', help='causal')
-   parser.add_argument('--tune', action='store_true', help='tune configs')
-   parser.add_argument('--groups', type=int, default=16, help='groups')
+   parser.add_argument("--batch", type=int, default=1, help="batch size")
+   parser.add_argument("--heads", type=int, default=64, help="heads")
+   parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
+   parser.add_argument("--dim", type=int, default=128, help="dim")
+   parser.add_argument("--is_causal", action="store_true", help="causal")
+   parser.add_argument("--tune", action="store_true", help="tune configs")
+   parser.add_argument("--groups", type=int, default=16, help="groups")
    args = parser.parse_args()
    main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups, args.tune)
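As in the `ref_program` above, grouped-query attention is checked against a dense reference by repeating K/V `groups` times along the head axis. A compact standalone version of that reference, with small illustrative shapes:

import torch

b, t, heads, groups, d = 1, 16, 8, 4, 32         # illustrative sizes
head_kv = heads // groups
Q = torch.randn(b, t, heads, d)
K = torch.randn(b, t, head_kv, d)
V = torch.randn(b, t, head_kv, d)
K = K.repeat_interleave(groups, dim=2)           # expand KV heads to match Q's heads
V = V.repeat_interleave(groups, dim=2)
scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) / d ** 0.5
out = torch.einsum("bhqk,bkhd->bqhd", torch.softmax(scores, dim=-1), V)   # [b, t, heads, d]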
examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py  View file @ 667632cc
...
@@ -24,9 +24,11 @@ def get_configs():
    rep=10,
)
-@tilelang.jit(out_idx=[3], pass_configs={
+@tilelang.jit(
+    out_idx=[3],
+    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-})
+    },
+)
def flashattn(
    batch,
    heads,
...
@@ -39,12 +41,12 @@ def flashattn(
    num_stages=0,
    threads=128,
):
-   scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
+   scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim]
    kv_shape = [batch, seq_len, head_kv, dim]
-   dtype = "float16"
-   accum_dtype = "float"
+   dtype = T.float16
+   accum_dtype = T.float32
    @T.macro
    def MMA0(
...
@@ -57,13 +59,13 @@ def flashattn(
        by: T.int32,
        bz: T.int32,
    ):
        T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared)
        if is_causal:
            for i, j in T.Parallel(block_M, block_N):
                acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
        else:
            T.clear(acc_s)
            for i, j in T.Parallel(block_M, block_N):
                acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
        T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
    @T.macro
...
@@ -76,7 +78,7 @@ def flashattn(
        by: T.int32,
        bz: T.int32,
    ):
        T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared)
        T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
    @T.macro
...
@@ -92,6 +94,8 @@ def flashattn(
        T.copy(scores_max, scores_max_prev)
        T.fill(scores_max, -T.infinity(accum_dtype))
        T.reduce_max(acc_s, scores_max, dim=1, clear=False)
+       for i in T.Parallel(block_M):
+           scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
        # To do causal softmax, we need to set the scores_max to 0 if it is -inf
        # This process is called Check_inf in FlashAttention3 code, and it only need to be done
        # in the first ceil_div(kBlockM, kBlockN) steps.
...
@@ -138,30 +142,30 @@ def flashattn(
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)
            T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))
            loop_range = (
                T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)
            )
            for k in T.Pipelined(
                loop_range,
                num_stages=num_stages,
                order=[-1, 0, 3, 1, -1, 2],
                stage=[-1, 0, 0, 1, -1, 1],
-               group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
+               group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]],
+           ):
                MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
                Rescale(acc_o, scores_scale)
                MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])
    return main
...
@@ -171,23 +175,21 @@ def ref_program(Q, K, V, is_causal, groups=1):
    # K: [B, T, HK, D]
    # V: [B, T, HV, D]
    # HQ = HKV * groups
    assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
    assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
    dim = Q.size(-1)
    K = K.repeat_interleave(groups, dim=2)
    V = V.repeat_interleave(groups, dim=2)
-   scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
+   scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
-       scores = scores.masked_fill(mask == 0, float('-inf'))
+       scores = scores.masked_fill(mask == 0, float("-inf"))
    attention_weights = F.softmax(scores, dim=-1)
-   output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
+   output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
    return output
...
@@ -205,18 +207,8 @@ def main(
    if is_causal:
        total_flops *= 0.5
-   if (not tune):
+   if not tune:
        kernel = flashattn(batch, heads, seq_len, dim, is_causal, groups=groups, block_M=128, block_N=128, num_stages=2, threads=256)
        ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups)
        profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
        profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
...
@@ -240,12 +232,12 @@ def main(
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
-   parser.add_argument('--batch', type=int, default=1, help='batch size')
-   parser.add_argument('--heads', type=int, default=64, help='heads')
-   parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
-   parser.add_argument('--dim', type=int, default=128, help='dim')
-   parser.add_argument('--is_causal', action='store_true', help='causal')
-   parser.add_argument('--tune', action='store_true', help='tune configs')
-   parser.add_argument('--groups', type=int, default=16, help='groups')
+   parser.add_argument("--batch", type=int, default=1, help="batch size")
+   parser.add_argument("--heads", type=int, default=64, help="heads")
+   parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
+   parser.add_argument("--dim", type=int, default=128, help="dim")
+   parser.add_argument("--is_causal", action="store_true", help="causal")
+   parser.add_argument("--tune", action="store_true", help="tune configs")
+   parser.add_argument("--groups", type=int, default=16, help="groups")
    args = parser.parse_args()
    main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups, args.tune)
examples/flash_attention/example_gqa_fwd_varlen.py
View file @
667632cc
...
...
@@ -26,7 +26,7 @@ def attention_ref(
    q, k, v = q.float(), k.float(), v.float()
    b, T, Hq, D = q.shape
    S = k.shape[1]
    scale = (1.0 / D)**0.5
    scale = (1.0 / D) ** 0.5
    k = repeat(k, "b s h d -> b s (h g) d", g=Hq // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=Hq // v.shape[2])
    scores = torch.einsum("bthd,bshd->bhts", q, k)
...
...
@@ -54,41 +54,31 @@ def attention_ref(
@tilelang.jit(
    out_idx=[6], pass_configs={
    out_idx=[6],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
def flashattn(batch_size, groups, UQ, UKV, heads, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128):
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    },
)
def flashattn(batch_size, groups, UQ, UKV, heads, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128):
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [UQ, heads, dim]
    kv_shape = [UKV, head_kv, dim]
    o_shape = [UQ, heads, dim]
    dtype = "float16"
    accum_dtype = "float"
    dtype = T.float16
    accum_dtype = T.float32

    @T.prim_func
    def main(
            Q_unpad: T.Tensor(q_shape, dtype),
            K_unpad: T.Tensor(kv_shape, dtype),
            V_unpad: T.Tensor(kv_shape, dtype),
            cu_seqlens_q: T.Tensor([batch_size + 1], "int32"),
            cu_seqlens_k: T.Tensor([batch_size + 1], "int32"),
            cu_seqlens_q: T.Tensor([batch_size + 1], T.int32),
            cu_seqlens_k: T.Tensor([batch_size + 1], T.int32),
            max_seqlen_q: T.int32,
            Output_unpad: T.Tensor(o_shape, dtype),
    ):
        with T.Kernel(T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=threads) as (bx, by, bz):
        with T.Kernel(T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=threads) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim], dtype)
            K_shared = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_N, dim], dtype)
...
...
@@ -102,10 +92,12 @@ def flashattn(batch_size,
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            T.annotate_layout({
            T.annotate_layout(
                {
                    O_shared: tilelang.layout.make_swizzled_layout(O_shared),
                    Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
                })
                }
            )

            batch_idx = bz
            head_idx = by
...
...
@@ -119,43 +111,40 @@ def flashattn(batch_size,
            q_current_seqlen = q_end_idx - q_start_idx
            kv_current_seqlen = k_end_idx - kv_start_idx

            T.copy(Q_unpad[q_start_idx + bx * block_M:q_start_idx + (bx + 1) * block_M, head_idx, :], Q_shared)
            T.copy(Q_unpad[q_start_idx + bx * block_M:q_start_idx + (bx + 1) * block_M, head_idx, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            loop_range = (
                T.min(T.ceildiv(q_current_seqlen + (bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N))
                if is_causal else T.ceildiv(kv_current_seqlen, block_N))
            loop_range = (
                T.min(T.ceildiv(q_current_seqlen + (bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N))
                if is_causal else T.ceildiv(kv_current_seqlen, block_N))

            for k in T.Pipelined(loop_range, num_stages=num_stages):
                T.copy(K_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N, kv_head_idx, :], K_shared)
                T.copy(K_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N, kv_head_idx, :], K_shared)
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(
                            (bx * block_M + i < k * block_N + j) or
                            (bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen), -1e9, 0)
                        acc_s[i, j] = T.if_then_else(
                            (bx * block_M + i < k * block_N + j)
                            or (bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen),
                            -1e9,
                            0,
                        )
                else:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(
                            (bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen), -1e9, 0)
                        acc_s[i, j] = T.if_then_else(
                            (bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen), -1e9, 0)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_M):
                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
...
...
@@ -171,9 +160,7 @@ def flashattn(batch_size,
                for i, j in T.Parallel(block_M, dim):
                    acc_o[i, j] *= scores_scale[i]
                T.copy(V_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N, kv_head_idx, :], V_shared)
                T.copy(V_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N, kv_head_idx, :], V_shared)
                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
...
...
@@ -188,13 +175,9 @@ def flashattn(batch_size,
    return main


def main(batch: int = 1, heads: int = 64, q_seqlen: int = 2048, k_seqlen: int = 2048, dim: int = 128, groups: int = 16, is_causal: bool = False):
def main(batch: int = 1, heads: int = 64, q_seqlen: int = 2048, k_seqlen: int = 2048, dim: int = 128, groups: int = 16, is_causal: bool = False):
    assert heads % groups == 0, "heads must be divisible by groups"
    flops_per_matmul = 2.0 * batch * heads * q_seqlen * k_seqlen * dim
...
...
@@ -232,24 +215,12 @@ def main(batch: int = 1,
        output_pad_fn,
        _,
        _,
    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)

    UQ = q_unpad.shape[0]
    UKV = k_unpad.shape[0]

    kernel = flashattn(batch, groups, UQ, UKV, heads, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256)
    kernel = flashattn(batch, groups, UQ, UKV, heads, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256)
    out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)
    out = output_pad_fn(out_unpad)
...
...
@@ -264,23 +235,19 @@ def main(batch: int = 1,
    )
    torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)
    print("All checks passed.✅")

    latency = do_bench(lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q), _n_warmup=5, _n_repeat=5)
    latency = do_bench(lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q), _n_warmup=5, _n_repeat=5)
    print("Tile-lang: {:.2f} ms".format(latency))
    print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=8, help='batch size')
    parser.add_argument('--heads', type=int, default=64, help='query heads')
    parser.add_argument('--groups', type=int, default=16, help='groups')
    parser.add_argument('--q_seqlen', type=int, default=2048, help='query sequence length')
    parser.add_argument('--k_seqlen', type=int, default=2048, help='key/value sequence length')
    parser.add_argument('--dim', type=int, default=128, help='head dim')
    parser.add_argument('--is_causal', action='store_true', help='causal attention')
    parser.add_argument("--batch", type=int, default=8, help="batch size")
    parser.add_argument("--heads", type=int, default=64, help="query heads")
    parser.add_argument("--groups", type=int, default=16, help="groups")
    parser.add_argument("--q_seqlen", type=int, default=2048, help="query sequence length")
    parser.add_argument("--k_seqlen", type=int, default=2048, help="key/value sequence length")
    parser.add_argument("--dim", type=int, default=128, help="head dim")
    parser.add_argument("--is_causal", action="store_true", help="causal attention")
    args = parser.parse_args()
    main(args.batch, args.heads, args.q_seqlen, args.k_seqlen, args.dim, args.groups, args.is_causal)
    main(args.batch, args.heads, args.q_seqlen, args.k_seqlen, args.dim, args.groups, args.is_causal)
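The varlen example above feeds the kernel packed (unpadded) tensors plus cu_seqlens_q / cu_seqlens_k prefix sums. As a rough illustration of that packing convention — pack_varlen here is a hypothetical helper, not the generate_qkv utility the example imports:

import torch

def pack_varlen(seqs):
    # seqs: list of [len_i, H, D] tensors; returns the packed tensor,
    # the cumulative-length index, and the maximum sequence length.
    lengths = torch.tensor([s.shape[0] for s in seqs], dtype=torch.int32)
    cu_seqlens = torch.zeros(len(seqs) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(lengths, dim=0)
    packed = torch.cat(seqs, dim=0)
    return packed, cu_seqlens, int(lengths.max())

# sequence b then occupies packed[cu_seqlens[b]:cu_seqlens[b + 1]]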
examples/flash_attention/example_mha_bwd_bhsd.py
View file @
667632cc
...
...
@@ -7,14 +7,16 @@ import argparse
@tilelang.jit(
    out_idx=[3, 4], pass_configs={
    out_idx=[3, 4],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
    },
)
def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    shape = [batch, heads, seq_len, dim]
    dtype = "float16"
    accum_dtype = "float"
    dtype = T.float16
    accum_dtype = T.float32

    @T.prim_func
    def flash_fwd(
...
...
@@ -39,28 +41,28 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
            logsum = T.alloc_fragment([block_M], accum_dtype)

            T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
            T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
            T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))
            # T.copy(Q_shared, Q_local)
            # for i, j in T.Parallel(block_M, dim):
            #     Q_local[i, j] *= scale
            loop_range = (T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
            loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_range, num_stages=1):
                T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
                T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
                        acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
                else:
                    T.clear(acc_s)
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
                T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
                T.copy(scores_max, scores_max_prev)
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_M):
                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                for i in T.Parallel(block_M):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_M, dim):
...
...
@@ -74,21 +76,23 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
            T.copy(acc_o, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
            for i in T.Parallel(block_M):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
            T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M])
            T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M])

    return flash_fwd


@tilelang.jit(
    out_idx=[2], pass_configs={
    out_idx=[2],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
    },
)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
    dtype = "float16"
    accum_dtype = "float"
    dtype = T.float16
    accum_dtype = T.float32
    shape = [batch, heads, seq_len, dim]
    blk = 32
...
...
@@ -105,29 +109,30 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
        delta = T.alloc_fragment([blk], accum_dtype)
        T.clear(acc)
        for k in range(T.ceildiv(dim, blk)):
            T.copy(O[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], o)
            T.copy(dO[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], do)
            T.copy(O[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], o)
            T.copy(dO[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], do)
            for i, j in T.Parallel(blk, blk):
                acc[i, j] += o[i, j] * do[i, j]
        T.reduce_sum(acc, delta, 1)
        T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk])
        T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk])

    return flash_bwd_prep


def make_dq_layout(dQ):
    # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
    return T.Layout(dQ.shape, lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
    return T.Layout(dQ.shape, lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])


@tilelang.jit(
    out_idx=[1], pass_configs={
    out_idx=[1],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
    },
)
def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
    dtype = "float16"
    accum_dtype = "float"
    dtype = T.float16
    accum_dtype = T.float32
    shape = [batch, heads, seq_len, dim]
    blk = 64
...
...
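The make_dq_layout lines above keep the same remapping in both versions: dQ is stored permuted so that the per-thread atomic_add of an 8x8 MMA fragment lands on contiguous addresses. A quick sanity check of the within-tile index map (plain Python mirroring the last two coordinates of the lambda, for illustration only):

def dq_tile_index(l, d):
    return (d % 2, 4 * (l % 8) + (d % 8) // 2)

seen = {dq_tile_index(l, d) for l in range(8) for d in range(8)}
assert len(seen) == 64  # the map is a bijection over the 8x8 fragment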
@@ -139,22 +144,24 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
    with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
        T.annotate_layout({dQ: make_dq_layout(dQ)})
        T.copy(dQ[bz, by, bx * blk:(bx + 1) * blk, :], dQ_out[bz, by, bx * blk:(bx + 1) * blk, :])
        T.copy(
            dQ[bz, by, bx * blk:(bx + 1) * blk, :],
            dQ_out[bz, by, bx * blk:(bx + 1) * blk, :],
        )

    return flash_bwd_post


@tilelang.jit(pass_configs={
@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
    }
)
def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
    sm_scale = (1.0 / dim)**0.5
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    sm_scale = (1.0 / dim)**0.5
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    shape = [batch, heads, seq_len, dim]
    dtype = "float16"
    accum_dtype = "float"
    dtype = T.float16
    accum_dtype = T.float32

    @T.prim_func
    def flash_bwd(
...
...
@@ -190,36 +197,39 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
            dv_shared = T.alloc_shared([block_M, dim], dtype)
            dk_shared = T.alloc_shared([block_M, dim], dtype)

            T.annotate_layout({
            T.annotate_layout(
                {
                    dQ: make_dq_layout(dQ),
                    K_shared: tilelang.layout.make_swizzled_layout(K_shared),
                    dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
                    dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
                })
            T.copy(K[bz, bx, by * block_M:(by + 1) * block_M, :], K_shared)
            T.copy(V[bz, bx, by * block_M:(by + 1) * block_M, :], V_shared)
                }
            )
            T.copy(K[bz, bx, by * block_M:(by + 1) * block_M, :], K_shared)
            T.copy(V[bz, bx, by * block_M:(by + 1) * block_M, :], V_shared)
            T.clear(dv)
            T.clear(dk)
            loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
            loop_ed = T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_st, loop_ed, num_stages=2):
                T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q)
                T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q)
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
                T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
                T.copy(dO[bz, bx, k * block_N:(k + 1) * block_N, :], do)
                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
                # We don't need to handle OOB positions for non-causal cases,
                # since OOB values won't affect other positions here.
                T.copy(dO[bz, bx, k * block_N:(k + 1) * block_N, :], do)
                T.clear(dsT)
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
                T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
                T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
...
...
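The dsT_cast update at the end of the hunk above is the usual softmax backward identity dS = P * (dP - rowsum(dP * P)), with delta supplied by the preprocess kernel as rowsum(O * dO) (the two row sums are equal for attention). A small PyTorch check of the identity itself, independent of the kernel:

import torch

s = torch.randn(4, 6, dtype=torch.float64, requires_grad=True)
p = torch.softmax(s, dim=-1)
dp = torch.randn_like(p)

delta = (dp * p).sum(dim=-1, keepdim=True)   # rowsum(dP * P)
ds_manual = p * (dp - delta)

(ds_autograd,) = torch.autograd.grad(p, s, grad_outputs=dp)
assert torch.allclose(ds_manual, ds_autograd)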
@@ -232,14 +242,13 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
                        T.atomic_add(dQ[bz, bx, k * block_N + i, j], dq[i, j])
            T.copy(dv, dv_shared)
            T.copy(dk, dk_shared)
            T.copy(dv_shared, dV[bz, bx, by * block_M:(by + 1) * block_M, :])
            T.copy(dk_shared, dK[bz, bx, by * block_M:(by + 1) * block_M, :])
            T.copy(dv_shared, dV[bz, bx, by * block_M:(by + 1) * block_M, :])
            T.copy(dk_shared, dK[bz, bx, by * block_M:(by + 1) * block_M, :])

    return flash_bwd


class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, causal):
        BATCH, H, N_CTX, D_HEAD = q.shape
...
...
@@ -281,15 +290,15 @@ attention = _attention.apply
def ref_program(Q, K, V, is_causal):
    dim = Q.size(-1)
    scores = torch.einsum('bhqd,bhkd->bhqk', Q, K)
    scores = torch.einsum("bhqd,bhkd->bhqk", Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(2)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float('-inf'))
        scores = scores.masked_fill(mask == 0, float("-inf"))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V)
    output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V)
    return output
...
...
@@ -304,9 +313,7 @@ def main(
    total_flops = 5 * flops_per_matmul
    if causal:
        total_flops *= 0.5
    Q = (torch.empty(BATCH, H, N_CTX, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_())
    Q = torch.empty(BATCH, H, N_CTX, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_()
    K = torch.empty_like(Q).normal_().requires_grad_()
    V = torch.empty_like(Q).normal_().requires_grad_()
    dO = torch.randn_like(Q)
...
...
@@ -347,10 +354,10 @@ def main(
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=8, help='Batch size')
    parser.add_argument('--h', type=int, default=32, help='Number of heads')
    parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
    parser.add_argument('--d_head', type=int, default=64, help='Head dimension')
    parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
    parser.add_argument("--batch", type=int, default=8, help="Batch size")
    parser.add_argument("--h", type=int, default=32, help="Number of heads")
    parser.add_argument("--n_ctx", type=int, default=1024, help="Context size")
    parser.add_argument("--d_head", type=int, default=64, help="Head dimension")
    parser.add_argument("--causal", type=bool, default=False, help="Causal flag")
    args = parser.parse_args()
    main(args.batch, args.h, args.n_ctx, args.d_head, args.causal)
examples/flash_attention/example_mha_bwd.py → examples/flash_attention/example_mha_bwd_bshd.py
View file @
667632cc
...
...
@@ -7,14 +7,16 @@ import argparse
@tilelang.jit(
    out_idx=[3, 4], pass_configs={
    out_idx=[3, 4],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
    },
)
def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    shape = [batch, seq_len, heads, dim]
    dtype = "float16"
    accum_dtype = "float"
    dtype = T.float16
    accum_dtype = T.float32

    @T.prim_func
    def flash_fwd(
...
...
@@ -38,25 +40,25 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
            T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))
            loop_range = (T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
            loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_range, num_stages=1):
                T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
                T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
                        acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
                else:
                    T.clear(acc_s)
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
                T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
                T.copy(scores_max, scores_max_prev)
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_M):
                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                for i in T.Parallel(block_M):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_M, dim):
...
...
@@ -70,21 +72,23 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
            T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
            for i in T.Parallel(block_M):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
            T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M])
            T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M])

    return flash_fwd


@tilelang.jit(
    out_idx=[2], pass_configs={
    out_idx=[2],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
    },
)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
    dtype = "float16"
    accum_dtype = "float"
    dtype = T.float16
    accum_dtype = T.float32
    shape = [batch, seq_len, heads, dim]
    blk = 32
...
...
@@ -101,29 +105,30 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
        delta = T.alloc_fragment([blk], accum_dtype)
        T.clear(acc)
        for k in range(T.ceildiv(dim, blk)):
            T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o)
            T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do)
            T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o)
            T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do)
            for i, j in T.Parallel(blk, blk):
                acc[i, j] += o[i, j] * do[i, j]
        T.reduce_sum(acc, delta, 1)
        T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk])
        T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk])

    return flash_bwd_prep


def make_dq_layout(dQ):
    # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
    return T.Layout(dQ.shape, lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
    return T.Layout(dQ.shape, lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])


@tilelang.jit(
    out_idx=[1], pass_configs={
    out_idx=[1],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
    },
)
def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
    dtype = "float16"
    accum_dtype = "float"
    dtype = T.float16
    accum_dtype = T.float32
    shape = [batch, seq_len, heads, dim]
    blk = 64
...
...
@@ -135,22 +140,24 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
    with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
        T.annotate_layout({dQ: make_dq_layout(dQ)})
        T.copy(dQ[bz, bx * blk:(bx + 1) * blk, by, :], dQ_out[bz, bx * blk:(bx + 1) * blk, by, :])
        T.copy(
            dQ[bz, bx * blk:(bx + 1) * blk, by, :],
            dQ_out[bz, bx * blk:(bx + 1) * blk, by, :],
        )

    return flash_bwd_post


@tilelang.jit(pass_configs={
@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
    }
)
def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
    sm_scale = (1.0 / dim)**0.5
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    sm_scale = (1.0 / dim)**0.5
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    shape = [batch, seq_len, heads, dim]
    dtype = "float16"
    accum_dtype = "float"
    dtype = T.float16
    accum_dtype = T.float32

    @T.prim_func
    def flash_bwd(
...
...
@@ -186,33 +193,36 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
            dv_shared = T.alloc_shared([block_M, dim], dtype)
            dk_shared = T.alloc_shared([block_M, dim], dtype)

            T.annotate_layout({
            T.annotate_layout(
                {
                    dQ: make_dq_layout(dQ),
                })
            T.copy(K[bz, by * block_M:(by + 1) * block_M, bx, :], K_shared)
            T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared)
                }
            )
            T.copy(K[bz, by * block_M:(by + 1) * block_M, bx, :], K_shared)
            T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared)
            T.clear(dv)
            T.clear(dk)
            loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
            loop_ed = T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_st, loop_ed, num_stages=2):
                T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
                T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
                T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
                T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
                # We don't need to handle OOB positions for non-causal cases,
                # since OOB values won't affect other positions here.
                T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
                T.clear(dsT)
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
                T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
                T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
...
...
@@ -225,14 +235,13 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
                        T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
            T.copy(dv, dv_shared)
            T.copy(dk, dk_shared)
            T.copy(dv_shared, dV[bz, by * block_M:(by + 1) * block_M, bx, :])
            T.copy(dk_shared, dK[bz, by * block_M:(by + 1) * block_M, bx, :])
            T.copy(dv_shared, dV[bz, by * block_M:(by + 1) * block_M, bx, :])
            T.copy(dk_shared, dK[bz, by * block_M:(by + 1) * block_M, bx, :])

    return flash_bwd


class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, causal):
        BATCH, N_CTX, H, D_HEAD = q.shape
...
...
@@ -274,15 +283,15 @@ attention = _attention.apply
def ref_program(Q, K, V, is_causal):
    dim = Q.size(-1)
    scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
    scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float('-inf'))
        scores = scores.masked_fill(mask == 0, float("-inf"))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
    output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
    return output
...
...
@@ -297,9 +306,7 @@ def main(
    total_flops = 5 * flops_per_matmul
    if causal:
        total_flops *= 0.5
    Q = (torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_())
    Q = torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_()
    K = torch.empty_like(Q).normal_().requires_grad_()
    V = torch.empty_like(Q).normal_().requires_grad_()
    dO = torch.randn_like(Q)
...
...
@@ -338,10 +345,10 @@ def main(
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=8, help='Batch size')
    parser.add_argument('--h', type=int, default=32, help='Number of heads')
    parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
    parser.add_argument('--d_head', type=int, default=64, help='Head dimension')
    parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
    parser.add_argument("--batch", type=int, default=8, help="Batch size")
    parser.add_argument("--h", type=int, default=32, help="Number of heads")
    parser.add_argument("--n_ctx", type=int, default=1024, help="Context size")
    parser.add_argument("--d_head", type=int, default=64, help="Head dimension")
    parser.add_argument("--causal", type=bool, default=False, help="Causal flag")
    args = parser.parse_args()
    main(args.batch, args.h, args.n_ctx, args.d_head, args.causal)
examples/flash_attention/example_mha_bwd_wgmma_pipelined.py → examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py
View file @
667632cc
...
...
@@ -7,14 +7,16 @@ import argparse
@tilelang.jit(
    out_idx=[3, 4], pass_configs={
    out_idx=[3, 4],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
    },
)
def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    shape = [batch, seq_len, heads, dim]
    dtype = "float16"
    accum_dtype = "float"
    dtype = T.float16
    accum_dtype = T.float32

    @T.prim_func
    def flash_fwd(
...
...
@@ -38,26 +40,26 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
            logsum = T.alloc_fragment([block_M], accum_dtype)

            T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
            T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
            T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))
            loop_range = (T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
            loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_range, num_stages=1):
                T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
                T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
                        acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
                else:
                    T.clear(acc_s)
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
                T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
                T.copy(scores_max, scores_max_prev)
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_M):
                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                for i in T.Parallel(block_M):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_M, dim):
...
@@ -71,21 +73,23 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
logsum
[
i
]
=
logsum
[
i
]
*
scores_scale
[
i
]
+
scores_sum
[
i
]
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o
,
Output
[
bz
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
by
,
:])
T
.
copy
(
acc_o
,
Output
[
bz
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
by
,
:])
for
i
in
T
.
Parallel
(
block_M
):
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
T
.
copy
(
logsum
,
lse
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
])
T
.
copy
(
logsum
,
lse
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
])
return
flash_fwd
@
tilelang
.
jit
(
out_idx
=
[
2
],
pass_configs
=
{
out_idx
=
[
2
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
},
)
def
flashattn_bwd_preprocess
(
batch
,
heads
,
seq_len
,
dim
):
dtype
=
"
float16
"
accum_dtype
=
"
float
"
dtype
=
T
.
float16
accum_dtype
=
T
.
float
32
shape
=
[
batch
,
seq_len
,
heads
,
dim
]
blk
=
32
...
...
@@ -102,25 +106,27 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
        delta = T.alloc_fragment([blk], accum_dtype)
        T.clear(acc)
        for k in range(T.ceildiv(dim, blk)):
            T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o)
            T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do)
            T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o)
            T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do)
            for i, j in T.Parallel(blk, blk):
                acc[i, j] += o[i, j] * do[i, j]
        T.reduce_sum(acc, delta, 1)
        T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk])
        T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk])

    return flash_bwd_prep


@tilelang.jit(pass_configs={
@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
    }
)
def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
    sm_scale = (1.0 / dim)**0.5
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    sm_scale = (1.0 / dim)**0.5
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    shape = [batch, seq_len, heads, dim]
    dtype = "float16"
    accum_dtype = "float"
    dtype = T.float16
    accum_dtype = T.float32

    @T.prim_func
    def flash_bwd(
...
...
@@ -157,47 +163,43 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
            dk_shared = T.alloc_shared([block_M, dim], dtype)
            dq_shared = T.alloc_shared([block_N, dim], accum_dtype)

            T.annotate_layout({
            T.annotate_layout(
                {
                    K_shared: tilelang.layout.make_swizzled_layout(K_shared),
                    dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
                    dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
                    dq_shared: tilelang.layout.make_swizzled_layout(dq_shared),
                })
                }
            )
            T.copy(K[bz, by * block_M:(by + 1) * block_M, bx, :], K_shared)
            T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared)
            T.copy(K[bz, by * block_M:(by + 1) * block_M, bx, :], K_shared)
            T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared)
            T.clear(dv)
            T.clear(dk)
            loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
            loop_ed = T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_st, loop_ed, num_stages=2):
                T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
                T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
                T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
                T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
                T.clear(dsT)
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
                T.wait_wgmma(1)
                T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
                T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
                # We don't need to handle OOB positions for non-causal cases,
                # since OOB values won't affect other positions here.
                T.wait_wgmma(0)
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
                T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
                T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
...
...
@@ -208,17 +210,16 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
                T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1)
                T.wait_wgmma(0)
                T.copy(dq, dq_shared)
                T.atomic_add(dQ[bz, k * block_N:(k + 1) * block_N, bx, :], dq_shared)
                T.atomic_add(dQ[bz, k * block_N:(k + 1) * block_N, bx, :], dq_shared)
            T.copy(dv, dv_shared)
            T.copy(dk, dk_shared)
            T.copy(dv_shared, dV[bz, by * block_M:(by + 1) * block_M, bx, :])
            T.copy(dk_shared, dK[bz, by * block_M:(by + 1) * block_M, bx, :])
            T.copy(dv_shared, dV[bz, by * block_M:(by + 1) * block_M, bx, :])
            T.copy(dk_shared, dK[bz, by * block_M:(by + 1) * block_M, bx, :])

    return flash_bwd


class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, causal):
        BATCH, N_CTX, H, D_HEAD = q.shape
...
...
@@ -260,15 +261,15 @@ attention = _attention.apply
def ref_program(Q, K, V, is_causal):
    dim = Q.size(-1)
    scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
    scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float('-inf'))
        scores = scores.masked_fill(mask == 0, float("-inf"))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
    output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
    return output
...
...
@@ -283,9 +284,7 @@ def main(
    total_flops = 5 * flops_per_matmul
    if causal:
        total_flops *= 0.5
    Q = (torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_())
    Q = torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_()
    K = torch.empty_like(Q).normal_().requires_grad_()
    V = torch.empty_like(Q).normal_().requires_grad_()
    dO = torch.randn_like(Q)
...
...
@@ -305,7 +304,7 @@ def main(
    assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2)
    assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2)
    assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
    print('All checks passed.✅')
    print("All checks passed.✅")

    def run():
        O_ref.backward(dO, retain_graph=True)
...
@@ -323,10 +322,10 @@ def main(
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
8
,
help
=
'
Batch size
'
)
parser
.
add_argument
(
'
--h
'
,
type
=
int
,
default
=
32
,
help
=
'
Number of heads
'
)
parser
.
add_argument
(
'
--n_ctx
'
,
type
=
int
,
default
=
1024
,
help
=
'
Context size
'
)
parser
.
add_argument
(
'
--d_head
'
,
type
=
int
,
default
=
64
,
help
=
'
Head dimension
'
)
parser
.
add_argument
(
'
--causal
'
,
type
=
bool
,
default
=
False
,
help
=
'
Causal flag
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
8
,
help
=
"
Batch size
"
)
parser
.
add_argument
(
"
--h
"
,
type
=
int
,
default
=
32
,
help
=
"
Number of heads
"
)
parser
.
add_argument
(
"
--n_ctx
"
,
type
=
int
,
default
=
1024
,
help
=
"
Context size
"
)
parser
.
add_argument
(
"
--d_head
"
,
type
=
int
,
default
=
64
,
help
=
"
Head dimension
"
)
parser
.
add_argument
(
"
--causal
"
,
type
=
bool
,
default
=
False
,
help
=
"
Causal flag
"
)
args
=
parser
.
parse_args
()
main
(
args
.
batch
,
args
.
h
,
args
.
n_ctx
,
args
.
d_head
,
args
.
causal
)
examples/flash_attention/example_mha_fwd_bhsd.py
View file @
667632cc
...
...
@@ -15,24 +15,17 @@ def get_configs():
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(
    out_idx=[3], pass_configs={
    out_idx=[3],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128):
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    },
)
def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128):
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    q_shape = [batch, heads, seq_q, dim]
    kv_shape = [batch, heads, seq_kv, dim]
    dtype = "float16"
    accum_dtype = "float"
    dtype = T.float16
    accum_dtype = T.float32
    past_len = seq_kv - seq_q
    assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
...
...
@@ -48,14 +41,16 @@ def flashattn(batch,
        by: T.int32,
        bz: T.int32,
    ):
        T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
        T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
        if is_causal:
            for i, j in T.Parallel(block_M, block_N):
                q_idx = bx * block_M + i + past_len
                k_idx = k * block_N + j
                acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
        else:
            T.clear(acc_s)
            # We shall fill -inf for OOB positions
            for i, j in T.Parallel(block_M, block_N):
                acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_kv, -T.infinity(acc_s.dtype), 0)
        T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

    @T.macro
...
...
@@ -68,7 +63,7 @@ def flashattn(batch,
        by: T.int32,
        bz: T.int32,
    ):
        T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
        T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
        T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

    @T.macro
...
@@ -84,6 +79,10 @@ def flashattn(batch,
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_M
):
scores_max
[
i
]
=
T
.
max
(
scores_max
[
i
],
scores_max_prev
[
i
])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
...
...
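All of these kernels evaluate softmax in base 2: scores are scaled by 1.44269504 (log2 e) and exponentiated with exp2, which is cheaper on GPU than exp and mathematically equivalent. A tiny standalone check of that equivalence (not part of the diff):

import torch

x = torch.randn(8, 16)
log2e = 1.44269504

m = x.max(dim=-1, keepdim=True).values
num = torch.exp2((x - m) * log2e)            # exp(x - m) via exp2
out = num / num.sum(dim=-1, keepdim=True)

assert torch.allclose(out, torch.softmax(x, dim=-1), atol=1e-5)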
@@ -131,43 +130,42 @@ def flashattn(batch,
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
            T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            loop_range = (
                T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
                if is_causal else T.ceildiv(seq_kv, block_N))
            loop_range = (
                T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
                if is_causal else T.ceildiv(seq_kv, block_N))

            for k in T.Pipelined(loop_range, num_stages=num_stages):
                MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
                Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
                Rescale(acc_o, scores_scale)
                MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
            T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])

    return main


def ref_program(Q, K, V, is_causal):
    dim = Q.size(-1)
    scores = torch.einsum('bhqd,bhkd->bhqk', Q, K)
    scores = torch.einsum("bhqd,bhkd->bhqk", Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_q = Q.size(2)
        seq_kv = K.size(2)
        mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float('-inf'))
        scores = scores.masked_fill(mask == 0, float("-inf"))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V)
    output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V)
    return output
...
...
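When seq_kv > seq_q, the reference above builds its causal mask with a diagonal offset of seq_kv - seq_q, so the queries are treated as the trailing positions of the key sequence and may attend to every earlier key plus themselves. For example (illustrative values only):

import torch

seq_q, seq_kv = 2, 4
mask = torch.tril(torch.ones(seq_q, seq_kv), diagonal=seq_kv - seq_q)
print(mask)
# tensor([[1., 1., 1., 0.],
#         [1., 1., 1., 1.]])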
@@ -185,18 +183,8 @@ def main(
    if is_causal:
        total_flops *= 0.5

    if (not tune):
        kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128)
    if not tune:
        kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128)
        ref_program_processed = partial(ref_program, is_causal=is_causal)
        profiler = kernel.get_profiler()
...
...
@@ -221,12 +209,12 @@ def main(
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=1, help='batch size')
    parser.add_argument('--heads', type=int, default=1, help='heads')
    parser.add_argument('--seq_q', type=int, default=256, help='query sequence length')
    parser.add_argument('--seq_kv', type=int, default=256, help='key/value sequence length')
    parser.add_argument('--dim', type=int, default=64, help='dim')
    parser.add_argument('--is_causal', action='store_true', help='causal')
    parser.add_argument('--tune', action='store_true', help='tune configs')
    parser.add_argument("--batch", type=int, default=1, help="batch size")
    parser.add_argument("--heads", type=int, default=1, help="heads")
    parser.add_argument("--seq_q", type=int, default=256, help="query sequence length")
    parser.add_argument("--seq_kv", type=int, default=256, help="key/value sequence length")
    parser.add_argument("--dim", type=int, default=64, help="dim")
    parser.add_argument("--is_causal", action="store_true", help="causal", default=False)
    parser.add_argument("--tune", action="store_true", help="tune configs")
    args = parser.parse_args()
    main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune)
examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py
View file @
667632cc
...
...
@@ -15,24 +15,17 @@ def get_configs():
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(
    out_idx=[3], pass_configs={
    out_idx=[3],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256):
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    },
)
def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256):
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    q_shape = [batch, heads, seq_q, dim]
    kv_shape = [batch, heads, seq_kv, dim]
    dtype = "float16"
    accum_dtype = "float"
    dtype = T.float16
    accum_dtype = T.float32
    past_len = seq_kv - seq_q
    assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
...
...
@@ -48,14 +41,16 @@ def flashattn(batch,
        by: T.int32,
        bz: T.int32,
    ):
        T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
        T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
        if is_causal:
            for i, j in T.Parallel(block_M, block_N):
                q_idx = bx * block_M + i + past_len
                k_idx = k * block_N + j
                acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
        else:
            T.clear(acc_s)
            # We shall fill -inf for OOB positions
            for i, j in T.Parallel(block_M, block_N):
                acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_kv, -T.infinity(acc_s.dtype), 0)
        T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

    @T.macro
...
...
@@ -68,7 +63,7 @@ def flashattn(batch,
        by: T.int32,
        bz: T.int32,
    ):
        T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
        T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
        T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

    @T.macro
...
...
@@ -84,6 +79,8 @@ def flashattn(batch,
        T.copy(scores_max, scores_max_prev)
        T.fill(scores_max, -T.infinity(accum_dtype))
        T.reduce_max(acc_s, scores_max, dim=1, clear=False)
        for i in T.Parallel(block_M):
            scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
        # To do causal softmax, we need to set the scores_max to 0 if it is -inf
        # This process is called Check_inf in FlashAttention3 code, and it only need to be done
        # in the first ceil_div(kBlockM, kBlockN) steps.
...
...
@@ -131,48 +128,48 @@ def flashattn(batch,
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
            T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            loop_range = (
                T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
                if is_causal else T.ceildiv(seq_kv, block_N))
            loop_range = (
                T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
                if is_causal else T.ceildiv(seq_kv, block_N))

            for k in T.Pipelined(
                    loop_range,
                    num_stages=num_stages,
                    order=[-1, 0, 3, 1, -1, 2],
                    stage=[-1, 0, 0, 1, -1, 1],
                    group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
                    group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]],
            ):
                MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
                Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
                Rescale(acc_o, scores_scale)
                MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
            T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])

    return main


def ref_program(Q, K, V, is_causal):
    dim = Q.size(-1)
    scores = torch.einsum('bhqd,bhkd->bhqk', Q, K)
    scores = torch.einsum("bhqd,bhkd->bhqk", Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_q = Q.size(2)
        seq_kv = K.size(2)
        mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float('-inf'))
        scores = scores.masked_fill(mask == 0, float("-inf"))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V)
    output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V)
    return output
...
...
@@ -190,18 +187,8 @@ def main(
    if is_causal:
        total_flops *= 0.5

    if (not tune):
        kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256)
    if not tune:
        kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256)
        ref_program_processed = partial(ref_program, is_causal=is_causal)
        profiler = kernel.get_profiler()
...
...
@@ -226,12 +213,12 @@ def main(
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
8
,
help
=
'
batch size
'
)
parser
.
add_argument
(
'
--heads
'
,
type
=
int
,
default
=
32
,
help
=
'
heads
'
)
parser
.
add_argument
(
'
--seq_q
'
,
type
=
int
,
default
=
4096
,
help
=
'
query sequence length
'
)
parser
.
add_argument
(
'
--seq_kv
'
,
type
=
int
,
default
=
4096
,
help
=
'
key/value sequence length
'
)
parser
.
add_argument
(
'
--dim
'
,
type
=
int
,
default
=
128
,
help
=
'
dim
'
)
parser
.
add_argument
(
'
--is_causal
'
,
action
=
'
store_true
'
,
help
=
'
causal
'
)
parser
.
add_argument
(
'
--tune
'
,
action
=
'
store_true
'
,
help
=
'
tune configs
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
8
,
help
=
"
batch size
"
)
parser
.
add_argument
(
"
--heads
"
,
type
=
int
,
default
=
32
,
help
=
"
heads
"
)
parser
.
add_argument
(
"
--seq_q
"
,
type
=
int
,
default
=
4096
,
help
=
"
query sequence length
"
)
parser
.
add_argument
(
"
--seq_kv
"
,
type
=
int
,
default
=
4096
,
help
=
"
key/value sequence length
"
)
parser
.
add_argument
(
"
--dim
"
,
type
=
int
,
default
=
128
,
help
=
"
dim
"
)
parser
.
add_argument
(
"
--is_causal
"
,
action
=
"
store_true
"
,
help
=
"
causal
"
)
parser
.
add_argument
(
"
--tune
"
,
action
=
"
store_true
"
,
help
=
"
tune configs
"
)
args
=
parser
.
parse_args
()
main
(
args
.
batch
,
args
.
heads
,
args
.
seq_q
,
args
.
seq_kv
,
args
.
dim
,
args
.
is_causal
,
args
.
tune
)
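Note (not part of the commit): the Softmax and Rescale macros touched above implement the standard online-softmax update — keep a running row maximum, rescale the partial output when the maximum grows, and accumulate a running denominator. A plain-PyTorch sketch of one KV-tile step, for readers following the diff:

# Illustrative sketch only: the per-tile online-softmax update these kernels perform.
import torch

def online_softmax_step(acc_o, logsum, scores_max, acc_s):
    # acc_o: [block_M, dim] running (unnormalized) output
    # logsum: [block_M] running sum of exponentials
    # scores_max: [block_M] running row maximum
    # acc_s: [block_M, block_N] raw scores for the current KV tile
    new_max = torch.maximum(scores_max, acc_s.max(dim=1).values)
    scale = torch.exp(scores_max - new_max)          # rescale factor for the old state
    p = torch.exp(acc_s - new_max[:, None])          # unnormalized tile probabilities
    acc_o = acc_o * scale[:, None]                   # corresponds to Rescale(acc_o, scores_scale)
    logsum = logsum * scale + p.sum(dim=1)           # running denominator
    return acc_o, logsum, new_max, p                 # p feeds the P @ V GEMM (MMA1)

After the last tile, dividing by the denominator (the `acc_o[i, j] /= logsum[i]` loop in the kernel) gives the softmax-normalized output; the kernels do the same thing in base 2 by folding log2(e) into the score scale.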
examples/flash_attention/example_mha_fwd_bshd.py View file @ 667632cc
...
@@ -15,22 +15,16 @@ def get_configs():
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(
-   out_idx=[3], pass_configs={
+   out_idx=[3],
+   pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-   })
-def flashattn(batch, heads, seq_len, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128):
-   scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
+   },
+)
+def flashattn(batch, heads, seq_len, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128):
+   scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    shape = [batch, seq_len, heads, dim]
-   dtype = "float16"
-   accum_dtype = "float"
+   dtype = T.float16
+   accum_dtype = T.float32

    @T.macro
    def MMA0(
...
@@ -43,13 +37,14 @@ def flashattn(batch,
            by: T.int32,
            bz: T.int32,
    ):
        T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
        if is_causal:
            for i, j in T.Parallel(block_M, block_N):
                acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
        else:
            T.clear(acc_s)
            # We shall fill -inf for OOB positions
            for i, j in T.Parallel(block_M, block_N):
                acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
        T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

    @T.macro
...
@@ -62,7 +57,7 @@ def flashattn(batch,
            by: T.int32,
            bz: T.int32,
    ):
        T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
        T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

    @T.macro
...
@@ -78,6 +73,8 @@ def flashattn(batch,
        T.copy(scores_max, scores_max_prev)
        T.fill(scores_max, -T.infinity(accum_dtype))
        T.reduce_max(acc_s, scores_max, dim=1, clear=False)
+       for i in T.Parallel(block_M):
+           scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
        # To do causal softmax, we need to set the scores_max to 0 if it is -inf
        # This process is called Check_inf in FlashAttention3 code, and it only need to be done
        # in the first ceil_div(kBlockM, kBlockN) steps.
...
@@ -124,40 +121,39 @@ def flashattn(batch,
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))
            loop_range = (
                T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
                if is_causal else T.ceildiv(seq_len, block_N))
            for k in T.Pipelined(loop_range, num_stages=num_stages):
                MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
                Rescale(acc_o, scores_scale)
                MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
    return main


def ref_program(Q, K, V, is_causal):
    dim = Q.size(-1)
-   scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
+   scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
-       scores = scores.masked_fill(mask == 0, float('-inf'))
+       scores = scores.masked_fill(mask == 0, float("-inf"))
    attention_weights = F.softmax(scores, dim=-1)
-   output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
+   output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
    return output
...
@@ -174,17 +170,8 @@ def main(
    if is_causal:
        total_flops *= 0.5

-   if (not tune):
+   if not tune:
        kernel = flashattn(
            batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=1, threads=128)
        ref_program_processed = partial(ref_program, is_causal=is_causal)
        profiler = kernel.get_profiler()
        profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
...
@@ -208,11 +195,11 @@ def main(
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
-   parser.add_argument('--batch', type=int, default=8, help='batch size')
-   parser.add_argument('--heads', type=int, default=32, help='heads')
-   parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
-   parser.add_argument('--dim', type=int, default=128, help='dim')
-   parser.add_argument('--is_causal', action='store_true', help='causal')
-   parser.add_argument('--tune', action='store_true', help='tune configs')
+   parser.add_argument("--batch", type=int, default=8, help="batch size")
+   parser.add_argument("--heads", type=int, default=32, help="heads")
+   parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
+   parser.add_argument("--dim", type=int, default=128, help="dim")
+   parser.add_argument("--is_causal", action="store_true", help="causal")
+   parser.add_argument("--tune", action="store_true", help="tune configs")
    args = parser.parse_args()
    main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.tune)
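Note (illustrative, not from the repo): the `loop_range` expression above bounds how many KV tiles a causal query block visits — rows of the block end at `(bx + 1) * block_M`, so keys past that index are always masked and the KV loop can stop at `ceildiv((bx + 1) * block_M, block_N)` instead of scanning the whole sequence. A small hypothetical helper that computes the same bound:

# Hypothetical helper mirroring the kernel's loop_range expression.
def ceildiv(a: int, b: int) -> int:
    return (a + b - 1) // b

def causal_kv_tiles(bx: int, block_M: int, block_N: int, seq_len: int, is_causal: bool) -> int:
    if not is_causal:
        return ceildiv(seq_len, block_N)
    # Rows in query block bx end at (bx + 1) * block_M, so later keys are never visible.
    return min(ceildiv(seq_len, block_N), ceildiv((bx + 1) * block_M, block_N))

# Example: block_M = block_N = 128, seq_len = 4096 -> the first query block
# visits a single KV tile, the last one visits all 32.
assert causal_kv_tiles(0, 128, 128, 4096, True) == 1
assert causal_kv_tiles(31, 128, 128, 4096, True) == 32

The earlier seq_q/seq_kv variant adds `past_len` to the row bound for the case where the key/value sequence is longer than the query sequence.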
examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py View file @ 667632cc
...
@@ -15,22 +15,16 @@ def get_configs():
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(
-   out_idx=[3], pass_configs={
+   out_idx=[3],
+   pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-   })
-def flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256):
-   scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
+   },
+)
+def flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256):
+   scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    shape = [batch, seq_len, heads, dim]
-   dtype = "float16"
-   accum_dtype = "float"
+   dtype = T.float16
+   accum_dtype = T.float32

    @T.macro
    def MMA0(
...
@@ -43,13 +37,14 @@ def flashattn(batch,
            by: T.int32,
            bz: T.int32,
    ):
        T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
        if is_causal:
            for i, j in T.Parallel(block_M, block_N):
                acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
        else:
            T.clear(acc_s)
            # We shall fill -inf for OOB positions
            for i, j in T.Parallel(block_M, block_N):
                acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
        T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

    @T.macro
...
@@ -62,7 +57,7 @@ def flashattn(batch,
            by: T.int32,
            bz: T.int32,
    ):
        T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
        T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

    @T.macro
...
@@ -78,6 +73,8 @@ def flashattn(batch,
        T.copy(scores_max, scores_max_prev)
        T.fill(scores_max, -T.infinity(accum_dtype))
        T.reduce_max(acc_s, scores_max, dim=1, clear=False)
+       for i in T.Parallel(block_M):
+           scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
        # To do causal softmax, we need to set the scores_max to 0 if it is -inf
        # This process is called Check_inf in FlashAttention3 code, and it only need to be done
        # in the first ceil_div(kBlockM, kBlockN) steps.
...
@@ -124,45 +121,45 @@ def flashattn(batch,
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))
            loop_range = (
                T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
                if is_causal else T.ceildiv(seq_len, block_N))
            for k in T.Pipelined(
                    loop_range,
                    num_stages=num_stages,
                    order=[-1, 0, 3, 1, -1, 2],
                    stage=[-1, 0, 0, 1, -1, 1],
-                   group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
+                   group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]],
+           ):
                MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
                Rescale(acc_o, scores_scale)
                MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
    return main


def ref_program(Q, K, V, is_causal):
    dim = Q.size(-1)
-   scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
+   scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
-       scores = scores.masked_fill(mask == 0, float('-inf'))
+       scores = scores.masked_fill(mask == 0, float("-inf"))
    attention_weights = F.softmax(scores, dim=-1)
-   output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
+   output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
    return output
...
@@ -179,17 +176,8 @@ def main(
    if is_causal:
        total_flops *= 0.5

-   if (not tune):
+   if not tune:
        kernel = flashattn(
            batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256)
        ref_program_processed = partial(ref_program, is_causal=is_causal)
        profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
        profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
...
@@ -213,11 +201,11 @@ def main(
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
-   parser.add_argument('--batch', type=int, default=8, help='batch size')
-   parser.add_argument('--heads', type=int, default=32, help='heads')
-   parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
-   parser.add_argument('--dim', type=int, default=128, help='dim')
-   parser.add_argument('--is_causal', action='store_true', help='causal')
-   parser.add_argument('--tune', action='store_true', help='tune configs')
+   parser.add_argument("--batch", type=int, default=8, help="batch size")
+   parser.add_argument("--heads", type=int, default=32, help="heads")
+   parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
+   parser.add_argument("--dim", type=int, default=128, help="dim")
+   parser.add_argument("--is_causal", action="store_true", help="causal")
+   parser.add_argument("--tune", action="store_true", help="tune configs")
    args = parser.parse_args()
    main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.tune)
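Note (illustrative, not part of the diff): all of these forward kernels fold 1.44269504 (log2 e) into `scale` so the softmax can be evaluated with base-2 exponentials, which map directly to the GPU's exp2 path: exp(x) = 2^(x * log2 e). A quick numerical check of the identity:

# Illustrative check of the exp/exp2 identity behind the kernels' scale factor.
import math

dim = 128
scale = (1.0 / dim) ** 0.5            # the usual 1/sqrt(dim) softmax scaling
scale_log2e = scale * 1.44269504      # the constant folded in by the kernels

x = 3.7                               # some raw Q.K^T score
assert abs(math.exp(x * scale) - 2.0 ** (x * scale_log2e)) < 1e-9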
examples/flash_attention/example_mha_fwd_varlen.py View file @ 667632cc
...
@@ -47,7 +47,7 @@ def attention_ref(
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
    dim = q.shape[-1]
    scale = (1.0 / dim)**0.5  # log2(e)
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
    scores = torch.einsum("bthd,bshd->bhts", q, k)
...
@@ -68,41 +68,32 @@ def attention_ref(
@tilelang.jit(
-   out_idx=[6], pass_configs={
+   out_idx=[6],
+   pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-   })
-def flashattn(batch_size, UQ, UKV, heads, dim, is_causal, block_M=64, block_N=64, num_stages=0, threads=32):
-   scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
+   },
+)
+def flashattn(batch_size, UQ, UKV, heads, dim, is_causal, block_M=64, block_N=64, num_stages=0, threads=32):
+   scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    q_shape = [UQ, heads, dim]
    k_shape = [UKV, heads, dim]
    v_shape = [UKV, heads, dim]
    o_shape = [UQ, heads, dim]
-   dtype = "float16"
-   accum_dtype = "float"
+   dtype = T.float16
+   accum_dtype = T.float32

    @T.prim_func
    def main(
            Q_unpad: T.Tensor(q_shape, dtype),
            K_unpad: T.Tensor(k_shape, dtype),
            V_unpad: T.Tensor(v_shape, dtype),
-           cu_seqlens_q: T.Tensor([batch_size + 1], "int32"),
-           cu_seqlens_k: T.Tensor([batch_size + 1], "int32"),
+           cu_seqlens_q: T.Tensor([batch_size + 1], T.int32),
+           cu_seqlens_k: T.Tensor([batch_size + 1], T.int32),
            max_seqlen_q: T.int32,
            Output_unpad: T.Tensor(o_shape, dtype),
    ):
        with T.Kernel(T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=threads) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim], dtype, "shared")
            K_shared = T.alloc_shared([block_N, dim], dtype, "shared")
            V_shared = T.alloc_shared([block_N, dim], dtype, "shared")
...
@@ -151,15 +142,17 @@ def flashattn(batch_size,
                    K_shared[i, d] = 0
            if is_causal:
                for i, j in T.Parallel(block_M, block_N):
-                   acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and (bx * block_M + i >= q_current_seqlen or k * block_N + j >= k_current_seqlen), -T.infinity(acc_s.dtype), 0)
+                   acc_s[i, j] = T.if_then_else(
+                       (bx * block_M + i >= k * block_N + j) and
+                       (bx * block_M + i >= q_current_seqlen or k * block_N + j >= k_current_seqlen),
+                       -T.infinity(acc_s.dtype),
+                       0,
+                   )
            else:
                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.if_then_else(
                        (bx * block_M + i >= q_current_seqlen or k * block_N + j >= k_current_seqlen),
                        -T.infinity(acc_s.dtype), 0)
            T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
...
@@ -167,6 +160,8 @@ def flashattn(batch_size,
            T.copy(scores_max, scores_max_prev)
            T.fill(scores_max, -T.infinity(accum_dtype))
            T.reduce_max(acc_s, scores_max, dim=1, clear=False)
+           for i in T.Parallel(block_M):
+               scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
            # To do causal softmax, we need to set the scores_max to 0 if it is -inf
            # This process is called Check_inf in FlashAttention3 code, and it only need to be done
            # in the first ceil_div(kBlockM, kBlockN) steps.
...
@@ -242,8 +237,7 @@ def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128):
        output_pad_fn,
        dq_pad_fn,
        dk_pad_fn,
    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
    UQ = q_unpad.shape[0]  # unpadded query length
    UK = k_unpad.shape[0]  # unpadded key length
...
@@ -285,10 +279,10 @@ def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128):
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
-   parser.add_argument('--batch', type=int, default=8, help='batch size')
-   parser.add_argument('--heads', type=int, default=64, help='heads')
-   parser.add_argument('--seq_len', type=int, default=2048, help='sequence length')
-   parser.add_argument('--dim', type=int, default=128, help='dim')
+   parser.add_argument("--batch", type=int, default=8, help="batch size")
+   parser.add_argument("--heads", type=int, default=64, help="heads")
+   parser.add_argument("--seq_len", type=int, default=2048, help="sequence length")
+   parser.add_argument("--dim", type=int, default=128, help="dim")
    args = parser.parse_args()
    main(args.batch, args.heads, args.seq_len, args.dim)
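Note (illustrative, not from the repo): in the varlen kernel each (query-block, head, batch) program works on one packed sequence, so besides the causal condition it also has to treat positions past that sequence's own lengths (`q_current_seqlen`, `k_current_seqlen`) as invalid. A conventional plain-PyTorch formulation of bounds-plus-causal masking for one score tile, with hypothetical names:

# Illustrative sketch: masking a [block_M, block_N] score tile of one packed
# sequence using its per-sequence lengths and an optional causal constraint.
import torch

def mask_tile(acc_s, q_start, k_start, q_len, k_len, causal):
    block_M, block_N = acc_s.shape
    q_idx = q_start + torch.arange(block_M)[:, None]   # global query positions in this sequence
    k_idx = k_start + torch.arange(block_N)[None, :]   # global key positions in this sequence
    invalid = (q_idx >= q_len) | (k_idx >= k_len)       # beyond this sequence's real length
    if causal:
        invalid = invalid | (q_idx < k_idx)             # keys in the future of the query
    return acc_s.masked_fill(invalid, float("-inf"))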
examples/flash_attention/test_example_flash_attention.py View file @ 667632cc
...
@@ -2,7 +2,7 @@ import tilelang.testing
import example_gqa_bwd
import example_gqa_bwd_wgmma_pipelined
-import example_mha_bwd
+import example_mha_bwd_bshd
import example_mha_bwd_bhsd
import example_mha_fwd_bhsd_wgmma_pipelined
import example_gqa_fwd_bshd
...
@@ -10,7 +10,7 @@ import example_mha_fwd_bshd
import example_gqa_fwd_bshd_wgmma_pipelined
import example_mha_fwd_bshd_wgmma_pipelined
import example_mha_fwd_varlen
-import example_mha_bwd_wgmma_pipelined
+import example_mha_bwd_bshd_wgmma_pipelined
import example_mha_fwd_bhsd
import example_gqa_bwd_tma_reduce_varlen
...
@@ -33,7 +33,7 @@ def test_example_gqa_bwd_wgmma_pipelined():
@tilelang.testing.requires_cuda
def test_example_mha_bwd():
-   example_mha_bwd.main(
+   example_mha_bwd_bshd.main(
        BATCH=1,
        H=16,
        N_CTX=512,
...
@@ -56,20 +56,18 @@ def test_example_mha_bwd_bhsd():
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_mha_bwd_wgmma_pipelined():
-   example_mha_bwd_wgmma_pipelined.main(BATCH=1, H=32, N_CTX=256, D_HEAD=64, causal=False)
+   example_mha_bwd_bshd_wgmma_pipelined.main(BATCH=1, H=32, N_CTX=256, D_HEAD=64, causal=False)


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_gqa_fwd_bshd_wgmma_pipelined():
    example_gqa_fwd_bshd_wgmma_pipelined.main(
        batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False)


@tilelang.testing.requires_cuda
def test_example_gqa_fwd_bshd():
    example_gqa_fwd_bshd.main(
        batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False)


@tilelang.testing.requires_cuda
...
examples/flash_attention/varlen_utils.py View file @ 667632cc
...
@@ -9,22 +9,14 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
    if mode == "full":
        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
    elif mode == "random":
        lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device)
    elif mode == "third":
        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
-   padding_mask = (
-       repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths)
+   padding_mask = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
    return padding_mask


def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
...
@@ -39,15 +31,12 @@ def generate_qkv(q,
    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask)
        output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen_q)
    else:
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device)
        max_seqlen_q = seqlen_q
        output_pad_fn = lambda output_unpad: rearrange(output_unpad, "(b s) h d -> b s h d", b=batch_size)
    if key_padding_mask is not None:
        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
...
@@ -55,8 +44,7 @@ def generate_qkv(q,
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device)
        max_seqlen_k = seqlen_k
    if qkvpacked:
...
@@ -67,8 +55,7 @@ def generate_qkv(q,
        if query_padding_mask is not None:
            dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
        else:
            dqkv_pad_fn = lambda dqkv_unpad: rearrange(dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size)
        return (
            qkv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
...
@@ -84,8 +71,7 @@ def generate_qkv(q,
    if key_padding_mask is not None:
        dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
    else:
        dkv_pad_fn = lambda dkv_unpad: rearrange(dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size)
    return (
        q_unpad.detach().requires_grad_(),
        kv_unpad.detach().requires_grad_(),
...
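Note (illustrative, not part of the commit): generate_qkv packs the per-sequence tokens into one flat tensor and describes the boundaries with cumulative sequence lengths (`cu_seqlens_q` / `cu_seqlens_k`): sequence b occupies rows cu_seqlens[b]:cu_seqlens[b+1]. For the fully padded case the code above builds them with torch.arange; for ragged lengths they are simply a prefix sum, as in this small example:

# Illustrative: building cu_seqlens from per-sequence lengths (prefix sum).
import torch
import torch.nn.functional as F

lengths = torch.tensor([5, 3, 7], dtype=torch.int32)                      # ragged sequence lengths
cu_seqlens = F.pad(torch.cumsum(lengths, 0, dtype=torch.int32), (1, 0))   # prepend a leading zero
print(cu_seqlens)  # tensor([ 0,  5,  8, 15], dtype=torch.int32)
# Sequence b lives in rows cu_seqlens[b]:cu_seqlens[b+1] of the packed tensor,
# and max_seqlen (lengths.max()) sizes the kernel grid over query blocks.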