chenpangpang / transformers — commit 1e4cf8bb (unverified)
Authored Feb 07, 2023 by Joao Gante; committed by GitHub on Feb 07, 2023
Generate: TF can now generate from embeddings in encoder-decoder models (#21475)
parent 12eb528b

Showing 4 changed files with 183 additions and 196 deletions (+183 / -196)
src/transformers/generation/tf_utils.py        +91  -19
tests/generation/test_framework_agnostic.py    +90   -0
tests/generation/test_tf_utils.py               +1   -0
tests/generation/test_utils.py                  +1 -177
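What this commit enables, as a minimal usage sketch: it mirrors the new `test_encoder_decoder_generate_with_inputs_embeds` test added below (the tiny checkpoint name and the `eos_token_id`/`max_length` settings come from that test), and it requires a transformers build that includes this change.

```python
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=5)
model.config.eos_token_id = None  # as in the new test, so generation runs to max_length

input_ids = tokenizer("welcome to parenthood", return_tensors="tf").input_ids
# Embed the prompt outside of generate(), e.g. to inject custom or soft prompts
inputs_embeds = model.get_input_embeddings()(input_ids)

# Before this commit, TF generate() needed `input_ids` for encoder-decoder models;
# now the encoder can be fed embeddings directly.
output_sequences = model.generate(inputs_embeds=inputs_embeds)
print(output_sequences.shape)  # (1, 5) with the settings above
```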
src/transformers/generation/tf_utils.py
@@ -664,9 +664,11 @@ class TFGenerationMixin:
         )

         # 4. Define model inputs
-        input_ids = self._prepare_model_inputs(input_ids, generation_config.bos_token_id)
+        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
+            input_ids, generation_config.bos_token_id, model_kwargs
+        )
         # inputs_ids now has to be defined and cannot be None anymore
-        batch_size = shape_list(input_ids)[0]
+        batch_size = shape_list(inputs_tensor)[0]

         # 5. Prepare other model kwargs
         model_kwargs["output_attentions"] = generation_config.output_attentions
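The switch from `input_ids` to `inputs_tensor` above works because the batch dimension is axis 0 both for token ids `(batch, seq_len)` and for embeddings `(batch, seq_len, hidden)`. A tiny illustration with made-up shapes (`tf.shape` is used here in place of the `shape_list` utility the diff calls; for this purpose they behave the same):

```python
import tensorflow as tf

token_ids = tf.zeros((2, 7), dtype=tf.int32)          # (batch, seq_len)
embeddings = tf.zeros((2, 7, 16), dtype=tf.float32)   # (batch, seq_len, hidden)

# Axis 0 is the batch dimension in both cases, so the same expression works
# whether generate() received ids or embeddings.
print(int(tf.shape(token_ids)[0]), int(tf.shape(embeddings)[0]))  # 2 2
```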
@@ -678,23 +680,26 @@ class TFGenerationMixin:
         if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
-                input_ids, generation_config.pad_token_id, generation_config.eos_token_id
+                inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
             )

         # decoder-only models should use left-padding for generation
         if not self.config.is_encoder_decoder:
             if generation_config.pad_token_id is not None and tf.math.reduce_any(
-                input_ids[:, -1] == generation_config.pad_token_id
+                inputs_tensor[:, -1] == generation_config.pad_token_id
             ):
                 logger.warning(
                     "A decoder-only architecture is being used, but right-padding was detected! For correct "
                     "generation results, please set `padding_side='left'` when initializing the tokenizer."
                 )

+        if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
+            # if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
+            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
+                inputs_tensor, model_kwargs, model_input_name
+            )
+
         # 6. Prepare model inputs which will be used for auto-regressive generation
         if self.config.is_encoder_decoder:
-            # if encoder-decoder, we create encoder_outputs and add to `model_kwargs`
-            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
             # if encoder-decoder then `input_ids` come from `decoder_start_token_id`
             input_ids = self._prepare_decoder_input_ids_for_generation(
                 batch_size,
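The right-padding warning above concerns decoder-only models. A common way to satisfy it is to create the tokenizer with left padding; a generic sketch (the "gpt2" checkpoint is only an example, not taken from this commit):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

batch = tokenizer(["Hello", "A longer prompt here"], padding=True, return_tensors="tf")
# Pad tokens now sit on the left, so the last position of every row is a real token
# and the check above no longer triggers the warning.
```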
@@ -702,6 +707,9 @@ class TFGenerationMixin:
                 bos_token_id=generation_config.bos_token_id,
                 model_kwargs=model_kwargs,
             )
+        else:
+            # if decoder-only then inputs_tensor has to be `input_ids`
+            input_ids = inputs_tensor

         # 7. Prepare `max_length` depending on other stopping criteria.
         input_ids_seq_length = input_ids.shape[-1]
@@ -924,7 +932,9 @@ class TFGenerationMixin:
         else:
             return tf.ones(inputs.shape[:2], dtype=tf.int32)

-    def _prepare_encoder_decoder_kwargs_for_generation(self, inputs_tensor: tf.Tensor, model_kwargs) -> Dict[str, Any]:
+    def _prepare_encoder_decoder_kwargs_for_generation(
+        self, inputs_tensor: tf.Tensor, model_kwargs, model_input_name: Optional[str] = None
+    ) -> Dict[str, Any]:
         # get encoder and store encoder outputs
         encoder = self.get_encoder()
@@ -938,7 +948,9 @@ class TFGenerationMixin:
         # vision models don't use `attention_mask`.
         encoder_kwargs["return_dict"] = True
-        encoder_kwargs[self.main_input_name] = inputs_tensor
+        encoder_kwargs[model_input_name] = inputs_tensor
+        if model_input_name != self.main_input_name:  # in Keras, the first input must always be passed
+            encoder_kwargs[self.main_input_name] = None
         encoder_outputs = encoder(**encoder_kwargs)
         model_kwargs["encoder_outputs"] = encoder_outputs
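For a text encoder-decoder model, the kwarg juggling above roughly amounts to the standalone call below: the embeddings go in under the encoder's input name, while the model's main input (`input_ids`) is still passed, but as `None`. This is a sketch of the resulting call, not the internal code path itself; the tiny checkpoint is the one used in the tests.

```python
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-bart")

input_ids = tokenizer("welcome to parenthood", return_tensors="tf").input_ids
inputs_embeds = model.get_input_embeddings()(input_ids)

encoder = model.get_encoder()
# `inputs_embeds` is the effective model input; `input_ids` is still present as None.
encoder_outputs = encoder(input_ids=None, inputs_embeds=inputs_embeds, return_dict=True)
print(encoder_outputs.last_hidden_state.shape)  # (1, seq_len, hidden_size)
```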
@@ -1007,19 +1019,79 @@ class TFGenerationMixin:
         return input_ids, model_kwargs

-    def _prepare_model_inputs(self, inputs: Optional[tf.Tensor] = None, bos_token_id: Optional[int] = None):
-        # TODO(Patrick) - adapt this function when making `generate` more flexible
-        # for all kinds of input types
-        if inputs is None:
-            # if no `inputs` are passed create prompt of size (1,1) filled with BOS token
-            if not isinstance(bos_token_id, int) or bos_token_id < 0:
-                raise ValueError(
-                    "you should either supply a context to complete as `input_ids` input "
-                    "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
-                )
-            return tf.cast(tf.fill((1, 1), bos_token_id), dtype=tf.int32)
-
-        return inputs
+    def _prepare_model_inputs(
+        self,
+        inputs: Optional[tf.Tensor] = None,
+        bos_token_id: Optional[int] = None,
+        model_kwargs: Optional[Dict[str, tf.Tensor]] = None,
+    ) -> Tuple[tf.Tensor, Optional[str], Dict[str, tf.Tensor]]:
+        """
+        This function extracts the model-specific `inputs` for generation.
+        """
+        # 1. retrieve all kwargs that are non-None or non-model input related.
+        # some encoder-decoder models have different names for model and encoder
+        if (
+            self.config.is_encoder_decoder
+            and hasattr(self, "encoder")
+            and hasattr(self.encoder, "main_input_name")
+            and self.encoder.main_input_name != self.main_input_name
+        ):
+            input_name = self.encoder.main_input_name
+        else:
+            input_name = self.main_input_name
+
+        model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}
+
+        # 2. check whether model_input_name is passed as kwarg
+        # if yes and `inputs` is None use kwarg inputs
+        inputs_kwarg = model_kwargs.pop(input_name, None)
+        if inputs_kwarg is not None and inputs is not None:
+            raise ValueError(
+                f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed."
+                f"Make sure to either pass {inputs} or {input_name}=..."
+            )
+        elif inputs_kwarg is not None:
+            inputs = inputs_kwarg
+
+        # 3. In the presence of `inputs_embeds` for text models:
+        # - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model
+        # doesn't have its forwarding implemented. `inputs_embeds` is kept in `model_kwargs` and can coexist with
+        # input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`)
+        # - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and
+        # pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states.
+        if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
+            if not self.config.is_encoder_decoder:
+                has_inputs_embeds_forwarding = "inputs_embeds" in set(
+                    inspect.signature(self.prepare_inputs_for_generation).parameters.keys()
+                )
+                if not has_inputs_embeds_forwarding:
+                    raise ValueError(
+                        f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} "
+                        "doesn't have its forwarding implemented. See the GPT2 implementation for an example "
+                        "(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!"
+                    )
+            else:
+                if inputs is not None:
+                    raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
+                inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
+
+        # 4. if `inputs` is still None, try to create `input_ids` from BOS token
+        if inputs is None:
+            inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))
+
+        return inputs, input_name, model_kwargs
+
+    def _prepare_input_ids_for_generation(
+        self, bos_token_id: Optional[int], encoder_outputs: Optional[ModelOutput]
+    ) -> tf.Tensor:
+        if self.config.is_encoder_decoder and encoder_outputs is not None:
+            # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
+            shape = encoder_outputs.last_hidden_state.size()[:-1]
+            return tf.ones(shape, dtype=tf.int32) * -100
+
+        if bos_token_id is None:
+            raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
+        return tf.ones((1, 1), dtype=tf.int32) * bos_token_id

     @staticmethod
     def _extract_past_from_model_output(outputs: ModelOutput):
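The decoder-only branch of `_prepare_model_inputs` gates `inputs_embeds` support on the signature of `prepare_inputs_for_generation`. The same check can be reproduced outside of `generate()` to find out whether a given TF model class will accept embeddings; a short sketch (the tiny checkpoint is the one used in the tests, and the printed value depends on the model class and library version):

```python
import inspect

from transformers import TFAutoModelForCausalLM

model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
# Mirrors the check in the diff above: does prepare_inputs_for_generation forward inputs_embeds?
accepts_inputs_embeds = "inputs_embeds" in inspect.signature(model.prepare_inputs_for_generation).parameters
print(accepts_inputs_embeds)
```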
tests/generation/test_framework_agnostic.py
@@ -5,11 +5,13 @@ Framework agnostic tests for generate()-related methods.
 import numpy as np

 from transformers import AutoTokenizer
+from transformers.testing_utils import torch_device


 class GenerationIntegrationTestsMixin:
     # To be populated by the child classes
     framework_dependent_parameters = {
+        "AutoModelForCausalLM": None,
         "AutoModelForSeq2SeqLM": None,
         "LogitsProcessorList": None,
         "MinLengthLogitsProcessor": None,
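The mixin relies on child classes filling in `framework_dependent_parameters`, so one test body can run against both the PT and TF auto classes. A stripped-down illustration of that pattern (the class names here are invented; the real wiring is in the two test files below):

```python
class SharedGenerationTestsMixin:
    # To be populated by the child classes (mirrors the real mixin above)
    framework_dependent_parameters = {"AutoModelForCausalLM": None, "return_tensors": None}

    def framework(self):
        model_cls = self.framework_dependent_parameters["AutoModelForCausalLM"]
        # The real tests derive is_pt from the model class name prefix
        return "tf" if model_cls.__name__.startswith("TF") else "pt"


class TFFakeAutoModel:  # stand-in for TFAutoModelForCausalLM
    pass


class TFSharedTests(SharedGenerationTestsMixin):
    framework_dependent_parameters = {"AutoModelForCausalLM": TFFakeAutoModel, "return_tensors": "tf"}


print(TFSharedTests().framework())  # tf
```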
@@ -60,3 +62,91 @@ class GenerationIntegrationTestsMixin:
         bart_model.config.min_length = None
         bart_model.generate(input_ids, logits_processor=logits_processor)
+
+    def test_max_new_tokens_encoder_decoder(self):
+        model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
+        return_tensors = self.framework_dependent_parameters["return_tensors"]
+        is_pt = not model_cls.__name__.startswith("TF")
+
+        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
+        bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
+        bart_model = model_cls.from_pretrained("hf-internal-testing/tiny-random-bart")
+        input_ids = bart_tokenizer(article, return_tensors=return_tensors).input_ids
+        if is_pt:
+            bart_model = bart_model.to(torch_device)
+            input_ids = input_ids.to(torch_device)
+
+        self.assertEqual(list(input_ids.shape), [1, 29])
+
+        max_new_tokens = 3
+        bart_model.config.max_length = 20
+        bart_model.config.eos_token_id = None
+
+        # Encoder decoder call
+        outputs = bart_model.generate(input_ids, max_new_tokens=max_new_tokens)
+        # 1 BOS + 3 new tokens
+        self.assertEqual(list(outputs.shape), [1, 4])
+
+        # Decoder only call
+        outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=max_new_tokens)
+        # 29 + 3 new tokens
+        self.assertEqual(list(outputs.shape), [1, 32])
+
+        # Encoder decoder call > 20
+        outputs = bart_model.generate(max_new_tokens=max_new_tokens + 20)
+        # 1 BOS + 20 + 3 new tokens
+        self.assertEqual(list(outputs.shape), [1, 24])
+
+    def test_max_new_tokens_decoder_only(self):
+        model_cls = self.framework_dependent_parameters["AutoModelForCausalLM"]
+        return_tensors = self.framework_dependent_parameters["return_tensors"]
+        is_pt = not model_cls.__name__.startswith("TF")
+
+        article = """Justin Timberlake."""
+        gpt2_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        gpt2_model = model_cls.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        input_ids = gpt2_tokenizer(article, return_tensors=return_tensors).input_ids
+        if is_pt:
+            gpt2_model = gpt2_model.to(torch_device)
+            input_ids = input_ids.to(torch_device)
+
+        self.assertEqual(list(input_ids.shape), [1, 9])
+
+        max_new_tokens = 3
+        gpt2_model.config.max_length = 20
+
+        # call < 20
+        outputs = gpt2_model.generate(input_ids, max_new_tokens=max_new_tokens)
+        # 9 input_ids + 3 new tokens
+        self.assertEqual(list(outputs.shape), [1, 12])
+
+        # call > 20
+        outputs = gpt2_model.generate(max_new_tokens=max_new_tokens + 20)
+        # 1 BOS token + 23 new tokens
+        self.assertEqual(list(outputs.shape), [1, 24])
+
+    def test_encoder_decoder_generate_with_inputs_embeds(self):
+        model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
+        return_tensors = self.framework_dependent_parameters["return_tensors"]
+        is_pt = not model_cls.__name__.startswith("TF")
+
+        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
+        model = model_cls.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=5)
+        model.config.eos_token_id = None
+        input_ids = tokenizer(article, return_tensors=return_tensors).input_ids
+        if is_pt:
+            model = model.to(torch_device)
+            input_ids = input_ids.to(torch_device)
+
+        inputs_embeds = model.get_input_embeddings()(input_ids)
+
+        output_sequences = model.generate(inputs_embeds=inputs_embeds)
+
+        # make sure model generated correctly until `max_length`
+        self.assertEqual(output_sequences.shape, (1, 5))
tests/generation/test_tf_utils.py
@@ -135,6 +135,7 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
     # setting framework_dependent_parameters needs to be gated, just like its contents' imports
     if is_tf_available():
         framework_dependent_parameters = {
+            "AutoModelForCausalLM": TFAutoModelForCausalLM,
             "AutoModelForSeq2SeqLM": TFAutoModelForSeq2SeqLM,
             "LogitsProcessorList": TFLogitsProcessorList,
             "MinLengthLogitsProcessor": TFMinLengthLogitsProcessor,
tests/generation/test_utils.py
@@ -40,7 +40,6 @@ if is_torch_available():
         ImageGPTForCausalImageModeling,
         Speech2TextForConditionalGeneration,
         SpeechEncoderDecoderModel,
-        T5ForConditionalGeneration,
         VisionEncoderDecoderModel,
         top_k_top_p_filtering,
     )
@@ -1792,6 +1791,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
     # setting framework_dependent_parameters needs to be gated, just like its contents' imports
     if is_torch_available():
         framework_dependent_parameters = {
+            "AutoModelForCausalLM": AutoModelForCausalLM,
             "AutoModelForSeq2SeqLM": AutoModelForSeq2SeqLM,
             "LogitsProcessorList": LogitsProcessorList,
             "MinLengthLogitsProcessor": MinLengthLogitsProcessor,
@@ -2094,182 +2094,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
         output = generator(prompt, stop_sequence=" number")
         self.assertEqual(output, [{"generated_text": "Hello I believe in in in number"}])

-    def test_max_new_tokens_encoder_decoder(self):
-        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
-        bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
-        bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
-            torch_device
-        )
-        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
-
-        self.assertEqual(list(input_ids.shape), [1, 29])
-
-        max_new_tokens = 3
-        bart_model.config.max_length = 20
-        bart_model.config.eos_token_id = None
-
-        # Encoder decoder call
-        outputs = bart_model.generate(input_ids, max_new_tokens=max_new_tokens)
-        # 1 BOS + 3 new tokens
-        self.assertEqual(list(outputs.shape), [1, 4])
-
-        # Decoder only call
-        outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=max_new_tokens)
-        # 29 + 3 new tokens
-        self.assertEqual(list(outputs.shape), [1, 32])
-
-        # Encoder decoder call > 20
-        outputs = bart_model.generate(max_new_tokens=max_new_tokens + 20)
-        # 1 BOS + 20 + 3 new tokens
-        self.assertEqual(list(outputs.shape), [1, 24])
-
-    def test_max_new_tokens_decoder_only_contrastive_search_t5(self):
-        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
-        t5_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
-        t5_model = T5ForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-t5").to(torch_device)
-        input_ids = t5_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
-
-        self.assertEqual(list(input_ids.shape), [1, 56])
-
-        max_new_tokens = 3
-        t5_model.config.max_length = 20
-        t5_model.config.eos_token_id = None
-
-        # Encoder decoder call
-        outputs = t5_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4)
-        # 1 BOS + 3 new tokens
-        self.assertEqual(list(outputs.shape), [1, 4])
-
-        # Decoder only call
-        outputs = t5_model.generate(
-            decoder_input_ids=input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4
-        )
-        # 56 + 3 new tokens
-        self.assertEqual(list(outputs.shape), [1, 59])
-
-        # Encoder decoder call > 20
-        outputs = t5_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4)
-        # 1 BOS + 20 + 3 new tokens
-        self.assertEqual(list(outputs.shape), [1, 24])
-
-    def test_max_new_tokens_decoder_only_contrastive_search_bart(self):
-        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
-        bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
-        bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
-            torch_device
-        )
-        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
-
-        self.assertEqual(list(input_ids.shape), [1, 29])
-
-        max_new_tokens = 3
-        bart_model.config.max_length = 20
-        bart_model.config.eos_token_id = None
-
-        # Encoder decoder call
-        outputs = bart_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4)
-        # 1 BOS + 3 new tokens
-        self.assertEqual(list(outputs.shape), [1, 4])
-
-        # Decoder only call
-        outputs = bart_model.generate(
-            decoder_input_ids=input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4
-        )
-        # 29 + 3 new tokens
-        self.assertEqual(list(outputs.shape), [1, 32])
-
-        # Encoder decoder call > 20
-        outputs = bart_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4)
-        # 1 BOS + 20 + 3 new tokens
-        self.assertEqual(list(outputs.shape), [1, 24])
-
-    def test_max_new_tokens_decoder_only_contrastive_search_gptj(self):
-        article = """Justin Timberlake."""
-        gptj_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gptj")
-        gptj_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gptj").to(torch_device)
-        input_ids = gptj_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
-
-        self.assertEqual(list(input_ids.shape), [1, 9])
-
-        max_new_tokens = 3
-        gptj_model.config.max_length = 20
-
-        # call < 20
-        outputs = gptj_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4)
-        # 9 input_ids + 3 new tokens
-        self.assertEqual(list(outputs.shape), [1, 12])
-
-        # call > 20
-        outputs = gptj_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4)
-        # 1 BOS token + 23 new tokens
-        self.assertEqual(list(outputs.shape), [1, 24])
-
-    def test_max_new_tokens_decoder_only_contrastive_search_gpt2(self):
-        article = """Justin Timberlake."""
-        gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
-        gpt2_model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
-        input_ids = gpt2_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
-
-        self.assertEqual(list(input_ids.shape), [1, 9])
-
-        max_new_tokens = 3
-        gpt2_model.config.max_length = 20
-
-        # call < 20
-        outputs = gpt2_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4)
-        # 9 input_ids + 3 new tokens
-        self.assertEqual(list(outputs.shape), [1, 12])
-
-        # call > 20
-        outputs = gpt2_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4)
-        # 1 BOS token + 23 new tokens
-        self.assertEqual(list(outputs.shape), [1, 24])
-
-    def test_max_new_tokens_decoder_only(self):
-        article = """Justin Timberlake."""
-        gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
-        gpt2_model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
-        input_ids = gpt2_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
-
-        self.assertEqual(list(input_ids.shape), [1, 9])
-
-        max_new_tokens = 3
-        gpt2_model.config.max_length = 20
-
-        # call < 20
-        outputs = gpt2_model.generate(input_ids, max_new_tokens=max_new_tokens)
-        # 9 input_ids + 3 new tokens
-        self.assertEqual(list(outputs.shape), [1, 12])
-
-        # call > 20
-        outputs = gpt2_model.generate(max_new_tokens=max_new_tokens + 20)
-        # 1 BOS token + 23 new tokens
-        self.assertEqual(list(outputs.shape), [1, 24])
-
-    def test_encoder_decoder_generate_with_inputs_embeds(self):
-        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
-        tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
-        model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=5).to(
-            torch_device
-        )
-        model.config.eos_token_id = None
-        input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
-
-        inputs_embeds = model.get_input_embeddings()(input_ids)
-
-        output_sequences = model.generate(inputs_embeds=inputs_embeds)
-
-        # make sure model generated correctly until `max_length`
-        self.assertEqual(output_sequences.shape, (1, 5))
-
     def test_encoder_decoder_generate_attention_mask(self):
         articles = ["Timberlake", "Jessica Biel, welcome to parenthood among other things"]
         tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")