"docs/vscode:/vscode.git/clone" did not exist on "f250beb8aac83009c70ff01ae8568384683d0f3c"
Unverified commit 425ba56c authored by Raushan Turganbay, committed by GitHub

Clean-up generation tests after moving methods to private (#29582)

* clean-up tests

* refine comments

* fix musicgen tests

* make style

* remove slow decorator from a test

* more clean-up

* fix other failing tests
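
Note: the diff below removes the tests' direct calls to the now-private decoding methods (`greedy_search`, `sample`, `beam_search`, `group_beam_search`, `constrained_beam_search`, `contrastive_search`) and has every test go through the public `generate()` API, asserting on output type and sequence length instead of comparing against those methods. The following is a minimal illustrative sketch of that pattern, not the PR's actual test code; the "gpt2" checkpoint and the hard-coded generation kwargs are assumptions for the example only.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Illustrative only: everything goes through the public `generate()` entry point,
    # beam search is requested via kwargs rather than by building a BeamSearchScorer.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed checkpoint
    model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

    inputs = tokenizer("Hello there", return_tensors="pt")
    max_length = inputs["input_ids"].shape[-1] + 3

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=2,
            return_dict_in_generate=True,
            output_scores=True,
        )

    # The refactored tests assert on shape/type instead of equality with private methods.
    # (They also set `config.eos_token_id = None` so they can assert exact equality with
    # `max_length`; here we only check the upper bound because EOS may stop generation early.)
    assert output.sequences.shape[-1] <= max_length
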
parent 56baa033
@@ -59,33 +59,21 @@ if is_torch_available():
         BeamSampleEncoderDecoderOutput,
         BeamSearchDecoderOnlyOutput,
         BeamSearchEncoderDecoderOutput,
-        BeamSearchScorer,
-        ConstrainedBeamSearchScorer,
         DisjunctiveConstraint,
-        ForcedBOSTokenLogitsProcessor,
-        ForcedEOSTokenLogitsProcessor,
         GenerateBeamDecoderOnlyOutput,
         GenerateBeamEncoderDecoderOutput,
         GenerateDecoderOnlyOutput,
         GenerateEncoderDecoderOutput,
         GreedySearchDecoderOnlyOutput,
         GreedySearchEncoderDecoderOutput,
-        HammingDiversityLogitsProcessor,
-        InfNanRemoveLogitsProcessor,
         LogitsProcessorList,
         MaxLengthCriteria,
         MinLengthLogitsProcessor,
-        NoBadWordsLogitsProcessor,
-        NoRepeatNGramLogitsProcessor,
         PhrasalConstraint,
-        RepetitionPenaltyLogitsProcessor,
         SampleDecoderOnlyOutput,
         SampleEncoderDecoderOutput,
         StoppingCriteria,
         StoppingCriteriaList,
-        TemperatureLogitsWarper,
-        TopKLogitsWarper,
-        TopPLogitsWarper,
     )
     from transformers.generation.utils import _speculative_sampling
@@ -104,7 +92,10 @@ class GenerationTesterMixin:
         input_ids = input_ids[:batch_size, :sequence_length]

         # generate max 3 tokens
-        max_length = input_ids.shape[-1] + 3
+        if config.is_encoder_decoder:
+            max_length = 4
+        else:
+            max_length = input_ids.shape[-1] + 3
         if config.eos_token_id is not None and config.pad_token_id is None:
             # hack to allow generate for models such as GPT2 as is done in `generate()`
             if isinstance(config.eos_token_id, int):
@@ -112,16 +103,19 @@ class GenerationTesterMixin:
                 config.pad_token_id = config.eos_token_id[0]

         attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:batch_size, :sequence_length]

+        # It is important set set the eos_token_id to None to ensure that no sequences
+        # shorter than `max_length` can be generated
+        config.eos_token_id = None
+        config.forced_eos_token_id = None
+
         return config, input_ids, attention_mask, max_length

     @staticmethod
-    def _get_logits_processor_and_kwargs(
+    def _get_logits_processor_and_warper_kwargs(
         input_length,
-        eos_token_id,
         forced_bos_token_id=None,
         forced_eos_token_id=None,
         max_length=None,
-        diversity_penalty=None,
     ):
         process_kwargs = {
             "min_length": input_length + 1 if max_length is None else max_length - 1,
@@ -133,78 +127,21 @@ class GenerationTesterMixin:
         if forced_bos_token_id is None and forced_eos_token_id is None:
             process_kwargs["no_repeat_ngram_size"] = 2

-        # NOTE: the order of operations here should match `generate` for accurate testing
-        logits_processor = LogitsProcessorList(
-            (
-                [
-                    HammingDiversityLogitsProcessor(diversity_penalty, num_beams=2, num_beam_groups=2),
-                ]
-                if diversity_penalty is not None
-                else []
-            )
-            + (
-                [
-                    MinLengthLogitsProcessor(process_kwargs["min_length"], eos_token_id),
-                ]
-                if eos_token_id is not None
-                else []
-            )
-            + (
-                [
-                    ForcedBOSTokenLogitsProcessor(forced_bos_token_id),
-                ]
-                if forced_bos_token_id is not None
-                else []
-            )
-            + (
-                [ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)]
-                if forced_eos_token_id is not None
-                else []
-            )
-            + [NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"], eos_token_id)]
-            + (
-                [NoRepeatNGramLogitsProcessor(process_kwargs["no_repeat_ngram_size"])]
-                if forced_bos_token_id is None and forced_eos_token_id is None
-                else []
-            )
-            + [RepetitionPenaltyLogitsProcessor(process_kwargs["repetition_penalty"])]
-            + [InfNanRemoveLogitsProcessor()]  # prevent flaky generation test failures
-        )
-        return process_kwargs, logits_processor
-
-    @staticmethod
-    def _get_warper_and_kwargs(num_beams):
         warp_kwargs = {"top_k": 10, "top_p": 0.7, "temperature": 0.7}
-        logits_warper = LogitsProcessorList(
-            [
-                TemperatureLogitsWarper(warp_kwargs["temperature"]),
-                TopKLogitsWarper(top_k=warp_kwargs["top_k"], min_tokens_to_keep=(2 if num_beams > 1 else 1)),
-                TopPLogitsWarper(top_p=warp_kwargs["top_p"], min_tokens_to_keep=(2 if num_beams > 1 else 1)),
-            ]
-        )
-        return warp_kwargs, logits_warper
+        return process_kwargs, warp_kwargs

     @staticmethod
-    def _get_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1):
+    def _get_beam_kwargs(num_return_sequences=1):
         beam_kwargs = {
             "early_stopping": False,
             "length_penalty": 2.0,
             "num_beams": 2,
             "num_return_sequences": num_return_sequences,
         }
-        beam_scorer = BeamSearchScorer(
-            batch_size=batch_size,
-            num_beams=beam_kwargs["num_beams"],
-            device=torch_device,
-            length_penalty=beam_kwargs["length_penalty"],
-            do_early_stopping=beam_kwargs["early_stopping"],
-            num_beam_hyps_to_keep=num_return_sequences,
-        )
-        return beam_kwargs, beam_scorer
+        return beam_kwargs

     @staticmethod
-    def _get_diverse_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1):
+    def _get_diverse_beam_kwargs(num_return_sequences=1):
         beam_kwargs = {
             "early_stopping": False,
             "length_penalty": 2.0,
@@ -213,35 +150,17 @@ class GenerationTesterMixin:
             "num_beam_groups": 2,  # one beam per group
             "diversity_penalty": 2.0,
         }
-        beam_scorer = BeamSearchScorer(
-            batch_size=batch_size,
-            num_beams=beam_kwargs["num_beams"],
-            device=torch_device,
-            length_penalty=beam_kwargs["length_penalty"],
-            do_early_stopping=beam_kwargs["early_stopping"],
-            num_beam_hyps_to_keep=num_return_sequences,
-            num_beam_groups=beam_kwargs["num_beam_groups"],
-        )
-        return beam_kwargs, beam_scorer
+        return beam_kwargs

     @staticmethod
-    def _get_constrained_beam_scorer_and_kwargs(batch_size, max_length, constraints, num_return_sequences=1):
+    def _get_constrained_beam_kwargs(num_return_sequences=1):
         beam_kwargs = {
             "early_stopping": False,
             "length_penalty": 2.0,
             "num_beams": num_return_sequences * 4,
             "num_return_sequences": num_return_sequences,
         }
-        beam_scorer = ConstrainedBeamSearchScorer(
-            batch_size=batch_size,
-            constraints=constraints,
-            num_beams=beam_kwargs["num_beams"],
-            device=torch_device,
-            length_penalty=beam_kwargs["length_penalty"],
-            do_early_stopping=beam_kwargs["early_stopping"],
-            num_beam_hyps_to_keep=num_return_sequences,
-        )
-        return beam_kwargs, beam_scorer
+        return beam_kwargs

     @staticmethod
     def _get_encoder_outputs(
@@ -273,17 +192,13 @@ class GenerationTesterMixin:
         output_hidden_states=False,
         return_dict_in_generate=False,
     ):
-        if model.config.is_encoder_decoder:
-            max_length = 4
-        logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
+        logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
             input_ids.shape[-1],
-            eos_token_id=model.config.eos_token_id,
             forced_bos_token_id=model.config.forced_bos_token_id,
             forced_eos_token_id=model.config.forced_eos_token_id,
             max_length=max_length,
         )
-        kwargs = {}

         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -299,31 +214,7 @@ class GenerationTesterMixin:
             **model_kwargs,
         )

-        if model.config.is_encoder_decoder:
-            encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
-                model,
-                input_ids,
-                attention_mask,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
-            kwargs["encoder_outputs"] = encoder_outputs
-
-        with torch.no_grad():
-            model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
-            output_greedy = model.greedy_search(
-                input_ids,
-                max_length=max_length,
-                logits_processor=logits_processor,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                output_scores=output_scores,
-                output_logits=output_logits,
-                return_dict_in_generate=return_dict_in_generate,
-                **kwargs,
-                **model_kwargs,
-            )
-        return output_greedy, output_generate
+        return output_generate

     def _sample_generate(
         self,
@@ -332,8 +223,6 @@ class GenerationTesterMixin:
         attention_mask,
         max_length,
         num_return_sequences,
-        logits_processor,
-        logits_warper,
         logits_warper_kwargs,
         process_kwargs,
         output_scores=False,
@@ -360,38 +249,7 @@ class GenerationTesterMixin:
             **model_kwargs,
         )

-        torch.manual_seed(0)
-        kwargs = {}
-        if model.config.is_encoder_decoder:
-            encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
-                model,
-                input_ids,
-                attention_mask,
-                num_interleave=num_return_sequences,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
-            kwargs["encoder_outputs"] = encoder_outputs
-        elif attention_mask is not None:
-            attention_mask = attention_mask.repeat_interleave(num_return_sequences, dim=0)
-
-        with torch.no_grad():
-            model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
-            output_sample = model.sample(
-                input_ids.repeat_interleave(num_return_sequences, dim=0),
-                max_length=max_length,
-                logits_processor=logits_processor,
-                logits_warper=logits_warper,
-                output_scores=output_scores,
-                output_logits=output_logits,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict_in_generate=return_dict_in_generate,
-                **kwargs,
-                **model_kwargs,
-            )
-        return output_sample, output_generate
+        return output_generate

     def _beam_search_generate(
         self,
@@ -399,9 +257,7 @@ class GenerationTesterMixin:
         input_ids,
         attention_mask,
         max_length,
-        beam_scorer,
         beam_kwargs,
-        logits_processor,
         logits_process_kwargs,
         output_scores=False,
         output_logits=False,
@@ -424,37 +280,7 @@ class GenerationTesterMixin:
             **model_kwargs,
         )

-        # beam_search does not automatically interleave `batch_size` dim for `num_beams`
-        kwargs = {}
-        if model.config.is_encoder_decoder:
-            encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
-                model,
-                input_ids,
-                attention_mask,
-                num_interleave=beam_scorer.num_beams,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
-            kwargs["encoder_outputs"] = encoder_outputs
-        elif attention_mask is not None:
-            attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
-
-        with torch.no_grad():
-            model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
-            output_beam_search = model.beam_search(
-                input_ids.repeat_interleave(beam_scorer.num_beams, dim=0),
-                beam_scorer,
-                max_length=max_length,
-                logits_processor=logits_processor,
-                output_scores=output_scores,
-                output_logits=output_logits,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict_in_generate=return_dict_in_generate,
-                **kwargs,
-                **model_kwargs,
-            )
-        return output_generate, output_beam_search
+        return output_generate

     def _beam_sample_generate(
         self,
@@ -462,9 +288,7 @@ class GenerationTesterMixin:
         input_ids,
         attention_mask,
         max_length,
-        beam_scorer,
         beam_kwargs,
-        logits_warper,
         logits_warper_kwargs,
         output_scores=False,
         output_logits=False,
@@ -487,44 +311,8 @@ class GenerationTesterMixin:
             **logits_warper_kwargs,
             **model_kwargs,
         )
-        # beam_search does not automatically interleave `batch_size` dim for `num_beams`
-        torch.manual_seed(0)
-        kwargs = {}
-        if model.config.is_encoder_decoder:
-            encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
-                model,
-                input_ids,
-                attention_mask,
-                num_interleave=beam_scorer.num_beams,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
-            kwargs["encoder_outputs"] = encoder_outputs
-        elif attention_mask is not None:
-            attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
-
-        # prevent flaky generation test failures
-        logits_processor = LogitsProcessorList()
-        logits_processor.append(InfNanRemoveLogitsProcessor())
-
-        with torch.no_grad():
-            model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
-            output_beam_sample = model.beam_sample(
-                input_ids.repeat_interleave(beam_scorer.num_beams, dim=0),
-                beam_scorer,
-                max_length=max_length,
-                logits_warper=logits_warper,
-                logits_processor=logits_processor,
-                output_scores=output_scores,
-                output_logits=output_logits,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict_in_generate=return_dict_in_generate,
-                **kwargs,
-                **model_kwargs,
-            )
-
-        return output_generate, output_beam_sample
+        return output_generate

     def _group_beam_search_generate(
         self,
@@ -532,9 +320,7 @@ class GenerationTesterMixin:
         input_ids,
         attention_mask,
         max_length,
-        beam_scorer,
         beam_kwargs,
-        logits_processor,
         logits_process_kwargs,
         output_scores=False,
         output_logits=False,
@@ -557,37 +343,7 @@ class GenerationTesterMixin:
             **model_kwargs,
         )

-        # group_beam_search does not automatically interleave `batch_size` dim for `num_beams`
-        kwargs = {}
-        if model.config.is_encoder_decoder:
-            encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
-                model,
-                input_ids,
-                attention_mask,
-                num_interleave=beam_scorer.num_beams,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
-            kwargs["encoder_outputs"] = encoder_outputs
-        elif attention_mask is not None:
-            attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
-
-        with torch.no_grad():
-            model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
-            output_group_beam_search = model.group_beam_search(
-                input_ids.repeat_interleave(beam_scorer.num_beams, dim=0),
-                beam_scorer,
-                max_length=max_length,
-                logits_processor=logits_processor,
-                output_scores=output_scores,
-                output_logits=output_logits,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict_in_generate=return_dict_in_generate,
-                **kwargs,
-                **model_kwargs,
-            )
-        return output_generate, output_group_beam_search
+        return output_generate

     def _constrained_beam_search_generate(
         self,
@@ -595,10 +351,8 @@ class GenerationTesterMixin:
         input_ids,
         attention_mask,
         max_length,
-        constrained_beam_scorer,
         constraints,
         beam_kwargs,
-        logits_processor,
         logits_process_kwargs,
         output_scores=False,
         output_logits=False,
@@ -622,37 +376,7 @@ class GenerationTesterMixin:
             **model_kwargs,
         )

-        # group_beam_search does not automatically interleave `batch_size` dim for `num_beams`
-        kwargs = {}
-        if model.config.is_encoder_decoder:
-            encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
-                model,
-                input_ids,
-                attention_mask,
-                num_interleave=constrained_beam_scorer.num_beams,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
-            kwargs["encoder_outputs"] = encoder_outputs
-        elif attention_mask is not None:
-            attention_mask = attention_mask.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)
-
-        with torch.no_grad():
-            model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
-            output_group_beam_search = model.constrained_beam_search(
-                input_ids.repeat_interleave(constrained_beam_scorer.num_beams, dim=0),
-                constrained_beam_scorer,
-                max_length=max_length,
-                logits_processor=logits_processor,
-                output_scores=output_scores,
-                output_logits=output_logits,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict_in_generate=return_dict_in_generate,
-                **kwargs,
-                **model_kwargs,
-            )
-        return output_generate, output_group_beam_search
+        return output_generate

     def _contrastive_generate(
         self,
@@ -671,17 +395,13 @@ class GenerationTesterMixin:
             "top_k": 5,
         }

-        if model.config.is_encoder_decoder:
-            max_length = 4
-        logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
+        logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
            input_ids.shape[-1],
-            eos_token_id=model.config.eos_token_id,
            forced_bos_token_id=model.config.forced_bos_token_id,
            forced_eos_token_id=model.config.forced_eos_token_id,
            max_length=max_length,
        )
-        kwargs = {}

         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -698,52 +418,26 @@ class GenerationTesterMixin:
             **contrastive_search_kwargs,
         )

-        if model.config.is_encoder_decoder:
-            encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
-                model,
-                input_ids,
-                attention_mask,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
-            kwargs["encoder_outputs"] = encoder_outputs
-
-        with torch.no_grad():
-            model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
-            stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
-            output_contrastive = model.contrastive_search(
-                input_ids,
-                stopping_criteria=stopping_criteria,
-                logits_processor=logits_processor,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                output_scores=output_scores,
-                output_logits=output_logits,
-                return_dict_in_generate=return_dict_in_generate,
-                **kwargs,
-                **model_kwargs,
-                **contrastive_search_kwargs,
-            )
-        return output_contrastive, output_generate
+        return output_generate

     def test_greedy_generate(self):
-        # check `generate()` and `greedy_search()` are equal
         for model_class in self.all_generative_model_classes:
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

-            # test old generation output for backwards compatibility
             model = model_class(config).to(torch_device).eval()
-            output_greedy, output_generate = self._greedy_generate(
+            output_generate = self._greedy_generate(
                 model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length
             )
-            self.assertListEqual(output_greedy.tolist(), output_generate.tolist())
+
+            self.assertTrue(output_generate.shape[-1] == max_length)

     def test_greedy_generate_dict_outputs(self):
         for model_class in self.all_generative_model_classes:
-            # disable cache
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
             config.use_cache = False

             model = model_class(config).to(torch_device).eval()
-            output_greedy, output_generate = self._greedy_generate(
+            output_generate = self._greedy_generate(
                 model=model,
                 input_ids=input_ids,
                 attention_mask=attention_mask,
@@ -756,26 +450,19 @@ class GenerationTesterMixin:
             )

             if model.config.is_encoder_decoder:
-                self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput)
                 self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
                 # Retrocompatibility check
-                self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
                 self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
             else:
-                self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
                 self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
                 # Retrocompatibility check
-                self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
                 self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)

-            self.assertListEqual(output_generate.sequences.tolist(), output_greedy.sequences.tolist())
-
-            for output in (output_greedy, output_generate):
-                self._check_outputs(output, input_ids, model.config)
+            self.assertTrue(output_generate.sequences.shape[-1] == max_length)
+            self._check_outputs(output_generate, input_ids, model.config)

     def test_greedy_generate_dict_outputs_use_cache(self):
         for model_class in self.all_generative_model_classes:
-            # enable cache
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

             if not hasattr(config, "use_cache"):
@@ -784,7 +471,7 @@ class GenerationTesterMixin:
             config.use_cache = True
             config.is_decoder = True
             model = model_class(config).to(torch_device).eval()
-            output_greedy, output_generate = self._greedy_generate(
+            output_generate = self._greedy_generate(
                 model=model,
                 input_ids=input_ids,
                 attention_mask=attention_mask,
@@ -796,82 +483,58 @@ class GenerationTesterMixin:
                 return_dict_in_generate=True,
             )

-            self.assertListEqual(output_generate.sequences.tolist(), output_greedy.sequences.tolist())
-
-            for output in (output_greedy, output_generate):
-                self._check_outputs(output, input_ids, model.config, use_cache=True)
+            self.assertTrue(output_generate.sequences.shape[-1] == max_length)
+            self._check_outputs(output_generate, input_ids, model.config, use_cache=True)

     def test_sample_generate(self):
         for model_class in self.all_generative_model_classes:
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+            model = model_class(config).to(torch_device).eval()

-            model = model_class(config).to(torch_device).eval()
             if model.config.is_encoder_decoder:
                 max_length = 4

-            process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
+            process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
                 input_ids.shape[-1],
-                model.config.eos_token_id,
                 forced_bos_token_id=model.config.forced_bos_token_id,
                 forced_eos_token_id=model.config.forced_eos_token_id,
                 max_length=max_length,
             )
-            logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2)

-            # check `generate()` and `sample()` are equal
-            output_sample, output_generate = self._sample_generate(
+            output_generate = self._sample_generate(
                 model=model,
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 max_length=max_length,
                 num_return_sequences=1,
-                logits_processor=logits_processor,
-                logits_warper=logits_warper,
                 logits_warper_kwargs=logits_warper_kwargs,
                 process_kwargs=process_kwargs,
             )
-            self.assertListEqual(output_sample.tolist(), output_generate.tolist())

-            # check `generate()` and `sample()` yield equal results for `num_return_sequences`
-            output_sample, output_generate = self._sample_generate(
-                model=model,
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                max_length=max_length,
-                num_return_sequences=3,
-                logits_processor=logits_processor,
-                logits_warper=logits_warper,
-                logits_warper_kwargs=logits_warper_kwargs,
-                process_kwargs=process_kwargs,
-            )
-            self.assertListEqual(output_sample.tolist(), output_generate.tolist())
+            self.assertTrue(output_generate.shape[-1] == max_length)

     def test_sample_generate_dict_output(self):
         for model_class in self.all_generative_model_classes:
-            # disable cache
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
             config.use_cache = False

             model = model_class(config).to(torch_device).eval()
             if model.config.is_encoder_decoder:
                 max_length = 4

-            process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
+            process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
                 input_ids.shape[-1],
-                model.config.eos_token_id,
                 forced_bos_token_id=model.config.forced_bos_token_id,
                 forced_eos_token_id=model.config.forced_eos_token_id,
                 max_length=max_length,
             )
-            logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)

-            output_sample, output_generate = self._sample_generate(
+            output_generate = self._sample_generate(
                 model=model,
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 max_length=max_length,
                 num_return_sequences=2,
-                logits_processor=logits_processor,
-                logits_warper=logits_warper,
                 logits_warper_kwargs=logits_warper_kwargs,
                 process_kwargs=process_kwargs,
                 output_scores=True,
@@ -882,75 +545,43 @@ class GenerationTesterMixin:
             )

             if model.config.is_encoder_decoder:
-                self.assertIsInstance(output_sample, GenerateEncoderDecoderOutput)
                 self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
                 # Retrocompatibility check
-                self.assertIsInstance(output_sample, SampleEncoderDecoderOutput)
                 self.assertIsInstance(output_generate, SampleEncoderDecoderOutput)
             else:
-                self.assertIsInstance(output_sample, GenerateDecoderOnlyOutput)
                 self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
                 # Retrocompatibility check
-                self.assertIsInstance(output_sample, SampleDecoderOnlyOutput)
                 self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)

-            self.assertListEqual(output_generate.sequences.tolist(), output_sample.sequences.tolist())
-
-            for output in (output_sample, output_generate):
-                self._check_outputs(output, input_ids, model.config, num_return_sequences=2)
+            self.assertTrue(output_generate.sequences.shape[-1] == max_length)
+            self._check_outputs(output_generate, input_ids, model.config, num_return_sequences=2)

     def test_beam_search_generate(self):
         for model_class in self.all_generative_model_classes:
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

-            # It is important set set the eos_token_id to None to ensure that no sequences
-            # shorter than `max_length` can be generated which could lead to flaky circle ci
-            # failures if the top `num_return_sequences` beams are all shorter than the longest beam
-            config.eos_token_id = None
-            config.forced_eos_token_id = None
-
             model = model_class(config).to(torch_device).eval()
             if model.config.is_encoder_decoder:
                 max_length = 4

-            logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
+            logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
                 input_ids.shape[-1],
-                config.eos_token_id,
                 config.forced_bos_token_id,
                 config.forced_eos_token_id,
                 max_length,
             )
-            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
+            beam_kwargs = self._get_beam_kwargs()

-            # check `generate()` and `beam_search()` are equal
-            output_generate, output_beam_search = self._beam_search_generate(
+            output_generate = self._beam_search_generate(
                 model=model,
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 max_length=max_length,
-                beam_scorer=beam_scorer,
                 beam_kwargs=beam_kwargs,
                 logits_process_kwargs=logits_process_kwargs,
-                logits_processor=logits_processor,
             )
-            self.assertListEqual(output_generate.tolist(), output_beam_search.tolist())

-            if model.config.is_encoder_decoder:
-                max_length = 4
-            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
-
-            output_generate, output_beam_search = self._beam_search_generate(
-                model=model,
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                max_length=max_length,
-                beam_scorer=beam_scorer,
-                beam_kwargs=beam_kwargs,
-                logits_process_kwargs=logits_process_kwargs,
-                logits_processor=logits_processor,
-            )
-            self.assertListEqual(output_generate.tolist(), output_beam_search.tolist())
+            self.assertTrue(output_generate.shape[-1] == max_length)

     def test_beam_search_generate_dict_output(self):
         for model_class in self.all_generative_model_classes:
@@ -959,33 +590,24 @@ class GenerationTesterMixin:
             # disable cache
             config.use_cache = False

-            # It is important set set the eos_token_id to None to ensure that no sequences
-            # shorter than `max_length` can be generated which could lead to flaky circle ci
-            # failures if the top `num_return_sequences` beams are all shorter than the longest beam
-            config.eos_token_id = None
-            config.forced_eos_token_id = None
-
             model = model_class(config).to(torch_device).eval()
             if model.config.is_encoder_decoder:
                 max_length = 4

-            logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
+            logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
                 input_ids.shape[-1],
-                config.eos_token_id,
                 config.forced_bos_token_id,
                 config.forced_eos_token_id,
                 max_length,
             )
-            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
+            beam_kwargs = self._get_beam_kwargs()

-            output_generate, output_beam_search = self._beam_search_generate(
+            output_generate = self._beam_search_generate(
                 model=model,
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 max_length=max_length,
-                beam_scorer=beam_scorer,
                 beam_kwargs=beam_kwargs,
                 logits_process_kwargs=logits_process_kwargs,
-                logits_processor=logits_processor,
                 output_scores=True,
                 output_logits=True,
                 output_hidden_states=True,
@@ -993,39 +615,24 @@ class GenerationTesterMixin:
                 return_dict_in_generate=True,
             )

             if model.config.is_encoder_decoder:
-                self.assertIsInstance(output_beam_search, GenerateBeamEncoderDecoderOutput)
                 self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
                 # Retrocompatibility check
-                self.assertIsInstance(output_beam_search, BeamSearchEncoderDecoderOutput)
                 self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
             else:
-                self.assertIsInstance(output_beam_search, GenerateBeamDecoderOnlyOutput)
                 self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
                 # Retrocompatibility check
-                self.assertIsInstance(output_beam_search, BeamSearchDecoderOnlyOutput)
                 self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)

-            self.assertListEqual(output_generate.sequences.tolist(), output_beam_search.sequences.tolist())
-            self.assertTrue(
-                torch.allclose(output_generate["sequences_scores"], output_beam_search["sequences_scores"], atol=1e-3)
-            )
-            self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],))
-            self.assertTrue((output_generate["sequences_scores"] < 0).all().item())
-
-            for output in (output_beam_search, output_generate):
-                self._check_outputs(output, input_ids, model.config, num_return_sequences=beam_scorer.num_beams)
+            self.assertTrue(output_generate.sequences.shape[-1] == max_length)
+            self._check_outputs(
+                output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
+            )

     def test_beam_search_generate_dict_outputs_use_cache(self):
         for model_class in self.all_generative_model_classes:
             # enable cache
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

-            # It is important set set the eos_token_id to None to ensure that no sequences
-            # shorter than `max_length` can be generated which could lead to flaky circle ci
-            # failures if the top `num_return_sequences` beams are all shorter than the longest beam
-            config.eos_token_id = None
-            config.forced_eos_token_id = None
-
             if not hasattr(config, "use_cache"):
                 self.skipTest("This model doesn't support caching")
@@ -1033,28 +640,25 @@ class GenerationTesterMixin:
             if model.config.is_encoder_decoder:
                 max_length = 4

-            logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
+            logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
                 input_ids.shape[-1],
-                config.eos_token_id,
                 config.forced_bos_token_id,
                 config.forced_eos_token_id,
                 max_length,
             )
-            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
+            beam_kwargs = self._get_beam_kwargs()

             config.use_cache = True
             config.is_decoder = True
             model = model_class(config).to(torch_device).eval()
-            output_beam, output_generate = self._beam_search_generate(
+            output_generate = self._beam_search_generate(
                 model=model,
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 max_length=max_length,
-                beam_scorer=beam_scorer,
                 beam_kwargs=beam_kwargs,
                 logits_process_kwargs=logits_process_kwargs,
-                logits_processor=logits_processor,
                 output_scores=True,
                 output_logits=True,
                 output_hidden_states=True,
@@ -1062,12 +666,10 @@ class GenerationTesterMixin:
                 return_dict_in_generate=True,
             )

-            self.assertListEqual(output_generate.sequences.tolist(), output_beam.sequences.tolist())
-
-            for output in (output_beam, output_generate):
-                self._check_outputs(
-                    output, input_ids, model.config, use_cache=True, num_return_sequences=beam_scorer.num_beams
-                )
+            self.assertTrue(output_generate.sequences.shape[-1] == max_length)
+            self._check_outputs(
+                output_generate, input_ids, model.config, use_cache=True, num_return_sequences=beam_kwargs["num_beams"]
+            )

     @require_accelerate
     @require_torch_multi_accelerator
@@ -1097,32 +699,24 @@ class GenerationTesterMixin:
         for model_class in self.all_generative_model_classes:
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

-            # It is important set set the eos_token_id to None to ensure that no sequences
-            # shorter than `max_length` can be generated which could lead to flaky circle ci
-            # failures if the top `num_return_sequences` beams are all shorter than the longest beam
-            config.eos_token_id = None
-            config.forced_eos_token_id = None
-
-            logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
+            _, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1])

             model = model_class(config).to(torch_device).eval()

-            # check `generate()` and `beam_search()` are equal
             if model.config.is_encoder_decoder:
                 max_length = 4
-            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
+            beam_kwargs = self._get_beam_kwargs()

-            output_generate, output_beam_sample = self._beam_sample_generate(
+            output_generate = self._beam_sample_generate(
                 model=model,
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 max_length=max_length,
-                beam_scorer=beam_scorer,
                 beam_kwargs=beam_kwargs,
-                logits_warper=logits_warper,
                 logits_warper_kwargs=logits_warper_kwargs,
             )
-            self.assertListEqual(output_generate.tolist(), output_beam_sample.tolist())
+
+            self.assertTrue(output_generate.shape[-1] == max_length)

     def test_beam_sample_generate_dict_output(self):
         for model_class in self.all_generative_model_classes:
@@ -1131,27 +725,19 @@ class GenerationTesterMixin:
             # disable cache
             config.use_cache = False

-            # It is important set set the eos_token_id to None to ensure that no sequences
-            # shorter than `max_length` can be generated which could lead to flaky circle ci
-            # failures if the top `num_return_sequences` beams are all shorter than the longest beam
-            config.eos_token_id = None
-            config.forced_eos_token_id = None
-
             model = model_class(config).to(torch_device).eval()
-            logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
+            _, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1])
             if model.config.is_encoder_decoder:
                 max_length = 4
-            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
+            beam_kwargs = self._get_beam_kwargs()

-            output_beam_sample, output_generate = self._beam_sample_generate(
+            output_generate = self._beam_sample_generate(
                 model=model,
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 max_length=max_length,
-                beam_scorer=beam_scorer,
                 beam_kwargs=beam_kwargs,
-                logits_warper=logits_warper,
                 logits_warper_kwargs=logits_warper_kwargs,
                 output_scores=True,
                 output_logits=True,
@@ -1161,27 +747,18 @@ class GenerationTesterMixin:
             )

             if model.config.is_encoder_decoder:
-                self.assertIsInstance(output_beam_sample, GenerateBeamEncoderDecoderOutput)
                 self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
                 # Retrocompatibility check
-                self.assertIsInstance(output_beam_sample, BeamSampleEncoderDecoderOutput)
                 self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
             else:
-                self.assertIsInstance(output_beam_sample, GenerateBeamDecoderOnlyOutput)
                 self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
                 # Retrocompatibility check
-                self.assertIsInstance(output_beam_sample, BeamSampleDecoderOnlyOutput)
                 self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)

-            self.assertListEqual(output_generate.sequences.tolist(), output_beam_sample.sequences.tolist())
-            self.assertTrue(
-                torch.allclose(output_generate["sequences_scores"], output_beam_sample["sequences_scores"], atol=1e-3)
-            )
-            self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],))
-            self.assertTrue((output_generate["sequences_scores"] < 0).all().item())
-
-            for output in (output_beam_sample, output_generate):
-                self._check_outputs(output, input_ids, model.config, num_return_sequences=beam_scorer.num_beams)
+            self.assertTrue(output_generate.sequences.shape[-1] == max_length)
+            self._check_outputs(
+                output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
+            )

     def test_generate_without_input_ids(self):
         config, _, _, max_length = self._get_input_ids_and_config()
@@ -1190,6 +767,10 @@ class GenerationTesterMixin:
         if config.bos_token_id is None:
             return

+        # hack in case they are equal, otherwise the attn mask will be [0]
+        if config.bos_token_id == config.pad_token_id:
+            config.pad_token_id = None
+
         for model_class in self.all_generative_model_classes:
             model = model_class(config).to(torch_device)
             model.eval()
@@ -1201,94 +782,65 @@ class GenerationTesterMixin:
         for model_class in self.all_generative_model_classes:
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

-            # It is important set set the eos_token_id to None to ensure that no sequences
-            # shorter than `max_length` can be generated which could lead to flaky circle ci
-            # failures if the top `num_return_sequences` beams are all shorter than the longest beam
-            config.eos_token_id = None
-            config.forced_eos_token_id = None
-
             model = model_class(config).to(torch_device).eval()
             if model.config.is_encoder_decoder:
                 max_length = 4

-            logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
+            logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
                 input_ids.shape[-1],
-                config.eos_token_id,
                 config.forced_bos_token_id,
                 config.forced_eos_token_id,
                 max_length,
-                diversity_penalty=2.0,
             )

             # check `generate()` and `group_beam_search()` are equal
-            beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
-            output_generate, output_group_beam_search = self._group_beam_search_generate(
+            beam_kwargs = self._get_diverse_beam_kwargs()
+            output_generate = self._group_beam_search_generate(
                 model=model,
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 max_length=max_length,
-                beam_scorer=beam_scorer,
                 beam_kwargs=beam_kwargs,
-                logits_processor=logits_processor,
                 logits_process_kwargs=logits_process_kwargs,
             )
-            self.assertListEqual(output_generate.tolist(), output_group_beam_search.tolist())
+            self.assertTrue(output_generate.shape[-1] == max_length)

-            # check `generate()` and `group_beam_search()` are equal for `num_return_sequences`
+            # check `group_beam_search` for higher than 1 `num_return_sequences`
             num_return_sequences = 2
-            if model.config.is_encoder_decoder:
-                max_length = 4
-            beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(
-                input_ids.shape[0], max_length, num_return_sequences=num_return_sequences
-            )
-
-            output_generate, output_group_beam_search = self._group_beam_search_generate(
+            beam_kwargs = self._get_diverse_beam_kwargs(num_return_sequences=num_return_sequences)
+            output_generate = self._group_beam_search_generate(
                 model=model,
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 max_length=max_length,
-                beam_scorer=beam_scorer,
                 beam_kwargs=beam_kwargs,
-                logits_processor=logits_processor,
                 logits_process_kwargs=logits_process_kwargs,
             )
-            self.assertListEqual(output_generate.tolist(), output_group_beam_search.tolist())
+            self.assertTrue(output_generate.shape[-1] == max_length)

     def test_group_beam_search_generate_dict_output(self):
         for model_class in self.all_generative_model_classes:
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
             config.use_cache = False

-            # It is important set set the eos_token_id to None to ensure that no sequences
-            # shorter than `max_length` can be generated which could lead to flaky circle ci
-            # failures if the top `num_return_sequences` beams are all shorter than the longest beam
-            config.eos_token_id = None
-            config.forced_eos_token_id = None
-
             model = model_class(config).to(torch_device).eval()
             if model.config.is_encoder_decoder:
                 max_length = 4

-            logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
+            logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
                 input_ids.shape[-1],
-                config.eos_token_id,
                 config.forced_bos_token_id,
                 config.forced_eos_token_id,
                 max_length,
-                diversity_penalty=2.0,
             )

-            num_return_sequences = 1
-            beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(
-                input_ids.shape[0], max_length, num_return_sequences=num_return_sequences
-            )
-
-            output_generate, output_group_beam_search = self._group_beam_search_generate(
+            beam_kwargs = self._get_diverse_beam_kwargs()
+            output_generate = self._group_beam_search_generate(
                 model=model,
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 max_length=max_length,
-                beam_scorer=beam_scorer,
                 beam_kwargs=beam_kwargs,
-                logits_processor=logits_processor,
                 logits_process_kwargs=logits_process_kwargs,
                 output_scores=True,
                 output_logits=True,
@@ -1297,31 +849,18 @@ class GenerationTesterMixin:
                 return_dict_in_generate=True,
             )

             if model.config.is_encoder_decoder:
-                self.assertIsInstance(output_group_beam_search, GenerateBeamEncoderDecoderOutput)
                 self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
                 # Retrocompatibility check
-                self.assertIsInstance(output_group_beam_search, BeamSearchEncoderDecoderOutput)
                 self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
             else:
-                self.assertIsInstance(output_group_beam_search, GenerateBeamDecoderOnlyOutput)
                 self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
                 # Retrocompatibility check
-                self.assertIsInstance(output_group_beam_search, BeamSearchDecoderOnlyOutput)
                 self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)

-            self.assertListEqual(output_generate.sequences.tolist(), output_group_beam_search.sequences.tolist())
-            self.assertTrue(
-                torch.allclose(
-                    output_generate["sequences_scores"], output_group_beam_search["sequences_scores"], atol=1e-3
-                )
-            )
-            self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],))
-            self.assertTrue((output_generate["sequences_scores"] < 0).all().item())
-
-            for output in (output_group_beam_search, output_generate):
-                self._check_outputs(
-                    output, input_ids, model.config, num_return_sequences=num_return_sequences * beam_scorer.num_beams
-                )
+            self.assertTrue(output_generate.sequences.shape[-1] == max_length)
+            self._check_outputs(
+                output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
+            )

     # TODO: @gante
     @is_flaky()
@@ -1329,24 +868,16 @@ for model_class in self.all_generative_model_classes:
         for model_class in self.all_generative_model_classes:
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

-            # It is important set set the eos_token_id to None to ensure that no sequences
-            # shorter than `max_length` can be generated which could lead to flaky circle ci
-            # failures if the top `num_return_sequences` beams are all shorter than the longest beam
-            config.eos_token_id = None
-            config.forced_eos_token_id = None
-
             model = model_class(config).to(torch_device).eval()
             max_length = 20

-            logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
+            logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
                 input_ids.shape[-1],
-                config.eos_token_id,
                 config.forced_bos_token_id,
                 config.forced_eos_token_id,
                 max_length,
             )

-            # check `generate()` and `constrained_beam_search()` are equal
             # Sample constraints
             min_id = 3
             max_id = config.vocab_size
...@@ -1356,50 +887,40 @@ class GenerationTesterMixin: ...@@ -1356,50 +887,40 @@ class GenerationTesterMixin:
PhrasalConstraint(force_tokens), PhrasalConstraint(force_tokens),
] ]
beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs( beam_kwargs = self._get_constrained_beam_kwargs()
input_ids.shape[0], max_length, constraints, num_return_sequences=1 output_generate = self._constrained_beam_search_generate(
)
output_generate, output_beam_search = self._constrained_beam_search_generate(
model=model, model=model,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
max_length=max_length, max_length=max_length,
constrained_beam_scorer=beam_scorer,
constraints=constraints, constraints=constraints,
beam_kwargs=beam_kwargs, beam_kwargs=beam_kwargs,
logits_processor=logits_processor,
logits_process_kwargs=logits_process_kwargs, logits_process_kwargs=logits_process_kwargs,
) )
self.assertListEqual(output_generate.tolist(), output_beam_search.tolist()) self.assertTrue(output_generate.shape[-1] == max_length)
for generation_output in output_generate: for generation_output in output_generate:
self._check_sequence_inside_sequence(force_tokens, generation_output) self._check_sequence_inside_sequence(force_tokens, generation_output)
-           # check `generate()` and `constrained_beam_search()` are equal for `num_return_sequences`
+           # check `constrained_beam_search` for higher than 1 `num_return_sequences`
            # Sample constraints
            force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
            constraints = [
                PhrasalConstraint(force_tokens),
            ]

-           num_return_sequences = 2
            max_length = 20

-           beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs(
-               input_ids.shape[0], max_length, constraints, num_return_sequences=num_return_sequences
-           )
-           output_generate, output_beam_search = self._constrained_beam_search_generate(
+           beam_kwargs = self._get_constrained_beam_kwargs(num_return_sequences=2)
+           output_generate = self._constrained_beam_search_generate(
                model=model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
-               constrained_beam_scorer=beam_scorer,
                constraints=constraints,
                beam_kwargs=beam_kwargs,
-               logits_processor=logits_processor,
                logits_process_kwargs=logits_process_kwargs,
            )
-           self.assertListEqual(output_generate.tolist(), output_beam_search.tolist())
+           self.assertTrue(output_generate.shape[-1] == max_length)

            for generation_output in output_generate:
                self._check_sequence_inside_sequence(force_tokens, generation_output)
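A minimal end-user sketch of the constrained decoding these tests cover; the checkpoint and the forced word are illustrative assumptions, not taken from this diff:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, PhrasalConstraint

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-bart")

inputs = tokenizer("Justin Timberlake and Jessica Biel, welcome to parenthood.", return_tensors="pt")
force_ids = tokenizer("parenthood", add_special_tokens=False).input_ids

# constrained beam search needs num_beams > 1; either pass Constraint objects ...
constrained = model.generate(**inputs, num_beams=4, constraints=[PhrasalConstraint(force_ids)], max_new_tokens=10)
# ... or use the equivalent `force_words_ids` shortcut
constrained_alt = model.generate(**inputs, num_beams=4, force_words_ids=[force_ids], max_new_tokens=10)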
@@ -1411,19 +932,12 @@ class GenerationTesterMixin:
            # disable cache
            config.use_cache = False

-           # It is important set set the eos_token_id to None to ensure that no sequences
-           # shorter than `max_length` can be generated which could lead to flaky circle ci
-           # failures if the top `num_return_sequences` beams are all shorter than the longest beam
-           config.eos_token_id = None
-           config.forced_eos_token_id = None
            model = model_class(config).to(torch_device).eval()
            if model.config.is_encoder_decoder:
                max_length = 20

-           logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
+           logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
                input_ids.shape[-1],
-               config.eos_token_id,
                config.forced_bos_token_id,
                config.forced_eos_token_id,
                max_length,
@@ -1437,18 +951,14 @@ class GenerationTesterMixin:
                PhrasalConstraint(force_tokens),
            ]

-           beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs(
-               input_ids.shape[0], max_length, constraints, num_return_sequences=1
-           )
-           output_generate, output_beam_search = self._constrained_beam_search_generate(
+           beam_kwargs = self._get_constrained_beam_kwargs()
+           output_generate = self._constrained_beam_search_generate(
                model=model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
-               constrained_beam_scorer=beam_scorer,
                constraints=constraints,
                beam_kwargs=beam_kwargs,
-               logits_processor=logits_processor,
                logits_process_kwargs=logits_process_kwargs,
                output_scores=True,
                output_logits=True,
@@ -1458,30 +968,20 @@ class GenerationTesterMixin:
            )

            if model.config.is_encoder_decoder:
-               self.assertIsInstance(output_beam_search, GenerateBeamEncoderDecoderOutput)
                self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
                # Retrocompatibility check
-               self.assertIsInstance(output_beam_search, BeamSearchEncoderDecoderOutput)
                self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
            else:
-               self.assertIsInstance(output_beam_search, GenerateBeamDecoderOnlyOutput)
                self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
                # Retrocompatibility check
-               self.assertIsInstance(output_beam_search, BeamSearchDecoderOnlyOutput)
                self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)

-           self.assertListEqual(output_generate.sequences.tolist(), output_beam_search.sequences.tolist())
-           self.assertTrue(
-               torch.allclose(output_generate["sequences_scores"], output_beam_search["sequences_scores"], atol=1e-3)
-           )
-           self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],))
-           self.assertTrue((output_generate["sequences_scores"] < 0).all().item())
-           for output in (output_beam_search, output_generate):
-               self._check_outputs(output, input_ids, model.config, num_return_sequences=beam_scorer.num_beams)
+           self.assertTrue(output_generate.sequences.shape[-1] == max_length)
+           self._check_outputs(
+               output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
+           )
    def test_contrastive_generate(self):
-       # check `generate()` and `contrastive_search()` are equal
        for model_class in self.all_generative_model_classes:
            # won't fix: FSMT and Reformer have a different cache variable type (and format).
            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
@@ -1497,10 +997,10 @@ class GenerationTesterMixin:
            # test old generation output for backwards compatibility
            model = model_class(config).to(torch_device).eval()
-           output_contrastive, output_generate = self._contrastive_generate(
+           output_generate = self._contrastive_generate(
                model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length
            )
-           self.assertListEqual(output_contrastive.tolist(), output_generate.tolist())
+           self.assertTrue(output_generate.shape[-1] == max_length)
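For reference, contrastive search is selected purely through `generate()` arguments; a minimal sketch, with the checkpoint and hyperparameters as illustrative assumptions:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

inputs = tokenizer("Today a dragon flew over Paris.", return_tensors="pt")
# penalty_alpha > 0 together with top_k > 1 switches generate() to contrastive search
out = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_new_tokens=10)
print(tokenizer.decode(out[0], skip_special_tokens=True))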
    def test_contrastive_generate_dict_outputs_use_cache(self):
        for model_class in self.all_generative_model_classes:
@@ -1508,7 +1008,6 @@ class GenerationTesterMixin:
            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
                self.skipTest("Won't fix: old model with different cache format")

-           # enable cache
            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

            # NOTE: contrastive search only works with cache on at the moment.
@@ -1518,7 +1017,7 @@ class GenerationTesterMixin:
            config.is_decoder = True

            model = model_class(config).to(torch_device).eval()
-           output_contrastive, output_generate = self._contrastive_generate(
+           output_generate = self._contrastive_generate(
                model=model,
                input_ids=input_ids,
                attention_mask=attention_mask,
@@ -1530,10 +1029,8 @@ class GenerationTesterMixin:
                return_dict_in_generate=True,
            )

-           self.assertListEqual(output_generate.sequences.tolist(), output_contrastive.sequences.tolist())
-           for output in (output_contrastive, output_generate):
-               self._check_outputs(output, input_ids, model.config, use_cache=True)
+           self.assertTrue(output_generate.sequences.shape[-1] == max_length)
+           self._check_outputs(output_generate, input_ids, model.config, use_cache=True)
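The dict-style return checked here carries the intermediate tensors requested by the output_* flags; a minimal sketch of reading them, reusing the tiny GPT-2 model and inputs from the previous sketch (field availability assumed as in recent transformers versions):

out = model.generate(
    **inputs,
    penalty_alpha=0.6,
    top_k=4,
    max_new_tokens=5,
    use_cache=True,
    output_scores=True,
    output_attentions=True,
    output_hidden_states=True,
    return_dict_in_generate=True,
)
# GenerateDecoderOnlyOutput: `scores`, `attentions` and `hidden_states` hold one entry per generated token
print(out.sequences.shape)
print(len(out.scores))
print(out.past_key_values is not None)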
    def test_contrastive_generate_low_memory(self):
        # Check that choosing 'low_memory' does not change the model output
@@ -1591,7 +1088,7 @@ class GenerationTesterMixin:
                ]
            ):
                self.skipTest("May fix in the future: need model-specific fixes")

-           config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=2)
+           config, input_ids, _, _ = self._get_input_ids_and_config(batch_size=2)

            # batch_size=1 is ok, but batch_size>1 will cause non-identical output
            config.use_cache = True
@@ -2455,220 +1952,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
            ],
        )
def test_max_length_backward_compat_greedy(self):
# PT-only test: TF doesn't have StoppingCriteria
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
torch_device
)
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
max_length = 20
input_ids = input_ids.expand(2, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
batch_size=input_ids.shape[0],
model_input_name=bart_model.main_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=bart_model.config.decoder_start_token_id,
bos_token_id=bart_model.config.bos_token_id,
)
with self.assertWarns(UserWarning):
bart_model.greedy_search(
input_ids,
max_length=max_length,
pad_token_id=bart_model.config.pad_token_id,
eos_token_id=bart_model.config.eos_token_id,
**model_kwargs,
)
def test_max_length_backward_compat_sample(self):
# PT-only test: TF doesn't have StoppingCriteria
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
torch_device
)
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
max_length = 20
input_ids = input_ids.expand(2, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
batch_size=input_ids.shape[0],
model_input_name=bart_model.main_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=bart_model.config.decoder_start_token_id,
bos_token_id=bart_model.config.bos_token_id,
)
with torch.no_grad():
with self.assertWarns(UserWarning):
bart_model.sample(
input_ids,
max_length=max_length,
pad_token_id=bart_model.config.pad_token_id,
eos_token_id=bart_model.config.eos_token_id,
**model_kwargs,
)
def test_max_length_backward_compat_beam_search(self):
# PT-only test: TF doesn't have StoppingCriteria
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
torch_device
)
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
batch_size = 1
max_length = 20
num_beams = 2
input_ids = input_ids.expand(2, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
batch_size=input_ids.shape[0],
model_input_name=bart_model.main_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=bart_model.config.decoder_start_token_id,
bos_token_id=bart_model.config.bos_token_id,
)
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
device=torch_device,
)
with self.assertWarns(UserWarning):
_ = bart_model.beam_search(
input_ids, num_beams=num_beams, max_length=max_length, beam_scorer=beam_scorer, **model_kwargs
)
def test_max_length_backward_compat_group_beam_search(self):
# PT-only test: TF doesn't have StoppingCriteria & group beam search
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
torch_device
)
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
batch_size = 1
max_length = 20
num_beams = 6
num_beam_groups = 3
num_return_sequences = num_beams * batch_size
input_ids = input_ids.expand(6, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
batch_size=input_ids.shape[0],
model_input_name=bart_model.main_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=bart_model.config.decoder_start_token_id,
bos_token_id=bart_model.config.bos_token_id,
)
diverse_beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
device=torch_device,
num_beam_hyps_to_keep=num_return_sequences,
num_beam_groups=num_beam_groups,
)
with self.assertWarns(UserWarning):
bart_model.group_beam_search(
input_ids, diverse_beam_scorer, num_beams=num_beams, max_length=max_length, **model_kwargs
)
def test_max_length_warning_if_different(self):
# PT-only test: TF doesn't have StoppingCriteria
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
torch_device
)
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
batch_size = 1
max_length = 20
num_beams = 6
num_beam_groups = 3
num_return_sequences = num_beams * batch_size
stopping_criteria_max_length = 18
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=stopping_criteria_max_length)])
# Greedy
input_ids = input_ids.expand(6, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
batch_size=input_ids.shape[0],
model_input_name=bart_model.main_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=bart_model.config.decoder_start_token_id,
bos_token_id=bart_model.config.bos_token_id,
)
with self.assertWarns(UserWarning):
bart_model.greedy_search(
input_ids,
max_length=max_length,
pad_token_id=bart_model.config.pad_token_id,
stopping_criteria=stopping_criteria,
eos_token_id=bart_model.config.eos_token_id,
**model_kwargs,
)
# Sample
with self.assertWarns(UserWarning):
with torch.no_grad():
bart_model.sample(
input_ids,
max_length=max_length,
stopping_criteria=stopping_criteria,
pad_token_id=bart_model.config.pad_token_id,
eos_token_id=bart_model.config.eos_token_id,
**model_kwargs,
)
# Beam
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
device=torch_device,
)
with self.assertWarns(UserWarning):
with torch.no_grad():
bart_model.beam_search(
input_ids,
num_beams=num_beams,
stopping_criteria=stopping_criteria,
max_length=max_length,
beam_scorer=beam_scorer,
**model_kwargs,
)
# Grouped beam search
diverse_beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
device=torch_device,
num_beam_hyps_to_keep=num_return_sequences,
num_beam_groups=num_beam_groups,
)
with self.assertWarns(UserWarning):
bart_model.group_beam_search(
input_ids,
diverse_beam_scorer,
stopping_criteria=stopping_criteria,
num_beams=num_beams,
max_length=max_length,
**model_kwargs,
)
    def test_max_length_if_input_embeds(self):
        # PT-only test: TF doesn't have StoppingCriteria
        article = "Today a dragon flew over Paris."
@@ -2819,31 +2102,15 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
        # lets run beam search using 3 beams
        num_beams = 3
        # define decoder start token ids
-       input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
+       input_ids = torch.ones((1, 1), device=model.device, dtype=torch.long)
        input_ids = input_ids * model.config.decoder_start_token_id

        # add encoder_outputs to model keyword arguments
-       model_kwargs = {
-           "encoder_outputs": model.get_encoder()(
-               encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
-           )
-       }
+       model_kwargs = {"encoder_outputs": model.get_encoder()(encoder_input_ids, return_dict=True)}

-       # instantiate beam scorer
-       beam_scorer = BeamSearchScorer(
-           batch_size=1,
-           num_beams=num_beams,
-           device=model.device,
-       )
-       # instantiate logits processors
-       logits_processor = LogitsProcessorList(
-           [
-               MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
-           ]
-       )
-       outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
+       outputs = model.generate(
+           input_ids, num_beams=num_beams, min_length=5, eos_token_id=model.config.eos_token_id, **model_kwargs
+       )

        outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        self.assertListEqual(outputs, ["Wie alt bist du?"])
@@ -3042,34 +2309,22 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
        # lets run beam search using 5 beams
        num_beams = 5
        # define decoder start token ids
-       input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
+       input_ids = torch.ones((1, 1), device=model.device, dtype=torch.long)
        input_ids = input_ids * model.config.decoder_start_token_id

        # add encoder_outputs to model keyword arguments
-       model_kwargs = {
-           "encoder_outputs": model.get_encoder()(
-               encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
-           )
-       }
+       model_kwargs = {"encoder_outputs": model.get_encoder()(encoder_input_ids, return_dict=True)}

        constraint_str = "sind"
        constraint_token_ids = tokenizer.encode(constraint_str)[:-1]  # remove eos token

-       constraints = [PhrasalConstraint(token_ids=constraint_token_ids)]
-       # instantiate beam scorer
-       beam_scorer = ConstrainedBeamSearchScorer(
-           batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints
-       )
-       # instantiate logits processors
-       logits_processor = LogitsProcessorList(
-           [
-               MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
-           ]
-       )
-       outputs = model.constrained_beam_search(
-           input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs
-       )
+       outputs = model.generate(
+           input_ids,
+           num_beams=num_beams,
+           force_words_ids=[constraint_token_ids],
+           min_length=5,
+           eos_token_id=model.config.eos_token_id,
+           **model_kwargs,
+       )

        outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
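`force_words_ids` also accepts a nested list per constraint to express disjunctive ("any of these words") constraints; a minimal sketch built on the variables above, where the word choices are illustrative assumptions:

# force the output to contain "sind", and additionally at least one of "alt" / "jung"
must_have = tokenizer("sind", add_special_tokens=False).input_ids
one_of = [
    tokenizer("alt", add_special_tokens=False).input_ids,
    tokenizer("jung", add_special_tokens=False).input_ids,
]
outputs = model.generate(
    input_ids,
    num_beams=num_beams,
    force_words_ids=[must_have, one_of],
    **model_kwargs,
)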
@@ -55,8 +55,6 @@ if is_torch_available():
    from transformers.generation import (
        GenerateDecoderOnlyOutput,
        GenerateEncoderDecoderOutput,
-       InfNanRemoveLogitsProcessor,
-       LogitsProcessorList,
    )
@@ -247,19 +245,17 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
        return config, input_ids, attention_mask, max_length

    @staticmethod
-   def _get_logits_processor_and_kwargs(
+   def _get_logits_processor_and_warper_kwargs(
        input_length,
-       eos_token_id,
        forced_bos_token_id=None,
        forced_eos_token_id=None,
        max_length=None,
-       diversity_penalty=None,
    ):
        process_kwargs = {
            "min_length": input_length + 1 if max_length is None else max_length - 1,
        }
-       logits_processor = LogitsProcessorList()
-       return process_kwargs, logits_processor
+       warper_kwargs = {}
+       return process_kwargs, warper_kwargs
    # override since we don't expect the outputs of `.generate` and `.greedy_search` to be the same, since we perform
    # additional post-processing in the former
@@ -269,7 +265,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
            config.use_cache = False
            model = model_class(config).to(torch_device).eval()
-           output_greedy, output_generate = self._greedy_generate(
+           output_generate = self._greedy_generate(
                model=model,
                input_ids=input_ids.to(torch_device),
                attention_mask=attention_mask.to(torch_device),
@@ -280,9 +276,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
                return_dict_in_generate=True,
            )

-           self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
            self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)

            self.assertNotIn(config.pad_token_id, output_generate)
# override since we don't expect the outputs of `.generate` and `.greedy_search` to be the same, since we perform # override since we don't expect the outputs of `.generate` and `.greedy_search` to be the same, since we perform
...@@ -295,7 +289,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste ...@@ -295,7 +289,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
config.use_cache = True config.use_cache = True
config.is_decoder = True config.is_decoder = True
model = model_class(config).to(torch_device).eval() model = model_class(config).to(torch_device).eval()
output_greedy, output_generate = self._greedy_generate( output_generate = self._greedy_generate(
model=model, model=model,
input_ids=input_ids.to(torch_device), input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device), attention_mask=attention_mask.to(torch_device),
...@@ -306,7 +300,6 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste ...@@ -306,7 +300,6 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
return_dict_in_generate=True, return_dict_in_generate=True,
) )
self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
# override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform # override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform
@@ -316,28 +309,21 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
            model = model_class(config).to(torch_device).eval()

-           process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
+           process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
                input_ids.shape[-1],
-               model.config.eos_token_id,
-               forced_bos_token_id=model.config.forced_bos_token_id,
-               forced_eos_token_id=model.config.forced_eos_token_id,
                max_length=max_length,
            )
-           logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2)

            # check `generate()` and `sample()` are equal
-           output_sample, output_generate = self._sample_generate(
+           output_generate = self._sample_generate(
                model=model,
                input_ids=input_ids.to(torch_device),
                attention_mask=attention_mask.to(torch_device),
                max_length=max_length,
                num_return_sequences=3,
-               logits_processor=logits_processor,
-               logits_warper=logits_warper,
                logits_warper_kwargs=logits_warper_kwargs,
                process_kwargs=process_kwargs,
            )
-           self.assertIsInstance(output_sample, torch.Tensor)
            self.assertIsInstance(output_generate, torch.Tensor)
# override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform # override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform
...@@ -349,23 +335,17 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste ...@@ -349,23 +335,17 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
config.use_cache = False config.use_cache = False
model = model_class(config).to(torch_device).eval() model = model_class(config).to(torch_device).eval()
process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1], input_ids.shape[-1],
model.config.eos_token_id,
forced_bos_token_id=model.config.forced_bos_token_id,
forced_eos_token_id=model.config.forced_eos_token_id,
max_length=max_length, max_length=max_length,
) )
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
output_sample, output_generate = self._sample_generate( output_generate = self._sample_generate(
model=model, model=model,
input_ids=input_ids.to(torch_device), input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device), attention_mask=attention_mask.to(torch_device),
max_length=max_length, max_length=max_length,
num_return_sequences=1, num_return_sequences=1,
logits_processor=logits_processor,
logits_warper=logits_warper,
logits_warper_kwargs=logits_warper_kwargs, logits_warper_kwargs=logits_warper_kwargs,
process_kwargs=process_kwargs, process_kwargs=process_kwargs,
output_scores=True, output_scores=True,
...@@ -374,7 +354,6 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste ...@@ -374,7 +354,6 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
return_dict_in_generate=True, return_dict_in_generate=True,
) )
self.assertIsInstance(output_sample, GenerateDecoderOnlyOutput)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
def test_greedy_generate_stereo_outputs(self): def test_greedy_generate_stereo_outputs(self):
...@@ -382,7 +361,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste ...@@ -382,7 +361,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
config.audio_channels = 2 config.audio_channels = 2
model = model_class(config).to(torch_device).eval() model = model_class(config).to(torch_device).eval()
output_greedy, output_generate = self._greedy_generate( output_generate = self._greedy_generate(
model=model, model=model,
input_ids=input_ids.to(torch_device), input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device), attention_mask=attention_mask.to(torch_device),
...@@ -393,7 +372,6 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste ...@@ -393,7 +372,6 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
return_dict_in_generate=True, return_dict_in_generate=True,
) )
self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
self.assertNotIn(config.pad_token_id, output_generate) self.assertNotIn(config.pad_token_id, output_generate)
@@ -834,10 +812,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
        attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)

        # generate max 3 tokens
-       decoder_input_ids = inputs_dict["decoder_input_ids"]
-       max_length = decoder_input_ids.shape[-1] + 3
-       decoder_input_ids = decoder_input_ids[: batch_size * config.decoder.num_codebooks, :]
-       return config, input_ids, attention_mask, decoder_input_ids, max_length
+       max_length = 3
+       return config, input_ids, attention_mask, max_length
# override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen (input / outputs are # override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen (input / outputs are
# different modalities -> different shapes) # different modalities -> different shapes)
...@@ -846,18 +822,14 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, ...@@ -846,18 +822,14 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
model, model,
input_ids, input_ids,
attention_mask, attention_mask,
decoder_input_ids,
max_length, max_length,
output_scores=False, output_scores=False,
output_attentions=False, output_attentions=False,
output_hidden_states=False, output_hidden_states=False,
return_dict_in_generate=False, return_dict_in_generate=False,
): ):
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1], input_ids.shape[-1],
eos_token_id=model.config.eos_token_id,
forced_bos_token_id=model.config.forced_bos_token_id,
forced_eos_token_id=model.config.forced_eos_token_id,
max_length=max_length, max_length=max_length,
) )
...@@ -876,28 +848,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, ...@@ -876,28 +848,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
**model_kwargs, **model_kwargs,
) )
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( return output_generate
model,
input_ids,
attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
with torch.no_grad():
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_greedy = model.greedy_search(
decoder_input_ids,
max_length=max_length,
logits_processor=logits_processor,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
encoder_outputs=encoder_outputs,
**model_kwargs,
)
return output_greedy, output_generate
# override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen (input / outputs are # override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen (input / outputs are
# different modalities -> different shapes) # different modalities -> different shapes)
...@@ -906,11 +857,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, ...@@ -906,11 +857,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
model, model,
input_ids, input_ids,
attention_mask, attention_mask,
decoder_input_ids,
max_length, max_length,
num_return_sequences, num_return_sequences,
logits_processor,
logits_warper,
logits_warper_kwargs, logits_warper_kwargs,
process_kwargs, process_kwargs,
output_scores=False, output_scores=False,
...@@ -936,62 +884,31 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, ...@@ -936,62 +884,31 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
**model_kwargs, **model_kwargs,
) )
torch.manual_seed(0) return output_generate
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
model,
input_ids,
attention_mask,
num_interleave=num_return_sequences,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
# prevent flaky generation test failures
logits_processor.append(InfNanRemoveLogitsProcessor())
with torch.no_grad():
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_sample = model.sample(
decoder_input_ids.repeat_interleave(num_return_sequences, dim=0),
max_length=max_length,
logits_processor=logits_processor,
logits_warper=logits_warper,
output_scores=output_scores,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
encoder_outputs=encoder_outputs,
**model_kwargs,
)
return output_sample, output_generate
    @staticmethod
-   def _get_logits_processor_and_kwargs(
+   def _get_logits_processor_and_warper_kwargs(
        input_length,
-       eos_token_id,
        forced_bos_token_id=None,
        forced_eos_token_id=None,
        max_length=None,
-       diversity_penalty=None,
    ):
        process_kwargs = {
            "min_length": input_length + 1 if max_length is None else max_length - 1,
        }
-       logits_processor = LogitsProcessorList()
-       return process_kwargs, logits_processor
+       warper_kwargs = {}
+       return process_kwargs, warper_kwargs
def test_greedy_generate_dict_outputs(self): def test_greedy_generate_dict_outputs(self):
for model_class in self.greedy_sample_model_classes: for model_class in self.greedy_sample_model_classes:
# disable cache # disable cache
config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config() config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
config.use_cache = False config.use_cache = False
model = model_class(config).to(torch_device).eval() model = model_class(config).to(torch_device).eval()
output_greedy, output_generate = self._greedy_generate( output_generate = self._greedy_generate(
model=model, model=model,
input_ids=input_ids.to(torch_device), input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device), attention_mask=attention_mask.to(torch_device),
decoder_input_ids=decoder_input_ids,
max_length=max_length, max_length=max_length,
output_scores=True, output_scores=True,
output_hidden_states=True, output_hidden_states=True,
...@@ -999,7 +916,6 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, ...@@ -999,7 +916,6 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
return_dict_in_generate=True, return_dict_in_generate=True,
) )
self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput) self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
self.assertNotIn(config.pad_token_id, output_generate) self.assertNotIn(config.pad_token_id, output_generate)
...@@ -1007,16 +923,15 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, ...@@ -1007,16 +923,15 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
def test_greedy_generate_dict_outputs_use_cache(self): def test_greedy_generate_dict_outputs_use_cache(self):
for model_class in self.greedy_sample_model_classes: for model_class in self.greedy_sample_model_classes:
# enable cache # enable cache
config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config() config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
config.use_cache = True config.use_cache = True
config.is_decoder = True config.is_decoder = True
model = model_class(config).to(torch_device).eval() model = model_class(config).to(torch_device).eval()
output_greedy, output_generate = self._greedy_generate( output_generate = self._greedy_generate(
model=model, model=model,
input_ids=input_ids.to(torch_device), input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device), attention_mask=attention_mask.to(torch_device),
decoder_input_ids=decoder_input_ids,
max_length=max_length, max_length=max_length,
output_scores=True, output_scores=True,
output_hidden_states=True, output_hidden_states=True,
...@@ -1024,64 +939,48 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, ...@@ -1024,64 +939,48 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
return_dict_in_generate=True, return_dict_in_generate=True,
) )
self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput) self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
def test_sample_generate(self): def test_sample_generate(self):
for model_class in self.greedy_sample_model_classes: for model_class in self.greedy_sample_model_classes:
config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config() config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
model = model_class(config).to(torch_device).eval() model = model_class(config).to(torch_device).eval()
process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1], input_ids.shape[-1],
model.config.eos_token_id,
forced_bos_token_id=model.config.forced_bos_token_id,
forced_eos_token_id=model.config.forced_eos_token_id,
max_length=max_length, max_length=max_length,
) )
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2)
# check `generate()` and `sample()` are equal # check `generate()` and `sample()` are equal
output_sample, output_generate = self._sample_generate( output_generate = self._sample_generate(
model=model, model=model,
input_ids=input_ids.to(torch_device), input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device), attention_mask=attention_mask.to(torch_device),
decoder_input_ids=decoder_input_ids,
max_length=max_length, max_length=max_length,
num_return_sequences=1, num_return_sequences=1,
logits_processor=logits_processor,
logits_warper=logits_warper,
logits_warper_kwargs=logits_warper_kwargs, logits_warper_kwargs=logits_warper_kwargs,
process_kwargs=process_kwargs, process_kwargs=process_kwargs,
) )
self.assertIsInstance(output_sample, torch.Tensor)
self.assertIsInstance(output_generate, torch.Tensor) self.assertIsInstance(output_generate, torch.Tensor)
def test_sample_generate_dict_output(self): def test_sample_generate_dict_output(self):
for model_class in self.greedy_sample_model_classes: for model_class in self.greedy_sample_model_classes:
# disable cache # disable cache
config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config() config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
config.use_cache = False config.use_cache = False
model = model_class(config).to(torch_device).eval() model = model_class(config).to(torch_device).eval()
process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1], input_ids.shape[-1],
model.config.eos_token_id,
forced_bos_token_id=model.config.forced_bos_token_id,
forced_eos_token_id=model.config.forced_eos_token_id,
max_length=max_length, max_length=max_length,
) )
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
output_sample, output_generate = self._sample_generate( output_generate = self._sample_generate(
model=model, model=model,
input_ids=input_ids.to(torch_device), input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device), attention_mask=attention_mask.to(torch_device),
decoder_input_ids=decoder_input_ids,
max_length=max_length, max_length=max_length,
num_return_sequences=3, num_return_sequences=3,
logits_processor=logits_processor,
logits_warper=logits_warper,
logits_warper_kwargs=logits_warper_kwargs, logits_warper_kwargs=logits_warper_kwargs,
process_kwargs=process_kwargs, process_kwargs=process_kwargs,
output_scores=True, output_scores=True,
...@@ -1090,11 +989,10 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, ...@@ -1090,11 +989,10 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
return_dict_in_generate=True, return_dict_in_generate=True,
) )
self.assertIsInstance(output_sample, GenerateEncoderDecoderOutput)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput) self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
    def test_generate_without_input_ids(self):
-       config, _, _, _, max_length = self._get_input_ids_and_config()
+       config, _, _, max_length = self._get_input_ids_and_config()

        # if no bos token id => cannot generate from None
if config.bos_token_id is None: if config.bos_token_id is None:
...@@ -1123,15 +1021,14 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, ...@@ -1123,15 +1021,14 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
def test_greedy_generate_stereo_outputs(self): def test_greedy_generate_stereo_outputs(self):
for model_class in self.greedy_sample_model_classes: for model_class in self.greedy_sample_model_classes:
config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config() config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
config.audio_channels = 2 config.audio_channels = 2
model = model_class(config).to(torch_device).eval() model = model_class(config).to(torch_device).eval()
output_greedy, output_generate = self._greedy_generate( output_generate = self._greedy_generate(
model=model, model=model,
input_ids=input_ids.to(torch_device), input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device), attention_mask=attention_mask.to(torch_device),
decoder_input_ids=decoder_input_ids,
max_length=max_length, max_length=max_length,
output_scores=True, output_scores=True,
output_hidden_states=True, output_hidden_states=True,
...@@ -1139,7 +1036,6 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, ...@@ -1139,7 +1036,6 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
return_dict_in_generate=True, return_dict_in_generate=True,
) )
-           self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput)
            self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)

            self.assertNotIn(config.pad_token_id, output_generate)