gaoqiong / flash-attention

Commit dca6d89d
authored Jul 10, 2024 by Tri Dao

Don't support softcap and dropout at the same time

These tests are failing so I'm just disabling this case for now

parent 81e01efd

Showing 3 changed files with 13 additions and 5 deletions:

csrc/flash_attn/flash_api.cpp                    +4  -0
csrc/flash_attn/src/flash_fwd_launch_template.h  +1  -1
tests/test_flash_attn.py                         +8  -4
csrc/flash_attn/flash_api.cpp

@@ -387,6 +387,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
+    if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
+
     if (window_size_left >= seqlen_k) { window_size_left = -1; }
     if (window_size_right >= seqlen_k) { window_size_right = -1; }
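The new TORCH_CHECK turns the previously failing softcap-plus-dropout combination in mha_fwd into an immediate, clearly worded error; the same guard is added to mha_varlen_fwd in the hunk below. As a minimal sketch of how this surfaces to users, assuming the Python wrapper flash_attn_func (not part of this diff) exposes softcap and forwards it together with dropout_p to mha_fwd:

# Sketch only: assumes flash_attn_func exposes softcap and forwards it to mha_fwd.
import torch
from flash_attn import flash_attn_func

q = torch.randn(2, 128, 4, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 128, 4, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 128, 4, 64, device="cuda", dtype=torch.float16)

out = flash_attn_func(q, k, v, dropout_p=0.17)   # dropout alone: still supported
out = flash_attn_func(q, k, v, softcap=30.0)     # softcap alone: still supported

# Combining the two now trips the TORCH_CHECK added above.
try:
    flash_attn_func(q, k, v, dropout_p=0.17, softcap=30.0)
except RuntimeError as err:
    print(err)  # "Softcapping does not support dropout for now"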
@@ -589,6 +591,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     const int head_size_og = sizes[2];
     const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
 
+    if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
+
     const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
     const int num_blocks = !paged_KV ? 0 : k.size(0);
     const int page_block_size = !paged_KV ? 1 : k.size(1);
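The variable-length (packed) path gets the identical guard. A rough usage sketch, assuming flash_attn_varlen_func (also not part of this diff) forwards dropout_p and softcap to mha_varlen_fwd and takes the usual packed q/k/v plus cu_seqlens arguments:

# Sketch only: packed (varlen) layout, total_tokens x nheads x headdim.
import torch
from flash_attn import flash_attn_varlen_func

seqlens = [100, 60, 128]
total = sum(seqlens)                      # 288 tokens packed back to back
cu_seqlens = torch.tensor([0, 100, 160, 288], device="cuda", dtype=torch.int32)

q = torch.randn(total, 4, 64, device="cuda", dtype=torch.float16)
k = torch.randn(total, 4, 64, device="cuda", dtype=torch.float16)
v = torch.randn(total, 4, 64, device="cuda", dtype=torch.float16)

# dropout_p > 0 together with softcap > 0 now hits the TORCH_CHECK in mha_varlen_fwd.
try:
    flash_attn_varlen_func(
        q, k, v, cu_seqlens, cu_seqlens, max(seqlens), max(seqlens),
        dropout_p=0.17, softcap=30.0,
    )
except RuntimeError as err:
    print(err)  # "Softcapping does not support dropout for now"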
csrc/flash_attn/src/flash_fwd_launch_template.h

@@ -73,7 +73,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
             // If return_softmax, set IsEvenMNConst to false to reduce number of templates
             // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
             // If Is_local, set Is_causal to false
-            auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout>;
+            auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
             // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
             // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
             // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
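The one-line kernel change masks the dropout template flag with !Is_softcap, so no softcapping kernel variant is instantiated with dropout, and the return-softmax specialization (which is already tied to Is_dropout here) is masked the same way. A small Python illustration of the resulting flag logic (not library code):

def effective_kernel_flags(is_dropout: bool, is_softcap: bool, return_softmax: bool):
    """Mirror how run_flash_fwd combines the template booleans after this change."""
    use_dropout = is_dropout and not is_softcap           # Is_dropout && !Is_softcap
    use_return_softmax = return_softmax and use_dropout   # ReturnSoftmaxConst && Is_dropout && !Is_softcap
    return use_dropout, use_return_softmax

# Softcap on: the dropout specialization is dropped at kernel-selection time;
# the runtime TORCH_CHECK in flash_api.cpp is what reports the unsupported combination.
assert effective_kernel_flags(True, True, True) == (False, False)
assert effective_kernel_flags(True, False, True) == (True, True)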
tests/test_flash_attn.py

@@ -895,12 +895,14 @@ def test_flash_attn_output(
         and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
     ):
         pytest.skip()  # Reference implementation OOM
+    if softcap > 0.0 and dropout_p > 0.0:
+        pytest.skip("Softcap and dropout not supported together")
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
     batch_size = 4
-    nheads = 9
-    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
+    nheads = 6 if softcap == 0.0 else 4  # softcap reference impl takes more memory
+    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
     assert nheads % nheads_k == 0
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
@@ -1162,12 +1164,14 @@ def test_flash_attn_varlen_output(
         and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
     ):
         pytest.skip()  # Reference implementation OOM
+    if softcap > 0.0 and dropout_p > 0.0:
+        pytest.skip("Softcap and dropout not supported together")
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
     batch_size = 4
-    nheads = 9
-    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
+    nheads = 6 if softcap == 0.0 else 4  # softcap reference impl takes more memory
+    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
     assert nheads % nheads_k == 0
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
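On the test side, the unsupported combination is skipped instead of executed, and the softcap cases use fewer heads because the reference implementation needs more memory there. A minimal, self-contained pytest sketch of the same skip pattern (hypothetical test, not the repository's full parametrization):

# Hypothetical example of skipping an unsupported parameter combination.
import pytest

@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
@pytest.mark.parametrize("softcap", [0.0, 50.0])
def test_softcap_dropout_combination(softcap, dropout_p):
    if softcap > 0.0 and dropout_p > 0.0:
        pytest.skip("Softcap and dropout not supported together")
    # ... run forward/backward checks for the remaining combinations ...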