gaoqiong / flash-attention · Commits

Commit fd20f16a
authored Sep 05, 2023 by Tri Dao

Support cache_seqlens being integer

parent 913922ca
Showing 2 changed files with 39 additions and 16 deletions (+39 -16)

flash_attn/flash_attn_interface.py   +9  -3
tests/models/test_gpt.py             +30 -13
flash_attn/flash_attn_interface.py

from typing import Optional, Union

import torch
import torch.nn as nn
from einops import rearrange

# isort: off
# We need to import the CUDA kernels after importing torch
...
@@ -799,7 +800,7 @@ def flash_attn_with_kvcache(
     v_cache,
     k=None,
     v=None,
-    cache_seqlens=None,
+    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
     softmax_scale=None,
     causal=False,
     num_splits=0,
...
@@ -840,7 +841,8 @@ def flash_attn_with_kvcache(
         k [optional]: (batch_size, seqlen, nheads_k, headdim). If not None, we concatenate k with
             k_cache, starting at the indices specified by cache_seqlens.
         v [optional]: (batch_size, seqlen, nheads_k, headdim). Similar to k.
-        cache_seqlens: (batch_size,), dtype torch.int32. The sequence lengths of the KV cache.
+        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
+            KV cache.
         softmax_scale: float. The scaling of QK^T before applying softmax.
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
...
@@ -858,6 +860,10 @@ def flash_attn_with_kvcache(
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     if softmax_scale is None:
         softmax_scale = q.shape[-1] ** (-0.5)
+    if cache_seqlens is not None and isinstance(cache_seqlens, int):
+        cache_seqlens = torch.full(
+            (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
+        )
     out, softmax_lse = flash_attn_cuda.fwd_kvcache(
         q, k_cache, v_cache, k, v, cache_seqlens, None, softmax_scale, causal, num_splits
     )
...
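The four added lines above broadcast an integer cache_seqlens into a per-batch int32 tensor before the kernel call. Below is a minimal usage sketch of the updated function; it is not part of the commit and assumes a CUDA device, a flash_attn build containing this revision, and purely illustrative tensor sizes.

import torch
from flash_attn.flash_attn_interface import flash_attn_with_kvcache

batch_size, nheads, headdim = 2, 8, 64
cache_len = 128  # allocated length of the KV cache

q = torch.randn(batch_size, 1, nheads, headdim, dtype=torch.float16, device="cuda")
k_cache = torch.randn(batch_size, cache_len, nheads, headdim, dtype=torch.float16, device="cuda")
v_cache = torch.randn_like(k_cache)

# Before this commit, the number of valid cache entries had to be a per-batch int32 tensor:
seqlens = torch.full((batch_size,), 32, dtype=torch.int32, device=k_cache.device)
out_a = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=seqlens, causal=True)

# With this commit, a plain Python int is accepted; it is expanded internally via
# torch.full((k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device).
out_b = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=32, causal=True)

print(torch.equal(out_a, out_b))  # expected True: both calls run the same kernel on the same seqlens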
tests/models/test_gpt.py

...
@@ -3,7 +3,12 @@ import re
 import pytest
 import torch
 from einops import rearrange
-from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2, shard_state_dict_tp, combine_state_dicts_tp
+from flash_attn.models.gpt import (
+    GPTLMHeadModel,
+    remap_state_dict_hf_gpt2,
+    shard_state_dict_tp,
+    combine_state_dicts_tp,
+)
 from flash_attn.utils.generation import InferenceParams
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from transformers import GPT2Config, GPT2Tokenizer
...
@@ -130,9 +135,9 @@ def test_gpt2_optimized(model_name):
 @pytest.mark.parametrize("fused_ft_kernel", [False, True])
-# @pytest.mark.parametrize('fused_ft_kernel', [True])
+# @pytest.mark.parametrize('fused_ft_kernel', [False])
 @pytest.mark.parametrize("optimized", [False, True])
-# @pytest.mark.parametrize('optimized', [False])
+# @pytest.mark.parametrize('optimized', [True])
 @pytest.mark.parametrize("rotary", [False, True])
 # @pytest.mark.parametrize('rotary', [False])
 @pytest.mark.parametrize("model_name", ["gpt2"])
...
@@ -204,7 +209,7 @@ def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
     )
     print(out.sequences)
     print(tokenizer.batch_decode(out.sequences.tolist()))
-    if fused_ft_kernel:
+    if fused_ft_kernel or config.use_flash_attn:
         out_cg = model.generate(
             input_ids=input_ids,
             max_length=max_length,
...
@@ -263,7 +268,6 @@ def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
     out = model.generate(
         input_ids=input_ids,
         max_length=max_length,
-        fused_ft_kernel=True,
         teacher_outputs=teacher_outputs,
         return_dict_in_generate=True,
         output_scores=True,
...
@@ -277,8 +281,9 @@ def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
 # @pytest.mark.parametrize('seqlen,maxlen', [(10, 20)])
 @pytest.mark.parametrize("rotary", [None, "interleaved", "block"])
 # @pytest.mark.parametrize('rotary', [None])
+@pytest.mark.parametrize("fused_ft_kernel", [False, True])
 @pytest.mark.parametrize("model_name", ["gpt2"])
-def test_gpt2_generation_cg(model_name, rotary, seqlen, maxlen):
+def test_gpt2_generation_cg(model_name, fused_ft_kernel, rotary, seqlen, maxlen):
     """Check that decoding with CUDA graph is the same as decoding without CUDA graph."""
     dtype = torch.float16
     device = "cuda"
...
@@ -308,8 +313,17 @@ def test_gpt2_generation_cg(model_name, rotary, seqlen, maxlen):
         0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
     )

-    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
-    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
+    logits = get_logits(
+        model, input_ids, maxlen, teacher_outputs=teacher_outputs, fused_ft_kernel=fused_ft_kernel
+    )
+    logits_cg = get_logits(
+        model,
+        input_ids,
+        maxlen,
+        teacher_outputs=teacher_outputs,
+        fused_ft_kernel=fused_ft_kernel,
+        cg=True,
+    )
     assert torch.equal(logits, logits_cg)

     # Try increasing batch size and seqlen, then decrease them to see if it's still correct
...
@@ -446,11 +460,14 @@ def test_gpt2_speculative_decoding(model_name, optimized, fused_ft_kernel, cg):
     print(tokenizer.batch_decode(out_og.sequences))


-@pytest.mark.parametrize("n_heads_q_kv", [
-    (8, 8),  # Regular attention
-    (8, 4),  # GQA
-    (8, 2),  # MQA
-])
+@pytest.mark.parametrize(
+    "n_heads_q_kv",
+    [
+        (8, 8),  # Regular attention
+        (8, 4),  # GQA
+        (8, 2),  # MQA
+    ],
+)
 def test_gpt2_shard_unshard(n_heads_q_kv):
     world_size = 2
...
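Stacked @pytest.mark.parametrize decorators, as used in the hunks above, take the cross product of their value lists, so adding the fused_ft_kernel parameter to test_gpt2_generation_cg doubles the number of collected cases for that test. A small self-contained sketch of that behavior follows; the test name and parameter combination here are illustrative and not taken from the repository.

import pytest


@pytest.mark.parametrize("rotary", [None, "interleaved", "block"])
@pytest.mark.parametrize("fused_ft_kernel", [False, True])
@pytest.mark.parametrize("model_name", ["gpt2"])
def test_param_cross_product(model_name, fused_ft_kernel, rotary):
    # pytest collects 1 * 2 * 3 = 6 cases for this function, one per combination.
    assert model_name == "gpt2"
    assert fused_ft_kernel in (False, True)
    assert rotary in (None, "interleaved", "block")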