gaoqiong / flash-attention · Commits

Commit f68d41ec, authored Jan 17, 2023 by Tri Dao (parent 88173a1a)

[Gen] Add OPT to generation test

Showing 4 changed files with 164 additions and 11 deletions:
- flash_attn/utils/generation.py: +8 -4
- flash_attn/utils/pretrained.py: +25 -3
- tests/models/test_gpt_generation.py: +130 -4
- tests/models/test_opt.py: +1 -0
flash_attn/utils/generation.py

```diff
@@ -71,7 +71,8 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
 def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
-           vocab_size=None, tensor_parallel=1, fused_ft_kernel=False, cg=False, timing=False):
+           eos_token_id=None, vocab_size=None, tensor_parallel=1, fused_ft_kernel=False,
+           cg=False, timing=False):
     """Decoding, either greedy or with top-k or top-p sampling.
     If top-k = 0, don't limit the number of candidates (pure sampling).
     Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
@@ -104,14 +105,15 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
     scores = []
     with torch.inference_mode():
         logits = model(input_ids, inference_params=inference_params).logits[:, -1]
+        if timing:
+            torch.cuda.synchronize()
+            start = time.time()
         if vocab_size is not None:
             logits = logits[..., :vocab_size]
         scores.append(logits)
         next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
         sequences = [next_token]
         inference_params.sequence_len_offset = seqlen_og
-        if timing:
-            start = time.time()
         while True:
             position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset,
                                       dtype=torch.long, device=input_ids.device)
@@ -127,11 +129,13 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
             next_token = sample(logits, top_k=top_k, temperature=temperature)
             sequences.append(next_token)
             inference_params.sequence_len_offset += 1
+            if eos_token_id is not None and (next_token == eos_token_id).all():
+                break
             if inference_params.sequence_len_offset >= max_length - 1:
                 break
     if timing:
         torch.cuda.synchronize()
-        print(f'Decoding time: {time.time() - start}')
+        print(f'Decoding time: {(time.time() - start) * 1000:.0f}ms')
     output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
     return output_cls(
         sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
```
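For reference, a minimal sketch of how the new `eos_token_id` argument is exercised end to end through `model.generate` (which drives `decode` above). The model name and prompt simply mirror the OPT test added below; they are illustrative, not part of this change.

```python
import torch
from transformers import AutoTokenizer, OPTConfig

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.opt import opt_config_to_gpt2_config

model_name = "facebook/opt-125m"  # illustrative; any model supported by GPTLMHeadModel works
config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
model = GPTLMHeadModel.from_pretrained(model_name, config, device='cuda', dtype=torch.float16)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
input_ids = tokenizer("Hello, my dog is cute and", return_tensors="pt").input_ids.to('cuda')

# With eos_token_id set, decoding stops as soon as every sequence in the batch has
# emitted EOS, instead of always running to max_length.
out = model.generate(input_ids=input_ids, max_length=30,
                     eos_token_id=tokenizer.eos_token_id,
                     return_dict_in_generate=True, output_scores=True)
print(tokenizer.batch_decode(out.sequences.tolist()))
```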
flash_attn/utils/pretrained.py

```diff
 import torch

-from transformers.utils import WEIGHTS_NAME
-from transformers.utils.hub import cached_file
+from transformers.utils import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
+from transformers.utils import is_remote_url
+from transformers.modeling_utils import load_state_dict
+from transformers.utils.hub import cached_file, get_checkpoint_shard_files


 def state_dict_from_pretrained(model_name, device=None, dtype=None):
-    state_dict = torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device)
+    is_sharded = False
+    resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
+                                        _raise_exceptions_for_missing_entries=False)
+    if resolved_archive_file is None:
+        resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME,
+                                            _raise_exceptions_for_missing_entries=False)
+        if resolved_archive_file is not None:
+            is_sharded = True
+    if resolved_archive_file is None:
+        raise EnvironmentError(f"Model name {model_name} was not found.")
+    if is_sharded:
+        # resolved_archive_file becomes a list of files that point to the different
+        # checkpoint shards in this case.
+        resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
+            model_name, resolved_archive_file
+        )
+        state_dict = {}
+        for sharded_file in resolved_archive_file:
+            state_dict.update(torch.load(sharded_file, map_location=device))
+    else:
+        state_dict = torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device)
     if dtype is not None:
         state_dict = {k: v.to(dtype) for k, v in state_dict.items()}
     return state_dict
```
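A minimal sketch of the new sharded-checkpoint path in `state_dict_from_pretrained`. The model name is illustrative: any Hub checkpoint whose weights are split across several shard files (with a `pytorch_model.bin.index.json`) takes the sharded branch; single-file checkpoints resolve `pytorch_model.bin` exactly as before.

```python
import torch
from flash_attn.utils.pretrained import state_dict_from_pretrained

# For a sharded checkpoint, the function falls back from WEIGHTS_NAME to
# WEIGHTS_INDEX_NAME, fetches every shard listed in the index, and merges
# them into a single state dict (optionally converting to the given dtype).
state_dict = state_dict_from_pretrained("facebook/opt-6.7b", device='cpu',
                                        dtype=torch.float16)
print(f'{len(state_dict)} tensors loaded')
```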
tests/models/test_gpt_generation.py

```diff
 import os
 import re
+import time

 import torch
 import pytest

 from einops import rearrange

-from transformers import GPT2Config, GPT2Tokenizer
+from transformers import GPT2Config, GPT2Tokenizer, OPTConfig, AutoTokenizer
 from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF
+from transformers.models.opt.modeling_opt import OPTForCausalLM

 from flash_attn.models.gpt import GPTLMHeadModel
 from flash_attn.models.gpt import remap_state_dict_gpt2
+from flash_attn.models.opt import remap_state_dict_opt, opt_config_to_gpt2_config
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from flash_attn.utils.distributed import all_gather_raw
+from flash_attn.utils.generation import update_graph_cache


 @pytest.mark.parametrize('fused_ft_kernel', [False, True])
@@ -22,7 +26,7 @@ from flash_attn.utils.distributed import all_gather_raw
 @pytest.mark.parametrize('rotary', [False, True])
 # @pytest.mark.parametrize('rotary', [False])
 @pytest.mark.parametrize('model_name', ["gpt2"])
-def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
+def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
     """Check that our implementation of GPT2 generation matches the HF implementation:
     the scores in fp16 should be around the same as the HF scores in fp16, when compared to
     the HF scores in fp32.
@@ -49,13 +53,14 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
     if not rotary:
         model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)
-        model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device, dtype=dtype)
+        model_hf = GPT2LMHeadModelHF.from_pretrained(model_name, torch_dtype=dtype).to(device=device)
         model_ref.eval()
         model_hf.eval()

     torch.manual_seed(0)
     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
-    input_ids = tokenizer("Hello, my dog is cute and ",
+    input_ids = tokenizer("Hello, my dog is cute and",
                           return_tensors="pt").input_ids.to(device=device)
     max_length = 30
     # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
@@ -106,3 +111,124 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
     assert torch.all(out.sequences == out_hf.sequences)
     assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (
         torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
+
+
+@pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m",
+                                        "facebook/opt-1.3b", "facebook/opt-2.7b",
+                                        "facebook/opt-6.7b"])
+# @pytest.mark.parametrize('model_name', ["facebook/opt-6.7b"])
+def test_greedy_decode_opt(model_name):
+    """Check that our implementation of OPT generation matches the HF implementation:
+    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
+    the HF scores in fp32.
+    """
+    print(f'\nMODEL: {model_name}')
+    verbose = False
+    dtype = torch.float16
+    device = 'cuda'
+    rtol, atol = 3e-3, 3e-1
+    fused_ft_kernel = True
+    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
+    # Only prenorm supports residual_in_fp32
+    config.residual_in_fp32 = getattr(config, 'prenorm', True)
+    config.use_flash_attn = True
+    config.fused_bias_fc = True
+    config.fused_mlp = True
+    config.fused_dropout_add_ln = True
+    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
+    model.eval()
+
+    torch.manual_seed(0)
+    # OPT tokenizer requires use_fast=False
+    # https://huggingface.co/docs/transformers/model_doc/opt
+    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
+    eos_token_id = tokenizer.eos_token_id
+
+    input_ids = tokenizer("Hello, my dog is cute and",
+                          return_tensors="pt").input_ids.to(device=device)
+    max_length = 30
+    # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
+    # max_length = input_ids.shape[1] + 40
+
+    # Slow generation for reference
+    sequences = []
+    scores = []
+    cur_input_ids = input_ids
+    with torch.inference_mode():
+        scores.append(model(cur_input_ids).logits[:, -1])
+        sequences.append(scores[-1].argmax(dim=-1))
+        for _ in range(input_ids.shape[1] + 1, max_length):
+            cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1)
+            scores.append(model(cur_input_ids).logits[:, -1])
+            sequences.append(scores[-1].argmax(dim=-1))
+            if eos_token_id is not None and (sequences[-1] == eos_token_id).all():
+                break
+    sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
+    scores = tuple(scores)
+
+    print('Without CUDA graph')
+    torch.cuda.synchronize()
+    start = time.time()
+    out = model.generate(input_ids=input_ids, max_length=max_length,
+                         eos_token_id=eos_token_id, fused_ft_kernel=fused_ft_kernel,
+                         return_dict_in_generate=True, output_scores=True, timing=True)
+    torch.cuda.synchronize()
+    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
+    if verbose:
+        print(out.sequences)
+    print(tokenizer.batch_decode(out.sequences.tolist()))
+
+    if fused_ft_kernel:
+        # Capture graph outside the timing loop
+        batch_size, seqlen_og = input_ids.shape
+        model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+        print('With CUDA graph')
+        torch.cuda.synchronize()
+        start = time.time()
+        out_cg = model.generate(input_ids=input_ids, max_length=max_length,
+                                fused_ft_kernel=fused_ft_kernel, cg=True,
+                                return_dict_in_generate=True, output_scores=True, timing=True)
+        torch.cuda.synchronize()
+        print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
+        if verbose:
+            print(out_cg.sequences)
+        print(tokenizer.batch_decode(out.sequences.tolist()))
+
+    del model
+
+    model_hf = OPTForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device)
+    model_hf.eval()
+    print("HF fp16")
+    torch.cuda.synchronize()
+    start = time.time()
+    out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
+                               return_dict_in_generate=True, output_scores=True)
+    torch.cuda.synchronize()
+    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
+    del model_hf
+
+    model_ref = OPTForCausalLM.from_pretrained(model_name).to(device=device)
+    model_ref.eval()
+    print("HF fp32")
+    torch.cuda.synchronize()
+    start = time.time()
+    out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
+                                 return_dict_in_generate=True, output_scores=True)
+    torch.cuda.synchronize()
+    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
+    del model_ref
+    print(tokenizer.batch_decode(out_ref.sequences.tolist()))
+
+    if verbose:
+        print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
+        print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
+        print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
+        print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
+
+    assert torch.all(out.sequences == sequences)
+    assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
+                          rtol=rtol, atol=atol)
+    assert torch.all(out.sequences == out_ref.sequences)
+    assert torch.all(out.sequences == out_hf.sequences)
+    assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (
+        torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
```
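The new `test_greedy_decode_opt` also exercises the CUDA-graph decoding path. Below is a compressed, self-contained sketch of just that path; the model name and prompt are illustrative, and a CUDA device with the fused FT kernels is assumed, as in the test.

```python
import torch
from transformers import AutoTokenizer, OPTConfig

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.opt import opt_config_to_gpt2_config
from flash_attn.utils.generation import update_graph_cache

model_name = "facebook/opt-125m"  # illustrative
config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
model = GPTLMHeadModel.from_pretrained(model_name, config, device='cuda', dtype=torch.float16)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
input_ids = tokenizer("Hello, my dog is cute and", return_tensors="pt").input_ids.to('cuda')
max_length = 30

# Capture the decoding CUDA graph once, outside any timing loop, then reuse it with cg=True.
batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
out = model.generate(input_ids=input_ids, max_length=max_length,
                     fused_ft_kernel=True, cg=True,
                     return_dict_in_generate=True, output_scores=True)
print(tokenizer.batch_decode(out.sequences.tolist()))
```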
tests/models/test_opt.py

```diff
@@ -35,6 +35,7 @@ def test_opt_optimized(model_name):
     config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
     config.use_flash_attn = True
     config.fused_bias_fc = True
+    config.fused_mlp = True
     config.fused_dropout_add_ln = True
     # Only prenorm supports residual_in_fp32
     config.residual_in_fp32 = getattr(config, 'prenorm', True)
```
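The optimized OPT test config now enables the fused MLP as well. A short sketch of the full config block the tests build (mirroring test_opt.py and test_gpt_generation.py; the model name is illustrative):

```python
from transformers import OPTConfig
from flash_attn.models.opt import opt_config_to_gpt2_config

config = opt_config_to_gpt2_config(OPTConfig.from_pretrained("facebook/opt-125m"))
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = True              # newly enabled in the test config by this commit
config.fused_dropout_add_ln = True
# Only prenorm supports residual_in_fp32
config.residual_in_fp32 = getattr(config, 'prenorm', True)
```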