wangkx1 / tilelang / Commits / bc2d5632

Commit bc2d5632 authored Jan 15, 2026 by root

init

Pipeline #3222 failed with stages in 0 seconds
Changes: 257 | Pipelines: 1
Showing 20 changed files with 6385 additions and 0 deletions (+6385 / -0)
examples/flash_attention/example_gqa_bwd.py  +550 -0
examples/flash_attention/example_gqa_bwd_tma_reduce.py  +570 -0
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py  +792 -0
examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py  +388 -0
examples/flash_attention/example_gqa_fwd_bshd.py  +277 -0
examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py  +251 -0
examples/flash_attention/example_gqa_fwd_varlen.py  +274 -0
examples/flash_attention/example_mha_bwd.py  +347 -0
examples/flash_attention/example_mha_bwd_bhsd.py  +356 -0
examples/flash_attention/example_mha_bwd_wgmma_pipelined.py  +332 -0
examples/flash_attention/example_mha_fwd_bhsd.py  +232 -0
examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py  +237 -0
examples/flash_attention/example_mha_fwd_bshd.py  +218 -0
examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py  +223 -0
examples/flash_attention/example_mha_fwd_varlen.py  +294 -0
examples/flash_attention/test_example_flash_attention.py  +91 -0
examples/flash_attention/varlen_utils.py  +122 -0
examples/flash_decoding/README.md  +1 -0
examples/flash_decoding/example_gqa_decode.py  +501 -0
examples/flash_decoding/example_mha_inference.py  +329 -0
Too many changes to show. To preserve performance only 257 of 257+ files are displayed.
examples/flash_attention/example_gqa_bwd.py (new file, 0 → 100644)
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
import argparse


@tilelang.jit(
    out_idx=[3, 4], pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
    scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim_qk]
    k_shape = [batch, seq_len, head_kv, dim_qk]
    v_shape = [batch, seq_len, head_kv, dim_v]
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def flash_fwd(
            Q: T.Tensor(q_shape, dtype),  # type: ignore
            K: T.Tensor(k_shape, dtype),  # type: ignore
            V: T.Tensor(v_shape, dtype),  # type: ignore
            Output: T.Tensor([batch, seq_len, heads, dim_v], dtype),  # type: ignore
            lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
    ):
        with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim_qk], dtype)
            K_shared = T.alloc_shared([block_N, dim_qk], dtype)
            V_shared = T.alloc_shared([block_N, dim_v], dtype)
            acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
            acc_o = T.alloc_fragment([block_M, dim_v], accum_dtype)
            scores_max = T.alloc_fragment([block_M], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
            scores_scale = T.alloc_fragment([block_M], accum_dtype)
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
            T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            loop_range = (
                T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
            for k in T.Pipelined(loop_range, num_stages=1):
                T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared)
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                                     -T.infinity(acc_s.dtype))
                else:
                    T.clear(acc_s)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
                T.copy(scores_max, scores_max_prev)
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_M):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_M, dim_v):
                    acc_o[i, j] *= scores_scale[i]
                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.copy(acc_s, acc_s_cast)
                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(block_M):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
            for i, j in T.Parallel(block_M, dim_v):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
            for i in T.Parallel(block_M):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
            T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M])

    return flash_fwd


@tilelang.jit(
    out_idx=[2], pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
    dtype = "float16"
    accum_dtype = "float"
    shape = [batch, seq_len, heads, dim_v]
    blk = 32

    @T.prim_func
    def flash_bwd_prep(
            O: T.Tensor(shape, dtype),  # type: ignore
            dO: T.Tensor(shape, dtype),  # type: ignore
            Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
            o = T.alloc_fragment([blk, blk], dtype)
            do = T.alloc_fragment([blk, blk], dtype)
            acc = T.alloc_fragment([blk, blk], accum_dtype)
            delta = T.alloc_fragment([blk], accum_dtype)
            T.clear(acc)
            for k in range(T.ceildiv(dim_v, blk)):
                T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o)
                T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do)
                for i, j in T.Parallel(blk, blk):
                    acc[i, j] += o[i, j] * do[i, j]
            T.reduce_sum(acc, delta, 1)
            T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk])

    return flash_bwd_prep


def make_dq_layout(dQ):
    # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
    return T.Layout(dQ.shape,
                    lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])


@tilelang.jit(
    out_idx=[1], pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
def flashattn_bwd_postprocess(batch, heads, seq_len, dim_qk):
    dtype = "float16"
    accum_dtype = "float"
    shape = [batch, seq_len, heads, dim_qk]
    blk = 64

    @T.prim_func
    def flash_bwd_post(
            dQ: T.Tensor(shape, accum_dtype),  # type: ignore
            dQ_out: T.Tensor(shape, dtype),  # type: ignore
    ):
        with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
            T.annotate_layout({dQ: make_dq_layout(dQ)})
            T.copy(
                dQ[bz, bx * blk:(bx + 1) * blk, by, :],
                dQ_out[bz, bx * blk:(bx + 1) * blk, by, :],
            )

    return flash_bwd_post


@tilelang.jit(pass_configs={
    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N,
                             threads=256, num_stages=2, groups=1):
    sm_scale = (1.0 / dim_qk)**0.5
    scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim_qk]
    k_shape = [batch, seq_len, head_kv, dim_qk]
    v_shape = [batch, seq_len, head_kv, dim_v]
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def flash_bwd(
            Q: T.Tensor(q_shape, dtype),  # type: ignore
            K: T.Tensor(k_shape, dtype),  # type: ignore
            V: T.Tensor(v_shape, dtype),  # type: ignore
            dO: T.Tensor([batch, seq_len, heads, dim_v], dtype),  # type: ignore
            lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
            Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
            dQ: T.Tensor(q_shape, accum_dtype),  # type: ignore
            dK: T.Tensor(k_shape, accum_dtype),  # type: ignore
            dV: T.Tensor(v_shape, accum_dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
            K_shared = T.alloc_shared([block_M, dim_qk], dtype)
            dsT_shared = T.alloc_shared([block_M, block_N], dtype)
            q = T.alloc_shared([block_N, dim_qk], dtype)
            V_shared = T.alloc_shared([block_M, dim_v], dtype)
            qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
            dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
            qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
            dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
            lse_shared = T.alloc_shared([block_N], accum_dtype)
            delta = T.alloc_shared([block_N], accum_dtype)
            do = T.alloc_shared([block_N, dim_v], dtype)
            dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
            dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
            dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
            dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype)
            dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype)

            T.annotate_layout({
                dQ: make_dq_layout(dQ),
                K_shared: tilelang.layout.make_swizzled_layout(K_shared),
            })
            T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared)
            T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
            T.clear(dv)
            T.clear(dk)
            loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
            loop_ed = T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
                T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
                T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
                T.clear(dsT)
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
                T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
                T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
                T.copy(dsT_cast, dsT_shared)
                T.clear(dq)
                T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
                for i, j in T.Parallel(block_N, dim_qk):
                    T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
            T.copy(dv, dv_shared)
            T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared)
            T.copy(dk, dk_shared)
            T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared)

    return flash_bwd


@tilelang.jit(pass_configs={
    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_split(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N,
                        threads=256, num_stages=2, groups=1):
    sm_scale = (1.0 / dim_qk)**0.5
    scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim_qk]
    k_shape = [batch, seq_len, head_kv, dim_qk]
    v_shape = [batch, seq_len, head_kv, dim_v]
    dk_shape = [groups, batch, seq_len, head_kv, dim_qk]  # sum after kernel
    dv_shape = [groups, batch, seq_len, head_kv, dim_v]  # sum after kernel
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def flash_bwd(
            Q: T.Tensor(q_shape, dtype),  # type: ignore
            K: T.Tensor(k_shape, dtype),  # type: ignore
            V: T.Tensor(v_shape, dtype),  # type: ignore
            dO: T.Tensor([batch, seq_len, heads, dim_v], dtype),  # type: ignore
            lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
            Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
            dQ: T.Tensor(q_shape, accum_dtype),  # type: ignore
            dK: T.Tensor(dk_shape, dtype),  # type: ignore
            dV: T.Tensor(dv_shape, dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
            K_shared = T.alloc_shared([block_M, dim_qk], dtype)
            dsT_shared = T.alloc_shared([block_M, block_N], dtype)
            q = T.alloc_shared([block_N, dim_qk], dtype)
            V_shared = T.alloc_shared([block_M, dim_v], dtype)
            qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
            dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
            qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
            dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
            lse_shared = T.alloc_shared([block_N], accum_dtype)
            delta = T.alloc_shared([block_N], accum_dtype)
            do = T.alloc_shared([block_N, dim_v], dtype)
            dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
            dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
            dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
            dv_shared = T.alloc_shared([block_M, dim_v], dtype)
            dk_shared = T.alloc_shared([block_M, dim_qk], dtype)

            T.annotate_layout({
                dQ: make_dq_layout(dQ),
                K_shared: tilelang.layout.make_swizzled_layout(K_shared),
                dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
                dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
            })
            T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared)
            T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
            T.clear(dv)
            T.clear(dk)
            loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
            loop_ed = T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
                T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
                T.clear(dsT)
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
                T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
                T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
                T.copy(dsT_cast, dsT_shared)
                T.clear(dq)
                T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
                for i, j in T.Parallel(block_N, dim_qk):
                    T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
            T.copy(dv, dv_shared)
            T.copy(dv_shared, dV[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
            T.copy(dk, dk_shared)
            T.copy(dk, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])

    return flash_bwd


@torch.compile
class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, causal, groups=1, use_atomic=True):
        BATCH, N_CTX, H, D_HEAD_QK = q.shape
        D_HEAD_V = v.shape[-1]
        block_M = 128
        block_N = 64
        mod = flashattn_fwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups)
        o, lse = mod(q, k, v)
        ctx.save_for_backward(q, k, v, o, lse)
        ctx.causal = causal
        ctx.use_atomic = use_atomic
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, o, lse = ctx.saved_tensors
        BATCH, N_CTX, H, D_HEAD_QK = q.shape
        HEAD_KV, D_HEAD_V, = v.shape[-2], v.shape[-1]
        groups = H // HEAD_KV

        def maybe_contiguous(x):
            if x.stride(-1) != 1:
                return x.contiguous()
            return x

        do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
        block_M = 128
        block_N = 32
        mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V)
        mod_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD_QK)
        delta = mod_prep(o, do)
        if ctx.use_atomic:
            kernel = flashattn_bwd_atomic_add(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal,
                                              block_M, block_N, threads=256, num_stages=2, groups=groups)
            shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
            shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
            shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]
            dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
            dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device)
            dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device)
            kernel(q, k, v, do, lse, delta, dq, dk, dv)
            dq = mod_post(dq)
            dk = dk.to(torch.float16)
            dv = dv.to(torch.float16)
        else:
            kernel = flashattn_bwd_split(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal,
                                         block_M, block_N, threads=256, num_stages=2, groups=groups)
            shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
            shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK]  # sum after kernel
            shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V]  # sum after kernel
            dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
            dk = torch.empty(shape_k, dtype=torch.float16, device=q.device)
            dv = torch.empty(shape_v, dtype=torch.float16, device=q.device)
            kernel(q, k, v, do, lse, delta, dq, dk, dv)
            dq = mod_post(dq)
            dk, dv = dk.sum(0), dv.sum(0)
        return dq, dk, dv, None, None, None


attention = _attention.apply


def ref_program(Q, K, V, is_causal, groups=1):
    # Q: [B, T, HQ, D_QK]
    # K: [B, T, HK, D_QK]
    # V: [B, T, HV, D_V]
    # HQ = HKV * groups
    assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
    assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"

    dim_qk = Q.size(-1)
    K = K.repeat_interleave(groups, dim=2)
    V = V.repeat_interleave(groups, dim=2)
    scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
    return output


def main(BATCH: int = 1, H: int = 32, N_CTX: int = 256, D_HEAD_QK: int = 192, D_HEAD_V: int = 128,
         groups: int = 16, causal: bool = False, use_atomic: bool = True):
    flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
    flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
    total_flops = 3 * flops_per_qk + 2 * flops_per_v
    if causal:
        total_flops *= 0.5

    Q = (torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_())
    head_kv = H // groups
    K = (torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_())
    V = (torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_())
    dO = (torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_())

    O = attention(Q, K, V, causal, groups, use_atomic)
    O.backward(dO, retain_graph=True)
    dQ, Q.grad = Q.grad.clone(), None
    dK, K.grad = K.grad.clone(), None
    dV, V.grad = V.grad.clone(), None

    O_ref = ref_program(Q, K, V, causal, groups)
    O_ref.backward(dO, retain_graph=True)
    dQ_ref, Q.grad = Q.grad.clone(), None
    dK_ref, K.grad = K.grad.clone(), None
    dV_ref, V.grad = V.grad.clone(), None

    torch.testing.assert_close(O, O_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
    print('All checks passed.✅')

    def run():
        O_ref.backward(dO, retain_graph=True)

    def run1():
        O.backward(dO, retain_graph=True)

    from tilelang.profiler import do_bench

    latency = do_bench(run, warmup=500)
    print("torch: {:.2f} ms".format(latency))
    print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
    latency = do_bench(run1, warmup=500)
    print("tilelang: {:.2f} ms".format(latency))
    print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=8, help='Batch size')
    parser.add_argument('--h', type=int, default=32, help='Number of heads')
    parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
    parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K')
    parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V')
    parser.add_argument('--causal', action='store_true', help='Causal flag')
    parser.add_argument('--groups', type=int, default=16, help='groups')
    parser.add_argument('--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
    parser.add_argument('--use_split', action='store_true', default=False, help='Use split for dK/dV')
    args = parser.parse_args()

    # Handle backward compatibility and logic
    if args.use_split:
        use_atomic = False
    elif args.use_atomic:
        use_atomic = True
    else:
        # Default: use atomic
        use_atomic = True

    main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic)
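
A minimal usage sketch of the attention wrapper defined above (illustrative only, not part of the commit; the shapes follow the defaults of main(), i.e. 32 query heads with groups=16, giving 2 KV heads):

# Smoke test of the autograd wrapper above; assumes a CUDA device and that the
# definitions from example_gqa_bwd.py are in scope.
import torch

q = torch.randn(1, 256, 32, 192, dtype=torch.half, device="cuda", requires_grad=True)
k = torch.randn(1, 256, 2, 192, dtype=torch.half, device="cuda", requires_grad=True)
v = torch.randn(1, 256, 2, 128, dtype=torch.half, device="cuda", requires_grad=True)

o = attention(q, k, v, True, 16, True)  # causal=True, groups=16, use_atomic=True
o.sum().backward()
print(o.shape, q.grad.shape, k.grad.shape, v.grad.shape)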
examples/flash_attention/example_gqa_bwd_tma_reduce.py (new file, 0 → 100644)
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
from tilelang.contrib import nvcc
import argparse


@tilelang.jit(
    out_idx=[3, 4], pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
    scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim_qk]
    k_shape = [batch, seq_len, head_kv, dim_qk]
    v_shape = [batch, seq_len, head_kv, dim_v]
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def flash_fwd(
            Q: T.Tensor(q_shape, dtype),  # type: ignore
            K: T.Tensor(k_shape, dtype),  # type: ignore
            V: T.Tensor(v_shape, dtype),  # type: ignore
            Output: T.Tensor([batch, seq_len, heads, dim_v], dtype),  # type: ignore
            lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
    ):
        with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim_qk], dtype)
            K_shared = T.alloc_shared([block_N, dim_qk], dtype)
            V_shared = T.alloc_shared([block_N, dim_v], dtype)
            acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
            acc_o = T.alloc_fragment([block_M, dim_v], accum_dtype)
            scores_max = T.alloc_fragment([block_M], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
            scores_scale = T.alloc_fragment([block_M], accum_dtype)
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
            T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            loop_range = (
                T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
            for k in T.Pipelined(loop_range, num_stages=1):
                T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared)
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                                     -T.infinity(acc_s.dtype))
                else:
                    T.clear(acc_s)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
                T.copy(scores_max, scores_max_prev)
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_M):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_M, dim_v):
                    acc_o[i, j] *= scores_scale[i]
                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.copy(acc_s, acc_s_cast)
                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(block_M):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
            for i, j in T.Parallel(block_M, dim_v):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
            for i in T.Parallel(block_M):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
            T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M])

    return flash_fwd


@tilelang.jit(
    out_idx=[2], pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
    dtype = "float16"
    accum_dtype = "float"
    shape = [batch, seq_len, heads, dim_v]
    blk = 32

    @T.prim_func
    def flash_bwd_prep(
            O: T.Tensor(shape, dtype),  # type: ignore
            dO: T.Tensor(shape, dtype),  # type: ignore
            Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
            o = T.alloc_fragment([blk, blk], dtype)
            do = T.alloc_fragment([blk, blk], dtype)
            acc = T.alloc_fragment([blk, blk], accum_dtype)
            delta = T.alloc_fragment([blk], accum_dtype)
            T.clear(acc)
            for k in range(T.ceildiv(dim_v, blk)):
                T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o)
                T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do)
                for i, j in T.Parallel(blk, blk):
                    acc[i, j] += o[i, j] * do[i, j]
            T.reduce_sum(acc, delta, 1)
            T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk])

    return flash_bwd_prep


def make_dq_layout(dQ):
    # bshd -> bhld to use tma reduction instruction
    return T.Layout(dQ.shape, lambda b, l, h, d: [b, h, l, d])


@tilelang.jit(
    out_idx=[3, 4, 5],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
def flashattn_bwd_postprocess(batch, heads, head_kv, seq_len, dim_qk, dim_v):
    dtype = "float16"
    accum_dtype = "float"
    q_shape = [batch, seq_len, heads, dim_qk]
    k_shape = [batch, seq_len, head_kv, dim_qk]
    v_shape = [batch, seq_len, head_kv, dim_v]
    blk = 64

    @T.prim_func
    def flash_bwd_post(
            dQ: T.Tensor(q_shape, accum_dtype),  # type: ignore
            dK: T.Tensor(k_shape, accum_dtype),  # type: ignore
            dV: T.Tensor(v_shape, accum_dtype),  # type: ignore
            dQ_out: T.Tensor(q_shape, dtype),  # type: ignore
            dK_out: T.Tensor(k_shape, dtype),  # type: ignore
            dV_out: T.Tensor(v_shape, dtype),  # type: ignore
    ):
        with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
            T.annotate_layout({dQ: make_dq_layout(dQ)})
            T.copy(dQ[bz, bx * blk:(bx + 1) * blk, by, :],
                   dQ_out[bz, bx * blk:(bx + 1) * blk, by, :])
        with T.Kernel(T.ceildiv(seq_len, blk), head_kv, batch, threads=128) as (bx, by, bz):
            T.annotate_layout({
                dK: make_dq_layout(dK),
                dV: make_dq_layout(dV),
            })
            T.copy(dK[bz, bx * blk:(bx + 1) * blk, by, :],
                   dK_out[bz, bx * blk:(bx + 1) * blk, by, :])
            T.copy(dV[bz, bx * blk:(bx + 1) * blk, by, :],
                   dV_out[bz, bx * blk:(bx + 1) * blk, by, :])

    return flash_bwd_post


@tilelang.jit(pass_configs={
    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N,
                             threads=256, num_stages=2, groups=1):
    sm_scale = (1.0 / dim_qk)**0.5
    scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim_qk]
    k_shape = [batch, seq_len, head_kv, dim_qk]
    v_shape = [batch, seq_len, head_kv, dim_v]
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def flash_bwd(
            Q: T.Tensor(q_shape, dtype),  # type: ignore
            K: T.Tensor(k_shape, dtype),  # type: ignore
            V: T.Tensor(v_shape, dtype),  # type: ignore
            dO: T.Tensor([batch, seq_len, heads, dim_v], dtype),  # type: ignore
            lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
            Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
            dQ: T.Tensor(q_shape, accum_dtype),  # type: ignore
            dK: T.Tensor(k_shape, accum_dtype),  # type: ignore
            dV: T.Tensor(v_shape, accum_dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
            K_shared = T.alloc_shared([block_M, dim_qk], dtype)
            dsT_shared = T.alloc_shared([block_M, block_N], dtype)
            q = T.alloc_shared([block_N, dim_qk], dtype)
            V_shared = T.alloc_shared([block_M, dim_v], dtype)
            qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
            dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
            qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
            dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
            lse_shared = T.alloc_shared([block_N], accum_dtype)
            delta = T.alloc_shared([block_N], accum_dtype)
            do = T.alloc_shared([block_N, dim_v], dtype)
            dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
            dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
            dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
            dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype)
            dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype)
            dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype)

            T.annotate_layout({
                dQ: make_dq_layout(dQ),
                dK: make_dq_layout(dK),
                dV: make_dq_layout(dV),
                K_shared: tilelang.layout.make_swizzled_layout(K_shared),
            })
            T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared)
            T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
            T.clear(dv)
            T.clear(dk)
            loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
            loop_ed = T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
                T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
                T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
                T.clear(dsT)
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
                T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
                T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
                T.copy(dsT_cast, dsT_shared)
                T.clear(dq)
                T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
                T.copy(dq, dq_shared)
                T.atomic_add(dQ[bz, k * block_N:(k + 1) * block_N, bx, :], dq_shared, use_tma=True)
            T.copy(dv, dv_shared)
            T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared, use_tma=True)
            T.copy(dk, dk_shared)
            T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared, use_tma=True)

    return flash_bwd


@tilelang.jit(pass_configs={
    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_split(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N,
                        threads=256, num_stages=2, groups=1):
    sm_scale = (1.0 / dim_qk)**0.5
    scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim_qk]
    k_shape = [batch, seq_len, head_kv, dim_qk]
    v_shape = [batch, seq_len, head_kv, dim_v]
    dk_shape = [groups, batch, seq_len, head_kv, dim_qk]  # sum after kernel
    dv_shape = [groups, batch, seq_len, head_kv, dim_v]  # sum after kernel
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def flash_bwd(
            Q: T.Tensor(q_shape, dtype),  # type: ignore
            K: T.Tensor(k_shape, dtype),  # type: ignore
            V: T.Tensor(v_shape, dtype),  # type: ignore
            dO: T.Tensor([batch, seq_len, heads, dim_v], dtype),  # type: ignore
            lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
            Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
            dQ: T.Tensor(q_shape, accum_dtype),  # type: ignore
            dK: T.Tensor(dk_shape, dtype),  # type: ignore
            dV: T.Tensor(dv_shape, dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
            K_shared = T.alloc_shared([block_M, dim_qk], dtype)
            dsT_shared = T.alloc_shared([block_M, block_N], dtype)
            q = T.alloc_shared([block_N, dim_qk], dtype)
            V_shared = T.alloc_shared([block_M, dim_v], dtype)
            qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
            dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
            qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
            dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
            lse_shared = T.alloc_shared([block_N], accum_dtype)
            delta = T.alloc_shared([block_N], accum_dtype)
            do = T.alloc_shared([block_N, dim_v], dtype)
            dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
            dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
            dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
            dv_shared = T.alloc_shared([block_M, dim_v], dtype)
            dk_shared = T.alloc_shared([block_M, dim_qk], dtype)

            T.annotate_layout({
                dQ: make_dq_layout(dQ),
                K_shared: tilelang.layout.make_swizzled_layout(K_shared),
                dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
                dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
            })
            T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared)
            T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
            T.clear(dv)
            T.clear(dk)
            loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
            loop_ed = T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
                T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
                T.clear(dsT)
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
                T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
                T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
                T.copy(dsT_cast, dsT_shared)
                T.clear(dq)
                T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
                for i, j in T.Parallel(block_N, dim_qk):
                    T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
            T.copy(dv, dv_shared)
            T.copy(dv_shared, dV[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
            T.copy(dk, dk_shared)
            T.copy(dk, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])

    return flash_bwd


@torch.compile
class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, causal, groups=1, use_atomic=True):
        BATCH, N_CTX, H, D_HEAD_QK = q.shape
        D_HEAD_V = v.shape[-1]
        block_M = 128
        block_N = 64
        mod = flashattn_fwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups)
        o, lse = mod(q, k, v)
        ctx.save_for_backward(q, k, v, o, lse)
        ctx.causal = causal
        ctx.use_atomic = use_atomic
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, o, lse = ctx.saved_tensors
        BATCH, N_CTX, H, D_HEAD_QK = q.shape
        HEAD_KV, D_HEAD_V, = v.shape[-2], v.shape[-1]
        groups = H // HEAD_KV

        def maybe_contiguous(x):
            if x.stride(-1) != 1:
                return x.contiguous()
            return x

        do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
        block_M = 128
        block_N = 32
        mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V)
        mod_post = flashattn_bwd_postprocess(BATCH, H, HEAD_KV, N_CTX, D_HEAD_QK, D_HEAD_V)
        delta = mod_prep(o, do)
        if ctx.use_atomic:
            kernel = flashattn_bwd_atomic_add(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal,
                                              block_M, block_N, threads=256, num_stages=2, groups=groups)
            shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
            shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
            shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]
            dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
            dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device)
            dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device)
            kernel(q, k, v, do, lse, delta, dq, dk, dv)
            dq, dk, dv = mod_post(dq, dk, dv)
        else:
            kernel = flashattn_bwd_split(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal,
                                         block_M, block_N, threads=256, num_stages=2, groups=groups)
            shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
            shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK]  # sum after kernel
            shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V]  # sum after kernel
            dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
            dk = torch.empty(shape_k, dtype=torch.float16, device=q.device)
            dv = torch.empty(shape_v, dtype=torch.float16, device=q.device)
            kernel(q, k, v, do, lse, delta, dq, dk, dv)
            dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32),
                                torch.zeros_like(v, dtype=torch.float32))
            dk, dv = dk.sum(0), dv.sum(0)
        return dq, dk, dv, None, None, None


attention = _attention.apply


def ref_program(Q, K, V, is_causal, groups=1):
    # Q: [B, T, HQ, D_QK]
    # K: [B, T, HK, D_QK]
    # V: [B, T, HV, D_V]
    # HQ = HKV * groups
    assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
    assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"

    dim_qk = Q.size(-1)
    K = K.repeat_interleave(groups, dim=2)
    V = V.repeat_interleave(groups, dim=2)
    scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
    return output


def main(BATCH: int = 1, H: int = 32, N_CTX: int = 256, D_HEAD_QK: int = 192, D_HEAD_V: int = 128,
         groups: int = 16, causal: bool = False, use_atomic: bool = True):
    flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
    flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
    total_flops = 3 * flops_per_qk + 2 * flops_per_v
    if causal:
        total_flops *= 0.5

    Q = (torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_())
    head_kv = H // groups
    K = (torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_())
    V = (torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_())
    dO = (torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_())

    O = attention(Q, K, V, causal, groups, use_atomic)
    O.backward(dO, retain_graph=True)
    dQ, Q.grad = Q.grad.clone(), None
    dK, K.grad = K.grad.clone(), None
    dV, V.grad = V.grad.clone(), None

    O_ref = ref_program(Q, K, V, causal, groups)
    O_ref.backward(dO, retain_graph=True)
    dQ_ref, Q.grad = Q.grad.clone(), None
    dK_ref, K.grad = K.grad.clone(), None
    dV_ref, V.grad = V.grad.clone(), None

    torch.testing.assert_close(O, O_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
    print('All checks passed.✅')

    def run():
        O_ref.backward(dO, retain_graph=True)

    def run1():
        O.backward(dO, retain_graph=True)

    from tilelang.profiler import do_bench

    latency = do_bench(run, warmup=500)
    print("torch: {:.2f} ms".format(latency))
    print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
    latency = do_bench(run1, warmup=500)
    print("tilelang: {:.2f} ms".format(latency))
    print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))


if __name__ == "__main__":
    arch = nvcc.get_target_compute_version()
    print(f"Detected GPU compute capability: {arch}")
    assert float(arch) >= 9.0, "This example only supports GPU with compute capability >= 9.0"
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=8, help='Batch size')
    parser.add_argument('--h', type=int, default=32, help='Number of heads')
    parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
    parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K')
    parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V')
    parser.add_argument('--causal', action='store_true', help='Causal flag')
    parser.add_argument('--groups', type=int, default=16, help='groups')
    parser.add_argument('--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
    parser.add_argument('--use_split', action='store_true', default=False, help='Use split for dK/dV')
    args = parser.parse_args()

    # Handle backward compatibility and logic
    if args.use_split:
        use_atomic = False
    elif args.use_atomic:
        use_atomic = True
    else:
        # Default: use atomic
        use_atomic = True

    main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic)
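
A rough benchmarking sketch comparing the two backward paths of the TMA-reduce variant above (illustrative only, not part of the commit; it reuses the attention entry point and do_bench exactly as main() does, and assumes a Hopper-class GPU since these kernels require compute capability >= 9.0):

# Hypothetical comparison of use_atomic=True (TMA atomic reduction) against the
# split-accumulation path; mirrors the benchmarking pattern in main() above.
import torch
from tilelang.profiler import do_bench

q = torch.randn(8, 1024, 32, 192, dtype=torch.half, device="cuda", requires_grad=True)
k = torch.randn(8, 1024, 2, 192, dtype=torch.half, device="cuda", requires_grad=True)
v = torch.randn(8, 1024, 2, 128, dtype=torch.half, device="cuda", requires_grad=True)
do = torch.randn(8, 1024, 32, 128, dtype=torch.half, device="cuda")

for use_atomic in (True, False):
    o = attention(q, k, v, False, 16, use_atomic)  # causal=False, groups=16
    latency = do_bench(lambda: o.backward(do, retain_graph=True), warmup=500)
    print("use_atomic={}: {:.2f} ms".format(use_atomic, latency))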
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (new file, 0 → 100644)
import
torch
import
torch.nn.functional
as
F
import
tilelang
import
tilelang.language
as
T
from
tilelang.contrib
import
nvcc
import
argparse
from
einops
import
rearrange
,
repeat
from
bert_padding
import
pad_input
,
unpad_input
def
generate_random_padding_mask
(
max_seqlen
,
batch_size
,
device
,
mode
=
"random"
):
assert
mode
in
[
"full"
,
"random"
,
"third"
]
if
mode
==
"full"
:
lengths
=
torch
.
full
((
batch_size
,
1
),
max_seqlen
,
device
=
device
,
dtype
=
torch
.
int32
)
elif
mode
==
"random"
:
lengths
=
torch
.
randint
(
max
(
1
,
max_seqlen
-
20
),
max_seqlen
+
1
,
(
batch_size
,
1
),
device
=
device
)
elif
mode
==
"third"
:
lengths
=
torch
.
randint
(
max_seqlen
//
3
,
max_seqlen
+
1
,
(
batch_size
,
1
),
device
=
device
)
padding_mask
=
(
repeat
(
torch
.
arange
(
max_seqlen
,
device
=
device
),
"s -> b s"
,
b
=
batch_size
)
<
lengths
)
return
padding_mask
@
tilelang
.
jit
(
out_idx
=
[
5
,
6
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
def
flashattn_fwd
(
batch
,
total_q
,
total_kv
,
heads
,
max_seq_len
,
dim_qk
,
dim_v
,
is_causal
,
block_M
,
block_N
,
groups
=
1
):
scale
=
(
1.0
/
dim_qk
)
**
0.5
*
1.44269504
# log2(e)
head_kv
=
heads
//
groups
q_shape
=
[
total_q
,
heads
,
dim_qk
]
k_shape
=
[
total_kv
,
head_kv
,
dim_qk
]
v_shape
=
[
total_kv
,
head_kv
,
dim_v
]
o_shape
=
[
total_q
,
heads
,
dim_v
]
dtype
=
"float16"
accum_dtype
=
"float"
@
T
.
prim_func
def
flash_fwd
(
Q
:
T
.
Tensor
(
q_shape
,
dtype
),
# type: ignore
K
:
T
.
Tensor
(
k_shape
,
dtype
),
# type: ignore
V
:
T
.
Tensor
(
v_shape
,
dtype
),
# type: ignore
cu_seqlens_q
:
T
.
Tensor
([
batch
+
1
],
"int32"
),
# type: ignore
cu_seqlens_k
:
T
.
Tensor
([
batch
+
1
],
"int32"
),
# type: ignore
Output
:
T
.
Tensor
(
o_shape
,
dtype
),
# type: ignore
lse
:
T
.
Tensor
([
total_q
,
heads
],
accum_dtype
),
# type: ignore
):
with
T
.
Kernel
(
T
.
ceildiv
(
max_seq_len
,
block_M
),
heads
,
batch
,
threads
=
256
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_M
,
dim_qk
],
dtype
)
K_shared
=
T
.
alloc_shared
([
block_N
,
dim_qk
],
dtype
)
V_shared
=
T
.
alloc_shared
([
block_N
,
dim_v
],
dtype
)
acc_s
=
T
.
alloc_fragment
([
block_M
,
block_N
],
accum_dtype
)
acc_s_cast
=
T
.
alloc_fragment
([
block_M
,
block_N
],
dtype
)
acc_o
=
T
.
alloc_fragment
([
block_M
,
dim_v
],
accum_dtype
)
scores_max
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
scores_max_prev
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
scores_scale
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
scores_sum
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
logsum
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
q_start_idx
=
cu_seqlens_q
[
bz
]
k_start_idx
=
cu_seqlens_k
[
bz
]
q_end_idx
=
cu_seqlens_q
[
bz
+
1
]
k_end_idx
=
cu_seqlens_k
[
bz
+
1
]
q_current_seqlen
=
q_end_idx
-
q_start_idx
k_current_seqlen
=
k_end_idx
-
k_start_idx
T
.
annotate_layout
({
Q_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
Q_shared
)})
for
i
,
d
in
T
.
Parallel
(
block_M
,
dim_qk
):
if
bx
*
block_M
+
i
<
q_current_seqlen
:
Q_shared
[
i
,
d
]
=
Q
[
q_start_idx
+
bx
*
block_M
+
i
,
by
,
d
]
else
:
Q_shared
[
i
,
d
]
=
0.0
T
.
fill
(
acc_o
,
0.0
)
T
.
fill
(
logsum
,
0.0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
loop_range
=
T
.
ceildiv
(
k_current_seqlen
,
block_N
)
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
1
):
for
i
,
d
in
T
.
Parallel
(
block_N
,
dim_qk
):
if
k
*
block_N
+
i
<
k_current_seqlen
:
K_shared
[
i
,
d
]
=
K
[
k_start_idx
+
k
*
block_N
+
i
,
by
//
groups
,
d
]
else
:
K_shared
[
i
,
d
]
=
0.0
if
is_causal
:
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
((
bx
*
block_M
+
i
>=
k
*
block_N
+
j
)
and
(
bx
*
block_M
+
i
<
q_current_seqlen
and
k
*
block_N
+
j
<
k_current_seqlen
),
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
else
:
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
bx
*
block_M
+
i
<
q_current_seqlen
and
k
*
block_N
+
j
<
k_current_seqlen
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
for
i
,
d
in
T
.
Parallel
(
block_N
,
dim_v
):
if
k
*
block_N
+
i
<
k_current_seqlen
:
V_shared
[
i
,
d
]
=
V
[
k_start_idx
+
k
*
block_N
+
i
,
by
//
groups
,
d
]
else
:
V_shared
[
i
,
d
]
=
0.0
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_M
):
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim_v
):
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
exp2
(
acc_s
[
i
,
j
]
*
scale
-
scores_max
[
i
]
*
scale
)
T
.
copy
(
acc_s
,
acc_s_cast
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
reduce_sum
(
acc_s
,
scores_sum
,
dim
=
1
)
for
i
in
T
.
Parallel
(
block_M
):
logsum
[
i
]
=
logsum
[
i
]
*
scores_scale
[
i
]
+
scores_sum
[
i
]
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim_v
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
for
i
,
d
in
T
.
Parallel
(
block_M
,
dim_v
):
if
bx
*
block_M
+
i
<
q_current_seqlen
:
Output
[
q_start_idx
+
bx
*
block_M
+
i
,
by
,
d
]
=
acc_o
[
i
,
d
]
for
i
in
T
.
Parallel
(
block_M
):
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
if
bx
*
block_M
+
i
<
q_current_seqlen
:
lse
[
q_start_idx
+
bx
*
block_M
+
i
,
by
]
=
logsum
[
i
]
return
flash_fwd
@
tilelang
.
jit
(
out_idx
=
[
3
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
def
flashattn_bwd_preprocess
(
batch
,
heads
,
total_q
,
max_seq_len
,
dim_v
):
dtype
=
"float16"
accum_dtype
=
"float"
shape
=
[
total_q
,
heads
,
dim_v
]
blk
=
32
@
T
.
prim_func
def
flash_bwd_prep
(
O
:
T
.
Tensor
(
shape
,
dtype
),
# type: ignore
dO
:
T
.
Tensor
(
shape
,
dtype
),
# type: ignore
cu_seqlens_q
:
T
.
Tensor
([
batch
+
1
],
"int32"
),
# type: ignore
Delta
:
T
.
Tensor
([
total_q
,
heads
],
accum_dtype
),
# type: ignore
):
with
T
.
Kernel
(
heads
,
T
.
ceildiv
(
max_seq_len
,
blk
),
batch
)
as
(
bx
,
by
,
bz
):
o
=
T
.
alloc_fragment
([
blk
,
blk
],
dtype
)
do
=
T
.
alloc_fragment
([
blk
,
blk
],
dtype
)
acc
=
T
.
alloc_fragment
([
blk
,
blk
],
accum_dtype
)
delta
=
T
.
alloc_fragment
([
blk
],
accum_dtype
)
q_start_idx
=
cu_seqlens_q
[
bz
]
q_end_idx
=
cu_seqlens_q
[
bz
+
1
]
q_current_seqlen
=
q_end_idx
-
q_start_idx
T
.
clear
(
acc
)
for
k
in
range
(
T
.
ceildiv
(
dim_v
,
blk
)):
for
i
,
j
in
T
.
Parallel
(
blk
,
blk
):
if
by
*
blk
+
i
<
q_current_seqlen
and
k
*
blk
+
j
<
dim_v
:
o
[
i
,
j
]
=
O
[
q_start_idx
+
by
*
blk
+
i
,
bx
,
k
*
blk
+
j
]
do
[
i
,
j
]
=
dO
[
q_start_idx
+
by
*
blk
+
i
,
bx
,
k
*
blk
+
j
]
else
:
o
[
i
,
j
]
=
0.0
do
[
i
,
j
]
=
0.0
for
i
,
j
in
T
.
Parallel
(
blk
,
blk
):
acc
[
i
,
j
]
+=
o
[
i
,
j
]
*
do
[
i
,
j
]
T
.
reduce_sum
(
acc
,
delta
,
1
)
for
i
in
T
.
Parallel
(
blk
):
if
by
*
blk
+
i
<
q_current_seqlen
:
Delta
[
q_start_idx
+
by
*
blk
+
i
,
bx
]
=
delta
[
i
]
return
flash_bwd_prep
def
make_dq_layout
(
dQ
):
# bshd -> bhld to use tma reduction instruction
return
T
.
Layout
(
dQ
.
shape
,
lambda
b
,
l
,
h
,
d
:
[
b
,
h
,
l
,
d
])
@
tilelang
.
jit
(
out_idx
=
[
3
,
4
,
5
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
def
flashattn_bwd_postprocess
(
total_q
,
total_kv
,
heads
,
head_kv
,
dim_qk
,
dim_v
):
dtype
=
"float16"
accum_dtype
=
"float"
q_shape
=
[
total_q
,
heads
,
dim_qk
]
k_shape
=
[
total_kv
,
head_kv
,
dim_qk
]
v_shape
=
[
total_kv
,
head_kv
,
dim_v
]
blk
=
64
@
T
.
prim_func
def
flash_bwd_post
(
dQ
:
T
.
Tensor
(
q_shape
,
accum_dtype
),
# type: ignore
dK
:
T
.
Tensor
(
k_shape
,
accum_dtype
),
# type: ignore
dV
:
T
.
Tensor
(
v_shape
,
accum_dtype
),
# type: ignore
dQ_out
:
T
.
Tensor
(
q_shape
,
dtype
),
# type: ignore
dK_out
:
T
.
Tensor
(
k_shape
,
dtype
),
# type: ignore
dV_out
:
T
.
Tensor
(
v_shape
,
dtype
),
# type: ignore
):
with
T
.
Kernel
(
T
.
ceildiv
(
total_q
,
blk
),
heads
,
threads
=
128
)
as
(
bx
,
by
):
# T.annotate_layout({dQ: make_dq_layout(dQ)})
T
.
copy
(
dQ
[
bx
*
blk
:(
bx
+
1
)
*
blk
,
by
,
:],
dQ_out
[
bx
*
blk
:(
bx
+
1
)
*
blk
,
by
,
:])
with
T
.
Kernel
(
T
.
ceildiv
(
total_kv
,
blk
),
head_kv
,
threads
=
128
)
as
(
bx
,
by
):
# T.annotate_layout({
# dK: make_dq_layout(dK),
# dV: make_dq_layout(dV),
# })
T
.
copy
(
dK
[
bx
*
blk
:(
bx
+
1
)
*
blk
,
by
,
:],
dK_out
[
bx
*
blk
:(
bx
+
1
)
*
blk
,
by
,
:])
T
.
copy
(
dV
[
bx
*
blk
:(
bx
+
1
)
*
blk
,
by
,
:],
dV_out
[
bx
*
blk
:(
bx
+
1
)
*
blk
,
by
,
:])
return
flash_bwd_post
@tilelang.jit(pass_configs={
    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_atomic_add(batch, total_q, total_kv, heads, max_seq_len, dim_qk, dim_v,
                             is_causal, block_M, block_N, threads=256, num_stages=2, groups=1):
    sm_scale = (1.0 / dim_qk)**0.5
    scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [total_q, heads, dim_qk]
    k_shape = [total_kv, head_kv, dim_qk]
    v_shape = [total_kv, head_kv, dim_v]
    do_shape = [total_q, heads, dim_v]
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def flash_bwd(
            Q: T.Tensor(q_shape, dtype),  # type: ignore
            K: T.Tensor(k_shape, dtype),  # type: ignore
            V: T.Tensor(v_shape, dtype),  # type: ignore
            dO: T.Tensor(do_shape, dtype),  # type: ignore
            lse: T.Tensor([total_q, heads], accum_dtype),  # type: ignore
            Delta: T.Tensor([total_q, heads], accum_dtype),  # type: ignore
            cu_seqlens_q: T.Tensor([batch + 1], "int32"),  # type: ignore
            cu_seqlens_k: T.Tensor([batch + 1], "int32"),  # type: ignore
            dQ: T.Tensor(q_shape, accum_dtype),  # type: ignore
            dK: T.Tensor(k_shape, accum_dtype),  # type: ignore
            dV: T.Tensor(v_shape, accum_dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(max_seq_len, block_M), batch, threads=threads) as (bx, by, bz):
            K_shared = T.alloc_shared([block_M, dim_qk], dtype)
            dsT_shared = T.alloc_shared([block_M, block_N], dtype)
            q = T.alloc_shared([block_N, dim_qk], dtype)
            V_shared = T.alloc_shared([block_M, dim_v], dtype)
            qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
            dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
            qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
            dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
            lse_shared = T.alloc_shared([block_N], accum_dtype)
            delta = T.alloc_shared([block_N], accum_dtype)
            do = T.alloc_shared([block_N, dim_v], dtype)
            dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
            dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
            dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)

            q_start_idx = cu_seqlens_q[bz]
            k_start_idx = cu_seqlens_k[bz]
            q_end_idx = cu_seqlens_q[bz + 1]
            k_end_idx = cu_seqlens_k[bz + 1]
            q_current_seqlen = q_end_idx - q_start_idx
            k_current_seqlen = k_end_idx - k_start_idx

            T.annotate_layout({
                # dQ: make_dq_layout(dQ),
                # dK: make_dq_layout(dK),
                # dV: make_dq_layout(dV),
                K_shared: tilelang.layout.make_swizzled_layout(K_shared),
            })

            for i, d in T.Parallel(block_M, dim_qk):
                if by * block_M + i < k_current_seqlen:
                    K_shared[i, d] = K[k_start_idx + by * block_M + i, bx // groups, d]
                    V_shared[i, d] = V[k_start_idx + by * block_M + i, bx // groups, d]
                else:
                    K_shared[i, d] = 0.0
                    V_shared[i, d] = 0.0
            T.clear(dv)
            T.clear(dk)
            loop_st = (T.floordiv(by * block_M, block_N) if is_causal else 0)
            loop_ed = T.ceildiv(q_current_seqlen, block_N)
            for k_base in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
                for i, d in T.Parallel(block_N, dim_qk):
                    if k_base * block_N + i < q_current_seqlen:
                        q[i, d] = Q[q_start_idx + k_base * block_N + i, bx, d]
                    else:
                        q[i, d] = 0.0
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                for i in T.Parallel(block_N):
                    if k_base * block_N + i < q_current_seqlen:
                        lse_shared[i] = lse[q_start_idx + k_base * block_N + i, bx]
                    else:
                        lse_shared[i] = 0.0
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        qkT[i, j] = T.if_then_else(
                            (by * block_M + i <= k_base * block_N + j) and
                            (by * block_M + i < k_current_seqlen and
                             k_base * block_N + j < q_current_seqlen), qkT[i, j], 0)
                else:
                    for i, j in T.Parallel(block_M, block_N):
                        qkT[i, j] = T.if_then_else(
                            by * block_M + i < k_current_seqlen and
                            k_base * block_N + j < q_current_seqlen, qkT[i, j], 0)
                for i, d in T.Parallel(block_N, dim_v):
                    if k_base * block_N + i < q_current_seqlen:
                        do[i, d] = dO[q_start_idx + k_base * block_N + i, bx, d]
                    else:
                        do[i, d] = 0.0
                T.clear(dsT)
                # dsT: (block_kv, block_q)
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
                for i in T.Parallel(block_N):
                    if k_base * block_N + i < q_current_seqlen:
                        delta[i] = Delta[q_start_idx + k_base * block_N + i, bx]
                    else:
                        delta[i] = 0.0
                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
                T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
                T.copy(dsT_cast, dsT_shared)
                T.clear(dq)
                T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
                T.atomic_add(
                    dQ[q_start_idx + k_base * block_N:q_start_idx + k_base * block_N + block_N, bx, :],
                    dq,
                    memory_order="release")
            T.atomic_add(
                dV[k_start_idx + by * block_M:k_start_idx + by * block_M + block_M, bx // groups, :],
                dv,
                memory_order="release")
            T.atomic_add(
                dK[k_start_idx + by * block_M:k_start_idx + by * block_M + block_M, bx // groups, :],
                dk,
                memory_order="release")

    return flash_bwd
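
# Illustrative reference (not part of the original file): per (batch, head) slice, the kernel
# above recomputes P = softmax(Q K^T * sm_scale) from the stored log-sum-exp and then applies
# the softmax backward identity dS = P * (dP - Delta), with dP = dO V^T and
# Delta = rowsum(O * dO). A dense float32 sketch of the same math; tensor names are hypothetical.
def softmax_bwd_reference(q, k, v, do, sm_scale):
    # q: [Sq, Dqk], k: [Sk, Dqk], v: [Sk, Dv], do: [Sq, Dv]
    p = torch.softmax((q @ k.T) * sm_scale, dim=-1)     # [Sq, Sk]
    o = p @ v                                           # forward output, [Sq, Dv]
    dp = do @ v.T                                       # [Sq, Sk]
    delta = (o * do).sum(dim=-1, keepdim=True)          # matches flash_bwd_prep's Delta
    ds = p * (dp - delta) * sm_scale                    # matches dsT_cast (kernel keeps it transposed)
    dq = ds @ k
    dk = ds.T @ q
    dv = p.T @ do
    return dq, dk, dv
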
@tilelang.jit(pass_configs={
    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_split(batch, total_q, total_kv, heads, max_seq_len, dim_qk, dim_v,
                        is_causal, block_M, block_N, threads=256, num_stages=2, groups=1):
    sm_scale = (1.0 / dim_qk)**0.5
    scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [total_q, heads, dim_qk]
    k_shape = [total_kv, head_kv, dim_qk]
    v_shape = [total_kv, head_kv, dim_v]
    do_shape = [total_q, heads, dim_v]
    dk_shape = [groups, total_kv, head_kv, dim_qk]  # sum after kernel
    dv_shape = [groups, total_kv, head_kv, dim_v]  # sum after kernel
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def flash_bwd(
            Q: T.Tensor(q_shape, dtype),  # type: ignore
            K: T.Tensor(k_shape, dtype),  # type: ignore
            V: T.Tensor(v_shape, dtype),  # type: ignore
            dO: T.Tensor(do_shape, dtype),  # type: ignore
            lse: T.Tensor([total_q, heads], accum_dtype),  # type: ignore
            Delta: T.Tensor([total_q, heads], accum_dtype),  # type: ignore
            cu_seqlens_q: T.Tensor([batch + 1], "int32"),  # type: ignore
            cu_seqlens_k: T.Tensor([batch + 1], "int32"),  # type: ignore
            dQ: T.Tensor(q_shape, accum_dtype),  # type: ignore
            dK: T.Tensor(dk_shape, dtype),  # type: ignore
            dV: T.Tensor(dv_shape, dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(max_seq_len, block_M), batch, threads=threads) as (bx, by, bz):
            K_shared = T.alloc_shared([block_M, dim_qk], dtype)
            dsT_shared = T.alloc_shared([block_M, block_N], dtype)
            q = T.alloc_shared([block_N, dim_qk], dtype)
            V_shared = T.alloc_shared([block_M, dim_v], dtype)
            qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
            dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
            qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
            dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
            lse_shared = T.alloc_shared([block_N], accum_dtype)
            delta = T.alloc_shared([block_N], accum_dtype)
            do = T.alloc_shared([block_N, dim_v], dtype)
            dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
            dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
            dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
            dv_shared = T.alloc_shared([block_M, dim_v], dtype)
            dk_shared = T.alloc_shared([block_M, dim_qk], dtype)

            q_start_idx = cu_seqlens_q[bz]
            k_start_idx = cu_seqlens_k[bz]
            q_end_idx = cu_seqlens_q[bz + 1]
            k_end_idx = cu_seqlens_k[bz + 1]
            q_current_seqlen = q_end_idx - q_start_idx
            k_current_seqlen = k_end_idx - k_start_idx

            T.annotate_layout({
                # dQ: make_dq_layout(dQ),
                K_shared: tilelang.layout.make_swizzled_layout(K_shared),
                dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
                dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
            })

            for i, d in T.Parallel(block_M, dim_qk):
                if by * block_M + i < k_current_seqlen:
                    K_shared[i, d] = K[k_start_idx + by * block_M + i, bx // groups, d]
                    V_shared[i, d] = V[k_start_idx + by * block_M + i, bx // groups, d]
                else:
                    K_shared[i, d] = 0.0
                    V_shared[i, d] = 0.0
            T.clear(dv)
            T.clear(dk)
            loop_st = (T.floordiv(by * block_M, block_N) if is_causal else 0)
            loop_ed = T.ceildiv(q_current_seqlen, block_N)
            for k_base in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
                for i, d in T.Parallel(block_N, dim_qk):
                    if k_base * block_N + i < q_current_seqlen:
                        q[i, d] = Q[q_start_idx + k_base * block_N + i, bx, d]
                    else:
                        q[i, d] = 0.0
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                for i, d in T.Parallel(block_N, dim_v):
                    if k_base * block_N + i < q_current_seqlen:
                        do[i, d] = dO[q_start_idx + k_base * block_N + i, bx, d]
                    else:
                        do[i, d] = 0.0
                T.clear(dsT)
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                for i in T.Parallel(block_N):
                    if k_base * block_N + i < q_current_seqlen:
                        lse_shared[i] = lse[q_start_idx + k_base * block_N + i, bx]
                    else:
                        lse_shared[i] = 0.0
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        qkT[i, j] = T.if_then_else(
                            (by * block_M + i <= k_base * block_N + j) and
                            (by * block_M + i < k_current_seqlen and
                             k_base * block_N + j < q_current_seqlen), qkT[i, j], 0)
                else:
                    for i, j in T.Parallel(block_M, block_N):
                        qkT[i, j] = T.if_then_else(
                            by * block_M + i < k_current_seqlen and
                            k_base * block_N + j < q_current_seqlen, qkT[i, j], 0)
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
                for i in T.Parallel(block_N):
                    if k_base * block_N + i < q_current_seqlen:
                        delta[i] = Delta[q_start_idx + k_base * block_N + i, bx]
                    else:
                        delta[i] = 0.0
                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
                T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
                T.copy(dsT_cast, dsT_shared)
                T.clear(dq)
                T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
                for i, j in T.Parallel(block_N, dim_qk):
                    if k_base * block_N + i < q_current_seqlen:
                        T.atomic_add(
                            dQ[q_start_idx + k_base * block_N + i, bx, j],
                            dq[i, j],
                            memory_order="release")
            T.copy(dv, dv_shared)
            for i, d in T.Parallel(block_M, dim_v):
                if by * block_M + i < k_current_seqlen:
                    dV[bx % groups, k_start_idx + by * block_M + i, bx // groups, d] = dv[i, d]
            T.copy(dk, dk_shared)
            for i, d in T.Parallel(block_M, dim_qk):
                if by * block_M + i < k_current_seqlen:
                    dK[bx % groups, k_start_idx + by * block_M + i, bx // groups, d] = dk[i, d]

    return flash_bwd
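
# Illustrative sketch (not part of the original file): the split variant differs from the
# atomic-add variant only in how partial dK/dV contributions from the `groups` query heads that
# share one KV head are combined -- atomics into a single float32 buffer versus private
# per-group slices that are summed afterwards (see `dk.sum(0)` / `dv.sum(0)` in
# _attention.backward below). The host-side combine step amounts to:
def combine_split_partials(partials):
    # partials: [groups, total_kv, head_kv, dim]; the sum reproduces the atomic-add result.
    return partials.sum(dim=0)
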
@torch.compile
class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
                max_seqlen_k, causal, groups=1, use_atomic=True):
        BATCH, N_CTX, H, D_HEAD_QK = q.shape
        D_HEAD_V = v.shape[-1]
        block_M = 128
        block_N = 64
        q_unpad, indices_q, _, _ = unpad_input(
            q, (torch.arange(N_CTX, device=q.device).unsqueeze(0) < seqlens_q.unsqueeze(1)))
        k_unpad, indices_k, _, _ = unpad_input(
            k, (torch.arange(N_CTX, device=k.device).unsqueeze(0) < seqlens_k.unsqueeze(1)))
        v_unpad, _, _, _ = unpad_input(
            v, (torch.arange(N_CTX, device=v.device).unsqueeze(0) < seqlens_k.unsqueeze(1)))
        total_q = q_unpad.shape[0]
        total_kv = k_unpad.shape[0]
        mod = flashattn_fwd(BATCH, total_q, total_kv, H, max_seqlen_q, D_HEAD_QK, D_HEAD_V,
                            causal, block_M, block_N, groups)
        o_unpad, lse = mod(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k)
        o = pad_input(o_unpad, indices_q, BATCH, N_CTX)
        ctx.save_for_backward(q_unpad, k_unpad, v_unpad, o_unpad, lse, seqlens_q, seqlens_k,
                              cu_seqlens_q, cu_seqlens_k)
        ctx.causal = causal
        ctx.use_atomic = use_atomic
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.indices_q = indices_q
        ctx.indices_k = indices_k
        return o

    @staticmethod
    def backward(ctx, do):
        N_CTX = do.shape[1]
        q, k, v, o, lse, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
        do_unpad, _, _, _ = unpad_input(
            do, (torch.arange(N_CTX, device=do.device).unsqueeze(0) < seqlens_q.unsqueeze(1)))
        total_q, H, D_HEAD_QK = q.shape
        total_kv, HEAD_KV, D_HEAD_V = v.shape
        groups = H // HEAD_KV
        BATCH = len(cu_seqlens_q) - 1

        def maybe_contiguous(x):
            if x.stride(-1) != 1:
                return x.contiguous()
            return x

        do, q, k, v, o = [maybe_contiguous(x) for x in (do_unpad, q, k, v, o)]
        block_M = 128
        block_N = 32
        mod_prep = flashattn_bwd_preprocess(BATCH, H, total_q, ctx.max_seqlen_q, D_HEAD_V)
        mod_post = flashattn_bwd_postprocess(total_q, total_kv, H, HEAD_KV, D_HEAD_QK, D_HEAD_V)
        delta = mod_prep(o, do, cu_seqlens_q)
        if ctx.use_atomic:
            kernel = flashattn_bwd_atomic_add(
                BATCH,
                total_q,
                total_kv,
                H,
                ctx.max_seqlen_q,
                D_HEAD_QK,
                D_HEAD_V,
                ctx.causal,
                block_M,
                block_N,
                threads=256,
                num_stages=2,
                groups=groups)
            dq = torch.zeros_like(q, dtype=torch.float32)
            dk = torch.zeros_like(k, dtype=torch.float32)
            dv = torch.zeros_like(v, dtype=torch.float32)
            kernel(q, k, v, do, lse, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv)
            dq, dk, dv = mod_post(dq, dk, dv)
        else:
            kernel = flashattn_bwd_split(
                BATCH,
                total_q,
                total_kv,
                H,
                ctx.max_seqlen_q,
                D_HEAD_QK,
                D_HEAD_V,
                ctx.causal,
                block_M,
                block_N,
                threads=256,
                num_stages=2,
                groups=groups)
            dq = torch.zeros_like(q, dtype=torch.float32)
            dk = torch.empty(groups, *k.shape, dtype=torch.float16, device=q.device)
            dv = torch.empty(groups, *v.shape, dtype=torch.float16, device=q.device)
            kernel(q, k, v, do, lse, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv)
            dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32),
                                torch.zeros_like(v, dtype=torch.float32))
            dk, dv = dk.sum(0), dv.sum(0)
        dq = pad_input(dq, ctx.indices_q, BATCH, N_CTX)
        dk = pad_input(dk, ctx.indices_k, BATCH, N_CTX)
        dv = pad_input(dv, ctx.indices_k, BATCH, N_CTX)
        return dq, dk, dv, None, None, None, None, None, None, None, None, None


attention = _attention.apply
def ref_program(Q, K, V, padding_mask, is_causal, groups=1):
    # Q: [B, T, HQ, D_QK]
    # K: [B, T, HK, D_QK]
    # V: [B, T, HV, D_V]
    # HQ = HKV * groups
    # To handle precision issue
    Q, K, V = Q.float(), K.float(), V.float()
    assert Q.size(2) == K.size(
        2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
    assert Q.size(2) == V.size(
        2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
    dim_qk = Q.size(-1)
    K = K.repeat_interleave(groups, dim=2)
    V = V.repeat_interleave(groups, dim=2)
    scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype))
    if padding_mask is not None:
        scores.masked_fill_(rearrange(~padding_mask, "b s -> b 1 1 s"), float("-inf"))
    if is_causal:
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
    if padding_mask is not None:
        output.masked_fill_(rearrange(~padding_mask, "b s -> b s 1 1"), 0.0)
    return output
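
# Quick numeric check (illustrative, not part of the original file): the kernels fold
# log2(e) = 1.44269504 into `scale` and call T.exp2, relying on exp(x) == 2 ** (x * log2(e)),
# while ref_program above goes through the plain softmax. Both paths therefore compute the
# same attention probabilities up to floating-point error:
_x = torch.randn(8, dtype=torch.float32)
assert torch.allclose(torch.exp(_x), torch.exp2(_x * 1.44269504), atol=1e-6)
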
def main(BATCH: int = 1,
         H: int = 32,
         N_CTX: int = 256,
         D_HEAD_QK: int = 192,
         D_HEAD_V: int = 128,
         groups: int = 16,
         causal: bool = False,
         use_atomic: bool = True):
    flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
    flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
    total_flops = 3 * flops_per_qk + 2 * flops_per_v
    if causal:
        total_flops *= 0.5

    Q = (torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half,
                     device="cuda").normal_().requires_grad_())
    head_kv = H // groups
    K = (torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half,
                     device="cuda").normal_().requires_grad_())
    V = (torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half,
                     device="cuda").normal_().requires_grad_())
    dO = (torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half,
                      device="cuda").normal_().requires_grad_())

    padding_mask = generate_random_padding_mask(N_CTX, BATCH, "cuda", mode="random")
    seqlens_q = padding_mask.sum(dim=-1, dtype=torch.int32)
    cu_seqlens_q = F.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0))
    max_seqlen_q = seqlens_q.max().item()
    # In training backward pass, seqlens_k should be the same as seqlens_q
    seqlens_k, cu_seqlens_k, max_seqlen_k = seqlens_q, cu_seqlens_q, max_seqlen_q

    O = attention(Q, K, V, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
                  max_seqlen_k, causal, groups, use_atomic)
    O.backward(dO, retain_graph=True)
    dQ, Q.grad = Q.grad.clone(), None
    dK, K.grad = K.grad.clone(), None
    dV, V.grad = V.grad.clone(), None

    O_ref = ref_program(Q, K, V, padding_mask, causal, groups)
    O_ref.backward(dO, retain_graph=True)
    dQ_ref, Q.grad = Q.grad.clone(), None
    dK_ref, K.grad = K.grad.clone(), None
    dV_ref, V.grad = V.grad.clone(), None

    torch.testing.assert_close(O, O_ref.half(), rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
    print('All checks passed.✅')

    def run():
        O_ref.backward(dO, retain_graph=True)

    def run1():
        O.backward(dO, retain_graph=True)

    from tilelang.profiler import do_bench

    latency = do_bench(run, warmup=500)
    print("torch: {:.2f} ms".format(latency))
    print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
    latency = do_bench(run1, warmup=500)
    print("tilelang: {:.2f} ms".format(latency))
    print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))


if __name__ == "__main__":
    arch = nvcc.get_target_compute_version()
    print(f"Detected GPU compute capability: {arch}")
    assert float(arch) >= 9.0, "This example only supports GPU with compute capability >= 9.0"
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=8, help='Batch size')
    parser.add_argument('--h', type=int, default=32, help='Number of heads')
    parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
    parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K')
    parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V')
    parser.add_argument('--causal', action='store_true', help='Causal flag')
    parser.add_argument('--groups', type=int, default=16, help='groups')
    parser.add_argument(
        '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
    parser.add_argument(
        '--use_split', action='store_true', default=False, help='Use split for dK/dV')
    args = parser.parse_args()
    # Handle backward compatibility and logic
    if args.use_split:
        use_atomic = False
    elif args.use_atomic:
        use_atomic = True
    else:
        # Default: use split
        use_atomic = False
    main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal,
         use_atomic)
examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (new file, mode 100644)
import
torch
import
torch.nn.functional
as
F
import
tilelang
import
tilelang.language
as
T
import
argparse
@
tilelang
.
jit
(
out_idx
=
[
3
,
4
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
def
flashattn_fwd
(
batch
,
heads
,
seq_len
,
dim_qk
,
dim_v
,
is_causal
,
block_M
,
block_N
,
groups
=
1
):
scale
=
(
1.0
/
dim_qk
)
**
0.5
*
1.44269504
# log2(e)
head_kv
=
heads
//
groups
q_shape
=
[
batch
,
seq_len
,
heads
,
dim_qk
]
k_shape
=
[
batch
,
seq_len
,
head_kv
,
dim_qk
]
v_shape
=
[
batch
,
seq_len
,
head_kv
,
dim_v
]
dtype
=
"float16"
accum_dtype
=
"float"
@
T
.
prim_func
def
flash_fwd
(
Q
:
T
.
Tensor
(
q_shape
,
dtype
),
# type: ignore
K
:
T
.
Tensor
(
k_shape
,
dtype
),
# type: ignore
V
:
T
.
Tensor
(
v_shape
,
dtype
),
# type: ignore
Output
:
T
.
Tensor
([
batch
,
seq_len
,
heads
,
dim_v
],
dtype
),
# type: ignore
lse
:
T
.
Tensor
([
batch
,
heads
,
seq_len
],
accum_dtype
),
# type: ignore
):
with
T
.
Kernel
(
T
.
ceildiv
(
seq_len
,
block_M
),
heads
,
batch
,
threads
=
256
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_M
,
dim_qk
],
dtype
)
K_shared
=
T
.
alloc_shared
([
block_N
,
dim_qk
],
dtype
)
V_shared
=
T
.
alloc_shared
([
block_N
,
dim_v
],
dtype
)
acc_s
=
T
.
alloc_fragment
([
block_M
,
block_N
],
accum_dtype
)
acc_s_cast
=
T
.
alloc_fragment
([
block_M
,
block_N
],
dtype
)
acc_o
=
T
.
alloc_fragment
([
block_M
,
dim_v
],
accum_dtype
)
scores_max
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
scores_max_prev
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
scores_scale
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
scores_sum
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
logsum
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
T
.
annotate_layout
({
Q_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
Q_shared
)})
T
.
copy
(
Q
[
bz
,
bx
*
block_M
:(
bx
+
1
)
*
block_M
,
by
,
:],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
loop_range
=
(
T
.
ceildiv
(
(
bx
+
1
)
*
block_M
,
block_N
)
if
is_causal
else
T
.
ceildiv
(
seq_len
,
block_N
))
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
1
):
T
.
copy
(
K
[
bz
,
k
*
block_N
:(
k
+
1
)
*
block_N
,
by
//
groups
,
:],
K_shared
)
if
is_causal
:
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
bx
*
block_M
+
i
>=
k
*
block_N
+
j
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
else
:
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
copy
(
V
[
bz
,
k
*
block_N
:(
k
+
1
)
*
block_N
,
by
//
groups
,
:],
V_shared
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_M
):
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim_v
):
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
exp2
(
acc_s
[
i
,
j
]
*
scale
-
scores_max
[
i
]
*
scale
)
T
.
copy
(
acc_s
,
acc_s_cast
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
reduce_sum
(
acc_s
,
scores_sum
,
dim
=
1
)
for
i
in
T
.
Parallel
(
block_M
):
logsum
[
i
]
=
logsum
[
i
]
*
scores_scale
[
i
]
+
scores_sum
[
i
]
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim_v
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o
,
Output
[
bz
,
bx
*
block_M
:(
bx
+
1
)
*
block_M
,
by
,
:])
for
i
in
T
.
Parallel
(
block_M
):
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
T
.
copy
(
logsum
,
lse
[
bz
,
by
,
bx
*
block_M
:(
bx
+
1
)
*
block_M
])
return
flash_fwd
@
tilelang
.
jit
(
out_idx
=
[
2
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
def
flashattn_bwd_preprocess
(
batch
,
heads
,
seq_len
,
dim_v
):
dtype
=
"float16"
accum_dtype
=
"float"
shape
=
[
batch
,
seq_len
,
heads
,
dim_v
]
blk
=
32
@
T
.
prim_func
def
flash_bwd_prep
(
O
:
T
.
Tensor
(
shape
,
dtype
),
# type: ignore
dO
:
T
.
Tensor
(
shape
,
dtype
),
# type: ignore
Delta
:
T
.
Tensor
([
batch
,
heads
,
seq_len
],
accum_dtype
),
# type: ignore
):
with
T
.
Kernel
(
heads
,
T
.
ceildiv
(
seq_len
,
blk
),
batch
)
as
(
bx
,
by
,
bz
):
o
=
T
.
alloc_fragment
([
blk
,
blk
],
dtype
)
do
=
T
.
alloc_fragment
([
blk
,
blk
],
dtype
)
acc
=
T
.
alloc_fragment
([
blk
,
blk
],
accum_dtype
)
delta
=
T
.
alloc_fragment
([
blk
],
accum_dtype
)
T
.
clear
(
acc
)
for
k
in
range
(
T
.
ceildiv
(
dim_v
,
blk
)):
T
.
copy
(
O
[
bz
,
by
*
blk
:(
by
+
1
)
*
blk
,
bx
,
k
*
blk
:(
k
+
1
)
*
blk
],
o
)
T
.
copy
(
dO
[
bz
,
by
*
blk
:(
by
+
1
)
*
blk
,
bx
,
k
*
blk
:(
k
+
1
)
*
blk
],
do
)
for
i
,
j
in
T
.
Parallel
(
blk
,
blk
):
acc
[
i
,
j
]
+=
o
[
i
,
j
]
*
do
[
i
,
j
]
T
.
reduce_sum
(
acc
,
delta
,
1
)
T
.
copy
(
delta
,
Delta
[
bz
,
bx
,
by
*
blk
:(
by
+
1
)
*
blk
])
return
flash_bwd_prep
@
tilelang
.
jit
(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
def
flashattn_bwd
(
batch
,
heads
,
seq_len
,
dim_qk
,
dim_v
,
is_causal
,
block_M
,
block_N
,
threads
=
256
,
num_stages
=
2
,
groups
=
1
):
sm_scale
=
(
1.0
/
dim_qk
)
**
0.5
scale
=
(
1.0
/
dim_qk
)
**
0.5
*
1.44269504
# log2(e)
head_kv
=
heads
//
groups
q_shape
=
[
batch
,
seq_len
,
heads
,
dim_qk
]
k_shape
=
[
batch
,
seq_len
,
head_kv
,
dim_qk
]
v_shape
=
[
batch
,
seq_len
,
head_kv
,
dim_v
]
dtype
=
"float16"
accum_dtype
=
"float"
@
T
.
prim_func
def
flash_bwd
(
Q
:
T
.
Tensor
(
q_shape
,
dtype
),
# type: ignore
K
:
T
.
Tensor
(
k_shape
,
dtype
),
# type: ignore
V
:
T
.
Tensor
(
v_shape
,
dtype
),
# type: ignore
dO
:
T
.
Tensor
([
batch
,
seq_len
,
heads
,
dim_v
],
dtype
),
# type: ignore
lse
:
T
.
Tensor
([
batch
,
heads
,
seq_len
],
accum_dtype
),
# type: ignore
Delta
:
T
.
Tensor
([
batch
,
heads
,
seq_len
],
accum_dtype
),
# type: ignore
dQ
:
T
.
Tensor
(
q_shape
,
accum_dtype
),
# type: ignore
dK
:
T
.
Tensor
(
k_shape
,
accum_dtype
),
# type: ignore
dV
:
T
.
Tensor
(
v_shape
,
accum_dtype
),
# type: ignore
):
with
T
.
Kernel
(
heads
,
T
.
ceildiv
(
seq_len
,
block_M
),
batch
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
K_shared
=
T
.
alloc_shared
([
block_M
,
dim_qk
],
dtype
)
dsT_shared
=
T
.
alloc_shared
([
block_M
,
block_N
],
dtype
)
q
=
T
.
alloc_shared
([
block_N
,
dim_qk
],
dtype
)
V_shared
=
T
.
alloc_shared
([
block_M
,
dim_v
],
dtype
)
qkT
=
T
.
alloc_fragment
([
block_M
,
block_N
],
accum_dtype
)
dsT
=
T
.
alloc_fragment
([
block_M
,
block_N
],
accum_dtype
)
qkT_cast
=
T
.
alloc_fragment
([
block_M
,
block_N
],
dtype
)
dsT_cast
=
T
.
alloc_fragment
([
block_M
,
block_N
],
dtype
)
lse_shared
=
T
.
alloc_shared
([
block_N
],
accum_dtype
)
delta
=
T
.
alloc_shared
([
block_N
],
accum_dtype
)
do
=
T
.
alloc_shared
([
block_N
,
dim_v
],
dtype
)
dv
=
T
.
alloc_fragment
([
block_M
,
dim_v
],
accum_dtype
)
dk
=
T
.
alloc_fragment
([
block_M
,
dim_qk
],
accum_dtype
)
dq
=
T
.
alloc_fragment
([
block_N
,
dim_qk
],
accum_dtype
)
dk_shared
=
T
.
alloc_shared
([
block_M
,
dim_qk
],
accum_dtype
)
dv_shared
=
T
.
alloc_shared
([
block_M
,
dim_v
],
accum_dtype
)
dq_shared
=
T
.
alloc_shared
([
block_N
,
dim_qk
],
accum_dtype
)
T
.
annotate_layout
({
K_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
K_shared
),
dq_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dq_shared
),
dk_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dk_shared
),
dv_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dv_shared
),
})
T
.
copy
(
K
[
bz
,
by
*
block_M
:(
by
+
1
)
*
block_M
,
bx
//
groups
,
:],
K_shared
)
T
.
copy
(
V
[
bz
,
by
*
block_M
:(
by
+
1
)
*
block_M
,
bx
//
groups
,
:],
V_shared
)
T
.
clear
(
dv
)
T
.
clear
(
dk
)
loop_st
=
T
.
floordiv
(
by
*
block_M
,
block_N
)
if
is_causal
else
0
loop_ed
=
T
.
ceildiv
(
seq_len
,
block_N
)
for
k
in
T
.
Pipelined
(
loop_st
,
loop_ed
,
num_stages
=
num_stages
):
T
.
copy
(
Q
[
bz
,
k
*
block_N
:(
k
+
1
)
*
block_N
,
bx
,
:],
q
)
T
.
clear
(
qkT
)
T
.
gemm
(
K_shared
,
q
,
qkT
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
,
wg_wait
=-
1
)
T
.
copy
(
lse
[
bz
,
bx
,
k
*
block_N
:(
k
+
1
)
*
block_N
],
lse_shared
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
qkT
[
i
,
j
]
=
T
.
exp2
(
qkT
[
i
,
j
]
*
scale
-
lse_shared
[
j
])
if
is_causal
:
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
qkT
[
i
,
j
]
=
T
.
if_then_else
(
by
*
block_M
+
i
<=
k
*
block_N
+
j
,
qkT
[
i
,
j
],
0
)
T
.
copy
(
dO
[
bz
,
k
*
block_N
:(
k
+
1
)
*
block_N
,
bx
,
:],
do
)
T
.
clear
(
dsT
)
T
.
gemm
(
V_shared
,
do
,
dsT
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
,
wg_wait
=-
1
)
T
.
wait_wgmma
(
1
)
T
.
copy
(
qkT
,
qkT_cast
)
T
.
gemm
(
qkT_cast
,
do
,
dv
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
,
wg_wait
=-
1
)
T
.
copy
(
Delta
[
bz
,
bx
,
k
*
block_N
:(
k
+
1
)
*
block_N
],
delta
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
dsT_cast
[
i
,
j
]
=
qkT
[
i
,
j
]
*
(
dsT
[
i
,
j
]
-
delta
[
j
])
*
sm_scale
T
.
wait_wgmma
(
0
)
T
.
gemm
(
dsT_cast
,
q
,
dk
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
,
wg_wait
=
1
)
T
.
copy
(
dsT_cast
,
dsT_shared
)
T
.
clear
(
dq
)
T
.
gemm
(
dsT_shared
,
K_shared
,
dq
,
transpose_A
=
True
,
wg_wait
=
1
)
T
.
wait_wgmma
(
0
)
T
.
copy
(
dq
,
dq_shared
)
T
.
atomic_add
(
dQ
[
bz
,
k
*
block_N
:(
k
+
1
)
*
block_N
,
bx
,
:],
dq_shared
)
T
.
copy
(
dv
,
dv_shared
)
T
.
atomic_add
(
dV
[
bz
,
by
*
block_M
:(
by
+
1
)
*
block_M
,
bx
//
groups
,
:],
dv_shared
)
T
.
copy
(
dk
,
dk_shared
)
T
.
atomic_add
(
dK
[
bz
,
by
*
block_M
:(
by
+
1
)
*
block_M
,
bx
//
groups
,
:],
dk_shared
)
return
flash_bwd
@
torch
.
compile
class
_attention
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
q
,
k
,
v
,
causal
,
groups
=
1
,
use_atomic
=
True
):
BATCH
,
N_CTX
,
H
,
D_HEAD_QK
=
q
.
shape
D_HEAD_V
=
v
.
shape
[
-
1
]
block_M
=
128
block_N
=
64
mod
=
flashattn_fwd
(
BATCH
,
H
,
N_CTX
,
D_HEAD_QK
,
D_HEAD_V
,
causal
,
block_M
,
block_N
,
groups
)
o
,
lse
=
mod
(
q
,
k
,
v
)
ctx
.
save_for_backward
(
q
,
k
,
v
,
o
,
lse
)
ctx
.
causal
=
causal
ctx
.
use_atomic
=
use_atomic
return
o
@
staticmethod
def
backward
(
ctx
,
do
):
q
,
k
,
v
,
o
,
lse
=
ctx
.
saved_tensors
BATCH
,
N_CTX
,
H
,
D_HEAD_QK
=
q
.
shape
HEAD_KV
,
D_HEAD_V
,
=
v
.
shape
[
-
2
],
v
.
shape
[
-
1
]
groups
=
H
//
HEAD_KV
def
maybe_contiguous
(
x
):
if
x
.
stride
(
-
1
)
!=
1
:
return
x
.
contiguous
()
return
x
do
,
q
,
k
,
v
,
o
=
[
maybe_contiguous
(
x
)
for
x
in
(
do
,
q
,
k
,
v
,
o
)]
block_M
=
128
block_N
=
32
mod_prep
=
flashattn_bwd_preprocess
(
BATCH
,
H
,
N_CTX
,
D_HEAD_V
)
delta
=
mod_prep
(
o
,
do
)
kernel
=
flashattn_bwd
(
BATCH
,
H
,
N_CTX
,
D_HEAD_QK
,
D_HEAD_V
,
ctx
.
causal
,
block_M
,
block_N
,
threads
=
256
,
num_stages
=
2
,
groups
=
groups
)
shape_q
=
[
BATCH
,
N_CTX
,
H
,
D_HEAD_QK
]
shape_k
=
[
BATCH
,
N_CTX
,
HEAD_KV
,
D_HEAD_QK
]
shape_v
=
[
BATCH
,
N_CTX
,
HEAD_KV
,
D_HEAD_V
]
dq
=
torch
.
zeros
(
shape_q
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
dk
=
torch
.
zeros
(
shape_k
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
dv
=
torch
.
zeros
(
shape_v
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
kernel
(
q
,
k
,
v
,
do
,
lse
,
delta
,
dq
,
dk
,
dv
)
dq
=
dq
.
to
(
torch
.
float16
)
dk
=
dk
.
to
(
torch
.
float16
)
dv
=
dv
.
to
(
torch
.
float16
)
return
dq
,
dk
,
dv
,
None
,
None
,
None
attention
=
_attention
.
apply
def
ref_program
(
Q
,
K
,
V
,
is_causal
,
groups
=
1
):
# Q: [B, T, HQ, D_QK]
# K: [B, T, HK, D_QK]
# V: [B, T, HV, D_V]
# HQ = HKV * groups
assert
Q
.
size
(
2
)
==
K
.
size
(
2
)
*
groups
,
f
"Q.size(2):
{
Q
.
size
(
2
)
}
, K.size(2):
{
K
.
size
(
2
)
}
, groups:
{
groups
}
"
assert
Q
.
size
(
2
)
==
V
.
size
(
2
)
*
groups
,
f
"Q.size(2):
{
Q
.
size
(
2
)
}
, V.size(2):
{
V
.
size
(
2
)
}
, groups:
{
groups
}
"
dim_qk
=
Q
.
size
(
-
1
)
K
=
K
.
repeat_interleave
(
groups
,
dim
=
2
)
V
=
V
.
repeat_interleave
(
groups
,
dim
=
2
)
scores
=
torch
.
einsum
(
'bqhd,bkhd->bhqk'
,
Q
,
K
)
scores
=
scores
/
torch
.
sqrt
(
torch
.
tensor
(
dim_qk
,
dtype
=
scores
.
dtype
))
if
is_causal
:
seq_len
=
Q
.
size
(
1
)
mask
=
torch
.
tril
(
torch
.
ones
(
seq_len
,
seq_len
,
device
=
scores
.
device
))
mask
=
mask
.
unsqueeze
(
0
).
unsqueeze
(
0
)
scores
=
scores
.
masked_fill
(
mask
==
0
,
float
(
'-inf'
))
attention_weights
=
F
.
softmax
(
scores
,
dim
=-
1
)
output
=
torch
.
einsum
(
'bhqk,bkhd->bqhd'
,
attention_weights
,
V
)
return
output
def
main
(
BATCH
:
int
=
1
,
H
:
int
=
32
,
N_CTX
:
int
=
256
,
D_HEAD_QK
:
int
=
192
,
D_HEAD_V
:
int
=
128
,
groups
:
int
=
16
,
causal
:
bool
=
False
):
flops_per_qk
=
2.0
*
BATCH
*
H
*
N_CTX
*
N_CTX
*
D_HEAD_QK
flops_per_v
=
2.0
*
BATCH
*
H
*
N_CTX
*
N_CTX
*
D_HEAD_V
total_flops
=
3
*
flops_per_qk
+
2
*
flops_per_v
if
causal
:
total_flops
*=
0.5
Q
=
(
torch
.
empty
(
BATCH
,
N_CTX
,
H
,
D_HEAD_QK
,
dtype
=
torch
.
half
,
device
=
"cuda"
).
normal_
().
requires_grad_
())
head_kv
=
H
//
groups
K
=
(
torch
.
empty
(
BATCH
,
N_CTX
,
head_kv
,
D_HEAD_QK
,
dtype
=
torch
.
half
,
device
=
"cuda"
).
normal_
().
requires_grad_
())
V
=
(
torch
.
empty
(
BATCH
,
N_CTX
,
head_kv
,
D_HEAD_V
,
dtype
=
torch
.
half
,
device
=
"cuda"
).
normal_
().
requires_grad_
())
dO
=
(
torch
.
empty
(
BATCH
,
N_CTX
,
H
,
D_HEAD_V
,
dtype
=
torch
.
half
,
device
=
"cuda"
).
normal_
().
requires_grad_
())
O
=
attention
(
Q
,
K
,
V
,
causal
,
groups
)
O
.
backward
(
dO
,
retain_graph
=
True
)
dQ
,
Q
.
grad
=
Q
.
grad
.
clone
(),
None
dK
,
K
.
grad
=
K
.
grad
.
clone
(),
None
dV
,
V
.
grad
=
V
.
grad
.
clone
(),
None
O_ref
=
ref_program
(
Q
,
K
,
V
,
causal
,
groups
)
O_ref
.
backward
(
dO
,
retain_graph
=
True
)
dQ_ref
,
Q
.
grad
=
Q
.
grad
.
clone
(),
None
dK_ref
,
K
.
grad
=
K
.
grad
.
clone
(),
None
dV_ref
,
V
.
grad
=
V
.
grad
.
clone
(),
None
torch
.
testing
.
assert_close
(
O
,
O_ref
,
rtol
=
1e-2
,
atol
=
1e-2
)
torch
.
testing
.
assert_close
(
dV
,
dV_ref
,
rtol
=
1e-2
,
atol
=
1e-2
)
torch
.
testing
.
assert_close
(
dK
,
dK_ref
,
rtol
=
1e-2
,
atol
=
1e-2
)
torch
.
testing
.
assert_close
(
dQ
,
dQ_ref
,
rtol
=
1e-2
,
atol
=
1e-2
)
print
(
'All checks passed.✅'
)
def
run
():
O_ref
.
backward
(
dO
,
retain_graph
=
True
)
def
run1
():
O
.
backward
(
dO
,
retain_graph
=
True
)
from
tilelang.profiler
import
do_bench
latency
=
do_bench
(
run
,
warmup
=
500
)
print
(
"torch: {:.2f} ms"
.
format
(
latency
))
print
(
"torch: {:.2f} TFlops"
.
format
(
total_flops
/
latency
*
1e-9
))
latency
=
do_bench
(
run1
,
warmup
=
500
)
print
(
"tilelang: {:.2f} ms"
.
format
(
latency
))
print
(
"tilelang: {:.2f} TFlops"
.
format
(
total_flops
/
latency
*
1e-9
))
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--batch'
,
type
=
int
,
default
=
8
,
help
=
'Batch size'
)
parser
.
add_argument
(
'--h'
,
type
=
int
,
default
=
32
,
help
=
'Number of heads'
)
parser
.
add_argument
(
'--n_ctx'
,
type
=
int
,
default
=
1024
,
help
=
'Context size'
)
parser
.
add_argument
(
'--d_head_qk'
,
type
=
int
,
default
=
192
,
help
=
'Head dimension for Q/K'
)
parser
.
add_argument
(
'--d_head_v'
,
type
=
int
,
default
=
128
,
help
=
'Head dimension for V'
)
parser
.
add_argument
(
'--causal'
,
action
=
'store_true'
,
help
=
'Causal flag'
)
parser
.
add_argument
(
'--groups'
,
type
=
int
,
default
=
16
,
help
=
'groups'
)
args
=
parser
.
parse_args
()
main
(
args
.
batch
,
args
.
h
,
args
.
n_ctx
,
args
.
d_head_qk
,
args
.
d_head_v
,
args
.
groups
,
args
.
causal
)
examples/flash_attention/example_gqa_fwd_bshd.py (new file, mode 100644)
import
torch
import
torch.nn.functional
as
F
import
tilelang
from
tilelang.autotuner
import
*
import
tilelang.language
as
T
import
itertools
import
argparse
from
functools
import
partial
class
FlashAttentionTuneSpace
:
def
__init__
(
self
,
block_sizes
=
(
64
,
128
,
256
),
thread_options
=
(
128
,
256
,
512
),
num_stages_range
=
(
2
,
3
),
max_shared_mem
=
100
*
1024
,
warp_alignment
=
16
,
dim
=
128
,
dtype_bytes
=
2
,
):
self
.
block_sizes
=
block_sizes
self
.
thread_options
=
thread_options
self
.
num_stages_range
=
num_stages_range
self
.
max_shared_mem
=
max_shared_mem
self
.
warp_alignment
=
warp_alignment
self
.
dim
=
dim
self
.
dtype_bytes
=
dtype_bytes
def
get_configs
(
user_config
=
None
):
config
=
user_config
or
FlashAttentionTuneSpace
()
valid_configs
=
[]
for
block_M
,
block_N
in
itertools
.
product
(
config
.
block_sizes
,
repeat
=
2
):
for
threads
in
config
.
thread_options
:
assert
threads
%
32
==
0
warp_count
=
threads
//
32
warp_M
=
block_M
//
warp_count
warp_N
=
block_N
//
warp_count
if
(
warp_M
%
config
.
warp_alignment
!=
0
or
warp_N
%
config
.
warp_alignment
!=
0
):
continue
shared_mem
=
2
*
config
.
dtype_bytes
*
config
.
dim
*
(
block_M
+
block_N
)
if
shared_mem
>
config
.
max_shared_mem
:
continue
for
num_stages
in
config
.
num_stages_range
:
valid_configs
.
append
({
"block_M"
:
block_M
,
"block_N"
:
block_N
,
"num_stages"
:
num_stages
,
"threads"
:
threads
,
})
return
valid_configs
@
autotune
(
configs
=
get_configs
(),
warmup
=
10
,
rep
=
10
)
@
tilelang
.
jit
(
out_idx
=
[
3
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
def
flashattn
(
batch
,
heads
,
seq_len
,
dim
,
is_causal
,
groups
=
1
,
block_M
=
64
,
block_N
=
64
,
num_stages
=
0
,
threads
=
128
):
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
head_kv
=
heads
//
groups
q_shape
=
[
batch
,
seq_len
,
heads
,
dim
]
kv_shape
=
[
batch
,
seq_len
,
head_kv
,
dim
]
dtype
=
"float16"
accum_dtype
=
"float"
@
T
.
macro
def
MMA0
(
K
:
T
.
Tensor
(
kv_shape
,
dtype
),
Q_shared
:
T
.
SharedBuffer
([
block_M
,
dim
],
dtype
),
K_shared
:
T
.
SharedBuffer
([
block_N
,
dim
],
dtype
),
acc_s
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
accum_dtype
),
k
:
T
.
int32
,
bx
:
T
.
int32
,
by
:
T
.
int32
,
bz
:
T
.
int32
,
):
T
.
copy
(
K
[
bz
,
k
*
block_N
:(
k
+
1
)
*
block_N
,
by
//
groups
,
:],
K_shared
)
if
is_causal
:
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
bx
*
block_M
+
i
>=
k
*
block_N
+
j
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
else
:
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
@
T
.
macro
def
MMA1
(
V
:
T
.
Tensor
(
kv_shape
,
dtype
),
V_shared
:
T
.
SharedBuffer
([
block_N
,
dim
],
dtype
),
acc_s_cast
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
dtype
),
acc_o
:
T
.
FragmentBuffer
([
block_M
,
dim
],
accum_dtype
),
k
:
T
.
int32
,
by
:
T
.
int32
,
bz
:
T
.
int32
,
):
T
.
copy
(
V
[
bz
,
k
*
block_N
:(
k
+
1
)
*
block_N
,
by
//
groups
,
:],
V_shared
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
@
T
.
macro
def
Softmax
(
acc_s
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
accum_dtype
),
acc_s_cast
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
dtype
),
scores_max
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_max_prev
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_scale
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_sum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
logsum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
):
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for
i
in
T
.
Parallel
(
block_M
):
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s
[
i
,
j
]
=
T
.
exp2
(
acc_s
[
i
,
j
]
*
scale
-
scores_max
[
i
]
*
scale
)
T
.
reduce_sum
(
acc_s
,
scores_sum
,
dim
=
1
)
for
i
in
T
.
Parallel
(
block_M
):
logsum
[
i
]
=
logsum
[
i
]
*
scores_scale
[
i
]
+
scores_sum
[
i
]
T
.
copy
(
acc_s
,
acc_s_cast
)
@
T
.
macro
def
Rescale
(
acc_o
:
T
.
FragmentBuffer
([
block_M
,
dim
],
accum_dtype
),
scores_scale
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
@
T
.
prim_func
def
main
(
Q
:
T
.
Tensor
(
q_shape
,
dtype
),
K
:
T
.
Tensor
(
kv_shape
,
dtype
),
V
:
T
.
Tensor
(
kv_shape
,
dtype
),
Output
:
T
.
Tensor
(
q_shape
,
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
seq_len
,
block_M
),
heads
,
batch
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_M
,
dim
],
dtype
)
K_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
V_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
O_shared
=
T
.
alloc_shared
([
block_M
,
dim
],
dtype
)
acc_s
=
T
.
alloc_fragment
([
block_M
,
block_N
],
accum_dtype
)
acc_s_cast
=
T
.
alloc_fragment
([
block_M
,
block_N
],
dtype
)
acc_o
=
T
.
alloc_fragment
([
block_M
,
dim
],
accum_dtype
)
scores_max
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
scores_max_prev
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
scores_scale
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
scores_sum
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
logsum
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
T
.
copy
(
Q
[
bz
,
bx
*
block_M
:(
bx
+
1
)
*
block_M
,
by
,
:],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
loop_range
=
(
T
.
min
(
T
.
ceildiv
(
seq_len
,
block_N
),
T
.
ceildiv
(
(
bx
+
1
)
*
block_M
,
block_N
))
if
is_causal
else
T
.
ceildiv
(
seq_len
,
block_N
))
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
MMA0
(
K
,
Q_shared
,
K_shared
,
acc_s
,
k
,
bx
,
by
,
bz
)
Softmax
(
acc_s
,
acc_s_cast
,
scores_max
,
scores_max_prev
,
scores_scale
,
scores_sum
,
logsum
)
Rescale
(
acc_o
,
scores_scale
)
MMA1
(
V
,
V_shared
,
acc_s_cast
,
acc_o
,
k
,
by
,
bz
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
O_shared
,
Output
[
bz
,
bx
*
block_M
:(
bx
+
1
)
*
block_M
,
by
,
:])
return
main
def
ref_program
(
Q
,
K
,
V
,
is_causal
,
groups
=
1
):
# Q: [B, T, HQ, D]
# K: [B, T, HK, D]
# V: [B, T, HV, D]
# HQ = HKV * groups
assert
Q
.
size
(
2
)
==
K
.
size
(
2
)
*
groups
,
f
"Q.size(2):
{
Q
.
size
(
2
)
}
, K.size(2):
{
K
.
size
(
2
)
}
, groups:
{
groups
}
"
assert
Q
.
size
(
2
)
==
V
.
size
(
2
)
*
groups
,
f
"Q.size(2):
{
Q
.
size
(
2
)
}
, V.size(2):
{
V
.
size
(
2
)
}
, groups:
{
groups
}
"
dim
=
Q
.
size
(
-
1
)
K
=
K
.
repeat_interleave
(
groups
,
dim
=
2
)
V
=
V
.
repeat_interleave
(
groups
,
dim
=
2
)
scores
=
torch
.
einsum
(
'bqhd,bkhd->bhqk'
,
Q
,
K
)
scores
=
scores
/
torch
.
sqrt
(
torch
.
tensor
(
dim
,
dtype
=
scores
.
dtype
))
if
is_causal
:
seq_len
=
Q
.
size
(
1
)
mask
=
torch
.
tril
(
torch
.
ones
(
seq_len
,
seq_len
,
device
=
scores
.
device
))
mask
=
mask
.
unsqueeze
(
0
).
unsqueeze
(
0
)
scores
=
scores
.
masked_fill
(
mask
==
0
,
float
(
'-inf'
))
attention_weights
=
F
.
softmax
(
scores
,
dim
=-
1
)
output
=
torch
.
einsum
(
'bhqk,bkhd->bqhd'
,
attention_weights
,
V
)
return
output
def
main
(
batch
:
int
=
1
,
heads
:
int
=
64
,
seq_len
:
int
=
4096
,
dim
:
int
=
128
,
is_causal
:
bool
=
False
,
groups
:
int
=
16
,
tune
:
bool
=
False
):
flops_per_matmul
=
2.0
*
batch
*
heads
*
seq_len
*
seq_len
*
dim
total_flops
=
2
*
flops_per_matmul
if
is_causal
:
total_flops
*=
0.5
if
(
not
tune
):
kernel
=
flashattn
(
batch
,
heads
,
seq_len
,
dim
,
is_causal
,
groups
=
groups
,
block_M
=
64
,
block_N
=
64
,
num_stages
=
2
,
threads
=
128
)
ref_program_processed
=
partial
(
ref_program
,
is_causal
=
is_causal
,
groups
=
groups
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Normal
)
profiler
.
assert_allclose
(
ref_program_processed
,
rtol
=
0.01
,
atol
=
0.01
)
print
(
"All checks pass."
)
latency
=
profiler
.
do_bench
(
ref_program_processed
,
warmup
=
500
)
print
(
"Ref: {:.2f} ms"
.
format
(
latency
))
print
(
"Ref: {:.2f} TFlops"
.
format
(
total_flops
/
latency
*
1e-9
))
latency
=
profiler
.
do_bench
(
warmup
=
500
)
print
(
"Tile-lang: {:.2f} ms"
.
format
(
latency
))
print
(
"Tile-lang: {:.2f} TFlops"
.
format
(
total_flops
/
latency
*
1e-9
))
else
:
kernel
=
flashattn
(
batch
,
heads
,
seq_len
,
dim
,
is_causal
)
best_latency
=
kernel
.
latency
best_config
=
kernel
.
config
ref_latency
=
kernel
.
ref_latency
print
(
f
"Best latency:
{
best_latency
}
"
)
print
(
f
"Best TFlops:
{
total_flops
/
best_latency
*
1e-9
}
"
)
print
(
f
"Best config:
{
best_config
}
"
)
print
(
f
"Ref latency:
{
ref_latency
}
"
)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--batch'
,
type
=
int
,
default
=
1
,
help
=
'batch size'
)
parser
.
add_argument
(
'--heads'
,
type
=
int
,
default
=
64
,
help
=
'heads'
)
parser
.
add_argument
(
'--seq_len'
,
type
=
int
,
default
=
4096
,
help
=
'sequence length'
)
parser
.
add_argument
(
'--dim'
,
type
=
int
,
default
=
128
,
help
=
'dim'
)
parser
.
add_argument
(
'--is_causal'
,
action
=
'store_true'
,
help
=
'causal'
)
parser
.
add_argument
(
'--tune'
,
action
=
'store_true'
,
help
=
'tune configs'
)
parser
.
add_argument
(
'--groups'
,
type
=
int
,
default
=
16
,
help
=
'groups'
)
args
=
parser
.
parse_args
()
main
(
args
.
batch
,
args
.
heads
,
args
.
seq_len
,
args
.
dim
,
args
.
is_causal
,
args
.
groups
,
args
.
tune
)
examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py (new file, mode 100644)
import
torch
import
torch.nn.functional
as
F
import
tilelang
from
tilelang.autotuner
import
*
import
tilelang.language
as
T
import
itertools
import
argparse
from
functools
import
partial
def
get_configs
():
iter_params
=
dict
(
block_M
=
[
128
],
block_N
=
[
128
],
num_stages
=
[
2
],
threads
=
[
256
],
)
return
[
dict
(
zip
(
iter_params
,
values
))
for
values
in
itertools
.
product
(
*
iter_params
.
values
())]
@
autotune
(
configs
=
get_configs
(),
warmup
=
10
,
rep
=
10
,
)
@
tilelang
.
jit
(
out_idx
=
[
3
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
def
flashattn
(
batch
,
heads
,
seq_len
,
dim
,
is_causal
,
groups
=
1
,
block_M
=
64
,
block_N
=
64
,
num_stages
=
0
,
threads
=
128
,
):
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
head_kv
=
heads
//
groups
q_shape
=
[
batch
,
seq_len
,
heads
,
dim
]
kv_shape
=
[
batch
,
seq_len
,
head_kv
,
dim
]
dtype
=
"float16"
accum_dtype
=
"float"
@
T
.
macro
def
MMA0
(
K
:
T
.
Tensor
(
kv_shape
,
dtype
),
Q_shared
:
T
.
SharedBuffer
([
block_M
,
dim
],
dtype
),
K_shared
:
T
.
SharedBuffer
([
block_N
,
dim
],
dtype
),
acc_s
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
accum_dtype
),
k
:
T
.
int32
,
bx
:
T
.
int32
,
by
:
T
.
int32
,
bz
:
T
.
int32
,
):
T
.
copy
(
K
[
bz
,
k
*
block_N
:(
k
+
1
)
*
block_N
,
by
//
groups
,
:],
K_shared
)
if
is_causal
:
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
bx
*
block_M
+
i
>=
k
*
block_N
+
j
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
else
:
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
@
T
.
macro
def
MMA1
(
V
:
T
.
Tensor
(
kv_shape
,
dtype
),
V_shared
:
T
.
SharedBuffer
([
block_N
,
dim
],
dtype
),
acc_s_cast
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
dtype
),
acc_o
:
T
.
FragmentBuffer
([
block_M
,
dim
],
accum_dtype
),
k
:
T
.
int32
,
by
:
T
.
int32
,
bz
:
T
.
int32
,
):
T
.
copy
(
V
[
bz
,
k
*
block_N
:(
k
+
1
)
*
block_N
,
by
//
groups
,
:],
V_shared
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
@
T
.
macro
def
Softmax
(
acc_s
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
accum_dtype
),
acc_s_cast
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
dtype
),
scores_max
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_max_prev
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_scale
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_sum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
logsum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
):
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for
i
in
T
.
Parallel
(
block_M
):
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s
[
i
,
j
]
=
T
.
exp2
(
acc_s
[
i
,
j
]
*
scale
-
scores_max
[
i
]
*
scale
)
T
.
reduce_sum
(
acc_s
,
scores_sum
,
dim
=
1
)
for
i
in
T
.
Parallel
(
block_M
):
logsum
[
i
]
=
logsum
[
i
]
*
scores_scale
[
i
]
+
scores_sum
[
i
]
T
.
copy
(
acc_s
,
acc_s_cast
)
@
T
.
macro
def
Rescale
(
acc_o
:
T
.
FragmentBuffer
([
block_M
,
dim
],
accum_dtype
),
scores_scale
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
@
T
.
prim_func
def
main
(
Q
:
T
.
Tensor
(
q_shape
,
dtype
),
K
:
T
.
Tensor
(
kv_shape
,
dtype
),
V
:
T
.
Tensor
(
kv_shape
,
dtype
),
Output
:
T
.
Tensor
(
q_shape
,
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
seq_len
,
block_M
),
heads
,
batch
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_M
,
dim
],
dtype
)
K_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
V_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
O_shared
=
T
.
alloc_shared
([
block_M
,
dim
],
dtype
)
acc_s
=
T
.
alloc_fragment
([
block_M
,
block_N
],
accum_dtype
)
acc_s_cast
=
T
.
alloc_fragment
([
block_M
,
block_N
],
dtype
)
acc_o
=
T
.
alloc_fragment
([
block_M
,
dim
],
accum_dtype
)
scores_max
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
scores_max_prev
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
scores_scale
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
scores_sum
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
logsum
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
T
.
copy
(
Q
[
bz
,
bx
*
block_M
:(
bx
+
1
)
*
block_M
,
by
,
:],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
loop_range
=
(
T
.
min
(
T
.
ceildiv
(
seq_len
,
block_N
),
T
.
ceildiv
(
(
bx
+
1
)
*
block_M
,
block_N
))
if
is_causal
else
T
.
ceildiv
(
seq_len
,
block_N
))
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
,
order
=
[
-
1
,
0
,
3
,
1
,
-
1
,
2
],
stage
=
[
-
1
,
0
,
0
,
1
,
-
1
,
1
],
group
=
[[
0
],
[
1
,
2
],
[
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
],
[
11
],
[
12
],
[
13
]]):
MMA0
(
K
,
Q_shared
,
K_shared
,
acc_s
,
k
,
bx
,
by
,
bz
)
Softmax
(
acc_s
,
acc_s_cast
,
scores_max
,
scores_max_prev
,
scores_scale
,
scores_sum
,
logsum
)
Rescale
(
acc_o
,
scores_scale
)
MMA1
(
V
,
V_shared
,
acc_s_cast
,
acc_o
,
k
,
by
,
bz
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
O_shared
,
Output
[
bz
,
bx
*
block_M
:(
bx
+
1
)
*
block_M
,
by
,
:])
return
main
def
ref_program
(
Q
,
K
,
V
,
is_causal
,
groups
=
1
):
# Q: [B, T, HQ, D]
# K: [B, T, HK, D]
# V: [B, T, HV, D]
# HQ = HKV * groups
assert
Q
.
size
(
2
)
==
K
.
size
(
2
)
*
groups
,
f
"Q.size(2):
{
Q
.
size
(
2
)
}
, K.size(2):
{
K
.
size
(
2
)
}
, groups:
{
groups
}
"
assert
Q
.
size
(
2
)
==
V
.
size
(
2
)
*
groups
,
f
"Q.size(2):
{
Q
.
size
(
2
)
}
, V.size(2):
{
V
.
size
(
2
)
}
, groups:
{
groups
}
"
dim
=
Q
.
size
(
-
1
)
K
=
K
.
repeat_interleave
(
groups
,
dim
=
2
)
V
=
V
.
repeat_interleave
(
groups
,
dim
=
2
)
scores
=
torch
.
einsum
(
'bqhd,bkhd->bhqk'
,
Q
,
K
)
scores
=
scores
/
torch
.
sqrt
(
torch
.
tensor
(
dim
,
dtype
=
scores
.
dtype
))
if
is_causal
:
seq_len
=
Q
.
size
(
1
)
mask
=
torch
.
tril
(
torch
.
ones
(
seq_len
,
seq_len
,
device
=
scores
.
device
))
mask
=
mask
.
unsqueeze
(
0
).
unsqueeze
(
0
)
scores
=
scores
.
masked_fill
(
mask
==
0
,
float
(
'-inf'
))
attention_weights
=
F
.
softmax
(
scores
,
dim
=-
1
)
output
=
torch
.
einsum
(
'bhqk,bkhd->bqhd'
,
attention_weights
,
V
)
return
output
def
main
(
batch
:
int
=
1
,
heads
:
int
=
64
,
seq_len
:
int
=
4096
,
dim
:
int
=
128
,
is_causal
:
bool
=
False
,
groups
:
int
=
16
,
tune
:
bool
=
False
,
):
flops_per_matmul
=
2.0
*
batch
*
heads
*
seq_len
*
seq_len
*
dim
total_flops
=
2
*
flops_per_matmul
if
is_causal
:
total_flops
*=
0.5
if
(
not
tune
):
kernel
=
flashattn
(
batch
,
heads
,
seq_len
,
dim
,
is_causal
,
groups
=
groups
,
block_M
=
128
,
block_N
=
128
,
num_stages
=
2
,
threads
=
256
)
ref_program_processed
=
partial
(
ref_program
,
is_causal
=
is_causal
,
groups
=
groups
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Normal
)
profiler
.
assert_allclose
(
ref_program_processed
,
rtol
=
0.01
,
atol
=
0.01
)
print
(
"All checks pass."
)
latency
=
profiler
.
do_bench
(
ref_program_processed
,
warmup
=
500
)
print
(
"Ref: {:.2f} ms"
.
format
(
latency
))
print
(
"Ref: {:.2f} TFlops"
.
format
(
total_flops
/
latency
*
1e-9
))
latency
=
profiler
.
do_bench
(
warmup
=
500
)
print
(
"Tile-lang: {:.2f} ms"
.
format
(
latency
))
print
(
"Tile-lang: {:.2f} TFlops"
.
format
(
total_flops
/
latency
*
1e-9
))
else
:
kernel
=
flashattn
(
batch
,
heads
,
seq_len
,
dim
,
is_causal
)
best_latency
=
kernel
.
latency
best_config
=
kernel
.
config
ref_latency
=
kernel
.
ref_latency
print
(
f
"Best latency:
{
best_latency
}
"
)
print
(
f
"Best TFlops:
{
total_flops
/
best_latency
*
1e-9
}
"
)
print
(
f
"Best config:
{
best_config
}
"
)
print
(
f
"Ref latency:
{
ref_latency
}
"
)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--batch'
,
type
=
int
,
default
=
1
,
help
=
'batch size'
)
parser
.
add_argument
(
'--heads'
,
type
=
int
,
default
=
64
,
help
=
'heads'
)
parser
.
add_argument
(
'--seq_len'
,
type
=
int
,
default
=
4096
,
help
=
'sequence length'
)
parser
.
add_argument
(
'--dim'
,
type
=
int
,
default
=
128
,
help
=
'dim'
)
parser
.
add_argument
(
'--is_causal'
,
action
=
'store_true'
,
help
=
'causal'
)
parser
.
add_argument
(
'--tune'
,
action
=
'store_true'
,
help
=
'tune configs'
)
parser
.
add_argument
(
'--groups'
,
type
=
int
,
default
=
16
,
help
=
'groups'
)
args
=
parser
.
parse_args
()
main
(
args
.
batch
,
args
.
heads
,
args
.
seq_len
,
args
.
dim
,
args
.
is_causal
,
args
.
groups
,
args
.
tune
)
examples/flash_attention/example_gqa_fwd_varlen.py (new file, mode 100644)
# ruff: noqa
import argparse
import torch
import tilelang
import tilelang.language as T
import tilelang.testing
from einops import rearrange, repeat
from tilelang.profiler import do_bench
from varlen_utils import generate_random_padding_mask, generate_qkv


def attention_ref(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    causal=False,
    window_size=(-1, -1),
    upcast=True,
):
    if causal:
        window_size = (window_size[0], 0)
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
    dim = q.shape[-1]
    scale = (1.0 / dim)**0.5
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
    scores = torch.einsum("bthd,bshd->bhts", q, k)
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
    scores = scores * scale
    attention = torch.softmax(scores, dim=-1).to(v.dtype)
    if query_padding_mask is not None:
        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    output = torch.einsum("bhts,bshd->bthd", attention, v)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)


@tilelang.jit(
    out_idx=[6],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
def flashattn(batch_size,
              groups,
              UQ,
              UKV,
              heads,
              dim,
              is_causal,
              block_M=64,
              block_N=64,
              num_stages=1,
              threads=128):
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [UQ, heads, dim]
    kv_shape = [UKV, head_kv, dim]
    o_shape = [UQ, heads, dim]
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def main(
            Q_unpad: T.Tensor(q_shape, dtype),
            K_unpad: T.Tensor(kv_shape, dtype),
            V_unpad: T.Tensor(kv_shape, dtype),
            cu_seqlens_q: T.Tensor([batch_size + 1], "int32"),
            cu_seqlens_k: T.Tensor([batch_size + 1], "int32"),
            max_seqlen_q: T.int32,
            Output_unpad: T.Tensor(o_shape, dtype),
    ):
        with T.Kernel(
                T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=threads) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim], dtype)
            K_shared = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_N, dim], dtype)
            O_shared = T.alloc_shared([block_M, dim], dtype)
            acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
            acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_M], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
            scores_scale = T.alloc_fragment([block_M], accum_dtype)
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            batch_idx = bz
            head_idx = by
            kv_head_idx = head_idx // groups

            q_start_idx = cu_seqlens_q[batch_idx]
            k_start_idx = cu_seqlens_k[batch_idx]
            v_start_idx = cu_seqlens_k[batch_idx]
            q_end_idx = cu_seqlens_q[batch_idx + 1]
            k_end_idx = cu_seqlens_k[batch_idx + 1]
            v_end_idx = cu_seqlens_k[batch_idx + 1]
            q_current_seqlen = q_end_idx - q_start_idx
            k_current_seqlen = k_end_idx - k_start_idx
            v_current_seqlen = v_end_idx - v_start_idx

            T.copy(Q_unpad[q_start_idx + bx * block_M:q_start_idx + (bx + 1) * block_M, head_idx, :],
                   Q_shared)
            for i, d in T.Parallel(block_M, dim):
                if bx * block_M + i >= q_current_seqlen:
                    Q_shared[i, d] = 0

            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            loop_range = T.ceildiv(k_current_seqlen, block_N)
            for k in T.Pipelined(loop_range, num_stages=num_stages):
                T.copy(K_unpad[k_start_idx + k * block_N:k_start_idx + (k + 1) * block_N, kv_head_idx, :],
                       K_shared)
                for i, d in T.Parallel(block_N, dim):
                    if k * block_N + i >= k_current_seqlen:
                        K_shared[i, d] = 0
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(
                            (bx * block_M + i >= k * block_N + j) and
                            (bx * block_M + i >= q_current_seqlen or
                             k * block_N + j >= k_current_seqlen), -T.infinity(acc_s.dtype), 0)
                else:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(
                            (bx * block_M + i >= q_current_seqlen or
                             k * block_N + j >= k_current_seqlen), -T.infinity(acc_s.dtype), 0)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_M):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(block_M):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                T.copy(acc_s, acc_s_cast)
                for i, j in T.Parallel(block_M, dim):
                    acc_o[i, j] *= scores_scale[i]
                T.copy(V_unpad[v_start_idx + k * block_N:v_start_idx + (k + 1) * block_N, kv_head_idx, :],
                       V_shared)
                for i, d in T.Parallel(block_N, dim):
                    if k * block_N + i >= v_current_seqlen:
                        V_shared[i, d] = 0
                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, O_shared)
            for i, d in T.Parallel(block_M, dim):
                if bx * block_M + i < q_current_seqlen:
                    Output_unpad[q_start_idx + bx * block_M + i, head_idx, d] = O_shared[i, d]

    return main


def main(batch: int = 1,
         heads: int = 64,
         q_seqlen: int = 2048,
         k_seqlen: int = 2048,
         dim: int = 128,
         groups: int = 16,
         is_causal: bool = False):
    assert heads % groups == 0, "heads must be divisible by groups"
    flops_per_matmul = 2.0 * batch * heads * q_seqlen * k_seqlen * dim
    total_flops = 2 * flops_per_matmul
    tilelang.testing.set_random_seed(0)

    causal = False
    if causal:
        total_flops *= 0.5

    tilelang.testing.set_random_seed(0)
    dtype = torch.float16
    device = torch.device("cuda")
    head_kv = heads // groups
    q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device, requires_grad=True)
    k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True)
    v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True)

    query_padding_mask = generate_random_padding_mask(q_seqlen, batch, device, mode="random")
    key_padding_mask = generate_random_padding_mask(k_seqlen, batch, device, mode="random")

    (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        q,
        k,
        v,
        output_pad_fn,
        _,
        _,
    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)

    UQ = q_unpad.shape[0]
    UKV = k_unpad.shape[0]

    kernel = flashattn(
        batch, groups, UQ, UKV, heads, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128)
    out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)
    out = output_pad_fn(out_unpad)
    out_ref, _ = attention_ref(
        q,
        k,
        v,
        query_padding_mask=query_padding_mask,
        key_padding_mask=key_padding_mask,
        causal=is_causal,
    )
    torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)
    print("All checks passed.✅")

    latency = do_bench(lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q))
    print("Tile-lang: {:.2f} ms".format(latency))
    print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=8, help='batch size')
    parser.add_argument('--heads', type=int, default=64, help='query heads')
    parser.add_argument('--groups', type=int, default=16, help='groups')
    parser.add_argument('--q_seqlen', type=int, default=2048, help='query sequence length')
    parser.add_argument('--k_seqlen', type=int, default=2048, help='key/value sequence length')
    parser.add_argument('--dim', type=int, default=128, help='head dim')
    parser.add_argument('--is_causal', action='store_true', help='causal attention')
    args = parser.parse_args()
    main(args.batch, args.heads, args.q_seqlen, args.k_seqlen, args.dim, args.groups, args.is_causal)
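For readers unfamiliar with the varlen layout used above: generate_qkv packs only the valid tokens of every sequence into one long tensor and returns cumulative sequence-length offsets (cu_seqlens) that the kernel indexes into. A minimal sketch of that bookkeeping, assuming a boolean padding mask of shape [batch, seqlen]; this is illustrative only, the example itself relies on varlen_utils for the real packing:

    import torch
    import torch.nn.functional as F

    def cu_seqlens_from_mask(mask: torch.Tensor) -> torch.Tensor:
        # mask: [batch, seqlen], True where a token is valid.
        seqlens = mask.sum(dim=-1, dtype=torch.int32)                 # valid length per sequence
        # Prefix-sum with a leading zero: cu_seqlens[b] is where sequence b starts in the packed tensor.
        return F.pad(torch.cumsum(seqlens, 0, dtype=torch.int32), (1, 0))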
examples/flash_attention/example_mha_bwd.py
0 → 100644
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import argparse


@tilelang.jit(
    out_idx=[3, 4],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    shape = [batch, seq_len, heads, dim]
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def flash_fwd(
            Q: T.Tensor(shape, dtype),  # type: ignore
            K: T.Tensor(shape, dtype),  # type: ignore
            V: T.Tensor(shape, dtype),  # type: ignore
            Output: T.Tensor(shape, dtype),  # type: ignore
            lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
    ):
        with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim], dtype)
            # Q_local = T.alloc_fragment([block_M, dim], dtype)
            K_shared = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_N, dim], dtype)
            acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
            acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_M], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
            scores_scale = T.alloc_fragment([block_M], accum_dtype)
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))
            loop_range = (
                T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
            for k in T.Pipelined(loop_range, num_stages=1):
                T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                                     -T.infinity(acc_s.dtype))
                else:
                    T.clear(acc_s)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
                T.copy(scores_max, scores_max_prev)
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_M):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_M, dim):
                    acc_o[i, j] *= scores_scale[i]
                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.copy(acc_s, acc_s_cast)
                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(block_M):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
            for i in T.Parallel(block_M):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
            T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M])

    return flash_fwd


@tilelang.jit(
    out_idx=[2],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
    dtype = "float16"
    accum_dtype = "float"
    shape = [batch, seq_len, heads, dim]
    blk = 32

    @T.prim_func
    def flash_bwd_prep(
            O: T.Tensor(shape, dtype),  # type: ignore
            dO: T.Tensor(shape, dtype),  # type: ignore
            Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
            o = T.alloc_fragment([blk, blk], dtype)
            do = T.alloc_fragment([blk, blk], dtype)
            acc = T.alloc_fragment([blk, blk], accum_dtype)
            delta = T.alloc_fragment([blk], accum_dtype)
            T.clear(acc)
            for k in range(T.ceildiv(dim, blk)):
                T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o)
                T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do)
                for i, j in T.Parallel(blk, blk):
                    acc[i, j] += o[i, j] * do[i, j]
            T.reduce_sum(acc, delta, 1)
            T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk])

    return flash_bwd_prep


def make_dq_layout(dQ):
    # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
    return T.Layout(dQ.shape,
                    lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])


@tilelang.jit(
    out_idx=[1],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
    dtype = "float16"
    accum_dtype = "float"
    shape = [batch, seq_len, heads, dim]
    blk = 64

    @T.prim_func
    def flash_bwd_post(
            dQ: T.Tensor(shape, accum_dtype),  # type: ignore
            dQ_out: T.Tensor(shape, dtype),  # type: ignore
    ):
        with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
            T.annotate_layout({dQ: make_dq_layout(dQ)})
            T.copy(
                dQ[bz, bx * blk:(bx + 1) * blk, by, :],
                dQ_out[bz, bx * blk:(bx + 1) * blk, by, :],
            )

    return flash_bwd_post


@tilelang.jit(pass_configs={
    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
    sm_scale = (1.0 / dim)**0.5
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    shape = [batch, seq_len, heads, dim]
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def flash_bwd(
            Q: T.Tensor(shape, dtype),  # type: ignore
            K: T.Tensor(shape, dtype),  # type: ignore
            V: T.Tensor(shape, dtype),  # type: ignore
            dO: T.Tensor(shape, dtype),  # type: ignore
            lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
            Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
            dQ: T.Tensor(shape, accum_dtype),  # type: ignore
            dK: T.Tensor(shape, dtype),  # type: ignore
            dV: T.Tensor(shape, dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz):
            K_shared = T.alloc_shared([block_M, dim], dtype)
            dsT_shared = T.alloc_shared([block_M, block_N], dtype)
            # should not store K to local if dim is large
            # K_local = T.alloc_fragment([block_M, dim], dtype)
            # K_local_T = T.alloc_fragment([block_M, dim], dtype)
            # V_local = T.alloc_fragment([block_M, dim], dtype)
            q = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_M, dim], dtype)
            qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
            dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
            qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
            dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
            lse_shared = T.alloc_shared([block_N], accum_dtype)
            delta = T.alloc_shared([block_N], accum_dtype)
            do = T.alloc_shared([block_N, dim], dtype)
            dv = T.alloc_fragment([block_M, dim], accum_dtype)
            dk = T.alloc_fragment([block_M, dim], accum_dtype)
            dq = T.alloc_fragment([block_N, dim], accum_dtype)
            dv_shared = T.alloc_shared([block_M, dim], dtype)
            dk_shared = T.alloc_shared([block_M, dim], dtype)

            T.annotate_layout({
                dQ: make_dq_layout(dQ),
            })
            T.copy(K[bz, by * block_M:(by + 1) * block_M, bx, :], K_shared)
            T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared)
            T.clear(dv)
            T.clear(dk)
            loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
            loop_ed = T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_st, loop_ed, num_stages=2):
                T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
                T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
                T.clear(dsT)
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
                T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
                T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
                T.copy(dsT_cast, dsT_shared)
                T.clear(dq)
                T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
                for i, j in T.Parallel(block_N, dim):
                    T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
            T.copy(dv, dv_shared)
            T.copy(dk, dk_shared)
            T.copy(dv_shared, dV[bz, by * block_M:(by + 1) * block_M, bx, :])
            T.copy(dk_shared, dK[bz, by * block_M:(by + 1) * block_M, bx, :])

    return flash_bwd


class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, causal):
        BATCH, N_CTX, H, D_HEAD = q.shape
        block_M = 64
        block_N = 64 if D_HEAD <= 128 else 32
        o, lse = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N)(q, k, v)
        ctx.save_for_backward(q, k, v, o, lse)
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, o, lse = ctx.saved_tensors
        BATCH, N_CTX, H, D_HEAD = q.shape

        def maybe_contiguous(x):
            if x.stride(-1) != 1:
                return x.contiguous()
            return x

        do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
        block_M = 64
        block_N = 64 if D_HEAD <= 64 else 32
        kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD)
        kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD)
        delta = kernel_prep(o, do)
        kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N)
        shape = [BATCH, N_CTX, H, D_HEAD]
        dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
        dk = torch.empty(shape, dtype=torch.float16, device=q.device)
        dv = torch.empty(shape, dtype=torch.float16, device=q.device)
        kernel(q, k, v, do, lse, delta, dq, dk, dv)
        dq = kernel_post(dq)
        return dq, dk, dv, None


attention = _attention.apply


def ref_program(Q, K, V, is_causal):
    dim = Q.size(-1)
    scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
    return output


def main(
    BATCH: int = 8,
    H: int = 32,
    N_CTX: int = 1024,
    D_HEAD: int = 64,
    causal: bool = False,
):
    flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
    total_flops = 5 * flops_per_matmul
    if causal:
        total_flops *= 0.5

    Q = (torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_())
    K = torch.empty_like(Q).normal_().requires_grad_()
    V = torch.empty_like(Q).normal_().requires_grad_()
    dO = torch.randn_like(Q)
    O = attention(Q, K, V, causal)
    O.backward(dO, retain_graph=True)
    dQ, Q.grad = Q.grad.clone(), None
    dK, K.grad = K.grad.clone(), None
    dV, V.grad = V.grad.clone(), None

    O_ref = ref_program(Q, K, V, causal)
    O_ref.backward(dO, retain_graph=True)
    dQ_ref, Q.grad = Q.grad.clone(), None
    dK_ref, K.grad = K.grad.clone(), None
    dV_ref, V.grad = V.grad.clone(), None

    assert torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2)
    assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2)
    assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2)
    assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2)

    def run():
        O_ref.backward(dO, retain_graph=True)

    def run1():
        O.backward(dO, retain_graph=True)

    from tilelang.profiler import do_bench
    latency = do_bench(run, warmup=500)
    print("torch: {:.2f} ms".format(latency))
    print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
    latency = do_bench(run1, warmup=500)
    print("tilelang: {:.2f} ms".format(latency))
    print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=8, help='Batch size')
    parser.add_argument('--h', type=int, default=32, help='Number of heads')
    parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
    parser.add_argument('--d_head', type=int, default=64, help='Head dimension')
    parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
    args = parser.parse_args()
    main(args.batch, args.h, args.n_ctx, args.d_head, args.causal)
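Because the kernels above are wrapped in torch.autograd.Function, they drop into an ordinary training step. A minimal sketch (not part of the commit; assumes a CUDA device and reuses the default shapes from main(), with tensors in the BSHD layout this file expects):

    q = torch.randn(8, 1024, 32, 64, dtype=torch.float16, device="cuda", requires_grad=True)
    k = torch.randn_like(q).requires_grad_()
    v = torch.randn_like(q).requires_grad_()
    out = attention(q, k, v, True)   # causal forward via flashattn_fwd
    out.sum().backward()             # gradients flow through the preprocess/bwd/postprocess kernels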
examples/flash_attention/example_mha_bwd_bhsd.py
0 → 100644
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import argparse


@tilelang.jit(
    out_idx=[3, 4],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    shape = [batch, heads, seq_len, dim]
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def flash_fwd(
            Q: T.Tensor(shape, dtype),  # type: ignore
            K: T.Tensor(shape, dtype),  # type: ignore
            V: T.Tensor(shape, dtype),  # type: ignore
            Output: T.Tensor(shape, dtype),  # type: ignore
            lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
    ):
        with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim], dtype)
            # Q_local = T.alloc_fragment([block_M, dim], dtype)
            K_shared = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_N, dim], dtype)
            acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
            acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_M], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
            scores_scale = T.alloc_fragment([block_M], accum_dtype)
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
            T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))
            # T.copy(Q_shared, Q_local)
            # for i, j in T.Parallel(block_M, dim):
            #     Q_local[i, j] *= scale
            loop_range = (
                T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
            for k in T.Pipelined(loop_range, num_stages=1):
                T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                                     -T.infinity(acc_s.dtype))
                else:
                    T.clear(acc_s)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
                T.copy(scores_max, scores_max_prev)
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_M):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_M, dim):
                    acc_o[i, j] *= scores_scale[i]
                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.copy(acc_s, acc_s_cast)
                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(block_M):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
            for i in T.Parallel(block_M):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
            T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M])

    return flash_fwd


@tilelang.jit(
    out_idx=[2],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
    dtype = "float16"
    accum_dtype = "float"
    shape = [batch, heads, seq_len, dim]
    blk = 32

    @T.prim_func
    def flash_bwd_prep(
            O: T.Tensor(shape, dtype),  # type: ignore
            dO: T.Tensor(shape, dtype),  # type: ignore
            Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
            o = T.alloc_fragment([blk, blk], dtype)
            do = T.alloc_fragment([blk, blk], dtype)
            acc = T.alloc_fragment([blk, blk], accum_dtype)
            delta = T.alloc_fragment([blk], accum_dtype)
            T.clear(acc)
            for k in range(T.ceildiv(dim, blk)):
                T.copy(O[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], o)
                T.copy(dO[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], do)
                for i, j in T.Parallel(blk, blk):
                    acc[i, j] += o[i, j] * do[i, j]
            T.reduce_sum(acc, delta, 1)
            T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk])

    return flash_bwd_prep


def make_dq_layout(dQ):
    # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
    return T.Layout(dQ.shape,
                    lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])


@tilelang.jit(
    out_idx=[1],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
    dtype = "float16"
    accum_dtype = "float"
    shape = [batch, heads, seq_len, dim]
    blk = 64

    @T.prim_func
    def flash_bwd_post(
            dQ: T.Tensor(shape, accum_dtype),  # type: ignore
            dQ_out: T.Tensor(shape, dtype),  # type: ignore
    ):
        with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
            T.annotate_layout({dQ: make_dq_layout(dQ)})
            T.copy(
                dQ[bz, by, bx * blk:(bx + 1) * blk, :],
                dQ_out[bz, by, bx * blk:(bx + 1) * blk, :],
            )

    return flash_bwd_post


@tilelang.jit(pass_configs={
    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
    sm_scale = (1.0 / dim)**0.5
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    shape = [batch, heads, seq_len, dim]
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def flash_bwd(
            Q: T.Tensor(shape, dtype),  # type: ignore
            K: T.Tensor(shape, dtype),  # type: ignore
            V: T.Tensor(shape, dtype),  # type: ignore
            dO: T.Tensor(shape, dtype),  # type: ignore
            lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
            Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
            dQ: T.Tensor(shape, accum_dtype),  # type: ignore
            dK: T.Tensor(shape, dtype),  # type: ignore
            dV: T.Tensor(shape, dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz):
            K_shared = T.alloc_shared([block_M, dim], dtype)
            dsT_shared = T.alloc_shared([block_M, block_N], dtype)
            # should not store K to local if dim is large
            # K_local = T.alloc_fragment([block_M, dim], dtype)
            # K_local_T = T.alloc_fragment([block_M, dim], dtype)
            # V_local = T.alloc_fragment([block_M, dim], dtype)
            q = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_M, dim], dtype)
            qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
            dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
            qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
            dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
            lse_shared = T.alloc_shared([block_N], accum_dtype)
            delta = T.alloc_shared([block_N], accum_dtype)
            do = T.alloc_shared([block_N, dim], dtype)
            dv = T.alloc_fragment([block_M, dim], accum_dtype)
            dk = T.alloc_fragment([block_M, dim], accum_dtype)
            dq = T.alloc_fragment([block_N, dim], accum_dtype)
            dv_shared = T.alloc_shared([block_M, dim], dtype)
            dk_shared = T.alloc_shared([block_M, dim], dtype)

            T.annotate_layout({
                dQ: make_dq_layout(dQ),
                K_shared: tilelang.layout.make_swizzled_layout(K_shared),
                dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
                dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
            })
            T.copy(K[bz, bx, by * block_M:(by + 1) * block_M, :], K_shared)
            T.copy(V[bz, bx, by * block_M:(by + 1) * block_M, :], V_shared)
            T.clear(dv)
            T.clear(dk)
            loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
            loop_ed = T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_st, loop_ed, num_stages=2):
                T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q)
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
                T.copy(dO[bz, bx, k * block_N:(k + 1) * block_N, :], do)
                T.clear(dsT)
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
                T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
                T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
                T.copy(dsT_cast, dsT_shared)
                T.clear(dq)
                T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
                for i, j in T.Parallel(block_N, dim):
                    T.atomic_add(dQ[bz, bx, k * block_N + i, j], dq[i, j])
            T.copy(dv, dv_shared)
            T.copy(dk, dk_shared)
            T.copy(dv_shared, dV[bz, bx, by * block_M:(by + 1) * block_M, :])
            T.copy(dk_shared, dK[bz, bx, by * block_M:(by + 1) * block_M, :])

    return flash_bwd


class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, causal):
        BATCH, H, N_CTX, D_HEAD = q.shape
        block_M = 64
        block_N = 64 if D_HEAD <= 128 else 32
        o, lse = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N)(q, k, v)
        ctx.save_for_backward(q, k, v, o, lse)
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, o, lse = ctx.saved_tensors
        BATCH, H, N_CTX, D_HEAD = q.shape

        def maybe_contiguous(x):
            if x.stride(-1) != 1:
                return x.contiguous()
            return x

        do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
        block_M = 64
        block_N = 64 if D_HEAD <= 64 else 32
        kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD)
        kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD)
        delta = kernel_prep(o, do)
        kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N)
        shape = [BATCH, H, N_CTX, D_HEAD]
        dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
        dk = torch.empty(shape, dtype=torch.float16, device=q.device)
        dv = torch.empty(shape, dtype=torch.float16, device=q.device)
        kernel(q, k, v, do, lse, delta, dq, dk, dv)
        dq = kernel_post(dq)
        return dq, dk, dv, None


attention = _attention.apply


def ref_program(Q, K, V, is_causal):
    dim = Q.size(-1)
    scores = torch.einsum('bhqd,bhkd->bhqk', Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(2)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V)
    return output


def main(
    BATCH: int = 8,
    H: int = 32,
    N_CTX: int = 1024,
    D_HEAD: int = 64,
    causal: bool = False,
):
    flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
    total_flops = 5 * flops_per_matmul
    if causal:
        total_flops *= 0.5

    Q = (torch.empty(BATCH, H, N_CTX, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_())
    K = torch.empty_like(Q).normal_().requires_grad_()
    V = torch.empty_like(Q).normal_().requires_grad_()
    dO = torch.randn_like(Q)
    O = attention(Q, K, V, causal)
    O.backward(dO, retain_graph=True)
    dQ, Q.grad = Q.grad.clone(), None
    dK, K.grad = K.grad.clone(), None
    dV, V.grad = V.grad.clone(), None

    O_ref = ref_program(Q, K, V, causal)
    O_ref.backward(dO, retain_graph=True)
    dQ_ref, Q.grad = Q.grad.clone(), None
    dK_ref, K.grad = K.grad.clone(), None
    dV_ref, V.grad = V.grad.clone(), None

    assert torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2)
    assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2)
    assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2)
    assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
    print("All checks passed.✅")

    def run():
        O_ref.backward(dO, retain_graph=True)

    def run1():
        O.backward(dO, retain_graph=True)

    from tilelang.profiler import do_bench
    latency = do_bench(run, warmup=500)
    print("torch: {:.2f} ms".format(latency))
    print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
    latency = do_bench(run1, warmup=500)
    print("tilelang: {:.2f} ms".format(latency))
    print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=8, help='Batch size')
    parser.add_argument('--h', type=int, default=32, help='Number of heads')
    parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
    parser.add_argument('--d_head', type=int, default=64, help='Head dimension')
    parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
    args = parser.parse_args()
    main(args.batch, args.h, args.n_ctx, args.d_head, args.causal)
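This variant expects tensors laid out as [batch, heads, seq, dim] (BHSD), whereas example_mha_bwd.py above uses [batch, seq, heads, dim] (BSHD). A one-line sketch of converting between the two, assuming a contiguous BSHD tensor x:

    x_bhsd = x.transpose(1, 2).contiguous()       # BSHD -> BHSD, as consumed by this file
    x_bshd = x_bhsd.transpose(1, 2).contiguous()  # and back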
examples/flash_attention/example_mha_bwd_wgmma_pipelined.py
0 → 100644
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
from tilelang.profiler import do_bench
import argparse


@tilelang.jit(
    out_idx=[3, 4],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    shape = [batch, seq_len, heads, dim]
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def flash_fwd(
            Q: T.Tensor(shape, dtype),  # type: ignore
            K: T.Tensor(shape, dtype),  # type: ignore
            V: T.Tensor(shape, dtype),  # type: ignore
            Output: T.Tensor(shape, dtype),  # type: ignore
            lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
    ):
        with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim], dtype)
            K_shared = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_N, dim], dtype)
            acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
            acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_M], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
            scores_scale = T.alloc_fragment([block_M], accum_dtype)
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
            T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))
            loop_range = (
                T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
            for k in T.Pipelined(loop_range, num_stages=1):
                T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                                     -T.infinity(acc_s.dtype))
                else:
                    T.clear(acc_s)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
                T.copy(scores_max, scores_max_prev)
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_M):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_M, dim):
                    acc_o[i, j] *= scores_scale[i]
                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.copy(acc_s, acc_s_cast)
                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(block_M):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
            for i in T.Parallel(block_M):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
            T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M])

    return flash_fwd


@tilelang.jit(
    out_idx=[2],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
    dtype = "float16"
    accum_dtype = "float"
    shape = [batch, seq_len, heads, dim]
    blk = 32

    @T.prim_func
    def flash_bwd_prep(
            O: T.Tensor(shape, dtype),  # type: ignore
            dO: T.Tensor(shape, dtype),  # type: ignore
            Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
            o = T.alloc_fragment([blk, blk], dtype)
            do = T.alloc_fragment([blk, blk], dtype)
            acc = T.alloc_fragment([blk, blk], accum_dtype)
            delta = T.alloc_fragment([blk], accum_dtype)
            T.clear(acc)
            for k in range(T.ceildiv(dim, blk)):
                T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o)
                T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do)
                for i, j in T.Parallel(blk, blk):
                    acc[i, j] += o[i, j] * do[i, j]
            T.reduce_sum(acc, delta, 1)
            T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk])

    return flash_bwd_prep


@tilelang.jit(pass_configs={
    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
    sm_scale = (1.0 / dim)**0.5
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    shape = [batch, seq_len, heads, dim]
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def flash_bwd(
            Q: T.Tensor(shape, dtype),  # type: ignore
            K: T.Tensor(shape, dtype),  # type: ignore
            V: T.Tensor(shape, dtype),  # type: ignore
            dO: T.Tensor(shape, dtype),  # type: ignore
            lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
            Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
            dQ: T.Tensor(shape, accum_dtype),  # type: ignore
            dK: T.Tensor(shape, dtype),  # type: ignore
            dV: T.Tensor(shape, dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=256) as (bx, by, bz):
            K_shared = T.alloc_shared([block_M, dim], dtype)
            dsT_shared = T.alloc_shared([block_M, block_N], dtype)
            # should not store K to local if dim is large
            # K_local = T.alloc_fragment([block_M, dim], dtype)
            # K_local_T = T.alloc_fragment([block_M, dim], dtype)
            # V_local = T.alloc_fragment([block_M, dim], dtype)
            q = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_M, dim], dtype)
            qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
            dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
            qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
            dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
            lse_shared = T.alloc_shared([block_N], accum_dtype)
            delta = T.alloc_shared([block_N], accum_dtype)
            do = T.alloc_shared([block_N, dim], dtype)
            dv = T.alloc_fragment([block_M, dim], accum_dtype)
            dk = T.alloc_fragment([block_M, dim], accum_dtype)
            dq = T.alloc_fragment([block_N, dim], accum_dtype)
            dv_shared = T.alloc_shared([block_M, dim], dtype)
            dk_shared = T.alloc_shared([block_M, dim], dtype)
            dq_shared = T.alloc_shared([block_N, dim], accum_dtype)

            T.annotate_layout({
                K_shared: tilelang.layout.make_swizzled_layout(K_shared),
                dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
                dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
                dq_shared: tilelang.layout.make_swizzled_layout(dq_shared),
            })
            T.copy(K[bz, by * block_M:(by + 1) * block_M, bx, :], K_shared)
            T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared)
            T.clear(dv)
            T.clear(dk)
            loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
            loop_ed = T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_st, loop_ed, num_stages=2):
                T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
                T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
                T.clear(dsT)
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
                T.wait_wgmma(1)
                T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
                T.wait_wgmma(0)
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
                T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
                T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow, wg_wait=1)
                T.copy(dsT_cast, dsT_shared)
                T.clear(dq)
                T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1)
                T.wait_wgmma(0)
                T.copy(dq, dq_shared)
                T.atomic_add(dQ[bz, k * block_N:(k + 1) * block_N, bx, :], dq_shared)
            T.copy(dv, dv_shared)
            T.copy(dk, dk_shared)
            T.copy(dv_shared, dV[bz, by * block_M:(by + 1) * block_M, bx, :])
            T.copy(dk_shared, dK[bz, by * block_M:(by + 1) * block_M, bx, :])

    return flash_bwd


class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, causal):
        BATCH, N_CTX, H, D_HEAD = q.shape
        block_M = 64
        block_N = 64 if D_HEAD <= 128 else 32
        mod = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N)
        o, lse = mod(q, k, v)
        ctx.save_for_backward(q, k, v, o, lse)
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, o, lse = ctx.saved_tensors
        BATCH, N_CTX, H, D_HEAD = q.shape

        def maybe_contiguous(x):
            if x.stride(-1) != 1:
                return x.contiguous()
            return x

        do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
        block_M = 128
        block_N = 128 if D_HEAD <= 64 else 32
        mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD)
        delta = mod_prep(o, do)
        mod = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N)
        shape = [BATCH, N_CTX, H, D_HEAD]
        dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
        dk = torch.empty(shape, dtype=torch.float16, device=q.device)
        dv = torch.empty(shape, dtype=torch.float16, device=q.device)
        mod(q, k, v, do, lse, delta, dq, dk, dv)
        dq = dq.to(torch.float16)
        return dq, dk, dv, None


attention = _attention.apply


def ref_program(Q, K, V, is_causal):
    dim = Q.size(-1)
    scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
    return output


def main(
    BATCH: int = 8,
    H: int = 32,
    N_CTX: int = 1024,
    D_HEAD: int = 64,
    causal: bool = False,
):
    flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
    total_flops = 5 * flops_per_matmul
    if causal:
        total_flops *= 0.5

    Q = (torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_())
    K = torch.empty_like(Q).normal_().requires_grad_()
    V = torch.empty_like(Q).normal_().requires_grad_()
    dO = torch.randn_like(Q)
    O = attention(Q, K, V, causal)
    O.backward(dO, retain_graph=True)
    dQ, Q.grad = Q.grad.clone(), None
    dK, K.grad = K.grad.clone(), None
    dV, V.grad = V.grad.clone(), None

    O_ref = ref_program(Q, K, V, causal)
    O_ref.backward(dO, retain_graph=True)
    dQ_ref, Q.grad = Q.grad.clone(), None
    dK_ref, K.grad = K.grad.clone(), None
    dV_ref, V.grad = V.grad.clone(), None

    assert torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2)
    assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2)
    assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2)
    assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
    print('All checks passed.✅')

    def run():
        O_ref.backward(dO, retain_graph=True)

    def run1():
        O.backward(dO, retain_graph=True)

    latency = do_bench(run, warmup=500)
    print("torch: {:.2f} ms".format(latency))
    print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
    latency = do_bench(run1, warmup=500)
    print("tilelang: {:.2f} ms".format(latency))
    print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=8, help='Batch size')
    parser.add_argument('--h', type=int, default=32, help='Number of heads')
    parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
    parser.add_argument('--d_head', type=int, default=64, help='Head dimension')
    parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
    args = parser.parse_args()
    main(args.batch, args.h, args.n_ctx, args.d_head, args.causal)
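The backward "preprocess" kernel in these files materializes Delta[b, h, s] = sum_d O[b, s, h, d] * dO[b, s, h, d], the row-wise dot product used later when forming dS. A plain PyTorch reference for that reduction (a sketch for checking intermediate values only; assumes BSHD tensors O and dO as in this file):

    # [batch, seq, heads] after the sum, transposed to [batch, heads, seq] to match Delta's layout.
    delta_ref = (O.float() * dO.float()).sum(dim=-1).transpose(1, 2)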
examples/flash_attention/example_mha_fwd_bhsd.py
0 → 100644
import
torch
import
torch.nn.functional
as
F
import
tilelang
from
tilelang.autotuner
import
*
import
tilelang.language
as
T
import
itertools
import
argparse
from
functools
import
partial
def
get_configs
():
iter_params
=
dict
(
block_M
=
[
128
],
block_N
=
[
128
],
num_stages
=
[
2
],
threads
=
[
256
])
return
[
dict
(
zip
(
iter_params
,
values
))
for
values
in
itertools
.
product
(
*
iter_params
.
values
())]
@
autotune
(
configs
=
get_configs
(),
warmup
=
10
,
rep
=
10
)
@
tilelang
.
jit
(
out_idx
=
[
3
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
def
flashattn
(
batch
,
heads
,
seq_q
,
seq_kv
,
dim
,
is_causal
,
block_M
=
64
,
block_N
=
64
,
num_stages
=
1
,
threads
=
128
):
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
q_shape
=
[
batch
,
heads
,
seq_q
,
dim
]
kv_shape
=
[
batch
,
heads
,
seq_kv
,
dim
]
dtype
=
"float16"
accum_dtype
=
"float"
past_len
=
seq_kv
-
seq_q
assert
past_len
>=
0
,
"seq_kv must be greater than or equal to seq_q"
@
T
.
macro
def
MMA0
(
K
:
T
.
Tensor
(
kv_shape
,
dtype
),
Q_shared
:
T
.
SharedBuffer
([
block_M
,
dim
],
dtype
),
K_shared
:
T
.
SharedBuffer
([
block_N
,
dim
],
dtype
),
acc_s
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
accum_dtype
),
k
:
T
.
int32
,
bx
:
T
.
int32
,
by
:
T
.
int32
,
bz
:
T
.
int32
,
):
T
.
copy
(
K
[
bz
,
by
,
k
*
block_N
:(
k
+
1
)
*
block_N
,
:],
K_shared
)
if
is_causal
:
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
q_idx
=
bx
*
block_M
+
i
+
past_len
k_idx
=
k
*
block_N
+
j
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
q_idx
>=
k_idx
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
else
:
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
@
T
.
macro
def
MMA1
(
V
:
T
.
Tensor
(
kv_shape
,
dtype
),
V_shared
:
T
.
SharedBuffer
([
block_N
,
dim
],
dtype
),
acc_s_cast
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
dtype
),
acc_o
:
T
.
FragmentBuffer
([
block_M
,
dim
],
accum_dtype
),
k
:
T
.
int32
,
by
:
T
.
int32
,
bz
:
T
.
int32
,
):
T
.
copy
(
V
[
bz
,
by
,
k
*
block_N
:(
k
+
1
)
*
block_N
,
:],
V_shared
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
@
T
.
macro
def
Softmax
(
acc_s
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
accum_dtype
),
acc_s_cast
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
dtype
),
scores_max
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_max_prev
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_scale
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_sum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
logsum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
):
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for
i
in
T
.
Parallel
(
block_M
):
            scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
        for i, j in T.Parallel(block_M, block_N):
            # Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            # max * log_2(e)) This allows the compiler to use the ffma
            # instruction instead of fadd and fmul separately.
            acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
        T.reduce_sum(acc_s, scores_sum, dim=1)
        for i in T.Parallel(block_M):
            logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
        T.copy(acc_s, acc_s_cast)

    @T.macro
    def Rescale(
            acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
    ):
        for i, j in T.Parallel(block_M, dim):
            acc_o[i, j] *= scores_scale[i]

    @T.prim_func
    def main(
            Q: T.Tensor(q_shape, dtype),
            K: T.Tensor(kv_shape, dtype),
            V: T.Tensor(kv_shape, dtype),
            Output: T.Tensor(q_shape, dtype),
    ):
        with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim], dtype)
            K_shared = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_N, dim], dtype)
            O_shared = T.alloc_shared([block_M, dim], dtype)
            acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
            acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_M], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
            scores_scale = T.alloc_fragment([block_M], accum_dtype)
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            loop_range = (
                T.min(
                    T.ceildiv(seq_kv, block_N),
                    T.ceildiv((bx + 1) * block_M + past_len, block_N))
                if is_causal else T.ceildiv(seq_kv, block_N))

            for k in T.Pipelined(loop_range, num_stages=num_stages):
                MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
                        scores_sum, logsum)
                Rescale(acc_o, scores_scale)
                MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])

    return main


def ref_program(Q, K, V, is_causal):
    dim = Q.size(-1)
    scores = torch.einsum('bhqd,bhkd->bhqk', Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_q = Q.size(2)
        seq_kv = K.size(2)
        mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V)
    return output


def main(
    batch: int = 1,
    heads: int = 1,
    seq_q: int = 256,
    seq_kv: int = 256,
    dim: int = 64,
    is_causal: bool = False,
    tune: bool = False,
):
    flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim
    total_flops = 2 * flops_per_matmul
    if is_causal:
        total_flops *= 0.5

    if (not tune):
        kernel = flashattn(
            batch,
            heads,
            seq_q,
            seq_kv,
            dim,
            is_causal,
            block_M=64,
            block_N=64,
            num_stages=1,
            threads=128)
        ref_program_processed = partial(ref_program, is_causal=is_causal)
        profiler = kernel.get_profiler()
        profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
        print("All checks pass.")
        latency = profiler.do_bench(ref_program_processed, warmup=500)
        print("Ref: {:.2f} ms".format(latency))
        print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
        latency = profiler.do_bench(warmup=500)
        print("Tile-lang: {:.2f} ms".format(latency))
        print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
    else:
        kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal)
        best_latency = kernel.latency
        best_config = kernel.config
        ref_latency = kernel.ref_latency
        print(f"Best latency: {best_latency}")
        print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
        print(f"Best config: {best_config}")
        print(f"Ref latency: {ref_latency}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=1, help='batch size')
    parser.add_argument('--heads', type=int, default=1, help='heads')
    parser.add_argument('--seq_q', type=int, default=256, help='query sequence length')
    parser.add_argument('--seq_kv', type=int, default=256, help='key/value sequence length')
    parser.add_argument('--dim', type=int, default=64, help='dim')
    parser.add_argument('--is_causal', action='store_true', help='causal')
    parser.add_argument('--tune', action='store_true', help='tune configs')
    args = parser.parse_args()
    main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune)
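Aside (not part of the original files): the `Softmax` macro above relies on the identity `exp((x - m) / sqrt(dim)) == exp2(x * scale - m * scale)` with `scale = 1/sqrt(dim) * log2(e)`, and on the running `scores_scale`/`logsum` update of online softmax. The following plain-PyTorch sketch checks both; shapes, the two-chunk split, and variable names are illustrative only.

```python
import math
import torch

torch.manual_seed(0)
dim, n = 64, 256
scores = torch.randn(n, dtype=torch.float32)   # one row of Q @ K^T
scale = (1.0 / dim) ** 0.5 * 1.44269504        # 1/sqrt(dim) * log2(e)

# exp((x - m) / sqrt(dim)) equals exp2(x * scale - m * scale)
m = scores.max()
a = torch.exp((scores - m) / math.sqrt(dim))
b = torch.exp2(scores * scale - m * scale)
assert torch.allclose(a, b, atol=1e-6)

# Online softmax over two chunks reproduces a one-shot softmax of the row.
ref = torch.softmax(scores / math.sqrt(dim), dim=-1)
m_run = torch.tensor(float("-inf"))
logsum = torch.tensor(0.0)
for chunk in scores.split(n // 2):
    m_new = torch.maximum(m_run, chunk.max())
    rescale = torch.exp2(m_run * scale - m_new * scale)   # scores_scale
    p = torch.exp2(chunk * scale - m_new * scale)         # acc_s after exp2
    logsum = logsum * rescale + p.sum()                   # logsum update
    m_run = m_new
approx = torch.exp2((scores - m_run) * scale) / logsum
assert torch.allclose(approx, ref, atol=1e-5)
```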
examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial


def get_configs():
    iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256])
    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]


@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(
    out_idx=[3],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
def flashattn(batch,
              heads,
              seq_q,
              seq_kv,
              dim,
              is_causal,
              block_M=128,
              block_N=128,
              num_stages=2,
              threads=256):
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    q_shape = [batch, heads, seq_q, dim]
    kv_shape = [batch, heads, seq_kv, dim]
    dtype = "float16"
    accum_dtype = "float"
    past_len = seq_kv - seq_q
    assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"

    @T.macro
    def MMA0(
        K: T.Tensor(kv_shape, dtype),
        Q_shared: T.SharedBuffer([block_M, dim], dtype),
        K_shared: T.SharedBuffer([block_N, dim], dtype),
        acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
        k: T.int32,
        bx: T.int32,
        by: T.int32,
        bz: T.int32,
    ):
        T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
        if is_causal:
            for i, j in T.Parallel(block_M, block_N):
                q_idx = bx * block_M + i + past_len
                k_idx = k * block_N + j
                acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
        else:
            T.clear(acc_s)
        T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

    @T.macro
    def MMA1(
        V: T.Tensor(kv_shape, dtype),
        V_shared: T.SharedBuffer([block_N, dim], dtype),
        acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
        acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
        k: T.int32,
        by: T.int32,
        bz: T.int32,
    ):
        T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
        T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

    @T.macro
    def Softmax(
            acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
            acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
            scores_max: T.FragmentBuffer([block_M], accum_dtype),
            scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
            scores_sum: T.FragmentBuffer([block_M], accum_dtype),
            logsum: T.FragmentBuffer([block_M], accum_dtype),
    ):
        T.copy(scores_max, scores_max_prev)
        T.fill(scores_max, -T.infinity(accum_dtype))
        T.reduce_max(acc_s, scores_max, dim=1, clear=False)
        # To do causal softmax, we need to set the scores_max to 0 if it is -inf
        # This process is called Check_inf in FlashAttention3 code, and it only need to be done
        # in the first ceil_div(kBlockM, kBlockN) steps.
        # for i in T.Parallel(block_M):
        #     scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
        for i in T.Parallel(block_M):
            scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
        for i, j in T.Parallel(block_M, block_N):
            # Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            # max * log_2(e)) This allows the compiler to use the ffma
            # instruction instead of fadd and fmul separately.
            acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
        T.reduce_sum(acc_s, scores_sum, dim=1)
        for i in T.Parallel(block_M):
            logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
        T.copy(acc_s, acc_s_cast)

    @T.macro
    def Rescale(
            acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
    ):
        for i, j in T.Parallel(block_M, dim):
            acc_o[i, j] *= scores_scale[i]

    @T.prim_func
    def main(
            Q: T.Tensor(q_shape, dtype),
            K: T.Tensor(kv_shape, dtype),
            V: T.Tensor(kv_shape, dtype),
            Output: T.Tensor(q_shape, dtype),
    ):
        with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim], dtype)
            K_shared = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_N, dim], dtype)
            O_shared = T.alloc_shared([block_M, dim], dtype)
            acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
            acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_M], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
            scores_scale = T.alloc_fragment([block_M], accum_dtype)
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            loop_range = (
                T.min(
                    T.ceildiv(seq_kv, block_N),
                    T.ceildiv((bx + 1) * block_M + past_len, block_N))
                if is_causal else T.ceildiv(seq_kv, block_N))

            for k in T.Pipelined(
                    loop_range,
                    num_stages=num_stages,
                    order=[-1, 0, 3, 1, -1, 2],
                    stage=[-1, 0, 0, 1, -1, 1],
                    group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
                MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
                        scores_sum, logsum)
                Rescale(acc_o, scores_scale)
                MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])

    return main


def ref_program(Q, K, V, is_causal):
    dim = Q.size(-1)
    scores = torch.einsum('bhqd,bhkd->bhqk', Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_q = Q.size(2)
        seq_kv = K.size(2)
        mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V)
    return output


def main(
    batch: int = 1,
    heads: int = 32,
    seq_q: int = 256,
    seq_kv: int = 256,
    dim: int = 128,
    is_causal: bool = False,
    tune: bool = False,
):
    flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim
    total_flops = 2 * flops_per_matmul
    if is_causal:
        total_flops *= 0.5

    if (not tune):
        kernel = flashattn(
            batch,
            heads,
            seq_q,
            seq_kv,
            dim,
            is_causal,
            block_M=128,
            block_N=128,
            num_stages=2,
            threads=256)
        ref_program_processed = partial(ref_program, is_causal=is_causal)
        profiler = kernel.get_profiler()
        profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
        print("All checks pass.")
        latency = profiler.do_bench(ref_program_processed, warmup=500)
        print("Ref: {:.2f} ms".format(latency))
        print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
        latency = profiler.do_bench(warmup=500)
        print("Tile-lang: {:.2f} ms".format(latency))
        print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
    else:
        kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal)
        best_latency = kernel.latency
        best_config = kernel.config
        ref_latency = kernel.ref_latency
        print(f"Best latency: {best_latency}")
        print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
        print(f"Best config: {best_config}")
        print(f"Ref latency: {ref_latency}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=8, help='batch size')
    parser.add_argument('--heads', type=int, default=32, help='heads')
    parser.add_argument('--seq_q', type=int, default=4096, help='query sequence length')
    parser.add_argument('--seq_kv', type=int, default=4096, help='key/value sequence length')
    parser.add_argument('--dim', type=int, default=128, help='dim')
    parser.add_argument('--is_causal', action='store_true', help='causal')
    parser.add_argument('--tune', action='store_true', help='tune configs')
    args = parser.parse_args()
    main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune)
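Aside (not part of the original files): the causal `loop_range` used by these kernels, `min(ceildiv(seq_kv, block_N), ceildiv((bx + 1) * block_M + past_len, block_N))`, can be checked in plain Python: any K/V tile at or beyond that bound only contains positions with `k_idx > q_idx`, so skipping it never drops an unmasked score. The tile sizes below are arbitrary example values.

```python
import math

def causal_tiles_needed(bx, block_M, block_N, seq_kv, past_len):
    # mirrors: min(ceildiv(seq_kv, block_N), ceildiv((bx + 1) * block_M + past_len, block_N))
    return min(math.ceil(seq_kv / block_N),
               math.ceil(((bx + 1) * block_M + past_len) / block_N))

def tile_fully_masked(bx, k, block_M, block_N, past_len):
    # True when every (i, j) in the tile violates q_idx >= k_idx
    return all(bx * block_M + i + past_len < k * block_N + j
               for i in range(block_M) for j in range(block_N))

block_M, block_N, seq_q, seq_kv = 64, 64, 256, 512
past_len = seq_kv - seq_q
for bx in range(math.ceil(seq_q / block_M)):
    bound = causal_tiles_needed(bx, block_M, block_N, seq_kv, past_len)
    for k in range(math.ceil(seq_kv / block_N)):
        if k >= bound:
            assert tile_fully_masked(bx, k, block_M, block_N, past_len)
```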
examples/flash_attention/example_mha_fwd_bshd.py
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial


def get_configs():
    iter_params = dict(block_M=[64], block_N=[64], num_stages=[1], threads=[128])
    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]


@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(
    out_idx=[3],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
def flashattn(batch,
              heads,
              seq_len,
              dim,
              is_causal,
              block_M=64,
              block_N=64,
              num_stages=1,
              threads=128):
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    shape = [batch, seq_len, heads, dim]
    dtype = "float16"
    accum_dtype = "float"

    @T.macro
    def MMA0(
        K: T.Tensor(shape, dtype),
        Q_shared: T.SharedBuffer([block_M, dim], dtype),
        K_shared: T.SharedBuffer([block_N, dim], dtype),
        acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
        k: T.int32,
        bx: T.int32,
        by: T.int32,
        bz: T.int32,
    ):
        T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
        if is_causal:
            for i, j in T.Parallel(block_M, block_N):
                acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                             -T.infinity(acc_s.dtype))
        else:
            T.clear(acc_s)
        T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

    @T.macro
    def MMA1(
        V: T.Tensor(shape, dtype),
        V_shared: T.SharedBuffer([block_N, dim], dtype),
        acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
        acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
        k: T.int32,
        by: T.int32,
        bz: T.int32,
    ):
        T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
        T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

    @T.macro
    def Softmax(
            acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
            acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
            scores_max: T.FragmentBuffer([block_M], accum_dtype),
            scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
            scores_sum: T.FragmentBuffer([block_M], accum_dtype),
            logsum: T.FragmentBuffer([block_M], accum_dtype),
    ):
        T.copy(scores_max, scores_max_prev)
        T.fill(scores_max, -T.infinity(accum_dtype))
        T.reduce_max(acc_s, scores_max, dim=1, clear=False)
        # To do causal softmax, we need to set the scores_max to 0 if it is -inf
        # This process is called Check_inf in FlashAttention3 code, and it only need to be done
        # in the first ceil_div(kBlockM, kBlockN) steps.
        # for i in T.Parallel(block_M):
        #     scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
        for i in T.Parallel(block_M):
            scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
        for i, j in T.Parallel(block_M, block_N):
            # Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            # max * log_2(e)) This allows the compiler to use the ffma
            # instruction instead of fadd and fmul separately.
            acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
        T.reduce_sum(acc_s, scores_sum, dim=1)
        for i in T.Parallel(block_M):
            logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
        T.copy(acc_s, acc_s_cast)

    @T.macro
    def Rescale(
            acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
    ):
        for i, j in T.Parallel(block_M, dim):
            acc_o[i, j] *= scores_scale[i]

    @T.prim_func
    def main(
            Q: T.Tensor(shape, dtype),
            K: T.Tensor(shape, dtype),
            V: T.Tensor(shape, dtype),
            Output: T.Tensor(shape, dtype),
    ):
        with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim], dtype)
            K_shared = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_N, dim], dtype)
            O_shared = T.alloc_shared([block_M, dim], dtype)
            acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
            acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_M], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
            scores_scale = T.alloc_fragment([block_M], accum_dtype)
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            loop_range = (
                T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
                if is_causal else T.ceildiv(seq_len, block_N))

            for k in T.Pipelined(loop_range, num_stages=num_stages):
                MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
                        scores_sum, logsum)
                Rescale(acc_o, scores_scale)
                MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])

    return main


def ref_program(Q, K, V, is_causal):
    dim = Q.size(-1)
    scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
    return output


def main(
    batch: int = 8,
    heads: int = 32,
    seq_len: int = 4096,
    dim: int = 128,
    is_causal: bool = False,
    tune: bool = False,
):
    flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
    total_flops = 2 * flops_per_matmul
    if is_causal:
        total_flops *= 0.5

    if (not tune):
        kernel = flashattn(
            batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=1, threads=128)
        ref_program_processed = partial(ref_program, is_causal=is_causal)
        profiler = kernel.get_profiler()
        profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
        print("All checks pass.")
        latency = profiler.do_bench(ref_program_processed, warmup=500)
        print("Ref: {:.2f} ms".format(latency))
        print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
        latency = profiler.do_bench(warmup=500)
        print("Tile-lang: {:.2f} ms".format(latency))
        print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
    else:
        best_result = flashattn(batch, heads, seq_len, dim, is_causal)
        best_latency = best_result.latency
        best_config = best_result.config
        ref_latency = best_result.ref_latency
        print(f"Best latency: {best_latency}")
        print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
        print(f"Best config: {best_config}")
        print(f"Ref latency: {ref_latency}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=8, help='batch size')
    parser.add_argument('--heads', type=int, default=32, help='heads')
    parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
    parser.add_argument('--dim', type=int, default=128, help='dim')
    parser.add_argument('--is_causal', action='store_true', help='causal')
    parser.add_argument('--tune', action='store_true', help='tune configs')
    args = parser.parse_args()
    main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.tune)
examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial


def get_configs():
    iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256])
    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]


@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(
    out_idx=[3],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
def flashattn(batch,
              heads,
              seq_len,
              dim,
              is_causal,
              block_M=128,
              block_N=128,
              num_stages=2,
              threads=256):
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    shape = [batch, seq_len, heads, dim]
    dtype = "float16"
    accum_dtype = "float"

    @T.macro
    def MMA0(
        K: T.Tensor(shape, dtype),
        Q_shared: T.SharedBuffer([block_M, dim], dtype),
        K_shared: T.SharedBuffer([block_N, dim], dtype),
        acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
        k: T.int32,
        bx: T.int32,
        by: T.int32,
        bz: T.int32,
    ):
        T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
        if is_causal:
            for i, j in T.Parallel(block_M, block_N):
                acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                             -T.infinity(acc_s.dtype))
        else:
            T.clear(acc_s)
        T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

    @T.macro
    def MMA1(
        V: T.Tensor(shape, dtype),
        V_shared: T.SharedBuffer([block_N, dim], dtype),
        acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
        acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
        k: T.int32,
        by: T.int32,
        bz: T.int32,
    ):
        T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
        T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

    @T.macro
    def Softmax(
            acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
            acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
            scores_max: T.FragmentBuffer([block_M], accum_dtype),
            scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
            scores_sum: T.FragmentBuffer([block_M], accum_dtype),
            logsum: T.FragmentBuffer([block_M], accum_dtype),
    ):
        T.copy(scores_max, scores_max_prev)
        T.fill(scores_max, -T.infinity(accum_dtype))
        T.reduce_max(acc_s, scores_max, dim=1, clear=False)
        # To do causal softmax, we need to set the scores_max to 0 if it is -inf
        # This process is called Check_inf in FlashAttention3 code, and it only need to be done
        # in the first ceil_div(kBlockM, kBlockN) steps.
        # for i in T.Parallel(block_M):
        #     scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
        for i in T.Parallel(block_M):
            scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
        for i, j in T.Parallel(block_M, block_N):
            # Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            # max * log_2(e)) This allows the compiler to use the ffma
            # instruction instead of fadd and fmul separately.
            acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
        T.reduce_sum(acc_s, scores_sum, dim=1)
        for i in T.Parallel(block_M):
            logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
        T.copy(acc_s, acc_s_cast)

    @T.macro
    def Rescale(
            acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
    ):
        for i, j in T.Parallel(block_M, dim):
            acc_o[i, j] *= scores_scale[i]

    @T.prim_func
    def main(
            Q: T.Tensor(shape, dtype),
            K: T.Tensor(shape, dtype),
            V: T.Tensor(shape, dtype),
            Output: T.Tensor(shape, dtype),
    ):
        with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim], dtype)
            K_shared = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_N, dim], dtype)
            O_shared = T.alloc_shared([block_M, dim], dtype)
            acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
            acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_M], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
            scores_scale = T.alloc_fragment([block_M], accum_dtype)
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            loop_range = (
                T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
                if is_causal else T.ceildiv(seq_len, block_N))

            for k in T.Pipelined(
                    loop_range,
                    num_stages=num_stages,
                    order=[-1, 0, 3, 1, -1, 2],
                    stage=[-1, 0, 0, 1, -1, 1],
                    group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
                MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
                        scores_sum, logsum)
                Rescale(acc_o, scores_scale)
                MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])

    return main


def ref_program(Q, K, V, is_causal):
    dim = Q.size(-1)
    scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
    return output


def main(
    batch: int = 8,
    heads: int = 32,
    seq_len: int = 4096,
    dim: int = 128,
    is_causal: bool = False,
    tune: bool = False,
):
    flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
    total_flops = 2 * flops_per_matmul
    if is_causal:
        total_flops *= 0.5

    if (not tune):
        kernel = flashattn(
            batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256)
        ref_program_processed = partial(ref_program, is_causal=is_causal)
        profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
        profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
        print("All checks pass.")
        latency = profiler.do_bench(ref_program_processed, warmup=500)
        print("Ref: {:.2f} ms".format(latency))
        print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
        latency = profiler.do_bench(warmup=500)
        print("Tile-lang: {:.2f} ms".format(latency))
        print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
    else:
        kernel = flashattn(batch, heads, seq_len, dim, is_causal)
        best_latency = kernel.latency
        best_config = kernel.config
        ref_latency = kernel.ref_latency
        print(f"Best latency: {best_latency}")
        print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
        print(f"Best config: {best_config}")
        print(f"Ref latency: {ref_latency}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=8, help='batch size')
    parser.add_argument('--heads', type=int, default=32, help='heads')
    parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
    parser.add_argument('--dim', type=int, default=128, help='dim')
    parser.add_argument('--is_causal', action='store_true', help='causal')
    parser.add_argument('--tune', action='store_true', help='tune configs')
    args = parser.parse_args()
    main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.tune)
examples/flash_attention/example_mha_fwd_varlen.py
# ruff: noqa
import torch
import tilelang
import tilelang.language as T
import tilelang.testing
import argparse
import torch
from einops import rearrange, repeat
from varlen_utils import generate_random_padding_mask, generate_qkv


def attention_ref(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
    upcast=True,
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads_k, head_dim)
        v: (batch_size, seqlen_k, nheads_k, head_dim)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
        causal: whether to apply causal masking
        window_size: (int, int), left and right window size
        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
            output back to fp16/bf16.
        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
            without changing the math. This is to estimate the numerical error from operation
            reordering.
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim)
        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
    """
    if causal:
        window_size = (window_size[0], 0)
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
    dim = q.shape[-1]
    scale = (1.0 / dim)**0.5  # log2(e)
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
    scores = torch.einsum("bthd,bshd->bhts", q, k)
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
        # scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0)
    scores = scores * scale
    attention = torch.softmax(scores, dim=-1).to(v.dtype)
    # We want to mask here so that the attention matrix doesn't have any NaNs
    # Otherwise we'll get NaN in dV
    if query_padding_mask is not None:
        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    output = torch.einsum("bhts,bshd->bthd", attention, v)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)


@tilelang.jit(
    out_idx=[6],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
def flashattn(batch_size,
              UQ,
              UKV,
              heads,
              dim,
              is_causal,
              block_M=64,
              block_N=64,
              num_stages=0,
              threads=32):
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    q_shape = [UQ, heads, dim]
    k_shape = [UKV, heads, dim]
    v_shape = [UKV, heads, dim]
    o_shape = [UQ, heads, dim]
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def main(
            Q_unpad: T.Tensor(q_shape, dtype),
            K_unpad: T.Tensor(k_shape, dtype),
            V_unpad: T.Tensor(v_shape, dtype),
            cu_seqlens_q: T.Tensor([batch_size + 1], "int32"),
            cu_seqlens_k: T.Tensor([batch_size + 1], "int32"),
            max_seqlen_q: T.int32,
            Output_unpad: T.Tensor(o_shape, dtype),
    ):
        with T.Kernel(
                T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=threads) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim], dtype, "shared")
            K_shared = T.alloc_shared([block_N, dim], dtype, "shared")
            V_shared = T.alloc_shared([block_N, dim], dtype, "shared")
            O_shared = T.alloc_shared([block_M, dim], dtype, "shared")
            acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
            acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_M], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
            scores_scale = T.alloc_fragment([block_M], accum_dtype)
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            batch_idx = bz
            head_idx = by

            q_start_idx = cu_seqlens_q[batch_idx]
            k_start_idx = cu_seqlens_k[batch_idx]
            v_start_idx = cu_seqlens_k[batch_idx]
            q_end_idx = cu_seqlens_q[batch_idx + 1]
            k_end_idx = cu_seqlens_k[batch_idx + 1]
            v_end_idx = cu_seqlens_k[batch_idx + 1]
            q_current_seqlen = q_end_idx - q_start_idx
            k_current_seqlen = k_end_idx - k_start_idx
            v_current_seqlen = v_end_idx - v_start_idx

            for i, d in T.Parallel(block_M, dim):
                if bx * block_M + i < q_current_seqlen:
                    Q_shared[i, d] = Q_unpad[q_start_idx + bx * block_M + i, head_idx, d]
                else:
                    Q_shared[i, d] = 0

            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            loop_range = T.ceildiv(k_current_seqlen, block_N)

            for k in T.Pipelined(loop_range, num_stages=num_stages):
                # Q * K
                for i, d in T.Parallel(block_N, dim):
                    if k * block_N + i < k_current_seqlen:
                        K_shared[i, d] = K_unpad[k_start_idx + k * block_N + i, head_idx, d]
                    else:
                        K_shared[i, d] = 0
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(
                            (bx * block_M + i >= k * block_N + j) and
                            (bx * block_M + i >= q_current_seqlen or
                             k * block_N + j >= k_current_seqlen), -T.infinity(acc_s.dtype), 0)
                else:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(
                            (bx * block_M + i >= q_current_seqlen or
                             k * block_N + j >= k_current_seqlen), -T.infinity(acc_s.dtype), 0)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

                # Softmax
                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                # To do causal softmax, we need to set the scores_max to 0 if it is -inf
                # This process is called Check_inf in FlashAttention3 code, and it only need to be done
                # in the first ceil_div(kBlockM, kBlockN) steps.
                # for i in T.Parallel(block_M):
                #     scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
                for i in T.Parallel(block_M):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_M, block_N):
                    # Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
                    # max * log_2(e)) This allows the compiler to use the ffma
                    # instruction instead of fadd and fmul separately.
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(block_M):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                T.copy(acc_s, acc_s_cast)

                # Rescale
                for i, j in T.Parallel(block_M, dim):
                    acc_o[i, j] *= scores_scale[i]

                # V * softmax(Q * K)
                for i, d in T.grid(block_N, dim):
                    if k * block_N + i < v_current_seqlen:
                        V_shared[i, d] = V_unpad[v_start_idx + k * block_N + i, head_idx, d]
                    else:
                        V_shared[i, d] = 0
                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, O_shared)
            for i, d in T.Parallel(block_M, dim):
                if bx * block_M + i < q_current_seqlen:
                    Output_unpad[q_start_idx + bx * block_M + i, head_idx, d] = O_shared[i, d]

    return main


def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128):
    flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
    total_flops = 2 * flops_per_matmul

    tilelang.testing.set_random_seed(0)

    causal = False
    if causal:
        total_flops *= 0.5

    dtype = torch.float16
    device = torch.device("cuda")
    window_size = (-1, -1)

    q = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device)
    k = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device)
    v = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device)

    query_padding_mask = generate_random_padding_mask(seq_len, batch, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seq_len, batch, device, mode="random")

    (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        q,
        k,
        v,
        output_pad_fn,
        dq_pad_fn,
        dk_pad_fn,
    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)

    UQ = q_unpad.shape[0]  # unpadded query length
    UK = k_unpad.shape[0]  # unpadded key length
    UKV = k_unpad.shape[0]  # unpadded query key length

    kernel = flashattn(batch, UQ, UKV, heads, dim, causal)
    out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)
    out = output_pad_fn(out_unpad)
    out_ref, _ = attention_ref(
        q,
        k,
        v,
        query_padding_mask,
        key_padding_mask,
        causal=causal,
    )
    torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)

    import flash_attn
    fla_out_unpad = flash_attn.flash_attn_varlen_func(
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        0.0,
        causal=causal,
    )
    fla_out = output_pad_fn(fla_out_unpad)
    torch.testing.assert_close(out, fla_out, rtol=1e-2, atol=1e-2)
    print("All checks passed.✅")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=8, help='batch size')
    parser.add_argument('--heads', type=int, default=64, help='heads')
    parser.add_argument('--seq_len', type=int, default=2048, help='sequence length')
    parser.add_argument('--dim', type=int, default=128, help='dim')
    args = parser.parse_args()
    main(args.batch, args.heads, args.seq_len, args.dim)
examples/flash_attention/test_example_flash_attention.py
import tilelang.testing

import example_gqa_bwd
import example_gqa_bwd_wgmma_pipelined
import example_mha_bwd
import example_mha_bwd_bhsd
import example_mha_fwd_bhsd_wgmma_pipelined
import example_gqa_fwd_bshd
import example_mha_fwd_bshd
import example_gqa_fwd_bshd_wgmma_pipelined
import example_mha_fwd_bshd_wgmma_pipelined
import example_mha_fwd_varlen
import example_mha_bwd_wgmma_pipelined
import example_mha_fwd_bhsd
import example_gqa_bwd_tma_reduce_varlen


@tilelang.testing.requires_cuda
def test_example_gqa_bwd_tma_reduce_varlen():
    example_gqa_bwd_tma_reduce_varlen.main()


@tilelang.testing.requires_cuda
def test_example_gqa_bwd():
    example_gqa_bwd.main()


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_gqa_bwd_wgmma_pipelined():
    example_gqa_bwd_wgmma_pipelined.main()


@tilelang.testing.requires_cuda
def test_example_mha_bwd():
    example_mha_bwd.main(BATCH=1)


@tilelang.testing.requires_cuda
def test_example_mha_bwd_bhsd():
    example_mha_bwd_bhsd.main(BATCH=1)


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_mha_bwd_wgmma_pipelined():
    example_mha_bwd_wgmma_pipelined.main(BATCH=1)


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_gqa_fwd_bshd_wgmma_pipelined():
    example_gqa_fwd_bshd_wgmma_pipelined.main(
        batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False)


@tilelang.testing.requires_cuda
def test_example_gqa_fwd_bshd():
    example_gqa_fwd_bshd.main(
        batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False)


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_mha_fwd_bhsd_wgmma_pipelined():
    example_mha_fwd_bhsd_wgmma_pipelined.main()


@tilelang.testing.requires_cuda
def test_example_mha_fwd_bhsd():
    example_mha_fwd_bhsd.main()


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_mha_fwd_bshd_wgmma_pipelined():
    example_mha_fwd_bshd_wgmma_pipelined.main(batch=1, heads=32, seq_len=256)


@tilelang.testing.requires_cuda
def test_example_mha_fwd_bshd():
    example_mha_fwd_bshd.main(batch=1, seq_len=256)


@tilelang.testing.requires_cuda
def test_example_mha_fwd_varlen():
    example_mha_fwd_varlen.main()


if __name__ == "__main__":
    tilelang.testing.main()
examples/flash_attention/varlen_utils.py
# ruff: noqa
import torch
from einops import rearrange, repeat
from bert_padding import pad_input, unpad_input


def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
    assert mode in ["full", "random", "third"]
    if mode == "full":
        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
    elif mode == "random":
        lengths = torch.randint(
            max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device)
    elif mode == "third":
        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
    padding_mask = (
        repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths)
    return padding_mask


def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False,
                 qkvpacked=False):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d)
        query_padding_mask: (batch_size, seqlen), bool
        key_padding_mask: (batch_size, seqlen), bool
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    _, seqlen_k, nheads_k, _ = k.shape

    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask)
        output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen_q)
    else:
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        cu_seqlens_q = torch.arange(
            0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device)
        max_seqlen_q = seqlen_q
        output_pad_fn = lambda output_unpad: rearrange(
            output_unpad, "(b s) h d -> b s h d", b=batch_size)

    if key_padding_mask is not None:
        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
        v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(
            0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device)
        max_seqlen_k = seqlen_k

    if qkvpacked:
        assert (query_padding_mask == key_padding_mask).all()
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        if query_padding_mask is not None:
            dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
        else:
            dqkv_pad_fn = lambda dqkv_unpad: rearrange(
                dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size)
        return (
            qkv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            max_seqlen_q,
            qkv.detach().requires_grad_(),
            output_pad_fn,
            dqkv_pad_fn,
        )
    elif kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
        else:
            dkv_pad_fn = lambda dkv_unpad: rearrange(
                dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size)
        return (
            q_unpad.detach().requires_grad_(),
            kv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            kv.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        )
    else:
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
        else:
            dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)
        return (
            q_unpad.detach().requires_grad_(),
            k_unpad.detach().requires_grad_(),
            v_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            k.detach().requires_grad_(),
            v.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        )
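Aside (not part of the original files): in the varlen layout that `generate_qkv` produces, all sequences are concatenated along the token axis and `cu_seqlens` holds cumulative offsets, so batch `b` occupies rows `cu_seqlens[b]:cu_seqlens[b + 1]` of the packed tensor. A tiny illustration with made-up lengths:

```python
import torch

# Hypothetical per-batch sequence lengths after padding is removed.
lengths = torch.tensor([5, 3, 7], dtype=torch.int32)
cu_seqlens = torch.zeros(len(lengths) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(lengths, dim=0)
print(cu_seqlens)  # tensor([ 0,  5,  8, 15], dtype=torch.int32)

# Rows of a packed (total_tokens, nheads, dim) tensor belonging to batch b:
b = 1
start, end = int(cu_seqlens[b]), int(cu_seqlens[b + 1])
# packed[start:end] would hold the 3 tokens of batch 1; max_seqlen is lengths.max().
```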
examples/flash_decoding/README.md
# Flash Decoding
examples/flash_decoding/example_gqa_decode.py
import
torch
import
torch.nn.functional
as
F
import
tilelang
from
tilelang.autotuner
import
*
import
tilelang.language
as
T
from
einops
import
rearrange
,
einsum
import
argparse
import
itertools
from
functools
import
lru_cache
from
typing
import
Tuple
,
Dict
torch
.
random
.
manual_seed
(
0
)
def
get_configs
():
block_N
=
[
64
,
128
]
block_H
=
[
64
]
num_split
=
[
2
,
4
,
8
]
num_stages
=
[
1
,
2
,
3
]
threads
=
[
128
]
_configs
=
list
(
itertools
.
product
(
block_N
,
block_H
,
num_split
,
num_stages
,
threads
))
configs
=
[{
'block_N'
:
c
[
0
],
'block_H'
:
c
[
1
],
'num_split'
:
c
[
2
],
'num_stages'
:
c
[
3
],
'threads'
:
c
[
4
]
}
for
c
in
_configs
]
return
configs
@
lru_cache
(
maxsize
=
1
)
def
get_heuristic_config
()
->
Tuple
[
Dict
,
int
]:
# Get CUDA device properties
if
not
torch
.
cuda
.
is_available
():
raise
RuntimeError
(
"CUDA is not available"
)
device
=
torch
.
cuda
.
current_device
()
sm_major
,
sm_minor
=
torch
.
cuda
.
get_device_capability
(
device
)
sm_version
=
sm_major
*
10
+
sm_minor
print
(
f
"CUDA device capability:
{
sm_version
}
"
)
if
sm_version
==
89
:
cfg
=
dict
(
block_N
=
128
,
block_H
=
64
,
num_split
=
16
,
num_stages
=
0
,
threads
=
128
)
else
:
cfg
=
dict
(
block_N
=
128
,
block_H
=
64
,
num_split
=
16
,
num_stages
=
2
,
threads
=
128
)
return
cfg
,
sm_version
# TODO(lei): fix warp specialized and tma lower pass
def
get_pass_configs
():
return
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
}
@
autotune
(
configs
=
get_configs
(),
warmup
=
10
,
rep
=
10
)
@
tilelang
.
jit
(
out_idx
=
[
6
],
pass_configs
=
get_pass_configs
())
def
flashattn
(
batch
,
heads
,
groups
,
seqlen_kv
,
dim
,
block_N
,
block_H
,
num_split
,
num_stages
,
threads
):
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
shape_q
=
[
batch
,
heads
,
dim
]
shape_k
=
[
batch
,
seqlen_kv
,
groups
,
dim
]
shape_v
=
[
batch
,
seqlen_kv
,
groups
,
dim
]
shape_o
=
[
batch
,
heads
,
dim
]
dtype
=
"float16"
accum_dtype
=
"float"
kv_group_num
=
heads
//
groups
part_shape
=
[
batch
,
heads
,
num_split
,
dim
]
valid_block_H
=
min
(
block_H
,
kv_group_num
)
valid_block_N
=
min
(
block_N
,
seqlen_kv
//
num_split
)
@
T
.
macro
def
flash_attn
(
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
mask
:
T
.
Tensor
([
batch
,
seqlen_kv
,
groups
],
"uint8"
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
with
T
.
Kernel
(
batch
,
heads
//
valid_block_H
,
num_split
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
K_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
V_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
O_shared
=
T
.
alloc_shared
([
valid_block_H
,
dim
],
dtype
)
acc_s
=
T
.
alloc_fragment
([
block_H
,
block_N
],
accum_dtype
)
acc_s_cast
=
T
.
alloc_fragment
([
block_H
,
block_N
],
dtype
)
mask_local
=
T
.
alloc_fragment
([
block_N
],
"uint8"
)
acc_o
=
T
.
alloc_fragment
([
block_H
,
dim
],
accum_dtype
)
scores_max
=
T
.
alloc_fragment
([
block_H
],
accum_dtype
)
scores_max_prev
=
T
.
alloc_fragment
([
block_H
],
accum_dtype
)
scores_scale
=
T
.
alloc_fragment
([
block_H
],
accum_dtype
)
scores_sum
=
T
.
alloc_fragment
([
block_H
],
accum_dtype
)
logsum
=
T
.
alloc_fragment
([
block_H
],
accum_dtype
)
bid
=
bx
hid
=
by
cur_kv_head
=
hid
//
(
kv_group_num
//
valid_block_H
)
T
.
copy
(
Q
[
bid
,
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
,
:],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
loop_range
=
T
.
ceildiv
((
seqlen_kv
//
num_split
),
block_N
)
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
T
.
copy
(
K
[
bid
,
k
*
block_N
:(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_shared
)
T
.
copy
(
mask
[
bid
,
k
*
block_N
:(
k
+
1
)
*
block_N
,
cur_kv_head
],
mask_local
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
mask_local
[
j
]
!=
0
,
acc_s
[
i
,
j
],
-
T
.
infinity
(
accum_dtype
))
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_H
):
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
exp2
(
acc_s
[
i
,
j
]
*
scale
-
scores_max
[
i
]
*
scale
)
T
.
reduce_sum
(
acc_s
,
scores_sum
,
dim
=
1
)
for
i
in
T
.
Parallel
(
block_H
):
logsum
[
i
]
=
logsum
[
i
]
*
scores_scale
[
i
]
+
scores_sum
[
i
]
T
.
copy
(
acc_s
,
acc_s_cast
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
T
.
copy
(
V
[
bid
,
k
*
block_N
:(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
V_shared
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
for
i
in
T
.
Parallel
(
block_H
):
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
T
.
copy
(
acc_o
[:
valid_block_H
,
:],
O_shared
)
T
.
copy
(
O_shared
,
Output
[
bid
,
hid
*
valid_block_H
:(
hid
+
1
)
*
valid_block_H
,
:])
@
T
.
macro
def
flash_attn_split
(
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
mask
:
T
.
Tensor
([
batch
,
seqlen_kv
,
groups
],
"uint8"
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
dtype
),
):
with
T
.
Kernel
(
batch
,
heads
//
valid_block_H
,
num_split
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
K_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
V_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
O_shared
=
T
.
alloc_shared
([
valid_block_H
,
dim
],
dtype
)
acc_s
=
T
.
alloc_fragment
([
block_H
,
block_N
],
accum_dtype
)
acc_s_cast
=
T
.
alloc_fragment
([
block_H
,
block_N
],
dtype
)
mask_local
=
T
.
alloc_fragment
([
block_N
],
"uint8"
)
acc_o
=
T
.
alloc_fragment
([
block_H
,
dim
],
accum_dtype
)
scores_max
=
T
.
alloc_fragment
([
block_H
],
accum_dtype
)
scores_max_prev
=
T
.
alloc_fragment
([
block_H
],
accum_dtype
)
scores_scale
=
T
.
alloc_fragment
([
block_H
],
accum_dtype
)
scores_sum
=
T
.
alloc_fragment
([
block_H
],
accum_dtype
)
logsum
=
T
.
alloc_fragment
([
block_H
],
accum_dtype
)
bid
=
bx
hid
=
by
sid
=
bz
cur_kv_head
=
hid
//
(
kv_group_num
//
valid_block_H
)
T
.
copy
(
Q
[
bid
,
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
,
:],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
loop_range
=
T
.
ceildiv
((
seqlen_kv
//
num_split
),
block_N
)
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
T
.
copy
(
K
[
bid
,
(
seqlen_kv
//
num_split
)
*
sid
+
k
*
valid_block_N
:(
seqlen_kv
//
num_split
)
*
sid
+
(
k
+
1
)
*
valid_block_N
,
cur_kv_head
,
:],
K_shared
)
T
.
copy
(
mask
[
bid
,
(
seqlen_kv
//
num_split
)
*
sid
+
k
*
valid_block_N
:(
seqlen_kv
//
num_split
)
*
sid
+
(
k
+
1
)
*
valid_block_N
,
cur_kv_head
],
mask_local
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
((
mask_local
[
j
]
!=
0
)
&
(
j
<
seqlen_kv
//
num_split
),
acc_s
[
i
,
j
],
-
T
.
infinity
(
accum_dtype
))
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_H
):
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
exp2
(
acc_s
[
i
,
j
]
*
scale
-
scores_max
[
i
]
*
scale
)
T
.
reduce_sum
(
acc_s
,
scores_sum
,
dim
=
1
)
for
i
in
T
.
Parallel
(
block_H
):
logsum
[
i
]
=
logsum
[
i
]
*
scores_scale
[
i
]
+
scores_sum
[
i
]
T
.
copy
(
acc_s
,
acc_s_cast
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
T
.
copy
(
V
[
bid
,
(
seqlen_kv
//
num_split
)
*
sid
+
k
*
valid_block_N
:(
seqlen_kv
//
num_split
)
*
sid
+
(
k
+
1
)
*
valid_block_N
,
cur_kv_head
,
:],
V_shared
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
for
i
in
T
.
Parallel
(
block_H
):
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
for
i
in
T
.
Parallel
(
block_H
):
if
i
<
valid_block_H
:
glse
[
bid
,
hid
*
valid_block_H
+
i
,
sid
]
=
logsum
[
i
]
T
.
copy
(
acc_o
[:
valid_block_H
,
:],
O_shared
)
T
.
copy
(
O_shared
,
Output_partial
[
bid
,
hid
*
valid_block_H
:(
hid
+
1
)
*
valid_block_H
,
sid
,
:])
@
T
.
macro
def
combine
(
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
dtype
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
):
with
T
.
Kernel
(
heads
,
batch
,
threads
=
128
)
as
(
by
,
bz
):
po_local
=
T
.
alloc_fragment
([
dim
],
dtype
)
o_accum_local
=
T
.
alloc_fragment
([
dim
],
accum_dtype
)
lse_local
=
T
.
alloc_fragment
([
num_split
,
128
],
dtype
)
lse_local_split
=
T
.
alloc_local
([
1
],
accum_dtype
)
lse_logsum_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
lse_max_local
=
T
.
alloc_fragment
([
128
],
accum_dtype
)
scale_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
T
.
annotate_layout
({
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
lse_max_local
:
T
.
Fragment
(
lse_max_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
# lse_local: (local_id, thread_id)
lse_local
:
T
.
Fragment
(
lse_local
.
shape
,
forward_fn
=
lambda
i
,
j
:
(
j
,
i
)),
})
T
.
clear
(
lse_logsum_local
)
T
.
clear
(
o_accum_local
)
for
k
,
j
in
T
.
Parallel
(
num_split
,
128
):
lse_local
[
k
,
j
]
=
glse
[
bz
,
by
,
k
]
T
.
reduce_max
(
lse_local
,
lse_max_local
,
dim
=
0
,
clear
=
True
)
for
k
in
T
.
Pipelined
(
num_split
,
num_stages
=
1
):
lse_local_split
[
0
]
=
glse
[
bz
,
by
,
k
]
lse_logsum_local
[
0
]
+=
T
.
exp2
(
lse_local_split
[
0
]
-
lse_max_local
[
0
])
lse_logsum_local
[
0
]
=
T
.
log2
(
lse_logsum_local
[
0
])
+
lse_max_local
[
0
]
for
k
in
T
.
serial
(
num_split
):
for
i
in
T
.
Parallel
(
dim
):
po_local
[
i
]
=
Output_partial
[
bz
,
by
,
k
,
i
]
lse_local_split
[
0
]
=
glse
[
bz
,
by
,
k
]
scale_local
[
0
]
=
T
.
exp2
(
lse_local_split
[
0
]
-
lse_logsum_local
[
0
])
for
i
in
T
.
Parallel
(
dim
):
o_accum_local
[
i
]
+=
po_local
[
i
]
*
scale_local
[
0
]
for
i
in
T
.
Parallel
(
dim
):
Output
[
bz
,
by
,
i
]
=
o_accum_local
[
i
]
    @T.prim_func
    def flashattn_gqa_decode_split(
            Q: T.Tensor(shape_q, dtype),
            K: T.Tensor(shape_k, dtype),
            V: T.Tensor(shape_v, dtype),
            mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
            glse: T.Tensor([batch, heads, num_split], dtype),
            Output_partial: T.Tensor(part_shape, dtype),
            Output: T.Tensor(shape_o, dtype),
    ):
        flash_attn_split(Q, K, V, mask, glse, Output_partial)
        combine(glse, Output_partial, Output)

    @T.prim_func
    def flashattn_gqa_decode_no_split(
            Q: T.Tensor(shape_q, dtype),
            K: T.Tensor(shape_k, dtype),
            V: T.Tensor(shape_v, dtype),
            mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
            glse: T.Tensor([batch, heads, num_split], dtype),
            Output_partial: T.Tensor(part_shape, dtype),
            Output: T.Tensor(shape_o, dtype),
    ):
        flash_attn(Q, K, V, mask, Output)

    if num_split > 1:
        return flashattn_gqa_decode_split
    else:
        return flashattn_gqa_decode_no_split
def ref_program(query, key, value, mask, glse, Output_partial):
    # """
    # Inputs:
    # - query (Tensor): [batch, heads, dim]
    # - key (Tensor): [batch, seqlen_kv, groups, dim]
    # - value (Tensor): [batch, seqlen_kv, groups, dim]
    # - mask (Tensor): [batch, seqlen_kv, groups]
    # Outputs:
    # - output (Tensor): [batch, heads, dim]
    # """
    dim = query.shape[-1]
    num_head_groups = query.shape[1] // key.shape[2]
    scale = dim**0.5
    key = rearrange(key, 'b n h d -> b h n d')  # [batch_size, groups, seqlen_kv, dim]
    value = rearrange(value, 'b n h d -> b h n d')  # [batch_size, groups, seqlen_kv, dim]
    query = rearrange(
        query, 'b (h g) d -> b g h d', g=num_head_groups)  # [batch_size, num_head_groups, groups, dim]
    scores = einsum(
        query, key, 'b g h d, b h s d -> b g h s')  # [batch_size, num_head_groups, groups, seqlen_kv]
    if mask is not None:
        mask = rearrange(mask, 'b s h -> b h s')
        mask = mask.unsqueeze(1)
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attention = F.softmax(scores / scale, dim=-1)  # [batch_size, num_head_groups, groups, seqlen_kv]
    out = einsum(attention, value, 'b g h s, b h s d -> b g h d')  # [batch_size, num_head_groups, groups, dim]
    out = rearrange(out, 'b g h d -> b (h g) d')  # [batch_size, heads, dim]
    return out
def flash_split_ref(Q, K, V, mask):
    num_split = 16
    batch = Q.size(0)
    nheads = Q.size(1)
    groups = K.size(2)
    dim = Q.size(-1)
    block_N = 32
    seqlen_kv = K.size(1)
    num_head_groups = nheads // groups
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)

    acc_s = torch.empty((batch, num_head_groups, groups, block_N), device="cuda", dtype=torch.float)
    acc_s_cast = torch.empty((batch, num_head_groups, groups, block_N), device="cuda", dtype=torch.float16)
    acc_o = torch.empty((batch, num_head_groups, groups, dim), device="cuda", dtype=torch.float)
    scores_max = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
    scores_max_prev = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
    scores_scale = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
    scores_sum = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
    logsum = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
    gacc_o = torch.empty((num_split, batch, nheads, dim), device="cuda", dtype=torch.float)
    glogsum = torch.empty((num_split, batch, nheads), device="cuda", dtype=torch.float)

    Q_ = Q * scale
    Q_ = rearrange(Q_, 'b (h g) d -> b g h d', g=num_head_groups)

    for ks in range(num_split):
        acc_o.fill_(0)
        logsum.fill_(0)
        scores_max.fill_(float('-inf'))
        scores_max_prev.fill_(float('-inf'))
        for i in range(int((seqlen_kv // num_split) / block_N)):
            acc_s.fill_(0)
            acc_s = torch.einsum(
                'bghd,bkhd->bghk', Q_,
                K[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks +
                  (i + 1) * block_N, :, :])  # [batch, nheads, block_N]
            if mask is not None:
                mask_local = mask[:, (seqlen_kv // num_split) * ks +
                                  i * block_N:(seqlen_kv // num_split) * ks + (i + 1) * block_N, :]
                mask_local = rearrange(mask_local, 'b s h -> b h s')
                mask_local = mask_local.unsqueeze(1)
                acc_s = acc_s.masked_fill(mask_local == 0, float('-inf'))
            scores_max_prev = scores_max
            scores_max = acc_s.max(dim=-1, keepdim=False).values  # [batch, nheads]
            scores_scale = torch.exp2(scores_max_prev - scores_max)  # [batch, nheads]
            acc_o *= scores_scale[:, :, :, None]
            acc_s = torch.exp2(acc_s - scores_max[:, :, :, None])
            acc_s_cast = acc_s.to(torch.float16)  # [batch, nheads, block_N]
            acc_o += torch.einsum(
                'bghk,bkhd->bghd', acc_s_cast,
                V[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks +
                  (i + 1) * block_N, :, :])
            scores_sum = acc_s.sum(dim=-1, keepdim=False)
            logsum = logsum * scores_scale + scores_sum
        acc_o_out = rearrange(acc_o, 'b g h d->b (h g) d')
        logsum_out = rearrange(logsum, 'b g h->b (h g)')
        acc_o_out /= logsum_out[:, :, None]
        logsum_out = torch.log2(logsum_out) + rearrange(scores_max, 'b g h->b (h g)')
        gacc_o[ks, :, :, :] = acc_o_out
        glogsum[ks, :, :] = logsum_out

    return glogsum.to(torch.float16).permute(1, 2, 0), gacc_o.to(torch.float16).permute(1, 2, 0, 3)
def reduce_ref(Q, K, V, mask, glse, Output_partial):
    num_split = 16
    o = torch.empty_like(Output_partial[:, :, 0, :]).fill_(0)
    lse_logsum = torch.empty_like(glse[:, :, 0]).fill_(0)  # [batch, heads]
    lse_max = glse.max(dim=2, keepdim=False).values
    for ks in range(num_split):
        lse = glse[:, :, ks]
        lse_logsum += torch.exp2(lse - lse_max)
    lse_logsum = torch.log2(lse_logsum) + lse_max
    for ks in range(num_split):
        lse = glse[:, :, ks]
        scale = torch.exp2(lse - lse_logsum)  # [batch, heads]
        o += Output_partial[:, :, ks, :] * scale[:, :, None]
    return o.to(torch.float16)
def ref_split_program(Q, K, V, mask, glse=None, Output_partial=None):
    glse_, Output_partial_ = flash_split_ref(Q, K, V, mask)
    return reduce_ref(Q, K, V, mask, glse_, Output_partial_)
def print_red_warning(msg):
    print(f"\033[91m{msg}\033[0m")


def calc_sim(x, y, name="tensor"):
    x, y = x.data.double(), y.data.double()
    denominator = (x * x + y * y).sum()
    if denominator == 0:
        print_red_warning(f'{name} all zero')
        return 1
    sim = 2 * (x * y).sum() / denominator
    return sim


def assert_similar(x, y, eps=1e-2, name="tensor", assert_=False, print_=True):
    sim = calc_sim(x, y, name)
    diff = 1. - sim
    if not (0 <= diff <= eps):
        print_red_warning(f'{name} Error: {diff}')
        if assert_:
            raise AssertionError(f'{name} Error: {diff}')
    else:
        if print_:
            print(f'passed: {name} diff={diff}')
def main(batch: int = 1,
         heads: int = 32,
         groups: int = 8,
         kv_seqlen: int = 8192,
         dim: int = 128,
         tune: bool = False):
    batch, heads, groups, kv_seqlen, dim = batch, heads, groups, kv_seqlen, dim
    qk_flops = 2 * batch * heads * kv_seqlen * dim
    pv_flops = 2 * batch * heads * kv_seqlen * dim
    total_flops = qk_flops + pv_flops

    if (not tune):
        config, sm_version = get_heuristic_config()
        kernel = flashattn(batch, heads, groups, kv_seqlen, dim, **config)
        profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
        q = torch.randn(batch, heads, dim, device="cuda", dtype=torch.float16)
        k = torch.randn(batch, kv_seqlen, groups, dim, device="cuda", dtype=torch.float16)
        v = torch.randn(batch, kv_seqlen, groups, dim, device="cuda", dtype=torch.float16)
        mask = torch.randint(0, 2, (batch, kv_seqlen, groups), device="cuda", dtype=torch.uint8)
        glse = torch.empty(batch, heads, 16, device="cuda", dtype=torch.float16)
        Output_partial = torch.empty(batch, heads, 16, dim, device="cuda", dtype=torch.float16)
        o = kernel(q, k, v, mask, glse, Output_partial)
        o_ref = ref_program(q, k, v, mask, glse, Output_partial)
        o_ref_split = ref_split_program(q, k, v, mask, glse, Output_partial)
        print(o)
        print(o_ref)
        assert_similar(o, o_ref, name="o_ref")
        assert_similar(o_ref_split, o_ref, name="o_ref_split")
        print("All checks pass.")
        latency = profiler.do_bench(ref_program, warmup=500)
        print("Ref: {:.2f} ms".format(latency))
        print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
        latency = profiler.do_bench(warmup=500)
        print("Tile-lang: {:.2f} ms".format(latency))
        print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
    else:
        kernel = flashattn(batch, heads, groups, kv_seqlen, dim)
        best_latency = kernel.latency
        best_config = kernel.config
        ref_latency = kernel.ref_latency
        print(f"Best latency: {best_latency}")
        print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
        print(f"Best config: {best_config}")
        print(f"Ref latency: {ref_latency}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=1, help='batch size')
    parser.add_argument('--heads', type=int, default=32, help='heads')
    parser.add_argument('--groups', type=int, default=8, help='groups')
    parser.add_argument('--kv_seqlen', type=int, default=8192, help='kv sequence length')
    parser.add_argument('--dim', type=int, default=128, help='dim')
    parser.add_argument('--tune', action='store_true', help='tune configs')
    args = parser.parse_args()
    main(args.batch, args.heads, args.groups, args.kv_seqlen, args.dim, args.tune)
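The split/combine scheme in this file rests on one identity: attention over the full KV range equals a weighted sum of per-split outputs, each weighted by exp2(lse_split - lse_total), where lse_total is the base-2 log-sum-exp of the per-split LSEs (this is what reduce_ref and the combine kernel compute). The short sketch below is a standalone plain-PyTorch check of that identity, not part of this commit; the head-free shapes, split count, and variable names are illustrative assumptions only.

# Standalone sketch of the split-softmax combine identity (illustrative shapes, not from the kernel above).
import torch

torch.manual_seed(0)
seqlen_kv, dim, num_split = 64, 8, 4
split_len = seqlen_kv // num_split
q = torch.randn(dim, dtype=torch.float64)
k = torch.randn(seqlen_kv, dim, dtype=torch.float64)
v = torch.randn(seqlen_kv, dim, dtype=torch.float64)

scores = k @ q  # [seqlen_kv]
full = torch.softmax(scores, dim=0) @ v  # reference: softmax over the whole KV range

# Per split: a locally normalized output plus its base-2 log-sum-exp, mirroring the kernel's bookkeeping.
log2e = 1.44269504
partial_o, partial_lse = [], []
for s in range(num_split):
    sl = scores[s * split_len:(s + 1) * split_len] * log2e
    m = sl.max()
    p = torch.exp2(sl - m)
    partial_o.append((p / p.sum()) @ v[s * split_len:(s + 1) * split_len])
    partial_lse.append(torch.log2(p.sum()) + m)

lse = torch.stack(partial_lse)
lse_max = lse.max()
lse_total = torch.log2(torch.exp2(lse - lse_max).sum()) + lse_max
combined = sum(torch.exp2(lse[s] - lse_total) * partial_o[s] for s in range(num_split))
assert torch.allclose(combined, full, atol=1e-10)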
examples/flash_decoding/example_mha_inference.py
0 → 100644
View file @
bc2d5632
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from functools import partial

num_split = 4


@tilelang.jit(out_idx=[5])
def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_N):
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    shape_q = [batch, seqlen_q, heads, dim]
    shape_kv = [batch, seqlen_kv, heads, dim]
    part_shape = [batch, seqlen_q, heads, num_split, dim]
    dtype = "float16"
    accum_dtype = "float"
    @T.macro
    def MMA0(
        K: T.Tensor(shape_kv, dtype),
        Q_shared: T.SharedBuffer([block_M, dim], dtype),
        K_shared: T.SharedBuffer([block_N, dim], dtype),
        acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
        k: T.int32,
        mid: T.int32,
        hid: T.int32,
        bid: T.int32,
        sid: T.int32,
    ):
        T.copy(
            K[bid, (seqlen_kv // num_split) * sid + k * block_N:(seqlen_kv // num_split) * sid +
              (k + 1) * block_N, hid, :], K_shared)
        # TODO: Handle causal split case
        if is_causal:
            for i, j in T.Parallel(block_M, block_N):
                acc_s[i, j] = T.if_then_else(mid * block_M + i >= k * block_N + j, 0,
                                             -T.infinity(acc_s.dtype))
        else:
            T.clear(acc_s)
        T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
    @T.macro
    def MMA1(
        V: T.Tensor(shape_kv, dtype),
        V_shared: T.SharedBuffer([block_N, dim], dtype),
        acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
        acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
        k: T.int32,
        hid: T.int32,
        bid: T.int32,
        sid: T.int32,
    ):
        T.copy(
            V[bid, (seqlen_kv // num_split) * sid + k * block_N:(seqlen_kv // num_split) * sid +
              (k + 1) * block_N, hid, :], V_shared)
        T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
    @T.macro
    def Softmax(
            acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
            acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
            scores_max: T.FragmentBuffer([block_M], accum_dtype),
            scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
            scores_sum: T.FragmentBuffer([block_M], accum_dtype),
            logsum: T.FragmentBuffer([block_M], accum_dtype),
    ):
        T.copy(scores_max, scores_max_prev)
        T.fill(scores_max, -T.infinity(accum_dtype))
        T.reduce_max(acc_s, scores_max, dim=1, clear=False)
        # To do causal softmax, we need to set scores_max to 0 if it is -inf.
        # This process is called Check_inf in the FlashAttention3 code, and it only needs to be
        # done in the first ceil_div(kBlockM, kBlockN) steps.
        # for i in T.Parallel(block_M):
        #     scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
        for i in T.Parallel(block_M):
            scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
        for i, j in T.Parallel(block_M, block_N):
            # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - max * log_2(e)).
            # This allows the compiler to use the ffma instruction instead of fadd and fmul
            # separately.
            acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
        T.reduce_sum(acc_s, scores_sum, dim=1)
        for i in T.Parallel(block_M):
            logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
        T.copy(acc_s, acc_s_cast)
    @T.macro
    def Rescale(
            acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
    ):
        for i, j in T.Parallel(block_M, dim):
            acc_o[i, j] *= scores_scale[i]
    @T.macro
    def flash_attn_split(
            Q: T.Tensor(shape_q, dtype),
            K: T.Tensor(shape_kv, dtype),
            V: T.Tensor(shape_kv, dtype),
            glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype),
            Output_partial: T.Tensor(part_shape, dtype),
    ):
        with T.Kernel(
                T.ceildiv(seqlen_q, block_M), heads * batch, num_split,
                threads=128) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim], dtype)
            K_shared = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_N, dim], dtype)
            O_shared = T.alloc_shared([block_M, dim], dtype)
            acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
            acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_M], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
            scores_scale = T.alloc_fragment([block_M], accum_dtype)
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            mid = bx
            hid = by % heads
            bid = by // heads
            sid = bz

            # NOTE(wt): tma barrier has some problems with padded dimensions (seq_q here) currently
            # disable relevant tma copy and use SIMT as fallback for now
            T.copy(Q[bid, mid * block_M:(mid + 1) * block_M, hid, :], Q_shared, disable_tma=True)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            # TODO: Handle causal split case
            loop_range = (
                T.min(T.ceildiv(seqlen_kv, block_N), T.ceildiv(
                    (mid + 1) * block_M, block_N)) if is_causal else T.ceildiv(
                        (seqlen_kv // num_split), block_N))

            for k in T.Pipelined(loop_range, num_stages=2):
                MMA0(K, Q_shared, K_shared, acc_s, k, mid, hid, bid, sid)
                Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
                        logsum)
                Rescale(acc_o, scores_scale)
                MMA1(V, V_shared, acc_s_cast, acc_o, k, hid, bid, sid)
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            for i in T.Parallel(block_M):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
            T.copy(logsum, glse[bid, hid, sid, mid * block_M:(mid + 1) * block_M])
            T.copy(acc_o, O_shared)
            T.copy(
                O_shared,
                Output_partial[bid, mid * block_M:(mid + 1) * block_M, hid, sid, :],
                disable_tma=True)
    @T.macro
    def combine(
            glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype),
            Output_partial: T.Tensor(part_shape, dtype),
            Output: T.Tensor(shape_q, dtype),
    ):
        with T.Kernel(T.ceildiv(seqlen_q, block_M), heads, batch, threads=128) as (bx, by, bz):
            po_local = T.alloc_fragment([block_M, dim], dtype)
            po_shared = T.alloc_shared([block_M, dim], dtype)
            o_accum_local = T.alloc_fragment([block_M, dim], accum_dtype)
            o_shared = T.alloc_shared([block_M, dim], dtype)
            lse_local = T.alloc_fragment([num_split, block_M], dtype)
            lse_local_split = T.alloc_fragment([block_M], accum_dtype)
            lse_logsum_local = T.alloc_fragment([block_M], accum_dtype)
            lse_max_local = T.alloc_fragment([block_M], accum_dtype)
            scale_local = T.alloc_fragment([block_M], accum_dtype)

            T.annotate_layout({
                o_accum_local:
                    T.Fragment(o_accum_local.shape, forward_thread_fn=lambda i, j: i),
                o_shared: tilelang.layout.make_swizzled_layout(o_shared),
                po_shared: tilelang.layout.make_swizzled_layout(po_shared),
            })

            T.clear(lse_logsum_local)
            T.clear(o_accum_local)
            T.copy(glse[bz, by, :, bx * block_M:(bx + 1) * block_M], lse_local)
            T.reduce_max(lse_local, lse_max_local, dim=0, clear=False)
            for k in T.Pipelined(num_split):
                T.copy(lse_local[k, :], lse_local_split)
                for i in T.Parallel(block_M):
                    lse_logsum_local[i] += T.exp2(lse_local_split[i] - lse_max_local[i])
            for i in T.Parallel(block_M):
                lse_logsum_local[i] = T.log2(lse_logsum_local[i]) + lse_max_local[i]
            for k in T.Pipelined(num_split, num_stages=2):
                T.copy(
                    Output_partial[bz, bx * block_M:(bx + 1) * block_M, by, k, :],
                    po_shared,
                    disable_tma=True)
                T.copy(po_shared, po_local)
                for i in T.Parallel(block_M):
                    lse_local_split[i] = lse_local[k, i]
                for i in T.Parallel(block_M):
                    scale_local[i] = T.exp2(lse_local_split[i] - lse_logsum_local[i])
                for i, j in T.Parallel(block_M, dim):
                    o_accum_local[i, j] += po_local[i, j] * scale_local[i]
            T.copy(o_accum_local, o_shared)
            T.copy(o_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :], disable_tma=True)
    @T.prim_func
    def flashattn_mha_inference(
            Q: T.Tensor(shape_q, dtype),
            K: T.Tensor(shape_kv, dtype),
            V: T.Tensor(shape_kv, dtype),
            glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype),
            Output_partial: T.Tensor(part_shape, dtype),  # [batch, seqlen_q, heads, num_split, dim]
            Output: T.Tensor(shape_q, dtype),
    ):
        flash_attn_split(Q, K, V, glse, Output_partial)
        combine(glse, Output_partial, Output)

    return flashattn_mha_inference
def ref_program(Q, K, V, glse, Output_partial, causal):
    assert causal is False
    dim = Q.size(-1)
    scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
    return output
def reduce_ref(Q, K, V, glse, Output_partial, causal):
    o = torch.empty_like(Output_partial[:, :, :, 0, :]).fill_(0)
    lse_logsum = torch.empty_like(glse[:, :, 0, :]).fill_(0)  # [batch, seqlen_q, heads]
    lse_max = glse.max(dim=2, keepdim=False).values
    for ks in range(num_split):
        lse = glse[:, :, ks, :]
        lse_logsum += torch.exp2(lse - lse_max)
    lse_logsum = torch.log2(lse_logsum) + lse_max
    for ks in range(num_split):
        lse = glse[:, :, ks, :]
        scale = torch.exp2(lse - lse_logsum)  # [batch, heads, seqlen_q]
        o += Output_partial[:, :, :, ks, :] * scale[:, :, :, None].transpose(1, 2)
    return o.to(torch.float16)
def flash_split_ref(Q, K, V, causal):
    # [batch, seqlen_q, heads, dim]
    batch = Q.size(0)
    block_M = Q.size(1)
    nheads = Q.size(2)
    dim = Q.size(3)
    block_N = 128
    seqlen_kv = K.size(1)

    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    acc_s = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float)
    acc_s_cast = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float16)
    acc_o = torch.empty((batch, block_M, nheads, dim), device="cuda", dtype=torch.float)
    scores_max = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float)
    scores_max_prev = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float)
    scores_scale = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float)
    scores_sum = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float)
    logsum = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float)
    gacc_o = torch.empty((num_split, batch, block_M, nheads, dim), device="cuda", dtype=torch.float)
    glogsum = torch.empty((num_split, batch, nheads, block_M), device="cuda", dtype=torch.float)

    Q_ = Q * scale
    for ks in range(num_split):
        acc_o.fill_(0)
        logsum.fill_(0)
        scores_max.fill_(float('-inf'))
        scores_max_prev.fill_(float('-inf'))
        for i in range(int((seqlen_kv // num_split) / block_N)):
            acc_s.fill_(0)
            acc_s = torch.einsum(
                'bqhd,bkhd->bhqk', Q_,
                K[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks +
                  (i + 1) * block_N, :, :])  # [batch, seqlen, nheads, block_N]
            scores_max_prev = scores_max
            scores_max = acc_s.max(dim=-1, keepdim=False).values  # [blockM]
            scores_scale = torch.exp2(scores_max_prev - scores_max)
            acc_o *= scores_scale[:, :, :, None].transpose(1, 2)
            acc_s = torch.exp2(acc_s - scores_max[:, :, :, None])
            acc_s_cast = acc_s.to(torch.float16)
            acc_o += torch.einsum(
                'bhqk,bkhd->bqhd', acc_s_cast,
                V[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks +
                  (i + 1) * block_N, :, :])
            scores_sum = acc_s.sum(dim=-1, keepdim=False)
            logsum = logsum * scores_scale + scores_sum
        acc_o /= logsum[:, :, :, None].transpose(1, 2)
        logsum = torch.log2(logsum) + scores_max
        gacc_o[ks, :, :, :, :] = acc_o
        glogsum[ks, :, :, :] = logsum

    return glogsum.to(torch.float16).permute(1, 2, 0, 3), gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4)
def main():
    BATCH, H, Q_CTX, KV_CTX, D_HEAD = 1, 32, 128, 8192, 128
    causal = False
    flops_per_matmul = 2.0 * BATCH * H * Q_CTX * KV_CTX * D_HEAD
    total_flops = 2 * flops_per_matmul
    if causal:
        total_flops *= 0.5

    BLOCK_M = 128
    BLOCK_N = 64  # if D_HEAD <= 128 else 32
    kernel = flashattn(BATCH, H, Q_CTX, KV_CTX, D_HEAD, causal, BLOCK_M, BLOCK_N)
    ref_fn = partial(ref_program, causal=causal)
    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
    profiler.assert_allclose(ref_fn, rtol=0.01, atol=0.01)
    print("All checks passed!")
    latency = profiler.do_bench(ref_fn, warmup=500)
    print("{:.2f} ms".format(latency))
    print("{:.2f} TFlops".format(total_flops / latency * 1e-9))
    latency = profiler.do_bench(n_warmup=10, n_repeat=10)
    print("{:.4f} ms".format(latency))
    print("{:.2f} TFlops".format(total_flops / latency * 1e-9))


if __name__ == "__main__":
    main()
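The kernels above work in base 2: the log2(e) factor is folded into scale, so exp(x - max) becomes exp2(x*scale - max*scale), and the running (acc_o, logsum) pair is rescaled by exp2(max_prev*scale - max_new*scale) whenever a new KV block raises the running maximum. The sketch below is a plain-PyTorch replay of that online-softmax bookkeeping under illustrative sizes, with the 1/sqrt(dim) factor omitted for brevity; it is not part of this commit, and it simply checks that the blocked result matches a full-row softmax.

# Standalone sketch of base-2 online softmax (illustrative sizes; 1/sqrt(dim) scaling omitted).
import torch

torch.manual_seed(0)
log2e = 1.44269504
x = torch.randn(4, 256, dtype=torch.float64)      # one row of scores, split into 4 KV blocks
v = torch.randn(4 * 256, 8, dtype=torch.float64)  # matching value rows

ref = torch.softmax(x.flatten(), dim=0) @ v  # reference: full-row softmax, no blocking

m = torch.tensor(float('-inf'), dtype=torch.float64)  # running max
logsum = torch.zeros((), dtype=torch.float64)         # running denominator
acc_o = torch.zeros(8, dtype=torch.float64)           # running un-normalized output
for blk in range(4):
    xb = x[blk]
    m_prev, m = m, torch.maximum(m, xb.max())
    rescale = torch.exp2(m_prev * log2e - m * log2e)  # shrink old accumulators to the new max
    p = torch.exp2(xb * log2e - m * log2e)            # exp(xb - m) computed via exp2
    acc_o = acc_o * rescale + p @ v[blk * 256:(blk + 1) * 256]
    logsum = logsum * rescale + p.sum()

assert torch.allclose(acc_o / logsum, ref, atol=1e-12)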