gaoqiong/flash-attention, commit 605655bc
[Gen] Fix FT kernel when using CG
Authored by Tri Dao on Apr 14, 2023
Parent: dceb2687
Showing 3 changed files with 107 additions and 12 deletions:

flash_attn/modules/mha.py        +4   -2
flash_attn/utils/generation.py   +12  -9
tests/models/test_gptj.py        +91  -1
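For context: the "FT kernel" in the commit title is the fused single-query attention kernel adapted from FasterTransformer (called as ft_attention.single_query_attention in mha.py below), and "CG" refers to CUDA graphs, which flash_attn.utils.generation uses to avoid per-step kernel launch overhead during decoding. As the new comment in capture_graph explains, that kernel sizes its shared memory from a CPU-side sequence length, so a graph has to be captured with that length set at least as large as any length it will later be replayed with. The following is only a minimal sketch of the general capture/replay pattern involved, not the repository's code; decode_step is a hypothetical stand-in for a model's single-token forward pass.

import torch

@torch.no_grad()
def capture_decode_step(decode_step, batch_size, device, mempool=None, n_warmups=2):
    # Static buffers: a CUDA graph always reads from and writes to the same memory.
    input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
    position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)

    # Warm up on a side stream so lazy initialization is not recorded in the graph.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmups):
            logits = decode_step(input_ids, position_ids)
    torch.cuda.current_stream().wait_stream(s)

    # Record one decoding step; a shared pool lets several graphs reuse memory.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
        logits = decode_step(input_ids, position_ids)

    def run(new_input_ids, new_position_ids):
        # Replay re-executes the recorded kernels on the current buffer contents.
        input_ids.copy_(new_input_ids)
        position_ids.copy_(new_position_ids)
        graph.replay()
        return logits

    return run

In capture_graph below, the same pattern is used, with the inference state pinned to max_seqlen - 1 during capture and restored afterwards.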
flash_attn/modules/mha.py
@@ -495,7 +495,8 @@ class MHA(nn.Module):
                 *inference_params.key_value_memory_dict[self.layer_idx],
                 inference_params.lengths_per_sample, inference_params.sequence_len_offset,
                 self.rotary_emb_dim,
-                not self.rotary_emb.interleaved  # neox_rotary_style
+                # neox_rotary_style
+                (not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True
             )
             context = rearrange(context, 'b h d -> b 1 h d')
         else:
@@ -609,7 +610,8 @@ class ParallelMHA(nn.Module):
                 *inference_params.key_value_memory_dict[self.layer_idx],
                 inference_params.lengths_per_sample, inference_params.sequence_len_offset,
                 self.rotary_emb_dim,
-                not self.rotary_emb.interleaved  # neox_rotary_style
+                # neox_rotary_style
+                (not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True
             )
             context = rearrange(context, 'b h d -> b 1 h d')
         if seqlen is None:
flash_attn/utils/generation.py
@@ -190,9 +190,9 @@ def seqlen_to_seqlen_type(seqlen: int) -> int:
     return 0 if seqlen < 32 else (1 if seqlen < 2048 else 2)
 
 
-def seqlen_type_to_seqlen(seqlen_type: int) -> int:
+def seqlen_type_to_max_seqlen(seqlen_type: int) -> int:
     assert seqlen_type in [0, 1, 2]
-    return 1 if seqlen_type == 0 else (32 if seqlen_type == 1 else 2048)
+    return 32 if seqlen_type == 0 else (2048 if seqlen_type == 1 else 2 ** 32)
 
 
 @dataclass
@@ -239,9 +239,9 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
         cache.mempool = torch.cuda.graphs.graph_pool_handle()
     for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_seqlen) + 1):
         if s_type not in cache.callables:
-            seqlen = min(max(seqlen_og, seqlen_type_to_seqlen(s_type)), max_seqlen)
+            max_seqlen_ = min(max(seqlen_og, seqlen_type_to_max_seqlen(s_type)), max_seqlen)
             cache.callables[s_type] = capture_graph(
-                model, cache.inference_params, batch_size, seqlen_og, seqlen, mempool=cache.mempool,
+                model, cache.inference_params, batch_size, max_seqlen_, mempool=cache.mempool,
                 n_warmups=n_warmups
             )
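The rename from seqlen_type_to_seqlen to seqlen_type_to_max_seqlen makes the intent explicit: each cached graph is now captured at the upper end of its sequence-length bucket (clamped via min(max(seqlen_og, ...), max_seqlen)), so any replay at a smaller length within the bucket stays inside the shared-memory size the FT kernel chose at capture time. A small self-contained illustration of the bucketing, reusing the two helpers exactly as defined in this file; the sample lengths are arbitrary.

def seqlen_to_seqlen_type(seqlen: int) -> int:
    # Bucket a sequence length into one of three "seqlen types".
    return 0 if seqlen < 32 else (1 if seqlen < 2048 else 2)

def seqlen_type_to_max_seqlen(seqlen_type: int) -> int:
    # Largest length a graph captured for this bucket may be replayed with.
    assert seqlen_type in [0, 1, 2]
    return 32 if seqlen_type == 0 else (2048 if seqlen_type == 1 else 2 ** 32)

for seqlen in [16, 100, 150, 4096]:
    s_type = seqlen_to_seqlen_type(seqlen)
    print(f'{seqlen} -> bucket {s_type}, captured at up to {seqlen_type_to_max_seqlen(s_type)}')
# 16 -> bucket 0, captured at up to 32
# 100 -> bucket 1, captured at up to 2048
# 150 -> bucket 1, captured at up to 2048
# 4096 -> bucket 2, captured at up to 4294967296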
@@ -249,17 +249,19 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
         return cache.callables[seqlen_to_seqlen_type(seqlen)](input_ids, position_ids, seqlen)
     cache.run = dispatch
-    cache.inference_params.sequence_length_offset = 0  # Reset so it's not confusing
+    cache.inference_params.sequence_len_offset = 0  # Reset so it's not confusing
     return cache
 
 
-def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, mempool=None, n_warmups=2):
-    assert max_seqlen >= seqlen_og
+def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None, n_warmups=2):
     device = next(iter(model.parameters())).device
     input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
     position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
-    inference_params.lengths_per_sample[:] = seqlen_og
+    sequence_len_offset_og = inference_params.sequence_len_offset
+    # TD [2023-04-14]: important for correctness of the FT's attention kernel, as seqlen_cpu is
+    # used to determine the size of smem. Hence seqlen_cpu must be >= lengths_per_sample.
+    inference_params.sequence_len_offset = max_seqlen - 1
+    inference_params.lengths_per_sample[:] = max_seqlen - 1
     # Warmup before capture
     s = torch.cuda.Stream()
@@ -289,4 +291,5 @@ def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, me
         graph.replay()
         return logits
 
+    inference_params.sequence_len_offset = sequence_len_offset_og
     return run
tests/models/test_gptj.py
-import re
+import time
 
 import torch
 import pytest
@@ -9,6 +9,7 @@ from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
 from flash_attn.models.gpt import GPTLMHeadModel
 from flash_attn.models.gptj import remap_state_dict_hf_gptj, gptj_config_to_gpt2_config
 from flash_attn.utils.pretrained import state_dict_from_pretrained
+from flash_attn.utils.generation import update_graph_cache
 
 
 @pytest.mark.parametrize('model_name', ["EleutherAI/gpt-j-6B"])
@@ -79,3 +80,92 @@ def test_gptj_optimized(model_name):
     print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
     print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
     assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()
+
+
+@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-j-6B"])
+def test_gptj_generation(model_name):
+    """Check that our implementation of GPT-J (with all optimizations enabled) matches the
+    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
+    forward pass in fp16, when compared to the HF forward pass in fp32.
+    """
+    dtype = torch.float16
+    device = 'cuda'
+    config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
+    config.use_flash_attn = False  # FlashAttention doesn't support hdim 256 yet
+    config.fused_bias_fc = True
+    config.fused_mlp = True
+    config.fused_dropout_add_ln = True
+    # Only prenorm supports residual_in_fp32
+    config.residual_in_fp32 = True
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    eos_token_id = tokenizer.eos_token_id
+
+    torch.manual_seed(0)
+    batch_size = 1
+    seqlen = 100
+    max_length = 150
+    input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
+                              device=device)
+
+    model_hf = GPTJForCausalLM.from_pretrained(model_name, torch_dtype=dtype,
+                                               device_map={"": device})
+    model_hf.eval()
+    print("HF fp16")
+    torch.cuda.synchronize()
+    start = time.time()
+    out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
+                               return_dict_in_generate=True, output_scores=True)
+    torch.cuda.synchronize()
+    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
+    del model_hf
+
+    model_ref = GPTJForCausalLM.from_pretrained(model_name, device_map={"": device})
+    model_ref.eval()
+    with torch.no_grad():
+        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1]
+    del model_ref
+
+    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
+    model.eval()
+
+    print('Without CUDA graph')
+    torch.cuda.synchronize()
+    start = time.time()
+    out = model.generate(input_ids=input_ids, max_length=max_length,
+                         eos_token_id=eos_token_id, fused_ft_kernel=True,
+                         # eos_token_id=eos_token_id, fused_ft_kernel=False,
+                         return_dict_in_generate=True, output_scores=True, timing=True,
+                         teacher_outputs=out_hf.sequences)
+    torch.cuda.synchronize()
+    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
+
+    # Capture graph outside the timing loop
+    batch_size, seqlen_og = input_ids.shape
+    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+    print('With CUDA graph')
+    torch.cuda.synchronize()
+    start = time.time()
+    out_cg = model.generate(input_ids=input_ids, max_length=max_length,
+                            fused_ft_kernel=True, cg=True,
+                            return_dict_in_generate=True, output_scores=True, timing=True,
+                            teacher_outputs=out_hf.sequences)
+    torch.cuda.synchronize()
+    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
+
+    with torch.no_grad():
+        logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1):-1]
+    logits_hf = torch.stack(out_hf.scores, dim=1)
+    logits = torch.stack(out.scores, dim=1)
+    logits_cg = torch.stack(out_cg.scores, dim=1)
+    del model
+
+    hf_error = (logits_hf - logits_ref).abs().max().item()
+    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
+
+    print(f'HF fp16 logits max diff: {hf_error}')
+    print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
+    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
+    print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}')
+    assert torch.equal(logits_cg, logits)
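The new test drives generation three ways against the same teacher sequences: the Hugging Face model, the fused FT kernel without CUDA graphs, and the fused FT kernel with cg=True, and it requires the CUDA-graph logits to be bitwise identical to the non-graph ones (torch.equal). Assuming a CUDA GPU with enough memory for GPT-J-6B and access to the pretrained weights, it can presumably be run with something like pytest -q -s tests/models/test_gptj.py -k test_gptj_generation.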