chenpangpang / transformers · Commits

Commit 8472a224 (Unverified)
Authored Mar 22, 2023 by jiqing-feng; committed by GitHub on Mar 22, 2023

Enable traced model for text-generation task (#22265)

Parent: 0558914d
Showing 1 changed file with 145 additions and 0 deletions:

examples/pytorch/text-generation/run_generation.py  (+145, -0)
@@ -20,6 +20,7 @@
import argparse
import logging
from typing import Tuple

import numpy as np
import torch
@@ -27,6 +28,7 @@ import torch
from transformers import (
    CTRLLMHeadModel,
    CTRLTokenizer,
    GenerationMixin,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    OpenAIGPTLMHeadModel,
@@ -38,6 +40,7 @@ from transformers import (
    XLNetLMHeadModel,
    XLNetTokenizer,
)
from transformers.modeling_outputs import CausalLMOutputWithPast


logging.basicConfig(
@@ -151,6 +154,131 @@ def adjust_length_to_model(length, max_sequence_length):
    return length


def sparse_model_config(model_config):
    embedding_size = None
    if hasattr(model_config, "hidden_size"):
        embedding_size = model_config.hidden_size
    elif hasattr(model_config, "n_embed"):
        embedding_size = model_config.n_embed
    elif hasattr(model_config, "n_embd"):
        embedding_size = model_config.n_embd
    num_head = None
    if hasattr(model_config, "num_attention_heads"):
        num_head = model_config.num_attention_heads
    elif hasattr(model_config, "n_head"):
        num_head = model_config.n_head
    if embedding_size is None or num_head is None or num_head == 0:
        raise ValueError("Check the model config")

    num_embedding_size_per_head = int(embedding_size / num_head)
    num_layer = model_config.n_layer

    return num_layer, num_head, num_embedding_size_per_head
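# Illustration (editor's sketch, not part of the committed file): for a default
# GPT2Config, whose hidden_size and num_attention_heads are aliases of n_embd and
# n_head, the helper above resolves to 12 layers, 12 heads, and 768 // 12 = 64
# dimensions per head.
from transformers import GPT2Config

example_layers, example_heads, example_head_dim = sparse_model_config(GPT2Config())
print(example_layers, example_heads, example_head_dim)  # 12 12 64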
def prepare_jit_inputs(inputs, model, tokenizer):
    num_batch = len(inputs)
    dummy_input = tokenizer.batch_encode_plus(inputs, return_tensors="pt", padding=True)
    num_block_layers, num_attention_heads, num_embedding_size_per_head = sparse_model_config(model.config)
    if model.config.model_type == "bloom":
        past_key_values = tuple(
            (
                torch.zeros(int(num_attention_heads * num_batch), num_embedding_size_per_head, 1)
                .to(model.config.torch_dtype)
                .to(model.device),
                torch.zeros(int(num_attention_heads * num_batch), 1, num_embedding_size_per_head)
                .to(model.config.torch_dtype)
                .to(model.device),
            )
            for _ in range(num_block_layers)
        )
    else:
        past_key_values = tuple(
            (
                torch.zeros(num_batch, num_attention_heads, 1, num_embedding_size_per_head)
                .to(model.config.torch_dtype)
                .to(model.device),
                torch.zeros(num_batch, num_attention_heads, 1, num_embedding_size_per_head)
                .to(model.config.torch_dtype)
                .to(model.device),
            )
            for _ in range(num_block_layers)
        )

    dummy_input["attention_mask"] = torch.cat(
        [
            torch.zeros(dummy_input["attention_mask"].shape[0], 1).to(dummy_input["attention_mask"].dtype),
            dummy_input["attention_mask"],
        ],
        -1,
    )

    if model.config.use_cache:
        jit_inputs = (
            dummy_input["input_ids"].to(model.device),
            past_key_values,
            dummy_input["attention_mask"].to(model.device),
        )
    else:
        jit_inputs = (
            dummy_input["input_ids"].to(model.device),
            dummy_input["attention_mask"].to(model.device),
        )

    return jit_inputs
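# Illustration (editor's sketch, not part of the committed file): the shapes that
# prepare_jit_inputs produces for a GPT-2 style model. Assumes the "gpt2" checkpoint
# can be downloaded; torch_dtype is passed explicitly so model.config.torch_dtype is
# set for the .to(...) casts above.
from transformers import GPT2LMHeadModel, GPT2Tokenizer

example_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
example_tokenizer.pad_token = example_tokenizer.eos_token  # GPT-2 has no pad token by default
example_model = GPT2LMHeadModel.from_pretrained("gpt2", torch_dtype=torch.float32)

# use_cache defaults to True for GPT-2, so three inputs come back
example_ids, example_past, example_mask = prepare_jit_inputs(["hello"], example_model, example_tokenizer)
print(len(example_past))         # 12: one (key, value) pair per layer
print(example_past[0][0].shape)  # torch.Size([1, 12, 1, 64]): (batch, heads, 1, head_dim)
print(example_mask.shape)        # one extra column prepended for the dummy past step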
class _ModelFallbackWrapper(GenerationMixin):
    __slots__ = ("_optimized", "_default")

    def __init__(self, optimized, default):
        self._optimized = optimized
        self._default = default

    def __call__(self, *args, **kwargs):
        if kwargs["past_key_values"] is None:
            return self._default(*args, **kwargs)
        trace_graph_inputs = []
        kwargs.pop("position_ids", None)
        for k, v in kwargs.items():
            if v is not None and not isinstance(v, bool):
                trace_graph_inputs.append(v)
        trace_graph_inputs = tuple(trace_graph_inputs)
        outputs = self._optimized(*trace_graph_inputs)
        lm_logits = outputs[0]
        past_key_values = outputs[1]
        fixed_output = CausalLMOutputWithPast(
            loss=None,
            logits=lm_logits,
            past_key_values=past_key_values,
            hidden_states=None,
            attentions=None,
        )
        return fixed_output
    def __getattr__(self, item):
        return getattr(self._default, item)

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, use_cache=None, **kwargs
    ):
        return self._default.prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, **kwargs
        )

    def _reorder_cache(
        self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or
        [`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.
        """
        return self._default._reorder_cache(past_key_values, beam_idx)
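# Toy illustration (editor's sketch, not part of the committed file) of the wrapper's
# dispatch: a call with past_key_values=None falls back to the eager model, while later
# calls forward the remaining non-None, non-bool keyword values positionally to the
# traced graph. The fake callables below are hypothetical stand-ins for the real models.
class _FakeEager:
    def __call__(self, *args, **kwargs):
        return "eager path"

def _fake_traced(input_ids, past_key_values, attention_mask):
    return (torch.ones(1, 1, 50257), past_key_values)  # (logits, refreshed cache)

_example_wrapper = _ModelFallbackWrapper(_fake_traced, _FakeEager())
print(_example_wrapper(input_ids=None, past_key_values=None))  # "eager path"
_example_out = _example_wrapper(
    input_ids=torch.tensor([[0]]),
    past_key_values=(torch.zeros(1),),
    attention_mask=torch.ones(1, 2),
    use_cache=True,  # bools are filtered out before the traced call
)
print(_example_out.logits.shape)  # torch.Size([1, 1, 50257])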
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
@@ -196,6 +324,9 @@ def main():
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--jit", type=bool, default=False, help="Whether or not to use jit trace to accelerate inference"
    )
    args = parser.parse_args()

    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
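A note on the new flag: with argparse, type=bool simply applies Python's bool() to the raw string, so any non-empty value, even the literal "False", enables tracing; only omitting the flag (or passing an empty string) leaves it off. A quick check, independent of this change:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--jit", type=bool, default=False)
print(parser.parse_args(["--jit", "False"]).jit)  # True: bool("False") is truthy
print(parser.parse_args(["--jit", ""]).jit)       # False: only the empty string is falsy
print(parser.parse_args([]).jit)                  # False: the default applies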
@@ -213,6 +344,8 @@ def main():
        raise KeyError("the model {} you specified is not supported. You are welcome to add it and open a PR :)")

    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = model_class.from_pretrained(args.model_name_or_path)
    model.to(args.device)
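The pad-token fallback matters because prepare_jit_inputs encodes a batch with padding=True, and GPT-2's tokenizer ships without a pad token. A small check of that behavior (a sketch, assuming the "gpt2" tokenizer can be downloaded):

from transformers import GPT2Tokenizer

tok = GPT2Tokenizer.from_pretrained("gpt2")
print(tok.pad_token)  # None: padding=True would raise without a pad token
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
batch = tok.batch_encode_plus(["a prompt", "a slightly longer prompt"], return_tensors="pt", padding=True)
print(batch["input_ids"].shape)  # both rows padded to the longer sequence length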
@@ -248,6 +381,18 @@ def main():
    else:
        input_ids = encoded_prompt

    if args.jit:
        jit_input_texts = ["jit"]
        jit_inputs = prepare_jit_inputs(jit_input_texts, model, tokenizer)
        torch._C._jit_set_texpr_fuser_enabled(False)
        model.config.return_dict = False
        traced_model = torch.jit.trace(model, jit_inputs, strict=False)
        traced_model = torch.jit.freeze(traced_model.eval())
        traced_model(*jit_inputs)
        traced_model(*jit_inputs)

        model = _ModelFallbackWrapper(traced_model, model)

    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=args.length + len(encoded_prompt[0]),
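Putting the pieces together, here is a condensed sketch of the same recipe applied to a small checkpoint using the helpers added in this commit. It is an illustration, not part of the diff: it assumes the "distilgpt2" checkpoint can be downloaded, that the local torch and transformers versions can trace this graph, and that sparse_model_config, prepare_jit_inputs, and _ModelFallbackWrapper from the patched script are in scope.

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained("distilgpt2", torch_dtype=torch.float32)
model.eval()

jit_inputs = prepare_jit_inputs(["jit"], model, tokenizer)       # helper added in this commit
torch._C._jit_set_texpr_fuser_enabled(False)
model.config.return_dict = False                                 # tracing expects tuple outputs
traced_model = torch.jit.trace(model, jit_inputs, strict=False)
traced_model = torch.jit.freeze(traced_model.eval())
traced_model(*jit_inputs)                                        # warm-up calls, as in the diff above
traced_model(*jit_inputs)

model = _ModelFallbackWrapper(traced_model, model)               # wrapper added in this commit
prompt_ids = tokenizer("Hello, my dog", return_tensors="pt").input_ids
print(tokenizer.decode(model.generate(input_ids=prompt_ids, max_length=20)[0]))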