gaoqiong / flash-attention · Commits · 5badfb78

Commit 5badfb78 authored Oct 13, 2022 by Tri Dao
Implement attention kernel that splits the batch into two
parent f515c77f

Showing 6 changed files with 260 additions and 33 deletions (+260 −33)
  csrc/flash_attn/fmha_api.cpp                  +20  −18
  csrc/flash_attn/src/fmha.h                     +2   −0
  csrc/flash_attn/src/fmha_fprop_kernel_1xN.h    +1   −1
  csrc/flash_attn/src/static_switch.h            +1   −1
  flash_attn/flash_attn_interface.py            +128  −10
  tests/test_flash_attn.py                      +108   −3
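This commit threads a caller-allocated output tensor through mha_fwd / set_params_fprop and adds a Python entry point, flash_attn_unpadded_qkvpacked_split_func, that runs one attention kernel on the first batch_size0 sequences and a second kernel on a separate CUDA stream for the rest, both writing into the same output. A minimal calling sketch; the function name and signature come from this diff, while the shapes, lengths, and dtype below are illustrative assumptions:

    import torch
    import torch.nn.functional as F
    from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_split_func

    nheads, headdim = 4, 64
    # Hypothetical variable-length batch: 3 "long" sequences followed by 1 "short" one (<= 128).
    seqlens = torch.tensor([512, 384, 300, 100], dtype=torch.int32, device='cuda')
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    total = int(cu_seqlens[-1])
    qkv = torch.randn(total, 3, nheads, headdim, device='cuda', dtype=torch.float16)

    out = flash_attn_unpadded_qkvpacked_split_func(
        qkv, cu_seqlens,
        max_seqlen0=512,  # max length in the 1st part of the batch
        max_seqlen1=128,  # max length in the 2nd part of the batch
        batch_size0=3,    # the first 3 sequences go to the 1st kernel
        dropout_p=0.0,
    )
    # out has shape (total, nheads, headdim), same as the non-split qkvpacked function.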
csrc/flash_attn/fmha_api.cpp

@@ -45,9 +45,9 @@ void set_params_fprop(FMHA_fprop_params &params,
                       const at::Tensor q,
                       const at::Tensor k,
                       const at::Tensor v,
+                      at::Tensor out,
                       void *cu_seqlens_q_d,
                       void *cu_seqlens_k_d,
-                      void *o_packed_d,
                       void *o_tmp_d,
                       void *s_d,
                       void *softmax_lse_d,
@@ -73,10 +73,12 @@ void set_params_fprop(FMHA_fprop_params &params,
     params.q_head_stride_in_elts = q.stride(1);
     params.k_head_stride_in_elts = k.stride(1);
     params.v_head_stride_in_elts = v.stride(1);
-    params.o_ptr = o_packed_d;
-    params.o_row_stride_in_elts = h * d;
-    params.o_head_stride_in_elts = d;
+    params.o_ptr = out.data_ptr();
+    params.o_row_stride_in_elts = out.stride(0);
+    params.o_head_stride_in_elts = out.stride(1);
     params.o_tmp_ptr = o_tmp_d;
+    params.o_tmp_row_stride_in_elts = h * d;
+    params.o_tmp_head_stride_in_elts = d;

     params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
     params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
@@ -127,12 +129,12 @@ void set_params_dgrad(FMHA_dgrad_params &params,
                       const at::Tensor q,
                       const at::Tensor k,
                       const at::Tensor v,
+                      const at::Tensor out,
                       at::Tensor dq,
                       at::Tensor dk,
                       at::Tensor dv,
                       void *cu_seqlens_q_d,
                       void *cu_seqlens_k_d,
-                      void *o_packed_d,
                       void *dq_tmp_d,
                       void *do_packed_d,
                       void *softmax_lse_d,
@@ -143,10 +145,9 @@ void set_params_dgrad(FMHA_dgrad_params &params,
     set_params_fprop(params,
                      b, seqlen_q, seqlen_k, h, d,
-                     q, k, v,
+                     q, k, v, out,
                      cu_seqlens_q_d,
                      cu_seqlens_k_d,
-                     o_packed_d,
                      dq_tmp_d,  // Reusing the o_tmp_ptr variable to store dq_tmp
                      nullptr,
                      softmax_lse_d,
@@ -174,6 +175,7 @@ std::vector<at::Tensor>
 mha_fwd(const at::Tensor &q,         // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
         const at::Tensor &k,         // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
         const at::Tensor &v,         // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
+        at::Tensor &out,             // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
         const at::Tensor &cu_seqlens_q,  // b+1
         const at::Tensor &cu_seqlens_k,  // b+1
         const int max_seqlen_q_,
@@ -198,18 +200,21 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     TORCH_CHECK(q_dtype == torch::kFloat16 || (is_sm8x && q_dtype == torch::kBFloat16));
     TORCH_CHECK(k.dtype() == q_dtype);
     TORCH_CHECK(v.dtype() == q_dtype);
+    TORCH_CHECK(out.dtype() == q_dtype);
     TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32);
     TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32);

     TORCH_CHECK(q.is_cuda());
     TORCH_CHECK(k.is_cuda());
     TORCH_CHECK(v.is_cuda());
+    TORCH_CHECK(out.is_cuda());
     TORCH_CHECK(cu_seqlens_q.is_cuda());
     TORCH_CHECK(cu_seqlens_k.is_cuda());

     TORCH_CHECK(q.stride(-1) == 1);
     TORCH_CHECK(k.stride(-1) == 1);
     TORCH_CHECK(v.stride(-1) == 1);
+    TORCH_CHECK(out.stride(-1) == 1);
     TORCH_CHECK(cu_seqlens_k.is_contiguous());
     TORCH_CHECK(cu_seqlens_k.is_contiguous());
@@ -226,6 +231,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     CHECK_SHAPE(q, total_q, num_heads, head_size);
     CHECK_SHAPE(k, total_k, num_heads, head_size);
     CHECK_SHAPE(v, total_k, num_heads, head_size);
+    CHECK_SHAPE(out, total_q, num_heads, head_size);
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
@@ -242,7 +248,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     auto opts = q.options();

-    auto o = torch::empty({ total_q, num_heads, head_size }, opts);
+    // auto o = torch::empty({ total_q, num_heads, head_size }, opts);
     at::Tensor o_tmp;
     if (loop) { o_tmp = torch::empty({total_q, num_heads, head_size}, opts.dtype(at::kFloat)); }
@@ -254,7 +260,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     if (return_softmax) { s = torch::empty({ batch_size, num_heads, max_seqlen_q, max_seqlen_k }, opts); }

     if (zero_tensors) {
-        o.zero_();
+        out.zero_();
         softmax_lse.fill_(-std::numeric_limits<float>::infinity());
         if (return_softmax) {s.zero_();}
     }
@@ -268,10 +274,9 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                      max_seqlen_k,
                      num_heads,
                      head_size,
-                     q, k, v,
+                     q, k, v, out,
                      cu_seqlens_q.data_ptr(),
                      cu_seqlens_k.data_ptr(),
-                     o.data_ptr(),
                      loop ? o_tmp.data_ptr() : nullptr,
                      return_softmax ? s.data_ptr() : nullptr,
                      softmax_lse.data_ptr(),
@@ -293,7 +298,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     run_fmha_fp16_sm80(launch_params, /*configure=*/false);

-    std::vector<at::Tensor> result = {o, softmax_lse};
+    std::vector<at::Tensor> result = {softmax_lse};
     if (return_softmax) {result.push_back(s);}
     return result;
 }
@@ -418,11 +423,10 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
                      max_seqlen_k,
                      num_heads,
                      head_size,
-                     q, k, v,
+                     q, k, v, out,
                      dq, dk, dv,
                      cu_seqlens_q.data_ptr(),
                      cu_seqlens_k.data_ptr(),
-                     out.data_ptr(),
                      loop ? dq_tmp.data_ptr() : nullptr,
                      dout.data_ptr(),
                      softmax_lse.data_ptr(),
@@ -541,10 +545,9 @@ mha_fwd_block(const at::Tensor &q, // total_q x num_heads x head_size, t
                      max_seqlen_k,
                      num_heads,
                      head_size,
-                     q, k, v,
+                     q, k, v, o,
                      cu_seqlens_q.data_ptr(),
                      cu_seqlens_k.data_ptr(),
-                     o.data_ptr(),
                      loop ? o_tmp.data_ptr() : nullptr,
                      return_softmax ? s.data_ptr() : nullptr,
                      softmax_lse.data_ptr(),
@@ -686,11 +689,10 @@ mha_bwd_block(const at::Tensor &dout, // total x num_heads, x head_size
                      max_seqlen_k,
                      num_heads,
                      head_size,
-                     q, k, v,
+                     q, k, v, out,
                      dq, dk, dv,
                      cu_seqlens_q.data_ptr(),
                      cu_seqlens_k.data_ptr(),
-                     out.data_ptr(),
                      loop ? dq_tmp.data_ptr() : nullptr,
                      dout.data_ptr(),
                      softmax_lse.data_ptr(),
csrc/flash_attn/src/fmha.h

@@ -81,6 +81,8 @@ struct FMHA_fprop_params : public Qkv_params {
     // size_t o_stride_in_bytes;
     uint32_t o_row_stride_in_elts;
     uint32_t o_head_stride_in_elts;
+    uint32_t o_tmp_row_stride_in_elts;
+    uint32_t o_tmp_head_stride_in_elts;

     // The pointer to the O_tmp matrix, which holds O intermediate value during
     // the loop;
csrc/flash_attn/src/fmha_fprop_kernel_1xN.h

@@ -259,7 +259,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
     Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx, true);
     // Allocate the global memory tile loader for O.
     Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
-    Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_tmp_row_stride_in_elts, params.o_tmp_head_stride_in_elts, binfo, tidx);
     // Allocate the global memory tile loader for S.
     Gmem_tile_s gmem_s(params, binfo, tidx);
     Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx);
csrc/flash_attn/src/static_switch.h

@@ -22,4 +22,4 @@
         constexpr bool CONST_NAME = false; \
         return __VA_ARGS__();              \
     }                                      \
-}()
\ No newline at end of file
+}()
flash_attn/flash_attn_interface.py

 import torch
 import torch.nn as nn
+import torch.nn.functional as F

 import flash_attn_cuda
@@ -14,11 +15,11 @@ def _get_block_size(device, head_dim, is_dropout):
     return 256 if (torch.cuda.get_device_capability(device) == (8, 0) and not is_dropout) else 128


-def _flash_attn_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-                        dropout_p, softmax_scale, causal, return_softmax):
-    out, softmax_lse, *rest = flash_attn_cuda.fwd(
-        q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
-        softmax_scale, False, causal, return_softmax, None
+def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
+                        dropout_p, softmax_scale, causal, return_softmax, generator=None):
+    softmax_lse, *rest = flash_attn_cuda.fwd(
+        q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
+        softmax_scale, False, causal, return_softmax, generator
     )
     # if out.isnan().any() or softmax_lse.isnan().any():
     #     breakpoint()
@@ -27,10 +28,11 @@ def _flash_attn_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_s
 def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
-                         max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal):
+                         max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, generator=None):
     softmax_d = flash_attn_cuda.bwd(
         dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
-        max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, None
+        max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, generator
     )
     # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
     #     breakpoint()
     return dq, dk, dv, softmax_d
@@ -82,8 +84,8 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
         out, softmax_lse, S_dmask = _flash_attn_forward(
-            q, kv[:, 0], kv[:, 1], cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-            dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax
+            q, kv[:, 0], kv[:, 1], torch.empty_like(q), cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
+            max_seqlen_k, dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax
         )
         ctx.save_for_backward(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state)
         ctx.dropout_p = dropout_p
@@ -121,7 +123,7 @@ class FlashAttnFunc(torch.autograd.Function):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
         out, softmax_lse, S_dmask = _flash_attn_forward(
-            q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
+            q, k, v, torch.empty_like(q), cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
             dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax
         )
         ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state)
@@ -148,6 +150,85 @@ class FlashAttnFunc(torch.autograd.Function):
         return dq, dk, dv, None, None, None, None, None, None, None, None


+class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p,
+                softmax_scale, causal, return_softmax):
+        # Save rng_state because the backward pass will regenerate the dropout mask
+        if dropout_p > 0:
+            rng_state0 = torch.cuda.get_rng_state()
+            generator1 = torch.Generator(device='cuda')
+            rng_state1 = generator1.get_state()
+        else:
+            rng_state0, generator1, rng_state1 = None, None, None
+        if softmax_scale is None:
+            softmax_scale = qkv.shape[-1] ** (-0.5)
+        out = torch.empty_like(qkv[:, 0])
+        _, softmax_lse0, S_dmask0 = _flash_attn_forward(
+            qkv[:, 0], qkv[:, 1], qkv[:, 2], out, cu_seqlens[:batch_size0 + 1],
+            cu_seqlens[:batch_size0 + 1], max_seqlen0, max_seqlen0, dropout_p, softmax_scale,
+            causal=causal, return_softmax=return_softmax
+        )
+        s = torch.cuda.Stream()
+        with torch.cuda.stream(s):
+            _, softmax_lse1, S_dmask1 = _flash_attn_forward(
+                qkv[:, 0], qkv[:, 1], qkv[:, 2], out, cu_seqlens[batch_size0:],
+                cu_seqlens[batch_size0:], max_seqlen1, max_seqlen1, dropout_p, softmax_scale,
+                causal=causal, return_softmax=return_softmax, generator=generator1
+            )
+        torch.cuda.current_stream().wait_stream(s)
+        ctx.save_for_backward(qkv, out, softmax_lse0, softmax_lse1, cu_seqlens,
+                              rng_state0, rng_state1)
+        ctx.dropout_p = dropout_p
+        ctx.max_seqlen0 = max_seqlen0
+        ctx.max_seqlen1 = max_seqlen1
+        ctx.batch_size0 = batch_size0
+        ctx.softmax_scale = softmax_scale
+        ctx.causal = causal
+        if not return_softmax:
+            return out
+        else:
+            max_seqlen_q = max(softmax_lse0.shape[2], softmax_lse1.shape[2])
+            max_seqlen_k = max(S_dmask0.shape[3], S_dmask1.shape[3])
+            softmax_lse = torch.cat([F.pad(softmax_lse0, (0, max_seqlen_q - softmax_lse0.shape[2])),
+                                     F.pad(softmax_lse1, (0, max_seqlen_q - softmax_lse1.shape[2]))],
+                                    dim=0)
+            return out, softmax_lse, S_dmask0, S_dmask1
+
+    @staticmethod
+    def backward(ctx, dout, *args):
+        qkv, out, softmax_lse0, softmax_lse1, cu_seqlens, rng_state0, rng_state1 = ctx.saved_tensors
+        batch_size0 = ctx.batch_size0
+        if rng_state0 is not None:
+            cur_rng_state = torch.cuda.get_rng_state()
+            torch.cuda.set_rng_state(rng_state0)
+        if rng_state1 is not None:
+            generator1 = torch.Generator(device='cuda')
+            generator1.set_state(rng_state1)
+        else:
+            generator1 = None
+        dqkv = torch.empty_like(qkv)
+        _flash_attn_backward(
+            dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse0, dqkv[:, 0], dqkv[:, 1],
+            dqkv[:, 2], cu_seqlens[:batch_size0 + 1], cu_seqlens[:batch_size0 + 1],
+            ctx.max_seqlen0, ctx.max_seqlen0, ctx.dropout_p, ctx.softmax_scale, ctx.causal
+        )
+        s = torch.cuda.Stream()
+        with torch.cuda.stream(s):
+            _flash_attn_backward(
+                dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse1, dqkv[:, 0], dqkv[:, 1],
+                dqkv[:, 2], cu_seqlens[batch_size0:], cu_seqlens[batch_size0:],
+                ctx.max_seqlen1, ctx.max_seqlen1, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
+                generator=generator1
+            )
+        torch.cuda.current_stream().wait_stream(s)
+        if rng_state0 is not None:
+            torch.cuda.set_rng_state(cur_rng_state)
+        return dqkv, None, None, None, None, None, None, None, None
+
+
 def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None,
                                        causal=False, return_attn_probs=False):
     """dropout_p should be set to 0.0 during evaluation
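The comment in FlashAttnQKVPackedSplitFunc.forward above notes that the RNG state is saved because the backward pass regenerates the dropout mask: the first kernel draws from the global CUDA RNG (whose state is captured in rng_state0), while the second kernel gets its own torch.Generator (generator1) so its dropout pattern can be replayed independently of anything else that touches the global RNG. A standalone sketch of that save-and-replay pattern, using plain torch.rand in place of the CUDA kernels (illustrative only, not code from this commit):

    import torch

    shape = (4, 8)
    keep_p = 0.9

    # "Forward": capture the global CUDA RNG state and create a dedicated generator for part 1.
    rng_state0 = torch.cuda.get_rng_state()
    generator1 = torch.Generator(device='cuda')
    rng_state1 = generator1.get_state()
    mask0 = torch.rand(shape, device='cuda') < keep_p                        # uses the global RNG
    mask1 = torch.rand(shape, device='cuda', generator=generator1) < keep_p  # uses generator1

    # "Backward": restore both states so the exact same masks are regenerated.
    cur_rng_state = torch.cuda.get_rng_state()
    torch.cuda.set_rng_state(rng_state0)
    assert torch.equal(mask0, torch.rand(shape, device='cuda') < keep_p)
    generator1.set_state(rng_state1)
    assert torch.equal(mask1, torch.rand(shape, device='cuda', generator=generator1) < keep_p)
    torch.cuda.set_rng_state(cur_rng_state)  # leave the global RNG where the caller expects it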
@@ -243,6 +324,43 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
                              dropout_p, softmax_scale, causal, return_attn_probs)


+def flash_attn_unpadded_qkvpacked_split_func(qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0,
+                                             dropout_p, softmax_scale=None, causal=False,
+                                             return_attn_probs=False):
+    """
+    Split attention into 2 kernels running on 2 separate streams for performance reason:
+    e.g., if the batch has some sequences of length <= 128 and some > 128, it might be faster to
+    have one kernel dealing with seqlen <= 128 and one kernel for seqlen > 128.
+
+    dropout_p should be set to 0.0 during evaluation.
+
+    Arguments:
+        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
+        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+           of the sequences in the batch, used to index into qkv.
+        max_seqlen0: int. Maximum sequence length in 1st part of the batch.
+        max_seqlen1: int. Maximum sequence length in 2nd part of the batch.
+        batch_size0: int. Number of sequences in the 1st part of the batch.
+        dropout_p: float. Dropout probability.
+        softmax_scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
+        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
+           testing only. The returned probabilities are not guaranteed to be correct
+           (they might not have the right scaling).
+    Return:
+        out: (total, nheads, headdim).
+        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
+            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
+            normalization factor).
+        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
+            The output of softmax (possibly with different scaling). It also encodes the dropout
+            pattern (negative means that location was dropped, nonnegative means it was kept).
+    """
+    return FlashAttnQKVPackedSplitFunc.apply(qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0,
+                                             dropout_p, softmax_scale, causal, return_attn_probs)
+
+
 def flash_attn_func(qkv, cu_seqlens, dropout_p, max_s, softmax_scale=None, causal=False,
                     return_attn_probs=False):
     """For backward-compatibility only, will remove soon.
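As the forward pass above shows, the split is done purely by slicing cu_seqlens: the first call receives cu_seqlens[:batch_size0 + 1] and the second receives cu_seqlens[batch_size0:], while both receive the full packed qkv and the same out tensor. A small worked example of that partition with illustrative lengths (the numbers are assumptions, not taken from the diff):

    import torch

    # Hypothetical batch: three long sequences followed by two short ones.
    lengths = torch.tensor([512, 384, 300, 100, 96], dtype=torch.int32)
    cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32), lengths.cumsum(0, dtype=torch.int32)])
    # cu_seqlens = [0, 512, 896, 1196, 1296, 1392]
    batch_size0 = 3

    cu0 = cu_seqlens[:batch_size0 + 1]   # [0, 512, 896, 1196]  -> 1st kernel, 3 sequences
    cu1 = cu_seqlens[batch_size0:]       # [1196, 1296, 1392]   -> 2nd kernel, 2 sequences
    # cu1 is not re-based to start at 0: its first entry is the token offset where the 1st part
    # ends, so the two kernels can index disjoint row ranges of the same packed qkv/out tensors.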
tests/test_flash_attn.py

@@ -8,6 +8,7 @@ import pytest
 from einops import rearrange, repeat

 from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_unpadded_qkvpacked_func, _get_block_size, flash_attn_unpadded_kvpacked_func, flash_attn_unpadded_func
+from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_split_func
 from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
@@ -16,13 +17,19 @@ is_sm80 = torch.cuda.get_device_capability('cuda') == (8, 0)
 def generate_random_padding_mask(max_seqlen, batch_size, device, mode='random'):
-    assert mode in ['full', 'random', 'third']
+    assert mode in ['full', 'random', 'third', 'split']
     if mode == 'full':
         lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
     elif mode == 'random':
-        lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen, (batch_size, 1), device=device)
+        lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device)
     elif mode == 'third':
-        lengths = torch.randint(max_seqlen // 3, max_seqlen, (batch_size, 1), device=device)
+        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
+    elif mode == 'split':
+        lengths0 = torch.randint(min(128, max_seqlen), max_seqlen + 1,
+                                 (batch_size // 4 * 3, 1), device=device)
+        lengths1 = torch.randint(min(max(1, max_seqlen - 20), 128), min(max_seqlen, 128) + 1,
+                                 (batch_size - batch_size // 4 * 3, 1), device=device)
+        lengths = torch.cat([lengths0, lengths1], dim=0)
     padding_mask = repeat(torch.arange(max_seqlen, device=device), 's -> b s', b=batch_size) < lengths
     return padding_mask
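For reference, with the new test's seqlen of 512, mode='split' makes the first batch_size // 4 * 3 sequences "long" and the remaining ones exactly 128 tokens, which is why the test below hard-codes batch_size0 = batch_size // 4 * 3 and max_seqlen1 = 128. A quick check of the bounds, derived from the code above (illustrative only):

    # mode='split' bounds for max_seqlen = 512 (the seqlen used by the new test below):
    max_seqlen = 512
    low0, high0 = min(128, max_seqlen), max_seqlen + 1                         # lengths0 in {128, ..., 512}
    low1, high1 = min(max(1, max_seqlen - 20), 128), min(max_seqlen, 128) + 1  # lengths1 always 128
    assert (low0, high0) == (128, 513) and (low1, high1) == (128, 129)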
@@ -605,6 +612,104 @@ def test_flash_attn_unpadded(seqlen, d, dropout_p, causal, dtype):
         # assert torch.allclose(dv, dv_ref, rtol=rtol, atol=atol)


+@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
+# @pytest.mark.parametrize('dtype', [torch.float16])
+@pytest.mark.parametrize('causal', [False, True])
+# @pytest.mark.parametrize('causal', [False])
+@pytest.mark.parametrize('d', [128, 64, 32, 16])
+# @pytest.mark.parametrize('d', [64])
+@pytest.mark.parametrize('seqlen', [512])
+@pytest.mark.parametrize('dropout_p', [0.0, 0.17])
+# @pytest.mark.parametrize('dropout_p', [0.0])
+def test_flash_attn_unpadded_qkvpacked_split(seqlen, d, dropout_p, causal, dtype):
+    if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
+        pytest.skip()  # Reference implementation OOM
+    device = 'cuda'
+    # if dtype == torch.float16:
+    #     rtol, atol = (1e-3, 3e-4) if not causal else (1e-3, 1e-3)
+    # else:  # torch.bfloat16
+    #     rtol, atol = (3e-3, 3e-3) if not causal else (1e-3, 1e-3)
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 32
+    nheads = 4
+    x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True)
+    Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype)
+
+    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='split')
+    batch_size0 = batch_size // 4 * 3  # this must match what's in generate_random_padding_mask
+    # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
+
+    qkv_unpad, cu_seqlens, max_seqlen0, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
+        x, Wqkv, nheads, key_padding_mask, key_padding_mask, qkvpacked=True
+    )
+    max_seqlen1 = 128
+
+    output_unpad, sm_lse, S_dmask0, S_dmask1 = flash_attn_unpadded_qkvpacked_split_func(
+        qkv_unpad, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p,
+        return_attn_probs=True, causal=causal
+    )
+    output = output_pad_fn(output_unpad)
+    S_dmask0_converted = convert_flash_attn_S_to_softmax(
+        S_dmask0, key_padding_mask[:batch_size0], key_padding_mask[:batch_size0], d,
+        dropout_p > 0.0, causal=causal
+    )
+    S_dmask1_converted = convert_flash_attn_S_to_softmax(
+        S_dmask1, key_padding_mask[batch_size0:, :max_seqlen1], key_padding_mask[batch_size0:, :max_seqlen1],
+        d, dropout_p > 0.0, causal=causal
+    )
+    padding = (S_dmask0_converted.shape[-1] - S_dmask1_converted.shape[-1],
+               S_dmask0_converted.shape[-2] - S_dmask1_converted.shape[-2])
+    S_dmask_converted = torch.cat([S_dmask0_converted,
+                                   F.pad(S_dmask1_converted, (0, padding[0], 0, padding[1]))], dim=0)
+    dropout_mask = S_dmask_converted >= 0
+    attn_unnorm = S_dmask_converted.abs()
+    attn = normalize_flash_attn_S(attn_unnorm, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2],
+                                  key_padding_mask, key_padding_mask, dropout_p > 0.0, causal=causal)
+    dropout_fraction = get_dropout_fraction(dropout_mask, key_padding_mask, key_padding_mask,
+                                            causal=causal).item()
+
+    output_ref, attn_ref = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask,
+                                                   causal=causal)
+    output_pt, attn_pt = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask,
+                                                 causal=causal, upcast=False, reorder_ops=True)
+    print(f'Actual dropout fraction: {dropout_fraction}')
+    print(f'Output max diff: {(output - output_ref).abs().max().item()}')
+    print(f'Output mean diff: {(output - output_ref).abs().mean().item()}')
+    print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
+    print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
+    print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
+    print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')
+
+    if is_sm80 or d < 128:  # Only run backward for d=128 on A100
+        g = torch.randn_like(output)
+        dqkv_unpad, = torch.autograd.grad(output, qkv_unpad, g)
+        dqkv = dqkv_pad_fn(dqkv_unpad)
+        dqkv_ref, = torch.autograd.grad(output_ref, qkv, g)
+        dqkv_pt, = torch.autograd.grad(output_pt, qkv, g)
+        print(f'dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
+        print(f'dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
+        print(f'dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
+        print(f'dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}')
+        print(f'dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
+        print(f'dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
+        print(f'dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
+        print(f'dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}')
+
+    # Check that FlashAttention's numerical error is at most twice the numerical error
+    # of a Pytorch implementation.
+    assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
+    # assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
+    assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
+    # assert torch.allclose(attn, attn_ref, rtol=rtol, atol=atol)
+    if dropout_p == 0.0:
+        assert dropout_mask.all()
+    else:
+        assert 0.99 <= dropout_fraction / dropout_p <= 1.01
+    if is_sm80 or d < 128:  # Only run backward for d=128 on A100
+        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
+        # assert torch.allclose(dqkv, dqkv_ref, rtol=rtol, atol=atol)
+
+
 @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize('causal', [False, True])