Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
1e4cf8bb
Unverified
Commit
1e4cf8bb
authored
Feb 07, 2023
by
Joao Gante
Committed by
GitHub
Feb 07, 2023
Browse files
Generate: TF can now generate from embeddings in encoder-decoder models (#21475)
parent
12eb528b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
183 additions
and
196 deletions
+183
-196
src/transformers/generation/tf_utils.py
src/transformers/generation/tf_utils.py
+91
-19
tests/generation/test_framework_agnostic.py
tests/generation/test_framework_agnostic.py
+90
-0
tests/generation/test_tf_utils.py
tests/generation/test_tf_utils.py
+1
-0
tests/generation/test_utils.py
tests/generation/test_utils.py
+1
-177
No files found.
src/transformers/generation/tf_utils.py
View file @
1e4cf8bb
...
@@ -664,9 +664,11 @@ class TFGenerationMixin:
...
@@ -664,9 +664,11 @@ class TFGenerationMixin:
)
)
# 4. Define model inputs
# 4. Define model inputs
input_ids
=
self
.
_prepare_model_inputs
(
input_ids
,
generation_config
.
bos_token_id
)
inputs_tensor
,
model_input_name
,
model_kwargs
=
self
.
_prepare_model_inputs
(
input_ids
,
generation_config
.
bos_token_id
,
model_kwargs
)
# inputs_ids now has to be defined and cannot be None anymore
# inputs_ids now has to be defined and cannot be None anymore
batch_size
=
shape_list
(
input
_ids
)[
0
]
batch_size
=
shape_list
(
input
s_tensor
)[
0
]
# 5. Prepare other model kwargs
# 5. Prepare other model kwargs
model_kwargs
[
"output_attentions"
]
=
generation_config
.
output_attentions
model_kwargs
[
"output_attentions"
]
=
generation_config
.
output_attentions
...
@@ -678,23 +680,26 @@ class TFGenerationMixin:
...
@@ -678,23 +680,26 @@ class TFGenerationMixin:
if
model_kwargs
.
get
(
"attention_mask"
,
None
)
is
None
and
requires_attention_mask
and
accepts_attention_mask
:
if
model_kwargs
.
get
(
"attention_mask"
,
None
)
is
None
and
requires_attention_mask
and
accepts_attention_mask
:
model_kwargs
[
"attention_mask"
]
=
self
.
_prepare_attention_mask_for_generation
(
model_kwargs
[
"attention_mask"
]
=
self
.
_prepare_attention_mask_for_generation
(
input
_ids
,
generation_config
.
pad_token_id
,
generation_config
.
eos_token_id
input
s_tensor
,
generation_config
.
pad_token_id
,
generation_config
.
eos_token_id
)
)
# decoder-only models should use left-padding for generation
# decoder-only models should use left-padding for generation
if
not
self
.
config
.
is_encoder_decoder
:
if
not
self
.
config
.
is_encoder_decoder
:
if
generation_config
.
pad_token_id
is
not
None
and
tf
.
math
.
reduce_any
(
if
generation_config
.
pad_token_id
is
not
None
and
tf
.
math
.
reduce_any
(
input
_ids
[:,
-
1
]
==
generation_config
.
pad_token_id
input
s_tensor
[:,
-
1
]
==
generation_config
.
pad_token_id
):
):
logger
.
warning
(
logger
.
warning
(
"A decoder-only architecture is being used, but right-padding was detected! For correct "
"A decoder-only architecture is being used, but right-padding was detected! For correct "
"generation results, please set `padding_side='left'` when initializing the tokenizer."
"generation results, please set `padding_side='left'` when initializing the tokenizer."
)
)
if
self
.
config
.
is_encoder_decoder
and
"encoder_outputs"
not
in
model_kwargs
:
# if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
model_kwargs
=
self
.
_prepare_encoder_decoder_kwargs_for_generation
(
inputs_tensor
,
model_kwargs
,
model_input_name
)
# 6. Prepare model inputs which will be used for auto-regressive generation
# 6. Prepare model inputs which will be used for auto-regressive generation
if
self
.
config
.
is_encoder_decoder
:
if
self
.
config
.
is_encoder_decoder
:
# if encoder-decoder, we create encoder_outputs and add to `model_kwargs`
model_kwargs
=
self
.
_prepare_encoder_decoder_kwargs_for_generation
(
input_ids
,
model_kwargs
)
# if encoder-decoder then `input_ids` come from `decoder_start_token_id`
# if encoder-decoder then `input_ids` come from `decoder_start_token_id`
input_ids
=
self
.
_prepare_decoder_input_ids_for_generation
(
input_ids
=
self
.
_prepare_decoder_input_ids_for_generation
(
batch_size
,
batch_size
,
...
@@ -702,6 +707,9 @@ class TFGenerationMixin:
...
@@ -702,6 +707,9 @@ class TFGenerationMixin:
bos_token_id
=
generation_config
.
bos_token_id
,
bos_token_id
=
generation_config
.
bos_token_id
,
model_kwargs
=
model_kwargs
,
model_kwargs
=
model_kwargs
,
)
)
else
:
# if decoder-only then inputs_tensor has to be `input_ids`
input_ids
=
inputs_tensor
# 7. Prepare `max_length` depending on other stopping criteria.
# 7. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length
=
input_ids
.
shape
[
-
1
]
input_ids_seq_length
=
input_ids
.
shape
[
-
1
]
...
@@ -924,7 +932,9 @@ class TFGenerationMixin:
...
@@ -924,7 +932,9 @@ class TFGenerationMixin:
else
:
else
:
return
tf
.
ones
(
inputs
.
shape
[:
2
],
dtype
=
tf
.
int32
)
return
tf
.
ones
(
inputs
.
shape
[:
2
],
dtype
=
tf
.
int32
)
def
_prepare_encoder_decoder_kwargs_for_generation
(
self
,
inputs_tensor
:
tf
.
Tensor
,
model_kwargs
)
->
Dict
[
str
,
Any
]:
def
_prepare_encoder_decoder_kwargs_for_generation
(
self
,
inputs_tensor
:
tf
.
Tensor
,
model_kwargs
,
model_input_name
:
Optional
[
str
]
=
None
)
->
Dict
[
str
,
Any
]:
# get encoder and store encoder outputs
# get encoder and store encoder outputs
encoder
=
self
.
get_encoder
()
encoder
=
self
.
get_encoder
()
...
@@ -938,7 +948,9 @@ class TFGenerationMixin:
...
@@ -938,7 +948,9 @@ class TFGenerationMixin:
# vision models don't use `attention_mask`.
# vision models don't use `attention_mask`.
encoder_kwargs
[
"return_dict"
]
=
True
encoder_kwargs
[
"return_dict"
]
=
True
encoder_kwargs
[
self
.
main_input_name
]
=
inputs_tensor
encoder_kwargs
[
model_input_name
]
=
inputs_tensor
if
model_input_name
!=
self
.
main_input_name
:
# in Keras, the first input must always be passed
encoder_kwargs
[
self
.
main_input_name
]
=
None
encoder_outputs
=
encoder
(
**
encoder_kwargs
)
encoder_outputs
=
encoder
(
**
encoder_kwargs
)
model_kwargs
[
"encoder_outputs"
]
=
encoder_outputs
model_kwargs
[
"encoder_outputs"
]
=
encoder_outputs
...
@@ -1007,19 +1019,79 @@ class TFGenerationMixin:
...
@@ -1007,19 +1019,79 @@ class TFGenerationMixin:
return
input_ids
,
model_kwargs
return
input_ids
,
model_kwargs
def
_prepare_model_inputs
(
self
,
inputs
:
Optional
[
tf
.
Tensor
]
=
None
,
bos_token_id
:
Optional
[
int
]
=
None
):
def
_prepare_model_inputs
(
# TODO(Patrick) - adapt this function when making `generate` more flexible
self
,
# for all kinds of input types
inputs
:
Optional
[
tf
.
Tensor
]
=
None
,
if
inputs
is
None
:
bos_token_id
:
Optional
[
int
]
=
None
,
# if no `inputs` are passed create prompt of size (1,1) filled with BOS token
model_kwargs
:
Optional
[
Dict
[
str
,
tf
.
Tensor
]]
=
None
,
if
not
isinstance
(
bos_token_id
,
int
)
or
bos_token_id
<
0
:
)
->
Tuple
[
tf
.
Tensor
,
Optional
[
str
],
Dict
[
str
,
tf
.
Tensor
]]:
raise
ValueError
(
"""
"you should either supply a context to complete as `input_ids` input "
This function extracts the model-specific `inputs` for generation.
"or a `bos_token_id` (integer >= 0) as a first token to start the generation."
"""
# 1. retrieve all kwargs that are non-None or non-model input related.
# some encoder-decoder models have different names for model and encoder
if
(
self
.
config
.
is_encoder_decoder
and
hasattr
(
self
,
"encoder"
)
and
hasattr
(
self
.
encoder
,
"main_input_name"
)
and
self
.
encoder
.
main_input_name
!=
self
.
main_input_name
):
input_name
=
self
.
encoder
.
main_input_name
else
:
input_name
=
self
.
main_input_name
model_kwargs
=
{
k
:
v
for
k
,
v
in
model_kwargs
.
items
()
if
v
is
not
None
or
k
!=
input_name
}
# 2. check whether model_input_name is passed as kwarg
# if yes and `inputs` is None use kwarg inputs
inputs_kwarg
=
model_kwargs
.
pop
(
input_name
,
None
)
if
inputs_kwarg
is
not
None
and
inputs
is
not
None
:
raise
ValueError
(
f
"`inputs`:
{
inputs
}
` were passed alongside
{
input_name
}
which is not allowed."
f
"Make sure to either pass
{
inputs
}
or
{
input_name
}
=..."
)
elif
inputs_kwarg
is
not
None
:
inputs
=
inputs_kwarg
# 3. In the presence of `inputs_embeds` for text models:
# - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model
# doesn't have its forwarding implemented. `inputs_embeds` is kept in `model_kwargs` and can coexist with
# input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`)
# - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and
# pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states.
if
input_name
==
"input_ids"
and
"inputs_embeds"
in
model_kwargs
:
if
not
self
.
config
.
is_encoder_decoder
:
has_inputs_embeds_forwarding
=
"inputs_embeds"
in
set
(
inspect
.
signature
(
self
.
prepare_inputs_for_generation
).
parameters
.
keys
()
)
)
return
tf
.
cast
(
tf
.
fill
((
1
,
1
),
bos_token_id
),
dtype
=
tf
.
int32
)
if
not
has_inputs_embeds_forwarding
:
raise
ValueError
(
f
"You passed `inputs_embeds` to `.generate()`, but the model class
{
self
.
__class__
.
__name__
}
"
"doesn't have its forwarding implemented. See the GPT2 implementation for an example "
"(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!"
)
else
:
if
inputs
is
not
None
:
raise
ValueError
(
"You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one."
)
inputs
,
input_name
=
model_kwargs
[
"inputs_embeds"
],
"inputs_embeds"
# 4. if `inputs` is still None, try to create `input_ids` from BOS token
if
inputs
is
None
:
inputs
=
self
.
_prepare_input_ids_for_generation
(
bos_token_id
,
model_kwargs
.
get
(
"encoder_outputs"
))
return
inputs
,
input_name
,
model_kwargs
return
inputs
def
_prepare_input_ids_for_generation
(
self
,
bos_token_id
:
Optional
[
int
],
encoder_outputs
:
Optional
[
ModelOutput
]
)
->
tf
.
Tensor
:
if
self
.
config
.
is_encoder_decoder
and
encoder_outputs
is
not
None
:
# make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
shape
=
encoder_outputs
.
last_hidden_state
.
size
()[:
-
1
]
return
tf
.
ones
(
shape
,
dtype
=
tf
.
int32
)
*
-
100
if
bos_token_id
is
None
:
raise
ValueError
(
"`bos_token_id` has to be defined when no `input_ids` are provided."
)
return
tf
.
ones
((
1
,
1
),
dtype
=
tf
.
int32
)
*
bos_token_id
@
staticmethod
@
staticmethod
def
_extract_past_from_model_output
(
outputs
:
ModelOutput
):
def
_extract_past_from_model_output
(
outputs
:
ModelOutput
):
...
...
tests/generation/test_framework_agnostic.py
View file @
1e4cf8bb
...
@@ -5,11 +5,13 @@ Framework agnostic tests for generate()-related methods.
...
@@ -5,11 +5,13 @@ Framework agnostic tests for generate()-related methods.
import
numpy
as
np
import
numpy
as
np
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
from
transformers.testing_utils
import
torch_device
class
GenerationIntegrationTestsMixin
:
class
GenerationIntegrationTestsMixin
:
# To be populated by the child classes
# To be populated by the child classes
framework_dependent_parameters
=
{
framework_dependent_parameters
=
{
"AutoModelForCausalLM"
:
None
,
"AutoModelForSeq2SeqLM"
:
None
,
"AutoModelForSeq2SeqLM"
:
None
,
"LogitsProcessorList"
:
None
,
"LogitsProcessorList"
:
None
,
"MinLengthLogitsProcessor"
:
None
,
"MinLengthLogitsProcessor"
:
None
,
...
@@ -60,3 +62,91 @@ class GenerationIntegrationTestsMixin:
...
@@ -60,3 +62,91 @@ class GenerationIntegrationTestsMixin:
bart_model
.
config
.
min_length
=
None
bart_model
.
config
.
min_length
=
None
bart_model
.
generate
(
input_ids
,
logits_processor
=
logits_processor
)
bart_model
.
generate
(
input_ids
,
logits_processor
=
logits_processor
)
def
test_max_new_tokens_encoder_decoder
(
self
):
model_cls
=
self
.
framework_dependent_parameters
[
"AutoModelForSeq2SeqLM"
]
return_tensors
=
self
.
framework_dependent_parameters
[
"return_tensors"
]
is_pt
=
not
model_cls
.
__name__
.
startswith
(
"TF"
)
article
=
"""Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer
=
AutoTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-bart"
)
bart_model
=
model_cls
.
from_pretrained
(
"hf-internal-testing/tiny-random-bart"
)
input_ids
=
bart_tokenizer
(
article
,
return_tensors
=
return_tensors
).
input_ids
if
is_pt
:
bart_model
=
bart_model
.
to
(
torch_device
)
input_ids
=
input_ids
.
to
(
torch_device
)
self
.
assertEqual
(
list
(
input_ids
.
shape
),
[
1
,
29
])
max_new_tokens
=
3
bart_model
.
config
.
max_length
=
20
bart_model
.
config
.
eos_token_id
=
None
# Encoder decoder call
outputs
=
bart_model
.
generate
(
input_ids
,
max_new_tokens
=
max_new_tokens
)
# 1 BOS + 3 new tokens
self
.
assertEqual
(
list
(
outputs
.
shape
),
[
1
,
4
])
# Decoder only call
outputs
=
bart_model
.
generate
(
decoder_input_ids
=
input_ids
,
max_new_tokens
=
max_new_tokens
)
# 29 + 3 new tokens
self
.
assertEqual
(
list
(
outputs
.
shape
),
[
1
,
32
])
# Encoder decoder call > 20
outputs
=
bart_model
.
generate
(
max_new_tokens
=
max_new_tokens
+
20
)
# 1 BOS + 20 + 3 new tokens
self
.
assertEqual
(
list
(
outputs
.
shape
),
[
1
,
24
])
def
test_max_new_tokens_decoder_only
(
self
):
model_cls
=
self
.
framework_dependent_parameters
[
"AutoModelForCausalLM"
]
return_tensors
=
self
.
framework_dependent_parameters
[
"return_tensors"
]
is_pt
=
not
model_cls
.
__name__
.
startswith
(
"TF"
)
article
=
"""Justin Timberlake."""
gpt2_tokenizer
=
AutoTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-gpt2"
)
gpt2_model
=
model_cls
.
from_pretrained
(
"hf-internal-testing/tiny-random-gpt2"
)
input_ids
=
gpt2_tokenizer
(
article
,
return_tensors
=
return_tensors
).
input_ids
if
is_pt
:
gpt2_model
=
gpt2_model
.
to
(
torch_device
)
input_ids
=
input_ids
.
to
(
torch_device
)
self
.
assertEqual
(
list
(
input_ids
.
shape
),
[
1
,
9
])
max_new_tokens
=
3
gpt2_model
.
config
.
max_length
=
20
# call < 20
outputs
=
gpt2_model
.
generate
(
input_ids
,
max_new_tokens
=
max_new_tokens
)
# 9 input_ids + 3 new tokens
self
.
assertEqual
(
list
(
outputs
.
shape
),
[
1
,
12
])
# call > 20
outputs
=
gpt2_model
.
generate
(
max_new_tokens
=
max_new_tokens
+
20
)
# 1 BOS token + 23 new tokens
self
.
assertEqual
(
list
(
outputs
.
shape
),
[
1
,
24
])
def
test_encoder_decoder_generate_with_inputs_embeds
(
self
):
model_cls
=
self
.
framework_dependent_parameters
[
"AutoModelForSeq2SeqLM"
]
return_tensors
=
self
.
framework_dependent_parameters
[
"return_tensors"
]
is_pt
=
not
model_cls
.
__name__
.
startswith
(
"TF"
)
article
=
"""Justin Timberlake and Jessica Biel, welcome to parenthood."""
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-bart"
)
model
=
model_cls
.
from_pretrained
(
"hf-internal-testing/tiny-random-bart"
,
max_length
=
5
)
model
.
config
.
eos_token_id
=
None
input_ids
=
tokenizer
(
article
,
return_tensors
=
return_tensors
).
input_ids
if
is_pt
:
model
=
model
.
to
(
torch_device
)
input_ids
=
input_ids
.
to
(
torch_device
)
inputs_embeds
=
model
.
get_input_embeddings
()(
input_ids
)
output_sequences
=
model
.
generate
(
inputs_embeds
=
inputs_embeds
)
# make sure model generated correctly until `max_length`
self
.
assertEqual
(
output_sequences
.
shape
,
(
1
,
5
))
tests/generation/test_tf_utils.py
View file @
1e4cf8bb
...
@@ -135,6 +135,7 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
...
@@ -135,6 +135,7 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
# setting framework_dependent_parameters needs to be gated, just like its contents' imports
# setting framework_dependent_parameters needs to be gated, just like its contents' imports
if
is_tf_available
():
if
is_tf_available
():
framework_dependent_parameters
=
{
framework_dependent_parameters
=
{
"AutoModelForCausalLM"
:
TFAutoModelForCausalLM
,
"AutoModelForSeq2SeqLM"
:
TFAutoModelForSeq2SeqLM
,
"AutoModelForSeq2SeqLM"
:
TFAutoModelForSeq2SeqLM
,
"LogitsProcessorList"
:
TFLogitsProcessorList
,
"LogitsProcessorList"
:
TFLogitsProcessorList
,
"MinLengthLogitsProcessor"
:
TFMinLengthLogitsProcessor
,
"MinLengthLogitsProcessor"
:
TFMinLengthLogitsProcessor
,
...
...
tests/generation/test_utils.py
View file @
1e4cf8bb
...
@@ -40,7 +40,6 @@ if is_torch_available():
...
@@ -40,7 +40,6 @@ if is_torch_available():
ImageGPTForCausalImageModeling
,
ImageGPTForCausalImageModeling
,
Speech2TextForConditionalGeneration
,
Speech2TextForConditionalGeneration
,
SpeechEncoderDecoderModel
,
SpeechEncoderDecoderModel
,
T5ForConditionalGeneration
,
VisionEncoderDecoderModel
,
VisionEncoderDecoderModel
,
top_k_top_p_filtering
,
top_k_top_p_filtering
,
)
)
...
@@ -1792,6 +1791,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
...
@@ -1792,6 +1791,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
# setting framework_dependent_parameters needs to be gated, just like its contents' imports
# setting framework_dependent_parameters needs to be gated, just like its contents' imports
if
is_torch_available
():
if
is_torch_available
():
framework_dependent_parameters
=
{
framework_dependent_parameters
=
{
"AutoModelForCausalLM"
:
AutoModelForCausalLM
,
"AutoModelForSeq2SeqLM"
:
AutoModelForSeq2SeqLM
,
"AutoModelForSeq2SeqLM"
:
AutoModelForSeq2SeqLM
,
"LogitsProcessorList"
:
LogitsProcessorList
,
"LogitsProcessorList"
:
LogitsProcessorList
,
"MinLengthLogitsProcessor"
:
MinLengthLogitsProcessor
,
"MinLengthLogitsProcessor"
:
MinLengthLogitsProcessor
,
...
@@ -2094,182 +2094,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
...
@@ -2094,182 +2094,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
output
=
generator
(
prompt
,
stop_sequence
=
" number"
)
output
=
generator
(
prompt
,
stop_sequence
=
" number"
)
self
.
assertEqual
(
output
,
[{
"generated_text"
:
"Hello I believe in in in number"
}])
self
.
assertEqual
(
output
,
[{
"generated_text"
:
"Hello I believe in in in number"
}])
def
test_max_new_tokens_encoder_decoder
(
self
):
article
=
"""Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer
=
BartTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-bart"
)
bart_model
=
BartForConditionalGeneration
.
from_pretrained
(
"hf-internal-testing/tiny-random-bart"
).
to
(
torch_device
)
input_ids
=
bart_tokenizer
(
article
,
return_tensors
=
"pt"
).
input_ids
.
to
(
torch_device
)
self
.
assertEqual
(
list
(
input_ids
.
shape
),
[
1
,
29
])
max_new_tokens
=
3
bart_model
.
config
.
max_length
=
20
bart_model
.
config
.
eos_token_id
=
None
# Encoder decoder call
outputs
=
bart_model
.
generate
(
input_ids
,
max_new_tokens
=
max_new_tokens
)
# 1 BOS + 3 new tokens
self
.
assertEqual
(
list
(
outputs
.
shape
),
[
1
,
4
])
# Decoder only call
outputs
=
bart_model
.
generate
(
decoder_input_ids
=
input_ids
,
max_new_tokens
=
max_new_tokens
)
# 29 + 3 new tokens
self
.
assertEqual
(
list
(
outputs
.
shape
),
[
1
,
32
])
# Encoder decoder call > 20
outputs
=
bart_model
.
generate
(
max_new_tokens
=
max_new_tokens
+
20
)
# 1 BOS + 20 + 3 new tokens
self
.
assertEqual
(
list
(
outputs
.
shape
),
[
1
,
24
])
def
test_max_new_tokens_decoder_only_contrastive_search_t5
(
self
):
article
=
"""Justin Timberlake and Jessica Biel, welcome to parenthood."""
t5_tokenizer
=
AutoTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-t5"
)
t5_model
=
T5ForConditionalGeneration
.
from_pretrained
(
"hf-internal-testing/tiny-random-t5"
).
to
(
torch_device
)
input_ids
=
t5_tokenizer
(
article
,
return_tensors
=
"pt"
).
input_ids
.
to
(
torch_device
)
self
.
assertEqual
(
list
(
input_ids
.
shape
),
[
1
,
56
])
max_new_tokens
=
3
t5_model
.
config
.
max_length
=
20
t5_model
.
config
.
eos_token_id
=
None
# Encoder decoder call
outputs
=
t5_model
.
generate
(
input_ids
,
max_new_tokens
=
max_new_tokens
,
penalty_alpha
=
0.6
,
top_k
=
4
)
# 1 BOS + 3 new tokens
self
.
assertEqual
(
list
(
outputs
.
shape
),
[
1
,
4
])
# Decoder only call
outputs
=
t5_model
.
generate
(
decoder_input_ids
=
input_ids
,
max_new_tokens
=
max_new_tokens
,
penalty_alpha
=
0.6
,
top_k
=
4
)
# 56 + 3 new tokens
self
.
assertEqual
(
list
(
outputs
.
shape
),
[
1
,
59
])
# Encoder decoder call > 20
outputs
=
t5_model
.
generate
(
max_new_tokens
=
max_new_tokens
+
20
,
penalty_alpha
=
0.6
,
top_k
=
4
)
# 1 BOS + 20 + 3 new tokens
self
.
assertEqual
(
list
(
outputs
.
shape
),
[
1
,
24
])
def
test_max_new_tokens_decoder_only_contrastive_search_bart
(
self
):
article
=
"""Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer
=
BartTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-bart"
)
bart_model
=
BartForConditionalGeneration
.
from_pretrained
(
"hf-internal-testing/tiny-random-bart"
).
to
(
torch_device
)
input_ids
=
bart_tokenizer
(
article
,
return_tensors
=
"pt"
).
input_ids
.
to
(
torch_device
)
self
.
assertEqual
(
list
(
input_ids
.
shape
),
[
1
,
29
])
max_new_tokens
=
3
bart_model
.
config
.
max_length
=
20
bart_model
.
config
.
eos_token_id
=
None
# Encoder decoder call
outputs
=
bart_model
.
generate
(
input_ids
,
max_new_tokens
=
max_new_tokens
,
penalty_alpha
=
0.6
,
top_k
=
4
)
# 1 BOS + 3 new tokens
self
.
assertEqual
(
list
(
outputs
.
shape
),
[
1
,
4
])
# Decoder only call
outputs
=
bart_model
.
generate
(
decoder_input_ids
=
input_ids
,
max_new_tokens
=
max_new_tokens
,
penalty_alpha
=
0.6
,
top_k
=
4
)
# 29 + 3 new tokens
self
.
assertEqual
(
list
(
outputs
.
shape
),
[
1
,
32
])
# Encoder decoder call > 20
outputs
=
bart_model
.
generate
(
max_new_tokens
=
max_new_tokens
+
20
,
penalty_alpha
=
0.6
,
top_k
=
4
)
# 1 BOS + 20 + 3 new tokens
self
.
assertEqual
(
list
(
outputs
.
shape
),
[
1
,
24
])
def
test_max_new_tokens_decoder_only_contrastive_search_gptj
(
self
):
article
=
"""Justin Timberlake."""
gptj_tokenizer
=
AutoTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-gptj"
)
gptj_model
=
AutoModelForCausalLM
.
from_pretrained
(
"hf-internal-testing/tiny-random-gptj"
).
to
(
torch_device
)
input_ids
=
gptj_tokenizer
(
article
,
return_tensors
=
"pt"
).
input_ids
.
to
(
torch_device
)
self
.
assertEqual
(
list
(
input_ids
.
shape
),
[
1
,
9
])
max_new_tokens
=
3
gptj_model
.
config
.
max_length
=
20
# call < 20
outputs
=
gptj_model
.
generate
(
input_ids
,
max_new_tokens
=
max_new_tokens
,
penalty_alpha
=
0.6
,
top_k
=
4
)
# 9 input_ids + 3 new tokens
self
.
assertEqual
(
list
(
outputs
.
shape
),
[
1
,
12
])
# call > 20
outputs
=
gptj_model
.
generate
(
max_new_tokens
=
max_new_tokens
+
20
,
penalty_alpha
=
0.6
,
top_k
=
4
)
# 1 BOS token + 23 new tokens
self
.
assertEqual
(
list
(
outputs
.
shape
),
[
1
,
24
])
def
test_max_new_tokens_decoder_only_contrastive_search_gpt2
(
self
):
article
=
"""Justin Timberlake."""
gpt2_tokenizer
=
GPT2Tokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-gpt2"
)
gpt2_model
=
GPT2LMHeadModel
.
from_pretrained
(
"hf-internal-testing/tiny-random-gpt2"
).
to
(
torch_device
)
input_ids
=
gpt2_tokenizer
(
article
,
return_tensors
=
"pt"
).
input_ids
.
to
(
torch_device
)
self
.
assertEqual
(
list
(
input_ids
.
shape
),
[
1
,
9
])
max_new_tokens
=
3
gpt2_model
.
config
.
max_length
=
20
# call < 20
outputs
=
gpt2_model
.
generate
(
input_ids
,
max_new_tokens
=
max_new_tokens
,
penalty_alpha
=
0.6
,
top_k
=
4
)
# 9 input_ids + 3 new tokens
self
.
assertEqual
(
list
(
outputs
.
shape
),
[
1
,
12
])
# call > 20
outputs
=
gpt2_model
.
generate
(
max_new_tokens
=
max_new_tokens
+
20
,
penalty_alpha
=
0.6
,
top_k
=
4
)
# 1 BOS token + 23 new tokens
self
.
assertEqual
(
list
(
outputs
.
shape
),
[
1
,
24
])
def
test_max_new_tokens_decoder_only
(
self
):
article
=
"""Justin Timberlake."""
gpt2_tokenizer
=
GPT2Tokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-gpt2"
)
gpt2_model
=
GPT2LMHeadModel
.
from_pretrained
(
"hf-internal-testing/tiny-random-gpt2"
).
to
(
torch_device
)
input_ids
=
gpt2_tokenizer
(
article
,
return_tensors
=
"pt"
).
input_ids
.
to
(
torch_device
)
self
.
assertEqual
(
list
(
input_ids
.
shape
),
[
1
,
9
])
max_new_tokens
=
3
gpt2_model
.
config
.
max_length
=
20
# call < 20
outputs
=
gpt2_model
.
generate
(
input_ids
,
max_new_tokens
=
max_new_tokens
)
# 9 input_ids + 3 new tokens
self
.
assertEqual
(
list
(
outputs
.
shape
),
[
1
,
12
])
# call > 20
outputs
=
gpt2_model
.
generate
(
max_new_tokens
=
max_new_tokens
+
20
)
# 1 BOS token + 23 new tokens
self
.
assertEqual
(
list
(
outputs
.
shape
),
[
1
,
24
])
def
test_encoder_decoder_generate_with_inputs_embeds
(
self
):
article
=
"""Justin Timberlake and Jessica Biel, welcome to parenthood."""
tokenizer
=
BartTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-bart"
)
model
=
BartForConditionalGeneration
.
from_pretrained
(
"hf-internal-testing/tiny-random-bart"
,
max_length
=
5
).
to
(
torch_device
)
model
.
config
.
eos_token_id
=
None
input_ids
=
tokenizer
(
article
,
return_tensors
=
"pt"
).
input_ids
.
to
(
torch_device
)
inputs_embeds
=
model
.
get_input_embeddings
()(
input_ids
)
output_sequences
=
model
.
generate
(
inputs_embeds
=
inputs_embeds
)
# make sure model generated correctly until `max_length`
self
.
assertEqual
(
output_sequences
.
shape
,
(
1
,
5
))
def
test_encoder_decoder_generate_attention_mask
(
self
):
def
test_encoder_decoder_generate_attention_mask
(
self
):
articles
=
[
"Timberlake"
,
"Jessica Biel, welcome to parenthood among other things"
]
articles
=
[
"Timberlake"
,
"Jessica Biel, welcome to parenthood among other things"
]
tokenizer
=
BartTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-bart"
)
tokenizer
=
BartTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-bart"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment