gaoqiong / flash-attention · Commits · b1fbbd83

Commit b1fbbd83, authored Aug 29, 2023 by Tri Dao
Parent: 7a983df7

    Implement splitKV attention

Changes: 25 · Showing 5 changed files with 155 additions and 7 deletions (+155 −7)
csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu    +7    −0
csrc/flash_attn/src/generate_kernels.py                    +16   −6
csrc/flash_attn/src/kernel_traits.h                        +13   −1
setup.py                                                   +16   −0
tests/test_flash_attn.py                                   +103  −0
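The commit message above is terse, so for orientation: split-KV attention processes the key/value sequence in several chunks, producing a partial output and a log-sum-exp (LSE) per chunk, and then combines the partials with softmax rescaling (the new Oaccum copy atoms in kernel_traits.h and the lse value used in the new test below are the visible pieces of this). The snippet that follows is only an illustration of that combination step in plain PyTorch, under the assumptions stated in its comments; it is not the CUDA implementation added by this commit, and the helper name attention_splitkv_reference is invented for the example.

# Illustration only (not code from this commit): the split-KV combination in PyTorch.
# Assumptions: single head, no causal mask, no dropout, softmax scale 1/sqrt(d).
import torch

def attention_splitkv_reference(q, k, v, num_splits=4):
    # q: (seqlen_q, d), k, v: (seqlen_k, d)
    scale = q.shape[-1] ** -0.5
    outs, lses = [], []
    for k_chunk, v_chunk in zip(k.chunk(num_splits), v.chunk(num_splits)):
        s = (q @ k_chunk.T) * scale                          # scores against this KV chunk
        lse = torch.logsumexp(s, dim=-1, keepdim=True)       # per-row log-sum-exp of the chunk
        outs.append(torch.exp(s - lse) @ v_chunk)            # chunk-local softmax output
        lses.append(lse)
    lses = torch.cat(lses, dim=-1)                           # (seqlen_q, num_splits)
    lse_total = torch.logsumexp(lses, dim=-1, keepdim=True)  # LSE over the whole KV sequence
    weights = torch.exp(lses - lse_total)                    # rescaling factor per split
    return sum(w.unsqueeze(-1) * o for w, o in zip(weights.unbind(-1), outs))

torch.manual_seed(0)
q, k, v = torch.randn(3, 8, 64).unbind(0)
out_split = attention_splitkv_reference(q, k, v)
out_full = torch.softmax((q @ k.T) * q.shape[-1] ** -0.5, dim=-1) @ v
assert torch.allclose(out_split, out_full, atol=1e-5)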
csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu (new file, 0 → 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream);
csrc/flash_attn/src/generate_kernels.py

@@ -16,14 +16,21 @@ DTYPE_MAP = {
 SM = [80]  # Sm80 kernels support up to
 HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 224, 256]
-KERNEL_IMPL_TEMPLATE_FWD = """
+KERNEL_IMPL_TEMPLATE_FWD = """#include "flash_fwd_launch_template.h"
 template<>
 void run_mha_fwd_<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params &params, cudaStream_t stream) {{
     run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream);
 }}
 """
-KERNEL_IMPL_TEMPLATE_BWD = """
+KERNEL_IMPL_TEMPLATE_FWD_SPLIT = """#include "flash_fwd_launch_template.h"
+template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params &params, cudaStream_t stream);
+"""
+
+KERNEL_IMPL_TEMPLATE_BWD = """#include "flash_bwd_launch_template.h"
 template<>
 void run_mha_bwd_<{DTYPE}, {HEAD_DIM}>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {{
     run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream, configure);

@@ -44,10 +51,14 @@ class Kernel:
             return KERNEL_IMPL_TEMPLATE_FWD.format(
                 DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
             )
-        else:
+        elif self.direction == "bwd":
             return KERNEL_IMPL_TEMPLATE_BWD.format(
                 DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
             )
+        else:
+            return KERNEL_IMPL_TEMPLATE_FWD_SPLIT.format(
+                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
+            )

     @property
     def filename(self) -> str:

@@ -56,7 +67,7 @@ class Kernel:
 def get_all_kernels() -> List[Kernel]:
     for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM):
-        for direction in ["fwd", "bwd"]:
+        for direction in ["fwd", "bwd", "fwd_split"]:
             yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, direction=direction)

@@ -65,8 +76,7 @@ def write_kernel(kernel: Kernel, autogen_dir: Path) -> None:
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"\n
 """
-    include = f'#include "flash_{kernel.direction}_launch_template.h"\n'
-    (autogen_dir / kernel.filename).write_text(prelude + include + kernel.template)
+    (autogen_dir / kernel.filename).write_text(prelude + kernel.template)


 def main(output_dir: Optional[str]) -> None:
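For concreteness, here is a small, self-contained illustration (not part of the commit) of what the new fwd_split template expands to: filling KERNEL_IMPL_TEMPLATE_FWD_SPLIT with the fp16 dtype and head dimension 96 reproduces the body of flash_fwd_split_hdim96_fp16_sm80.cu shown above. The cutlass::half_t value is taken from that generated file; the full DTYPE_MAP lies outside the hunks shown here, and exact whitespace may differ.

# Illustration only: expanding the new fwd_split template by hand.
KERNEL_IMPL_TEMPLATE_FWD_SPLIT = """#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params &params, cudaStream_t stream);
"""

# fp16 -> cutlass::half_t is visible in the generated file above; the bf16 mapping in
# DTYPE_MAP is not shown in this diff.
print(KERNEL_IMPL_TEMPLATE_FWD_SPLIT.format(DTYPE="cutlass::half_t", HEAD_DIM=96))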
csrc/flash_attn/src/kernel_traits.h

@@ -113,7 +113,8 @@ struct Flash_fwd_kernel_traits : public Base {
     using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, Shape<Int<kBlockM>, Int<kHeadDim>>{}));
-    using SmemCopyAtomO = Copy_Atom<DefaultCopy, elem_type>;
+    using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
+    using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;

     static constexpr int kSmemQCount = size(SmemLayoutQ{});
     static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;

@@ -158,6 +159,17 @@ struct Flash_fwd_kernel_traits : public Base {
                                                   GmemLayoutAtomP{},
                                                   Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
+    using GmemLayoutAtomOaccum = std::conditional_t<
+        kBlockKSmem == 32,
+        Layout<Shape<_16, _8>,   // Thread layout, 8 threads per row
+               Stride<_8, _1>>,
+        Layout<Shape<_8, _16>,   // Thread layout, 16 threads per row
+               Stride<_16, _1>>
+    >;
+    using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+                                                         GmemLayoutAtomOaccum{},
+                                                         Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store
 };

 // Is_V_in_regs is an option to reduce smem usage, but will increase register pressue.
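A quick sanity check on the new Oaccum copy layouts (my own arithmetic, not part of the commit): the val layout Shape<_1, _4> means each thread stores 4 fp32 accumulator values, i.e. a 16-byte (128-bit) store, so the thread layout appears to be chosen so that one row of threads covers one kBlockKSmem-wide row of the fp32 accumulator tile.

# Illustration only: why the thread layout switches on kBlockKSmem.
def oaccum_elems_per_thread_row(kBlockKSmem: int, vals_per_store: int = 4) -> int:
    # Mirrors the std::conditional_t above: 8 threads per row when kBlockKSmem == 32,
    # otherwise 16 threads per row; each thread stores `vals_per_store` fp32 values.
    threads_per_row = 8 if kBlockKSmem == 32 else 16
    return threads_per_row * vals_per_store

assert oaccum_elems_per_thread_row(32) == 32   # one thread row spans a 32-wide fp32 row
assert oaccum_elems_per_thread_row(64) == 64   # one thread row spans a 64-wide fp32 row
assert 4 * 4 == 16                             # 4 fp32 values per thread = 16 bytes = one 128-bit store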
setup.py

@@ -173,6 +173,22 @@ if not SKIP_CUDA_BUILD:
                 "csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu",
                 "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
                 "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
             ],
             extra_compile_args={
                 "cxx": ["-O3", "-std=c++17"] + generator_flag,
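The sixteen new entries follow the naming scheme used by generate_kernels.py: one fwd_split translation unit per (head dimension, dtype) pair. The snippet below (illustration only, not from the commit) regenerates the same list.

# Illustration only: reproduce the 16 new fwd_split source paths added above.
import itertools

HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 224, 256]
DTYPES = ["fp16", "bf16"]

split_sources = [
    f"csrc/flash_attn/src/flash_fwd_split_hdim{hdim}_{dtype}_sm80.cu"
    for hdim, dtype in itertools.product(HEAD_DIMENSIONS, DTYPES)
]
assert len(split_sources) == 16  # matches the +16 additions in this hunk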
tests/test_flash_attn.py

@@ -1367,6 +1367,109 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
     assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5


+@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
+# @pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("causal", [False, True])
+# @pytest.mark.parametrize("causal", [True])
+@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
+# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
+# @pytest.mark.parametrize('d', [56, 80])
+# @pytest.mark.parametrize("d", [128])
+@pytest.mark.parametrize("swap_sq_sk", [False, True])
+# @pytest.mark.parametrize("swap_sq_sk", [False])
+@pytest.mark.parametrize(
+    "seqlen_q,seqlen_k",
+    [
+        (3, 1024),
+        (1, 339),
+        (3, 799),
+        (64, 2048),
+        (16, 20000),
+        (16, 100000),
+        (128, 128),
+        (256, 256),
+    ],
+)
+# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
+def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
+    if (
+        max(seqlen_q, seqlen_k) >= 2048
+        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
+    ):
+        pytest.skip()  # Reference implementation OOM
+    if swap_sq_sk:
+        seqlen_q, seqlen_k = seqlen_k, seqlen_q
+    device = "cuda"
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 1
+    nheads = 12
+    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    out, lse, _ = flash_attn_func(q, k, v, 0.0, causal=causal, return_attn_probs=True)
+    out_ref, attn_ref = attention_ref(q, k, v, None, None, 0.0, None, causal=causal)
+    out_pt, attn_pt = attention_ref(
+        q,
+        k,
+        v,
+        None,
+        None,
+        0.0,
+        None,
+        causal=causal,
+        upcast=False,
+        reorder_ops=True,
+    )
+    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
+    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
+
+    g = torch.randn_like(out)
+    do_o = (g.float() * out.float()).sum(-1)
+    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+        (
+            dq,
+            dk,
+            dv,
+        ) = torch.autograd.grad(out, (q, k, v), g)
+        (
+            dq_ref,
+            dk_ref,
+            dv_ref,
+        ) = torch.autograd.grad(out_ref, (q, k, v), g)
+        (
+            dq_pt,
+            dk_pt,
+            dv_pt,
+        ) = torch.autograd.grad(out_pt, (q, k, v), g)
+        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
+        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
+        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
+        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
+        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
+        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
+        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
+        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
+        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
+        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
+        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
+        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
+
+    # Check that FlashAttention's numerical error is at most twice the numerical error
+    # of a Pytorch implementation.
+    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
+
+    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 2e-4
+        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 2e-4
+        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 2e-4
+
+
 # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 @pytest.mark.parametrize("dtype", [torch.float16])
 @pytest.mark.parametrize("causal", [False, True])