gaoqiong / flash-attention · Commits

Commit 8a733cbd, authored Sep 10, 2023 by Tri Dao
Parent: 4c91621a

[Gen] Fix calling update_graph_cache in tests

Showing 8 changed files with 29 additions and 20 deletions (+29 -20)
flash_attn/modules/mha.py        +4   -8
flash_attn/ops/triton/rotary.py  +2   -0
tests/models/test_baichuan.py    +1   -1
tests/models/test_falcon.py      +6   -2
tests/models/test_gpt_neox.py    +0   -1
tests/models/test_gptj.py        +7   -5
tests/models/test_llama.py       +6   -2
tests/models/test_opt.py         +3   -1
flash_attn/modules/mha.py

@@ -659,8 +659,7 @@ class MHA(nn.Module):
                 qkv = rearrange(
                     self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                 ).contiguous()
-            # qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
-            qkv = qkv.reshape(batch, seqlen, 3, self.num_heads, self.head_dim)
+            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
             if (
                 inference_params is None
                 or inference_params.sequence_len_offset == 0

@@ -700,10 +699,8 @@ class MHA(nn.Module):
                 qkv, x = self.Wqkv(x)
                 q = qkv[..., : self.num_heads * self.head_dim]
                 kv = qkv[..., self.num_heads * self.head_dim :]
-            # q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
-            q = q.reshape(batch, seqlen, -1, self.head_dim)
-            # kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
-            kv = kv.reshape(batch, seqlen, 2, -1, self.head_dim)
+            q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
+            kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
             if self.dwconv:
                 q = rearrange(
                     self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"

@@ -731,8 +728,7 @@ class MHA(nn.Module):
                 context = self._update_kvcache_attention(q, kv, inference_params)
             else:
                 context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
-        # out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
-        out = self.out_proj(context.reshape(batch, seqlen, -1))
+        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
         return out if not self.return_residual else (out, x)
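Note (not part of the commit): all three mha.py hunks replace an explicit reshape, plus its commented-out predecessor, with the einops rearrange it had temporarily supplanted. The two formulations split the packed projection identically; rearrange is simply agnostic to the leading dimensions. A minimal standalone sketch with made-up shapes, assuming torch and einops are installed:

# Sketch only: illustrates the equivalence, not code from the repository.
import torch
from einops import rearrange

batch, seqlen, num_heads, head_dim = 2, 16, 8, 64
qkv = torch.randn(batch, seqlen, 3 * num_heads * head_dim)

# Old formulation in the diff: explicit reshape to (batch, seqlen, 3, heads, head_dim)
qkv_reshaped = qkv.reshape(batch, seqlen, 3, num_heads, head_dim)

# Restored formulation: einops pattern that splits the last dimension the same way
qkv_rearranged = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=head_dim)

assert torch.equal(qkv_reshaped, qkv_rearranged)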
flash_attn/ops/triton/rotary.py

 # Copyright (c) 2023, Tri Dao.
 
+from typing import Optional, Union
+
 import torch
tests/models/test_baichuan.py

@@ -404,7 +404,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
     model._decoding_cache = update_graph_cache(
-        model, None, batch_size, seqlen_og, max_length
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=False
     )
     print("With CUDA graph")
     out_cg = model.generate(
tests/models/test_falcon.py

@@ -253,7 +253,9 @@ def test_falcon_generation(model_name):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+    model._decoding_cache = update_graph_cache(
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
+    )
     print("With CUDA graph")
     torch.cuda.synchronize()
     start = time.time()

@@ -356,7 +358,9 @@ def test_falcon_parallel_generation(model_name, world_size):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+    model._decoding_cache = update_graph_cache(
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
+    )
     print("With CUDA graph")
     out_cg = model.generate(
         input_ids=input_ids,
tests/models/test_gpt_neox.py

@@ -6,7 +6,6 @@ import pytest
 import torch
 from flash_attn.models.gpt import GPTLMHeadModel
 from flash_attn.models.gpt_neox import gpt_neox_config_to_gpt2_config, remap_state_dict_hf_gpt_neox
-from flash_attn.utils.generation import update_graph_cache
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from transformers import AutoTokenizer, GPTNeoXConfig
 from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
tests/models/test_gptj.py

@@ -83,8 +83,9 @@ def test_gptj_optimized(model_name):
     ).abs().max().item()
 
 
+@pytest.mark.parametrize("fused_ft_kernel", [False, True])
 @pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"])
-def test_gptj_generation(model_name):
+def test_gptj_generation(model_name, fused_ft_kernel):
     """Check that our implementation of GPT-J (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
     forward pass in fp16, when compared to the HF forward pass in fp32.

@@ -140,8 +141,7 @@ def test_gptj_generation(model_name):
         input_ids=input_ids,
         max_length=max_length,
         eos_token_id=eos_token_id,
-        fused_ft_kernel=True,
-        # eos_token_id=eos_token_id, fused_ft_kernel=False,
+        fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
         enable_timing=True,

@@ -152,14 +152,16 @@ def test_gptj_generation(model_name):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+    model._decoding_cache = update_graph_cache(
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=fused_ft_kernel
+    )
     print("With CUDA graph")
     torch.cuda.synchronize()
     start = time.time()
     out_cg = model.generate(
         input_ids=input_ids,
         max_length=max_length,
-        fused_ft_kernel=True,
+        fused_ft_kernel=fused_ft_kernel,
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
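Note (not part of the commit): stacking a second pytest.mark.parametrize decorator, as in the first test_gptj.py hunk, makes pytest run the cross-product of the parameter lists, so test_gptj_generation now runs once with fused_ft_kernel=False and once with fused_ft_kernel=True per model name. A tiny self-contained sketch of that mechanism (hypothetical test, not from the repo):

# pytest collects the cross-product: 2 values x 1 value = 2 test cases.
import pytest

@pytest.mark.parametrize("fused_ft_kernel", [False, True])
@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"])
def test_parametrize_cross_product(model_name, fused_ft_kernel):
    assert model_name == "EleutherAI/gpt-j-6B"
    assert isinstance(fused_ft_kernel, bool)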
tests/models/test_llama.py

@@ -303,7 +303,9 @@ def test_llama_generation(model_name, checkpoint_format):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+    model._decoding_cache = update_graph_cache(
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
+    )
     print("With CUDA graph")
     torch.cuda.synchronize()
     start = time.time()

@@ -408,7 +410,9 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+    model._decoding_cache = update_graph_cache(
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
+    )
     print("With CUDA graph")
     out_cg = model.generate(
         input_ids=input_ids,
tests/models/test_opt.py

@@ -168,7 +168,9 @@ def test_opt_generation(model_name):
     if fused_ft_kernel:
         # Capture graph outside the timing loop
         batch_size, seqlen_og = input_ids.shape
-        model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+        model._decoding_cache = update_graph_cache(
+            model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
+        )
         print("With CUDA graph")
         torch.cuda.synchronize()
         start = time.time()
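Note (not part of the commit): every test file touched here applies the same fix: pass fused_ft_kernel through to update_graph_cache when capturing the CUDA graph, matching the value later given to generate(..., cg=True). A sketch of that shared pattern, assembled from the hunks above (the helper name is hypothetical, model/tokenizer/input setup is omitted, and flash-attn must be installed):

from flash_attn.utils.generation import update_graph_cache

def generate_with_cuda_graph(model, input_ids, max_length, fused_ft_kernel=True):
    # Capture the CUDA graph outside the timing loop, forwarding fused_ft_kernel
    # explicitly; this is the argument the commit adds to the update_graph_cache calls.
    batch_size, seqlen_og = input_ids.shape
    model._decoding_cache = update_graph_cache(
        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=fused_ft_kernel
    )
    # Decode with the captured graph; the kwargs mirror the test_gptj.py hunk above.
    return model.generate(
        input_ids=input_ids,
        max_length=max_length,
        fused_ft_kernel=fused_ft_kernel,
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
    )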