gaoqiong / flash-attention · Commits

Commit f68d41ec
Authored Jan 17, 2023 by Tri Dao

[Gen] Add OPT to generation test

Parent: 88173a1a
Showing 4 changed files with 164 additions and 11 deletions (+164 -11):

  flash_attn/utils/generation.py       +8    -4
  flash_attn/utils/pretrained.py       +25   -3
  tests/models/test_gpt_generation.py  +130  -4
  tests/models/test_opt.py             +1    -0
flash_attn/utils/generation.py

@@ -71,7 +71,8 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
 def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
-           vocab_size=None, tensor_parallel=1, fused_ft_kernel=False, cg=False, timing=False):
+           eos_token_id=None, vocab_size=None, tensor_parallel=1, fused_ft_kernel=False,
+           cg=False, timing=False):
     """Decoding, either greedy or with top-k or top-p sampling.
     If top-k = 0, don't limit the number of candidates (pure sampling).
     Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,

@@ -104,14 +105,15 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
     scores = []
     with torch.inference_mode():
         logits = model(input_ids, inference_params=inference_params).logits[:, -1]
+        if timing:
+            torch.cuda.synchronize()
+            start = time.time()
         if vocab_size is not None:
             logits = logits[..., :vocab_size]
         scores.append(logits)
         next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
         sequences = [next_token]
         inference_params.sequence_len_offset = seqlen_og
-        if timing:
-            start = time.time()
         while True:
             position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset,
                                       dtype=torch.long, device=input_ids.device)

@@ -127,11 +129,13 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
             next_token = sample(logits, top_k=top_k, temperature=temperature)
             sequences.append(next_token)
             inference_params.sequence_len_offset += 1
+            if eos_token_id is not None and (next_token == eos_token_id).all():
+                break
             if inference_params.sequence_len_offset >= max_length - 1:
                 break
         if timing:
             torch.cuda.synchronize()
-            print(f'Decoding time: {time.time() - start}')
+            print(f'Decoding time: {(time.time() - start) * 1000:.0f}ms')
     output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
     return output_cls(sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
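Note on the sampling semantics referenced in the `decode` docstring above (greedy when top_k == 1; top-k filtering applied first, then top-p): the sketch below is a minimal, standalone illustration of that ordering, not the repository's actual `sample` implementation, and the helper name `sample_sketch` is hypothetical.

import torch


def sample_sketch(logits, top_k=1, top_p=0.0, temperature=1.0):
    """Hypothetical standalone sketch of the sampling order described above.

    top_k == 1 -> greedy; top_k == 0 -> no top-k limit; if top_k > 0 and
    top_p > 0.0, top-k filtering is applied first, then top-p (nucleus).
    logits: (batch, vocab_size).
    """
    if top_k == 1:
        # Greedy decoding: just take the most likely token.
        return logits.argmax(dim=-1)
    logits = logits / temperature
    if top_k > 0:
        # Keep only the k largest logits, mask out everything else.
        k = min(top_k, logits.shape[-1])
        kth_largest = torch.topk(logits, k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_largest, float('-inf'))
    if top_p > 0.0:
        # Nucleus filtering on whatever survived top-k: drop tokens once the
        # cumulative probability of the more likely tokens already exceeds top_p.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        sorted_probs = torch.softmax(sorted_logits, dim=-1)
        cum_probs = sorted_probs.cumsum(dim=-1)
        remove = (cum_probs - sorted_probs) > top_p  # the top token is always kept
        sorted_logits = sorted_logits.masked_fill(remove, float('-inf'))
        # Undo the sort so the filtered logits are back in vocabulary order.
        logits = torch.empty_like(logits).scatter(-1, sorted_idx, sorted_logits)
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)

With top_k=0 and top_p=0.0 this reduces to pure sampling from the temperature-scaled distribution, matching the docstring's description.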
flash_attn/utils/pretrained.py

 import torch
-from transformers.utils import WEIGHTS_NAME
-from transformers.utils.hub import cached_file
+from transformers.utils import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
+from transformers.utils import is_remote_url
+from transformers.modeling_utils import load_state_dict
+from transformers.utils.hub import cached_file, get_checkpoint_shard_files


 def state_dict_from_pretrained(model_name, device=None, dtype=None):
-    state_dict = torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device)
+    is_sharded = False
+    resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
+                                        _raise_exceptions_for_missing_entries=False)
+    if resolved_archive_file is None:
+        resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME,
+                                            _raise_exceptions_for_missing_entries=False)
+        if resolved_archive_file is not None:
+            is_sharded = True
+    if resolved_archive_file is None:
+        raise EnvironmentError(f"Model name {model_name} was not found.")
+    if is_sharded:
+        # resolved_archive_file becomes a list of files that point to the different
+        # checkpoint shards in this case.
+        resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
+            model_name, resolved_archive_file
+        )
+        state_dict = {}
+        for sharded_file in resolved_archive_file:
+            state_dict.update(torch.load(sharded_file, map_location=device))
+    else:
+        state_dict = torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device)
     if dtype is not None:
         state_dict = {k: v.to(dtype) for k, v in state_dict.items()}
     return state_dict
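For orientation, a minimal usage sketch of the updated loader. The model names are examples taken from the tests below; whether the sharded (WEIGHTS_INDEX_NAME) branch or the single-file branch is exercised depends on how the checkpoint is published on the Hub.

import torch
from flash_attn.utils.pretrained import state_dict_from_pretrained

# Example only: larger OPT checkpoints such as "facebook/opt-6.7b" are typically
# stored as multiple shards, so they go through get_checkpoint_shard_files;
# smaller ones like "facebook/opt-125m" load a single pytorch_model.bin.
state_dict = state_dict_from_pretrained("facebook/opt-125m", device="cpu", dtype=torch.float16)
print(f"{len(state_dict)} tensors loaded")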
tests/models/test_gpt_generation.py

 import os
 import re
+import time

 import torch
 import pytest

 from einops import rearrange

-from transformers import GPT2Config, GPT2Tokenizer
+from transformers import GPT2Config, GPT2Tokenizer, OPTConfig, AutoTokenizer
 from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF
+from transformers.models.opt.modeling_opt import OPTForCausalLM

 from flash_attn.models.gpt import GPTLMHeadModel
 from flash_attn.models.gpt import remap_state_dict_gpt2
+from flash_attn.models.opt import remap_state_dict_opt, opt_config_to_gpt2_config
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from flash_attn.utils.distributed import all_gather_raw
+from flash_attn.utils.generation import update_graph_cache


 @pytest.mark.parametrize('fused_ft_kernel', [False, True])

@@ -22,7 +26,7 @@ from flash_attn.utils.distributed import all_gather_raw
 @pytest.mark.parametrize('rotary', [False, True])
 # @pytest.mark.parametrize('rotary', [False])
 @pytest.mark.parametrize('model_name', ["gpt2"])
-def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
+def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
     """Check that our implementation of GPT2 generation matches the HF implementation:
     the scores in fp16 should be around the same as the HF scores in fp16, when compared to
     the HF scores in fp32.

@@ -49,13 +53,14 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
     if not rotary:
         model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)
-        model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device, dtype=dtype)
+        model_hf = GPT2LMHeadModelHF.from_pretrained(model_name, torch_dtype=dtype).to(device=device)
     model_ref.eval()
     model_hf.eval()

     torch.manual_seed(0)
     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
-    input_ids = tokenizer("Hello, my dog is cute and ",
+    input_ids = tokenizer("Hello, my dog is cute and",
                           return_tensors="pt").input_ids.to(device=device)
     max_length = 30
     # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')

@@ -106,3 +111,124 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
     assert torch.all(out.sequences == out_hf.sequences)
     assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (
         torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
+
+
+@pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b",
+                                        "facebook/opt-2.7b", "facebook/opt-6.7b"])
+# @pytest.mark.parametrize('model_name', ["facebook/opt-6.7b"])
+def test_greedy_decode_opt(model_name):
+    """Check that our implementation of OPT generation matches the HF implementation:
+    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
+    the HF scores in fp32.
+    """
+    print(f'\nMODEL: {model_name}')
+    verbose = False
+    dtype = torch.float16
+    device = 'cuda'
+    rtol, atol = 3e-3, 3e-1
+    fused_ft_kernel = True
+    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
+    # Only prenorm supports residual_in_fp32
+    config.residual_in_fp32 = getattr(config, 'prenorm', True)
+    config.use_flash_attn = True
+    config.fused_bias_fc = True
+    config.fused_mlp = True
+    config.fused_dropout_add_ln = True
+
+    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
+    model.eval()
+
+    torch.manual_seed(0)
+    # OPT tokenizer requires use_fast=False
+    # https://huggingface.co/docs/transformers/model_doc/opt
+    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
+    eos_token_id = tokenizer.eos_token_id
+    input_ids = tokenizer("Hello, my dog is cute and",
+                          return_tensors="pt").input_ids.to(device=device)
+    max_length = 30
+    # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
+    # max_length = input_ids.shape[1] + 40
+
+    # Slow generation for reference
+    sequences = []
+    scores = []
+    cur_input_ids = input_ids
+    with torch.inference_mode():
+        scores.append(model(cur_input_ids).logits[:, -1])
+        sequences.append(scores[-1].argmax(dim=-1))
+        for _ in range(input_ids.shape[1] + 1, max_length):
+            cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1)
+            scores.append(model(cur_input_ids).logits[:, -1])
+            sequences.append(scores[-1].argmax(dim=-1))
+            if eos_token_id is not None and (sequences[-1] == eos_token_id).all():
+                break
+    sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
+    scores = tuple(scores)
+
+    print('Without CUDA graph')
+    torch.cuda.synchronize()
+    start = time.time()
+    out = model.generate(input_ids=input_ids, max_length=max_length,
+                         eos_token_id=eos_token_id, fused_ft_kernel=fused_ft_kernel,
+                         return_dict_in_generate=True, output_scores=True, timing=True)
+    torch.cuda.synchronize()
+    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
+    if verbose:
+        print(out.sequences)
+    print(tokenizer.batch_decode(out.sequences.tolist()))
+
+    if fused_ft_kernel:
+        # Capture graph outside the timing loop
+        batch_size, seqlen_og = input_ids.shape
+        model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+        print('With CUDA graph')
+        torch.cuda.synchronize()
+        start = time.time()
+        out_cg = model.generate(input_ids=input_ids, max_length=max_length,
+                                fused_ft_kernel=fused_ft_kernel, cg=True,
+                                return_dict_in_generate=True, output_scores=True, timing=True)
+        torch.cuda.synchronize()
+        print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
+        if verbose:
+            print(out_cg.sequences)
+            print(tokenizer.batch_decode(out.sequences.tolist()))
+
+    del model
+
+    model_hf = OPTForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device)
+    model_hf.eval()
+    print("HF fp16")
+    torch.cuda.synchronize()
+    start = time.time()
+    out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
+                               return_dict_in_generate=True, output_scores=True)
+    torch.cuda.synchronize()
+    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
+    del model_hf
+
+    model_ref = OPTForCausalLM.from_pretrained(model_name).to(device=device)
+    model_ref.eval()
+    print("HF fp32")
+    torch.cuda.synchronize()
+    start = time.time()
+    out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
+                                 return_dict_in_generate=True, output_scores=True)
+    torch.cuda.synchronize()
+    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
+    del model_ref
+    print(tokenizer.batch_decode(out_ref.sequences.tolist()))
+
+    if verbose:
+        print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
+        print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
+        print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
+        print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
+
+    assert torch.all(out.sequences == sequences)
+    assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol)
+    assert torch.all(out.sequences == out_ref.sequences)
+    assert torch.all(out.sequences == out_hf.sequences)
+    assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (
+        torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
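The final assertion in both tests above encodes the tolerance described in their docstrings: the fp16 implementation under test may deviate from the fp32 HF reference by at most 3x as much as HF's own fp16 model does. Restated as a standalone sketch (the helper name assert_scores_close_relative_to_hf is hypothetical):

import torch


def assert_scores_close_relative_to_hf(scores, scores_hf_fp16, scores_ref_fp32, factor=3.0):
    """Sketch of the relative-tolerance check used at the end of the tests above.

    Each argument is a stacked score tensor of shape (batch, steps, vocab),
    e.g. torch.stack(out.scores, dim=1).
    """
    err = (scores - scores_ref_fp32).abs().max().item()
    err_hf = (scores_hf_fp16 - scores_ref_fp32).abs().max().item()
    assert err < factor * err_hf, (
        f"max score error {err:.4g} exceeds {factor} x HF fp16 error {err_hf:.4g}"
    )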
tests/models/test_opt.py

@@ -35,6 +35,7 @@ def test_opt_optimized(model_name):
     config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
     config.use_flash_attn = True
     config.fused_bias_fc = True
     config.fused_mlp = True
     config.fused_dropout_add_ln = True
+    # Only prenorm supports residual_in_fp32
     config.residual_in_fp32 = getattr(config, 'prenorm', True)