chenpangpang / transformers · Commits

Commit 33687a3f (Unverified)
add GPTJ/bloom/llama/opt into model list and enhance the jit support (#23291)

Authored May 24, 2023 by Wang, Yi; committed by GitHub on May 24, 2023
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Parent: 003a0cf8
Changes: 2 changed files, 55 additions and 42 deletions (+55 / -42)

  examples/pytorch/text-generation/README.md            +1  / -1
  examples/pytorch/text-generation/run_generation.py    +54 / -41
examples/pytorch/text-generation/README.md

@@ -18,7 +18,7 @@ limitations under the License.
 Based on the script [`run_generation.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-generation/run_generation.py).
 
-Conditional text generation using the auto-regressive models of the library: GPT, GPT-2, Transformer-XL, XLNet, CTRL.
+Conditional text generation using the auto-regressive models of the library: GPT, GPT-2, GPTJ, Transformer-XL, XLNet, CTRL, BLOOM, LLAMA, OPT.
 A similar script is used for our official demo [Write With Transfomer](https://transformer.huggingface.co), where you
 can try out the different models available in the library.
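With the new entries the example script can be pointed at any of the added architectures; for example (checkpoint name chosen purely for illustration), something like `python run_generation.py --model_type=opt --model_name_or_path=facebook/opt-125m --prompt "Hello" --length 30` should select the OPT classes registered below.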
examples/pytorch/text-generation/run_generation.py

@@ -19,6 +19,7 @@
 import argparse
+import inspect
 import logging
 from typing import Tuple
@@ -26,13 +27,20 @@ import numpy as np
 import torch
 
 from transformers import (
+    AutoTokenizer,
+    BloomForCausalLM,
+    BloomTokenizerFast,
     CTRLLMHeadModel,
     CTRLTokenizer,
     GenerationMixin,
     GPT2LMHeadModel,
     GPT2Tokenizer,
+    GPTJForCausalLM,
+    LlamaForCausalLM,
+    LlamaTokenizer,
     OpenAIGPTLMHeadModel,
     OpenAIGPTTokenizer,
+    OPTForCausalLM,
     TransfoXLLMHeadModel,
     TransfoXLTokenizer,
     XLMTokenizer,
@@ -59,6 +67,10 @@ MODEL_CLASSES = {
     "xlnet": (XLNetLMHeadModel, XLNetTokenizer),
     "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer),
     "xlm": (XLMWithLMHeadModel, XLMTokenizer),
+    "gptj": (GPTJForCausalLM, AutoTokenizer),
+    "bloom": (BloomForCausalLM, BloomTokenizerFast),
+    "llama": (LlamaForCausalLM, LlamaTokenizer),
+    "opt": (OPTForCausalLM, GPT2Tokenizer),
 }
 
 # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
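For context (not part of the diff): the script resolves classes through this registry, so supporting a new architecture only requires a new entry. A minimal sketch of that pattern, with names and checkpoint chosen purely for illustration rather than copied from the script:

```python
# Illustrative sketch of how a model-type registry is consumed.
from transformers import BloomForCausalLM, BloomTokenizerFast, GPT2LMHeadModel, GPT2Tokenizer

MODEL_CLASSES = {
    "gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
    "bloom": (BloomForCausalLM, BloomTokenizerFast),
}

def load(model_type: str, model_name_or_path: str):
    # Look up the (model class, tokenizer class) pair for the requested --model_type.
    model_class, tokenizer_class = MODEL_CLASSES[model_type]
    tokenizer = tokenizer_class.from_pretrained(model_name_or_path)
    model = model_class.from_pretrained(model_name_or_path)
    return model, tokenizer

# e.g. load("bloom", "bigscience/bloom-560m")  # checkpoint name is illustrative
```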
@@ -173,23 +185,26 @@ def sparse_model_config(model_config):
         raise ValueError("Check the model config")
 
     num_embedding_size_per_head = int(embedding_size / num_head)
-    num_layer = model_config.n_layer
+    if hasattr(model_config, "n_layer"):
+        num_layer = model_config.n_layer
+    elif hasattr(model_config, "num_hidden_layers"):
+        num_layer = model_config.num_hidden_layers
+    else:
+        raise ValueError("Number of hidden layers couldn't be determined from the model config")
 
     return num_layer, num_head, num_embedding_size_per_head
 
 
-def prepare_jit_inputs(inputs, model, tokenizer):
-    num_batch = len(inputs)
-    dummy_input = tokenizer.batch_encode_plus(inputs, return_tensors="pt", padding=True)
+def generate_past_key_values(model, batch_size, seq_len):
     num_block_layers, num_attention_heads, num_embedding_size_per_head = sparse_model_config(model.config)
     if model.config.model_type == "bloom":
         past_key_values = tuple(
             (
-                torch.zeros(int(num_attention_heads * num_batch), num_embedding_size_per_head, 1)
-                .to(model.config.torch_dtype)
+                torch.empty(int(num_attention_heads * batch_size), num_embedding_size_per_head, seq_len)
+                .to(model.dtype)
                 .to(model.device),
-                torch.zeros(int(num_attention_heads * num_batch), 1, num_embedding_size_per_head)
-                .to(model.config.torch_dtype)
+                torch.empty(int(num_attention_heads * batch_size), seq_len, num_embedding_size_per_head)
+                .to(model.dtype)
                 .to(model.device),
             )
             for _ in range(num_block_layers)
@@ -197,37 +212,34 @@ def prepare_jit_inputs(inputs, model, tokenizer):
     else:
         past_key_values = tuple(
             (
-                torch.zeros(num_batch, num_attention_heads, 1, num_embedding_size_per_head)
-                .to(model.config.torch_dtype)
+                torch.empty(batch_size, num_attention_heads, seq_len, num_embedding_size_per_head)
+                .to(model.dtype)
                 .to(model.device),
-                torch.zeros(num_batch, num_attention_heads, 1, num_embedding_size_per_head)
-                .to(model.config.torch_dtype)
+                torch.empty(batch_size, num_attention_heads, seq_len, num_embedding_size_per_head)
+                .to(model.dtype)
                 .to(model.device),
             )
             for _ in range(num_block_layers)
         )
+    return past_key_values
+
+
+def prepare_jit_inputs(inputs, model, tokenizer):
+    batch_size = len(inputs)
+    dummy_input = tokenizer.batch_encode_plus(inputs, return_tensors="pt")
+    dummy_input = dummy_input.to(model.device)
+    if model.config.use_cache:
+        dummy_input["past_key_values"] = generate_past_key_values(model, batch_size, 1)
     dummy_input["attention_mask"] = torch.cat(
         [
-            torch.zeros(dummy_input["attention_mask"].shape[0], 1).to(dummy_input["attention_mask"].dtype),
+            torch.zeros(dummy_input["attention_mask"].shape[0], 1)
+            .to(dummy_input["attention_mask"].dtype)
+            .to(model.device),
             dummy_input["attention_mask"],
         ],
         -1,
     )
-
-    if model.config.use_cache:
-        jit_inputs = (
-            dummy_input["input_ids"].to(model.device),
-            past_key_values,
-            dummy_input["attention_mask"].to(model.device),
-        )
-    else:
-        jit_inputs = (
-            dummy_input["input_ids"].to(model.device),
-            dummy_input["attention_mask"].to(model.device),
-        )
-    return jit_inputs
+    return dummy_input
 
 
 class _ModelFallbackWrapper(GenerationMixin):
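For context (not part of the diff): the refactor splits the dummy KV-cache construction out into `generate_past_key_values` so the same helper can be reused by the fallback wrapper further down, and `prepare_jit_inputs` now returns a ready-to-trace input dict instead of a hand-ordered tuple. Note the two cache layouts: BLOOM keeps keys as `[batch * heads, head_dim, seq_len]` and values as `[batch * heads, seq_len, head_dim]`, while the other decoders use `[batch, heads, seq_len, head_dim]`. A rough usage sketch, assuming it runs next to `run_generation.py` so the helpers are importable, with an illustrative GPT-2 checkpoint:

```python
# Usage sketch only; model/checkpoint choice is illustrative.
from transformers import AutoModelForCausalLM, AutoTokenizer
from run_generation import generate_past_key_values, prepare_jit_inputs

model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# One (key, value) pair per layer; for GPT-2-style models each tensor is
# [batch_size, num_heads, seq_len, head_dim] (BLOOM fuses batch and heads instead).
past = generate_past_key_values(model, batch_size=1, seq_len=1)
print(len(past), past[0][0].shape)

# The dummy inputs carry a one-token past plus an attention mask padded with one
# extra column, so mask length == past length + input length for tracing.
dummy = prepare_jit_inputs(["enable jit"], model, tokenizer)
print(dummy["input_ids"].shape, dummy["attention_mask"].shape)
```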
@@ -238,15 +250,13 @@ class _ModelFallbackWrapper(GenerationMixin):
         self._default = default
 
     def __call__(self, *args, **kwargs):
-        if kwargs["past_key_values"] is None:
-            return self._default(*args, **kwargs)
-        trace_graph_inputs = []
+        if kwargs["past_key_values"] is None and self._default.config.use_cache:
+            kwargs["past_key_values"] = generate_past_key_values(self._default, kwargs["input_ids"].shape[0], 0)
         kwargs.pop("position_ids", None)
-        for k, v in kwargs.items():
-            if v is not None and not isinstance(v, bool):
-                trace_graph_inputs.append(v)
-        trace_graph_inputs = tuple(trace_graph_inputs)
-        outputs = self._optimized(*trace_graph_inputs)
+        for k in list(kwargs.keys()):
+            if kwargs[k] is None or isinstance(kwargs[k], bool):
+                kwargs.pop(k)
+        outputs = self._optimized(**kwargs)
         lm_logits = outputs[0]
         past_key_values = outputs[1]
         fixed_output = CausalLMOutputWithPast(
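For context (not part of the diff): instead of falling back to the eager model on the first generation step, the wrapper now fabricates an empty (zero-length) past with `generate_past_key_values` and always routes through the traced graph, and it calls that graph with keyword arguments rather than a hand-ordered positional tuple. The `None`/bool entries are pruned first because a traced module only accepts tensor inputs. The pruning idiom in isolation, with illustrative values:

```python
# Illustrative only: non-tensor entries are dropped before calling the traced graph.
import torch

kwargs = {
    "input_ids": torch.ones(1, 4, dtype=torch.long),
    "position_ids": None,   # dropped
    "use_cache": True,      # dropped
}
for k in list(kwargs.keys()):  # list() so we can pop while iterating
    if kwargs[k] is None or isinstance(kwargs[k], bool):
        kwargs.pop(k)
print(sorted(kwargs))  # ['input_ids']
```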
@@ -324,9 +334,7 @@ def main():
         action="store_true",
         help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
     )
-    parser.add_argument(
-        "--jit", type=bool, default=False, help="Whether or not to use jit trace to accelerate inference"
-    )
+    parser.add_argument("--jit", action="store_true", help="Whether or not to use jit trace to accelerate inference")
     args = parser.parse_args()
 
     args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
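For context (not part of the diff): `type=bool` is a classic argparse pitfall; argparse feeds the raw string to `bool()`, and `bool("False")` is truthy, so any explicit value would have enabled JIT. `action="store_true"` gives the intended off-by-default switch, enabled simply by appending `--jit` to the usual `python run_generation.py --model_type=... --model_name_or_path=...` invocation. A standalone check of the `store_true` behaviour:

```python
import argparse

# With type=bool, argparse would pass the raw string to bool(), so "--jit False"
# still evaluates to True. action="store_true" behaves as intended.
parser = argparse.ArgumentParser()
parser.add_argument("--jit", action="store_true")
print(parser.parse_args([]).jit)          # False
print(parser.parse_args(["--jit"]).jit)   # True
```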
@@ -351,8 +359,8 @@ def main():
     if args.fp16:
         model.half()
 
-    args.length = adjust_length_to_model(args.length, max_sequence_length=model.config.max_position_embeddings)
+    max_seq_length = getattr(model.config, "max_position_embeddings", 0)
+    args.length = adjust_length_to_model(args.length, max_sequence_length=max_seq_length)
     logger.info(args)
 
     prompt_text = args.prompt if args.prompt else input("Model prompt >>> ")
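For context (not part of the diff): reading `max_position_embeddings` through `getattr` with a default of 0 keeps the script working for configs that do not define the attribute at all (BLOOM, for instance, relies on ALiBi and has no fixed positional limit), so the length adjustment degrades gracefully instead of raising an `AttributeError`.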
@@ -382,10 +390,15 @@ def main():
         input_ids = encoded_prompt
 
     if args.jit:
-        jit_input_texts = ["jit"]
+        jit_input_texts = ["enable jit"]
         jit_inputs = prepare_jit_inputs(jit_input_texts, model, tokenizer)
         torch._C._jit_set_texpr_fuser_enabled(False)
         model.config.return_dict = False
+        if hasattr(model, "forward"):
+            sig = inspect.signature(model.forward)
+        else:
+            sig = inspect.signature(model.__call__)
+        jit_inputs = tuple(jit_inputs[key] for key in sig.parameters if jit_inputs.get(key, None) is not None)
         traced_model = torch.jit.trace(model, jit_inputs, strict=False)
         traced_model = torch.jit.freeze(traced_model.eval())
         traced_model(*jit_inputs)
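For context (not part of the diff): the new `inspect.signature` step orders the prepared inputs to match the model's `forward` signature and drops anything missing or `None`, so `torch.jit.trace` receives a clean positional tuple; the final `traced_model(*jit_inputs)` call then exercises the frozen graph once before real generation. A minimal sketch of that ordering/filtering step, with illustrative names:

```python
# Illustrative sketch: pull arguments in forward-signature order, skipping None/absent keys.
import inspect

def forward(input_ids, past_key_values=None, attention_mask=None, use_cache=None):
    return input_ids

prepared = {"attention_mask": "mask", "input_ids": "ids", "position_ids": None}
sig = inspect.signature(forward)
ordered = tuple(prepared[key] for key in sig.parameters if prepared.get(key, None) is not None)
print(ordered)  # ('ids', 'mask') -- signature order, None/absent keys dropped
```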