"...networks/git@developer.sourcefind.cn:Wenxuan/LightX2V.git" did not exist on "64948a2ea2c1a0c51656425e1cd2dd66f8b88d08"
Unverified Commit eb6c59bc authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: TF supports multiple eos tokens (#21571)

parent c836f772
......@@ -1230,7 +1230,7 @@ class TFGenerationMixin:
) -> tf.Tensor:
if self.config.is_encoder_decoder and encoder_outputs is not None:
# make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
shape = encoder_outputs.last_hidden_state.size()[:-1]
shape = encoder_outputs.last_hidden_state.shape[:-1]
return tf.ones(shape, dtype=tf.int32) * -100
if bos_token_id is None:
......@@ -1515,8 +1515,8 @@ class TFGenerationMixin:
The maximum length of the sequence to be generated.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details.
......@@ -1575,6 +1575,8 @@ class TFGenerationMixin:
max_length = max_length if max_length is not None else self.generation_config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_attentions = (
output_attentions if output_attentions is not None else self.generation_config.output_attentions
......@@ -1660,7 +1662,13 @@ class TFGenerationMixin:
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
unfinished_seq = 1 - tf.cast(finished_sequences, tf.int32)
next_tokens = next_tokens * unfinished_seq + pad_token_id * (1 - unfinished_seq)
finished_sequences = finished_sequences | (next_tokens == eos_token_id)
next_token_is_eos = tf.math.reduce_any(
tf.equal(
tf.broadcast_to(next_tokens, (len(eos_token_id), batch_size)), tf.expand_dims(eos_token_id, -1)
),
axis=0,
)
finished_sequences = finished_sequences | next_token_is_eos
# update `generated` and `cur_len`
update_indices = tf.stack([tf.range(batch_size), tf.broadcast_to(cur_len, [batch_size])], axis=-1)
......@@ -1776,8 +1784,8 @@ class TFGenerationMixin:
The maximum length of the sequence to be generated.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
seed (`List[int]`, *optional*):
Random seed to control sampling, containing two integers, used when `do_sample` is `True`. See the
`seed` argument from stateless functions in `tf.random`.
......@@ -1852,6 +1860,8 @@ class TFGenerationMixin:
max_length = max_length if max_length is not None else self.generation_config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_attentions = (
output_attentions if output_attentions is not None else self.generation_config.output_attentions
......@@ -1943,7 +1953,13 @@ class TFGenerationMixin:
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
unfinished_seq = 1 - tf.cast(finished_sequences, tf.int32)
next_tokens = next_tokens * unfinished_seq + pad_token_id * (1 - unfinished_seq)
finished_sequences = finished_sequences | (next_tokens == eos_token_id)
next_token_is_eos = tf.math.reduce_any(
tf.equal(
tf.broadcast_to(next_tokens, (len(eos_token_id), batch_size)), tf.expand_dims(eos_token_id, -1)
),
axis=0,
)
finished_sequences = finished_sequences | next_token_is_eos
# update `generated` and `cur_len`
update_indices = tf.stack([tf.range(batch_size), tf.broadcast_to(cur_len, [batch_size])], axis=-1)
......@@ -2079,8 +2095,8 @@ class TFGenerationMixin:
The maximum length of the sequence to be generated.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
length_penalty (`float`, *optional*, defaults to 1.0):
Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent
to the sequence length, which in turn is used to divide the score of the sequence. Since the score is
......@@ -2180,6 +2196,8 @@ class TFGenerationMixin:
max_length = max_length if max_length is not None else self.generation_config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.generation_config.num_return_sequences
)
......@@ -2401,9 +2419,18 @@ class TFGenerationMixin:
# Update current sequences: Did the top `num_beams` sequences reach an end marker?
# To prevent these just finished sequences from being added to the current sequences
# set of active beam search sequences, set their log probs to a very large negative value.
eos_in_next_token = topk_sequences[:, :, cur_len] == eos_token_id
if eos_token_id is None:
eos_in_next_token = tf.broadcast_to(eos_in_next_token, topk_sequences[:, :, cur_len].shape)
eos_in_next_token = tf.zeros(topk_sequences[:, :, cur_len].shape, dtype=tf.bool)
else:
eos_in_next_token = tf.math.reduce_any(
tf.equal(
tf.broadcast_to(
topk_sequences[:, :, cur_len], [len(eos_token_id)] + topk_sequences[:, :, cur_len].shape
),
tf.expand_dims(tf.expand_dims(eos_token_id, -1), -1),
),
axis=0,
)
did_topk_just_finished = eos_in_next_token & tf.broadcast_to(
tf.concat((tf.ones((num_beams), dtype=tf.bool), tf.zeros((num_beams), dtype=tf.bool)), axis=0),
shape_list(eos_in_next_token),
......@@ -2649,8 +2676,8 @@ class TFGenerationMixin:
The maximum length of the sequence to be generated.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details.
......@@ -2700,6 +2727,8 @@ class TFGenerationMixin:
max_length = max_length if max_length is not None else self.generation_config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_attentions = (
output_attentions if output_attentions is not None else self.generation_config.output_attentions
......@@ -2924,7 +2953,13 @@ class TFGenerationMixin:
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
unfinished_seq = 1 - tf.cast(finished_sequences, tf.int32)
next_tokens = next_tokens * unfinished_seq + pad_token_id * (1 - unfinished_seq)
finished_sequences = finished_sequences | (next_tokens == eos_token_id)
next_token_is_eos = tf.math.reduce_any(
tf.equal(
tf.broadcast_to(next_tokens, (len(eos_token_id), batch_size)), tf.expand_dims(eos_token_id, -1)
),
axis=0,
)
finished_sequences = finished_sequences | next_token_is_eos
# update `generated` and `cur_len`
update_indices = tf.stack([tf.range(batch_size), tf.broadcast_to(cur_len, [batch_size])], axis=-1)
......
......@@ -1702,8 +1702,8 @@ class GenerationMixin:
used to tell if the generation loop should stop.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details.
......@@ -2057,8 +2057,8 @@ class GenerationMixin:
tokens. The maximum length of the sequence to be generated.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details.
......@@ -2306,8 +2306,8 @@ class GenerationMixin:
tokens. The maximum length of the sequence to be generated.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details.
......@@ -2574,8 +2574,8 @@ class GenerationMixin:
tokens. The maximum length of the sequence to be generated.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details.
......@@ -2902,8 +2902,8 @@ class GenerationMixin:
tokens. The maximum length of the sequence to be generated.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details.
......@@ -3230,8 +3230,8 @@ class GenerationMixin:
tokens. The maximum length of the sequence to be generated.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details.
......@@ -3613,8 +3613,8 @@ class GenerationMixin:
tokens. The maximum length of the sequence to be generated.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details.
......
......@@ -12,11 +12,15 @@ class GenerationIntegrationTestsMixin:
# To be populated by the child classes
framework_dependent_parameters = {
"AutoModelForCausalLM": None,
"AutoModelForSpeechSeq2Seq": None,
"AutoModelForSeq2SeqLM": None,
"AutoModelForVision2Seq": None,
"LogitsProcessorList": None,
"MinLengthLogitsProcessor": None,
"create_tensor_fn": None,
"floats_tensor": None,
"return_tensors": None,
"set_seed": None,
}
def test_validate_generation_inputs(self):
......@@ -486,3 +490,171 @@ class GenerationIntegrationTestsMixin:
input_ids = tokenizer(article, return_tensors=return_tensors).input_ids
with self.assertRaises(ValueError):
model.generate(input_ids, input_ids=input_ids)
def test_generate_too_many_encoder_kwargs(self):
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
return_tensors = self.framework_dependent_parameters["return_tensors"]
article = """I need input_ids to generate"""
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=10)
input_ids = tokenizer(article, return_tensors=return_tensors).input_ids
with self.assertRaises(ValueError):
model.generate(input_ids=input_ids, inputs_embeds=input_ids)
def test_generate_input_features_as_encoder_kwarg(self):
model_cls = self.framework_dependent_parameters["AutoModelForSpeechSeq2Seq"]
floats_tensor = self.framework_dependent_parameters["floats_tensor"]
is_pt = not model_cls.__name__.startswith("TF")
input_features = floats_tensor((3, 80, 60))
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-WhisperForConditionalGeneration")
if is_pt:
input_features.to(torch_device)
model = model.to(torch_device)
output_sequences_kwargs = model.generate(input_features=input_features, max_length=5)
output_sequences = model.generate(input_features, max_length=5)
if is_pt:
output_sequences_kwargs = output_sequences_kwargs.cpu().numpy()
output_sequences = output_sequences.cpu().numpy()
self.assertTrue(np.array_equal(output_sequences, output_sequences_kwargs))
self.assertEqual(output_sequences.shape, (3, 5))
def test_generate_pixel_values_as_encoder_kwarg(self):
model_cls = self.framework_dependent_parameters["AutoModelForVision2Seq"]
floats_tensor = self.framework_dependent_parameters["floats_tensor"]
is_pt = not model_cls.__name__.startswith("TF")
pixel_values = floats_tensor((2, 3, 30, 30))
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2")
model.config.decoder.eos_token_id = None
if is_pt:
pixel_values = pixel_values.to(torch_device)
model = model.to(torch_device)
output_sequences_kwargs = model.generate(pixel_values=pixel_values, max_length=5)
output_sequences = model.generate(pixel_values, max_length=5)
if is_pt:
output_sequences_kwargs = output_sequences_kwargs.cpu().numpy()
output_sequences = output_sequences.cpu().numpy()
self.assertTrue(np.array_equal(output_sequences, output_sequences_kwargs))
self.assertEqual(output_sequences.shape, (2, 5))
def test_generate_encoder_outputs_attention_mask(self):
model_cls = self.framework_dependent_parameters["AutoModelForSpeechSeq2Seq"]
floats_tensor = self.framework_dependent_parameters["floats_tensor"]
create_tensor_fn = self.framework_dependent_parameters["create_tensor_fn"]
is_pt = not model_cls.__name__.startswith("TF")
input_features = floats_tensor((3, 80, 60))
attention_mask = create_tensor_fn(np.ones(input_features.shape))
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-WhisperForConditionalGeneration")
if is_pt:
input_features = input_features.to(torch_device)
attention_mask = attention_mask.to(torch_device)
model = model.to(torch_device)
encoder = model.get_encoder()
encoder_outputs = encoder(input_features)
output_sequences_no_mask = model.generate(encoder_outputs=encoder_outputs)
output_sequences_with_mask = model.generate(encoder_outputs=encoder_outputs, attention_mask=attention_mask)
if is_pt:
output_sequences_no_mask = output_sequences_no_mask.cpu().numpy()
output_sequences_with_mask = output_sequences_with_mask.cpu().numpy()
self.assertTrue(np.array_equal(output_sequences_no_mask, output_sequences_with_mask))
def test_eos_token_id_int_and_list_greedy_search(self):
model_cls = self.framework_dependent_parameters["AutoModelForCausalLM"]
return_tensors = self.framework_dependent_parameters["return_tensors"]
is_pt = not model_cls.__name__.startswith("TF")
generation_kwargs = {
"do_sample": False,
"num_beams": 1,
}
expectation = 13
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
text = """Hello, my dog is cute and"""
tokens = tokenizer(text, return_tensors=return_tensors)
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-gpt2")
if is_pt:
model = model.to(torch_device)
tokens = tokens.to(torch_device)
eos_token_id = 873
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
eos_token_id = [873, 198]
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
def test_eos_token_id_int_and_list_contrastive_search(self):
model_cls = self.framework_dependent_parameters["AutoModelForCausalLM"]
return_tensors = self.framework_dependent_parameters["return_tensors"]
is_pt = not model_cls.__name__.startswith("TF")
generation_kwargs = {
"do_sample": False,
"num_beams": 1,
"penalty_alpha": 0.6,
"top_k": 4,
}
expectation = 17
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
text = """Hello, my dog is cute and"""
tokens = tokenizer(text, return_tensors=return_tensors)
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-gpt2")
if is_pt:
model = model.to(torch_device)
tokens = tokens.to(torch_device)
eos_token_id = 225
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
eos_token_id = [225, 198]
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
def test_eos_token_id_int_and_list_beam_search(self):
model_cls = self.framework_dependent_parameters["AutoModelForCausalLM"]
return_tensors = self.framework_dependent_parameters["return_tensors"]
is_pt = not model_cls.__name__.startswith("TF")
generation_kwargs = {
"do_sample": False,
"num_beams": 3,
}
expectation = 13
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
text = """Hello, my dog is cute and"""
tokens = tokenizer(text, return_tensors=return_tensors)
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-gpt2")
if is_pt:
model = model.to(torch_device)
tokens = tokens.to(torch_device)
eos_token_id = 873
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
unpadded_correct_condition = expectation == len(generated_tokens[0])
padded_correct_condition = expectation < len(generated_tokens[0]) and all(
[token == model.config.pad_token_id for token in generated_tokens[0][expectation:]]
)
self.assertTrue(unpadded_correct_condition or padded_correct_condition)
eos_token_id = [873, 198]
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
unpadded_correct_condition = expectation == len(generated_tokens[0])
padded_correct_condition = expectation < len(generated_tokens[0]) and all(
[token == model.config.pad_token_id for token in generated_tokens[0][expectation:]]
)
self.assertTrue(unpadded_correct_condition or padded_correct_condition)
......@@ -19,6 +19,7 @@ import unittest
from transformers import is_tf_available
from transformers.testing_utils import require_tf, slow
from ..test_modeling_tf_common import floats_tensor
from .test_framework_agnostic import GenerationIntegrationTestsMixin
......@@ -26,8 +27,11 @@ if is_tf_available():
import tensorflow as tf
from transformers import (
AutoTokenizer,
TFAutoModelForCausalLM,
TFAutoModelForSeq2SeqLM,
TFAutoModelForSpeechSeq2Seq,
TFAutoModelForVision2Seq,
TFLogitsProcessorList,
TFMinLengthLogitsProcessor,
tf_top_k_top_p_filtering,
......@@ -136,15 +140,19 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
if is_tf_available():
framework_dependent_parameters = {
"AutoModelForCausalLM": TFAutoModelForCausalLM,
"AutoModelForSpeechSeq2Seq": TFAutoModelForSpeechSeq2Seq,
"AutoModelForSeq2SeqLM": TFAutoModelForSeq2SeqLM,
"AutoModelForVision2Seq": TFAutoModelForVision2Seq,
"LogitsProcessorList": TFLogitsProcessorList,
"MinLengthLogitsProcessor": TFMinLengthLogitsProcessor,
"create_tensor_fn": tf.convert_to_tensor,
"floats_tensor": floats_tensor,
"return_tensors": "tf",
}
@slow
def test_generate_tf_function_export_fixed_input_length(self):
# TF-only test: tf.saved_model export
test_model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
input_length = 2
max_new_tokens = 2
......@@ -187,6 +195,7 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
@slow
def test_generate_tf_function_export_fixed_batch_size(self):
# TF-only test: tf.saved_model export
test_model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
batch_size = 1
max_new_tokens = 2
......@@ -226,3 +235,32 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
tf_func_outputs = serving_func(**inputs)["sequences"]
tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_new_tokens)
tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)
def test_eos_token_id_int_and_list_top_k_top_sampling(self):
# Has PT equivalent: this test relies on random sampling
generation_kwargs = {
"do_sample": True,
"num_beams": 1,
"top_p": 0.7,
"top_k": 10,
"temperature": 0.7,
}
expectation = 14
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
text = """Hello, my dog is cute and"""
tokens = tokenizer(text, return_tensors="tf")
model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
eos_token_id = 638
# forces the generation to happen on CPU, to avoid GPU-related quirks
with tf.device(":/CPU:0"):
tf.random.set_seed(0)
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
eos_token_id = [638, 198]
with tf.device(":/CPU:0"):
tf.random.set_seed(0)
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
......@@ -30,15 +30,15 @@ if is_torch_available():
from transformers import (
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoModelForSpeechSeq2Seq,
AutoModelForVision2Seq,
AutoTokenizer,
BartForConditionalGeneration,
BartTokenizer,
GPT2LMHeadModel,
GPT2Tokenizer,
ImageGPTForCausalImageModeling,
Speech2TextForConditionalGeneration,
SpeechEncoderDecoderModel,
VisionEncoderDecoderModel,
top_k_top_p_filtering,
)
from transformers.generation import (
......@@ -1790,10 +1790,13 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
if is_torch_available():
framework_dependent_parameters = {
"AutoModelForCausalLM": AutoModelForCausalLM,
"AutoModelForSpeechSeq2Seq": AutoModelForSpeechSeq2Seq,
"AutoModelForSeq2SeqLM": AutoModelForSeq2SeqLM,
"AutoModelForVision2Seq": AutoModelForVision2Seq,
"LogitsProcessorList": LogitsProcessorList,
"MinLengthLogitsProcessor": MinLengthLogitsProcessor,
"create_tensor_fn": torch.tensor,
"floats_tensor": floats_tensor,
"return_tensors": "pt",
}
......@@ -2093,7 +2096,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertEqual(output, [{"generated_text": "Hello I believe in in in number"}])
def test_generate_non_nlp_input_ids_as_kwarg(self):
# PT-only test: AFAIK there is no non-NLP model architecture in TF that supports `input_ids` as its only input
# PT-only test: AFAIK there's no non-NLP model architecture in TF that supports `input_ids` as its only input
model = ImageGPTForCausalImageModeling.from_pretrained(
"hf-internal-testing/tiny-random-imagegpt", max_length=10
).to(torch_device)
......@@ -2105,17 +2108,8 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
self.assertEqual(output_sequences.shape, (3, 10))
def test_generate_too_many_encoder_kwargs(self):
article = """I need input_ids to generate"""
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=10).to(
torch_device
)
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
with self.assertRaises(ValueError):
model.generate(input_ids=input_ids, inputs_embeds=input_ids)
def test_generate_input_values_as_encoder_kwarg(self):
# PT-only test: AFAIK there's no generate-capable architecture in TF that supports `input_values` as its input
input_values = floats_tensor((2, 250))
model = SpeechEncoderDecoderModel.from_pretrained("hf-internal-testing/tiny-random-speech-encoder-decoder")
model = model.to(torch_device)
......@@ -2125,43 +2119,8 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
self.assertEqual(output_sequences.shape, (2, 5))
def test_generate_input_features_as_encoder_kwarg(self):
input_features = floats_tensor((3, 20, 24))
model = Speech2TextForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-speech_to_text")
model = model.to(torch_device)
output_sequences_kwargs = model.generate(input_features=input_features, max_length=5).cpu()
output_sequences = model.generate(input_features, max_length=5).cpu()
self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
self.assertEqual(output_sequences.shape, (3, 5))
def test_generate_pixel_values_as_encoder_kwarg(self):
pixel_values = floats_tensor((2, 3, 30, 30))
model = VisionEncoderDecoderModel.from_pretrained("hf-internal-testing/tiny-random-vision-encoder-decoder")
model = model.to(torch_device)
output_sequences_kwargs = model.generate(pixel_values=pixel_values, max_length=5).cpu()
output_sequences = model.generate(pixel_values, max_length=5).cpu()
self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
self.assertEqual(output_sequences.shape, (2, 5))
def test_generate_encoder_outputs_attention_mask(self):
input_values = floats_tensor((2, 250)).to(torch_device)
attention_mask = torch.ones_like(input_values)
model = SpeechEncoderDecoderModel.from_pretrained("hf-internal-testing/tiny-random-speech-encoder-decoder")
model = model.to(torch_device)
encoder = model.get_encoder()
encoder_outputs = encoder(input_values)
output_sequences_no_mask = model.generate(encoder_outputs=encoder_outputs).cpu()
output_sequences_with_mask = model.generate(encoder_outputs=encoder_outputs, attention_mask=attention_mask)
output_sequences_with_mask = output_sequences_with_mask.cpu()
self.assertListEqual(output_sequences_no_mask.tolist(), output_sequences_with_mask.tolist())
def test_transition_scores_group_beam_search_encoder_decoder(self):
# PT-only test: TF doesn't have group beam search
articles = [
"Justin Timberlake and Jessica Biel, welcome to parenthood.",
"Michael Phelps is arguably the most decorated Olympian of all time.",
......@@ -2188,64 +2147,9 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
def test_log_scores_sample_decoder_only(self):
articles = ["I need input_ids to generate", "Short and"]
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
inputs = tokenizer(articles, return_tensors="pt", padding=True).to(torch_device)
result = model.generate(
**inputs,
max_length=15,
return_dict_in_generate=True,
do_sample=False,
output_scores=True,
)
# decoder-only starts generating from `input_ids`
begin_generation = inputs.input_ids.shape[-1]
gen_sequences = result.sequences[:, begin_generation:]
probs = torch.stack(result.scores, dim=1).softmax(-1)
gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1)
expected_probs = torch.tensor([[0.0014, 0.0015], [0.0014, 0.0014]])
self.assertTrue(torch.allclose(gen_probs.cpu(), expected_probs, atol=1e-3))
def test_log_scores_sample_encoder_decoder(self):
articles = ["I need input_ids to generate", "Short and"]
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(torch_device)
inputs = tokenizer(articles, return_tensors="pt", padding=True).to(torch_device)
result = model.generate(
**inputs,
max_length=3,
return_dict_in_generate=True,
do_sample=False,
num_beams=1,
output_scores=True,
)
# encoder-decoder has one decoder_start_token_id by default
begin_generation = 1
gen_sequences = result.sequences[:, begin_generation:]
probs = torch.stack(result.scores, dim=1).softmax(-1)
gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1)
expected_probs = torch.tensor([[0.0013, 1.0000], [0.0013, 1.0000]])
self.assertTrue(torch.allclose(gen_probs.cpu(), expected_probs, atol=1e-3))
@slow
def test_beam_search_example_integration(self):
# PT-only test: TF doesn't have a BeamSearchScorer
# exactly the example provided in the docstrings of beam search, which previously
# failed after directly copying from it. Refer to PR #15555
tokenizer = AutoTokenizer.from_pretrained("t5-base")
......@@ -2288,6 +2192,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
@slow
def test_constrained_beam_search(self):
# PT-only test: TF doesn't have constrained beam search
model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
......@@ -2325,6 +2230,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
@slow
def test_constrained_beam_search_mixed(self):
# PT-only test: TF doesn't have constrained beam search
model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
......@@ -2365,6 +2271,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
@slow
def test_constrained_beam_search_mixed_mixin(self):
# PT-only test: TF doesn't have constrained beam search
model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
......@@ -2402,6 +2309,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
@slow
def test_constrained_beam_search_example_translation_mixin(self):
# PT-only test: TF doesn't have constrained beam search
tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
......@@ -2426,6 +2334,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
@slow
def test_constrained_beam_search_example_integration(self):
# PT-only test: TF doesn't have constrained beam search
tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
......@@ -2469,6 +2378,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertListEqual(outputs, ["Wie alt sind Sie?"])
def test_constrained_beam_search_mixin_type_checks(self):
# PT-only test: TF doesn't have constrained beam search
tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")
model = AutoModelForSeq2SeqLM.from_pretrained("patrickvonplaten/t5-tiny-random")
......@@ -2509,6 +2419,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
model.generate(input_ids, force_words_ids=[[[-1]]])
def test_contrastive_search_batched(self):
# PT-only test: TF doesn't have constrained beam search
# Tests that contrastive search works with batched inputs (i.e. has the same output as for non-batched inputs)
articles = ["Foo", "Bar Baz"]
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
......@@ -2533,55 +2444,32 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
max_score_diff = (output_sequences_batched.scores[0][1] - output_sequences.scores[0][0]).abs().max()
self.assertTrue(max_score_diff < 1e-5)
def test_eos_token_id_int_and_list_greedy_search(self):
generation_kwargs = {
"do_sample": False,
"num_beams": 1,
}
expectation = 13
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
text = """Hello, my dog is cute and"""
tokens = tokenizer(text, return_tensors="pt").to(torch_device)
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
torch.manual_seed(0)
eos_token_id = 873
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
torch.manual_seed(0)
eos_token_id = [873, 198]
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
def test_eos_token_id_int_and_list_contrastive_search(self):
generation_kwargs = {
"do_sample": False,
"num_beams": 1,
"penalty_alpha": 0.6,
"top_k": 4,
}
expectation = 17
def test_generate_from_input_embeds_decoder_only(self):
# PT-only test: TF doesn't have a model with support to generate from input embeds (yet ;))
# Note: the model must support generation from input embeddings
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
text = """Hello, my dog is cute and"""
tokens = tokenizer(text, return_tensors="pt").to(torch_device)
text = "Hello world"
input_ids = tokenizer.encode(text, return_tensors="pt")
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
# Traditional way of generating text
outputs_from_ids = model.generate(input_ids)
torch.manual_seed(0)
eos_token_id = 225
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
# Same thing, but from input embeddings
inputs_embeds = model.transformer.wte(input_ids)
outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds)
self.assertListEqual(outputs_from_ids.tolist(), outputs_from_embeds.tolist())
# But if we pass different inputs_embeds, we should get different outputs
torch.manual_seed(0)
eos_token_id = [225, 198]
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
random_embeds = torch.rand_like(inputs_embeds)
outputs_from_rand_embeds = model.generate(input_ids, inputs_embeds=random_embeds)
with self.assertRaises(AssertionError):
self.assertListEqual(outputs_from_rand_embeds.tolist(), outputs_from_embeds.tolist())
def test_eos_token_id_int_and_list_top_k_top_sampling(self):
# Has TF equivalent: this test relies on random sampling
generation_kwargs = {
"do_sample": True,
"num_beams": 1,
......@@ -2591,11 +2479,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
}
expectation = 15
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
text = """Hello, my dog is cute and"""
tokens = tokenizer(text, return_tensors="pt").to(torch_device)
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
torch.manual_seed(0)
eos_token_id = 846
......@@ -2606,49 +2493,3 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
eos_token_id = [846, 198]
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
def test_eos_token_id_int_and_list_beam_search(self):
generation_kwargs = {
"do_sample": False,
"num_beams": 3,
}
expectation = 13
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
text = """Hello, my dog is cute and"""
tokens = tokenizer(text, return_tensors="pt").to(torch_device)
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
torch.manual_seed(0)
eos_token_id = 873
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
torch.manual_seed(0)
eos_token_id = [873, 198]
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
def test_generate_from_input_embeds_decoder_only(self):
# Note: the model must support generation from input embeddings
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
text = "Hello world"
input_ids = tokenizer.encode(text, return_tensors="pt")
# Traditional way of generating text
outputs_from_ids = model.generate(input_ids)
# Same thing, but from input embeddings
inputs_embeds = model.transformer.wte(input_ids)
outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds)
self.assertListEqual(outputs_from_ids.tolist(), outputs_from_embeds.tolist())
# But if we pass different inputs_embeds, we should get different outputs
torch.manual_seed(0)
random_embeds = torch.rand_like(inputs_embeds)
outputs_from_rand_embeds = model.generate(input_ids, inputs_embeds=random_embeds)
with self.assertRaises(AssertionError):
self.assertListEqual(outputs_from_rand_embeds.tolist(), outputs_from_embeds.tolist())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment