Unverified Commit b1cd4874 authored by Raushan Turganbay's avatar Raushan Turganbay Committed by GitHub
Browse files

Do not remove half seq length in generation tests (#30016)



* remove seq length from generation tests

* style and quality

* [test_all] & PR suggestion
Co-authored-by: default avatarJoao Gante <joaofranciscocardosogante@gmail.com>

* Update tests/generation/test_utils.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* [test all] remove unused variables

---------
Co-authored-by: default avatarJoao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>
parent b4fd49b6
This diff is collapsed.
...@@ -299,12 +299,10 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT ...@@ -299,12 +299,10 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
input_ids = input_ids[:batch_size, :sequence_length] input_ids = input_ids[:batch_size, :sequence_length]
attention_mask = attention_mask[:batch_size, :sequence_length] attention_mask = attention_mask[:batch_size, :sequence_length]
# generate max 3 tokens
max_length = input_ids.shape[-1] + 3
if config.eos_token_id is not None and config.pad_token_id is None: if config.eos_token_id is not None and config.pad_token_id is None:
# hack to allow generate for models such as GPT2 as is done in `generate()` # hack to allow generate for models such as GPT2 as is done in `generate()`
config.pad_token_id = config.eos_token_id config.pad_token_id = config.eos_token_id
return config, input_ids, attention_mask, max_length return config, input_ids, attention_mask
def setUp(self): def setUp(self):
self.model_tester = BigBirdPegasusModelTester(self) self.model_tester = BigBirdPegasusModelTester(self)
......
...@@ -457,6 +457,20 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, ...@@ -457,6 +457,20 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
], ],
) )
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
# overwrite because LED does not have (bs, num_heads, seq_len, seq_len) shape
encoder_expected_shape = (
batch_size,
config.num_attention_heads,
seq_length,
self.model_tester.attention_window // 2 * 2 + 1,
)
self.assertIsInstance(attentions, tuple)
self.assertListEqual(
[layer_attentions.shape for layer_attentions in attentions],
[encoder_expected_shape] * len(attentions),
)
def assert_tensors_close(a, b, atol=1e-12, prefix=""): def assert_tensors_close(a, b, atol=1e-12, prefix=""):
"""If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error.""" """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
......
...@@ -752,7 +752,7 @@ class LongT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix ...@@ -752,7 +752,7 @@ class LongT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length): def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
block_len = getattr(self.model_tester, "block_len", None) block_len = getattr(self.model_tester, "block_len", None)
encoder_expected_shape = (batch_size, 1, config.num_attention_heads, block_len, 3 * block_len) encoder_expected_shape = (batch_size, 2, config.num_attention_heads, block_len, 3 * block_len)
self.assertIsInstance(attentions, tuple) self.assertIsInstance(attentions, tuple)
self.assertListEqual( self.assertListEqual(
[layer_attentions.shape for layer_attentions in attentions], [layer_attentions.shape for layer_attentions in attentions],
...@@ -885,7 +885,7 @@ class LongT5TGlobalModelTest(LongT5ModelTest): ...@@ -885,7 +885,7 @@ class LongT5TGlobalModelTest(LongT5ModelTest):
global_seq_length = seq_length // global_block_size global_seq_length = seq_length // global_block_size
encoder_expected_shape = ( encoder_expected_shape = (
batch_size, batch_size,
1, 2,
config.num_attention_heads, config.num_attention_heads,
block_len, block_len,
3 * block_len + global_seq_length, 3 * block_len + global_seq_length,
......
...@@ -245,34 +245,28 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste ...@@ -245,34 +245,28 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
sequence_length = input_ids.shape[-1] sequence_length = input_ids.shape[-1]
input_ids = input_ids[: batch_size * config.num_codebooks, :] input_ids = input_ids[: batch_size * config.num_codebooks, :]
# generate max 3 tokens
max_length = input_ids.shape[-1] + 3
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long) attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
return config, input_ids, attention_mask, max_length return config, input_ids, attention_mask
@staticmethod @staticmethod
def _get_logits_processor_and_warper_kwargs( def _get_logits_processor_and_warper_kwargs(
input_length, input_length,
forced_bos_token_id=None, forced_bos_token_id=None,
forced_eos_token_id=None, forced_eos_token_id=None,
max_length=None,
): ):
process_kwargs = { process_kwargs = {}
"min_length": input_length + 1 if max_length is None else max_length - 1,
}
warper_kwargs = {} warper_kwargs = {}
return process_kwargs, warper_kwargs return process_kwargs, warper_kwargs
def test_greedy_generate_stereo_outputs(self): def test_greedy_generate_stereo_outputs(self):
for model_class in self.greedy_sample_model_classes: for model_class in self.greedy_sample_model_classes:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() config, input_ids, attention_mask = self._get_input_ids_and_config()
config.audio_channels = 2 config.audio_channels = 2
model = model_class(config).to(torch_device).eval() model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate( output_generate = self._greedy_generate(
model=model, model=model,
input_ids=input_ids.to(torch_device), input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device), attention_mask=attention_mask.to(torch_device),
max_length=max_length,
output_scores=True, output_scores=True,
output_hidden_states=True, output_hidden_states=True,
output_attentions=True, output_attentions=True,
...@@ -1327,9 +1321,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, ...@@ -1327,9 +1321,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
input_ids = input_ids[:batch_size, :] input_ids = input_ids[:batch_size, :]
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long) attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
# generate max 3 tokens return config, input_ids, attention_mask
max_length = 3
return config, input_ids, attention_mask, max_length
# override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen (input / outputs are # override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen (input / outputs are
# different modalities -> different shapes) # different modalities -> different shapes)
...@@ -1338,29 +1330,22 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, ...@@ -1338,29 +1330,22 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
model, model,
input_ids, input_ids,
attention_mask, attention_mask,
max_length,
output_scores=False, output_scores=False,
output_attentions=False, output_attentions=False,
output_hidden_states=False, output_hidden_states=False,
return_dict_in_generate=False, return_dict_in_generate=False,
): ):
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
max_length=max_length,
)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_generate = model.generate( output_generate = model.generate(
input_ids, input_ids,
do_sample=False, do_sample=False,
num_beams=1, num_beams=1,
max_length=max_length, max_new_tokens=self.max_new_tokens,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
output_scores=output_scores, output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate, return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True, remove_invalid_values=True,
**logits_process_kwargs,
**model_kwargs, **model_kwargs,
) )
...@@ -1373,10 +1358,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, ...@@ -1373,10 +1358,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
model, model,
input_ids, input_ids,
attention_mask, attention_mask,
max_length,
num_return_sequences, num_return_sequences,
logits_warper_kwargs,
process_kwargs,
output_scores=False, output_scores=False,
output_attentions=False, output_attentions=False,
output_hidden_states=False, output_hidden_states=False,
...@@ -1388,15 +1370,13 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, ...@@ -1388,15 +1370,13 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
input_ids, input_ids,
do_sample=True, do_sample=True,
num_beams=1, num_beams=1,
max_length=max_length, max_new_tokens=self.max_new_tokens,
num_return_sequences=num_return_sequences, num_return_sequences=num_return_sequences,
output_scores=output_scores, output_scores=output_scores,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate, return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True, remove_invalid_values=True,
**logits_warper_kwargs,
**process_kwargs,
**model_kwargs, **model_kwargs,
) )
...@@ -1407,25 +1387,21 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, ...@@ -1407,25 +1387,21 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
input_length, input_length,
forced_bos_token_id=None, forced_bos_token_id=None,
forced_eos_token_id=None, forced_eos_token_id=None,
max_length=None,
): ):
process_kwargs = { process_kwargs = {}
"min_length": input_length + 1 if max_length is None else max_length - 1,
}
warper_kwargs = {} warper_kwargs = {}
return process_kwargs, warper_kwargs return process_kwargs, warper_kwargs
def test_greedy_generate_dict_outputs(self): def test_greedy_generate_dict_outputs(self):
for model_class in self.greedy_sample_model_classes: for model_class in self.greedy_sample_model_classes:
# disable cache # disable cache
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = False config.use_cache = False
model = model_class(config).to(torch_device).eval() model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate( output_generate = self._greedy_generate(
model=model, model=model,
input_ids=input_ids.to(torch_device), input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device), attention_mask=attention_mask.to(torch_device),
max_length=max_length,
output_scores=True, output_scores=True,
output_hidden_states=True, output_hidden_states=True,
output_attentions=True, output_attentions=True,
...@@ -1439,7 +1415,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, ...@@ -1439,7 +1415,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
def test_greedy_generate_dict_outputs_use_cache(self): def test_greedy_generate_dict_outputs_use_cache(self):
for model_class in self.greedy_sample_model_classes: for model_class in self.greedy_sample_model_classes:
# enable cache # enable cache
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = True config.use_cache = True
config.is_decoder = True config.is_decoder = True
...@@ -1448,7 +1424,6 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, ...@@ -1448,7 +1424,6 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
model=model, model=model,
input_ids=input_ids.to(torch_device), input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device), attention_mask=attention_mask.to(torch_device),
max_length=max_length,
output_scores=True, output_scores=True,
output_hidden_states=True, output_hidden_states=True,
output_attentions=True, output_attentions=True,
...@@ -1459,46 +1434,30 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, ...@@ -1459,46 +1434,30 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
def test_sample_generate(self): def test_sample_generate(self):
for model_class in self.greedy_sample_model_classes: for model_class in self.greedy_sample_model_classes:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() config, input_ids, attention_mask = self._get_input_ids_and_config()
model = model_class(config).to(torch_device).eval() model = model_class(config).to(torch_device).eval()
process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
max_length=max_length,
)
# check `generate()` and `sample()` are equal # check `generate()` and `sample()` are equal
output_generate = self._sample_generate( output_generate = self._sample_generate(
model=model, model=model,
input_ids=input_ids.to(torch_device), input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device), attention_mask=attention_mask.to(torch_device),
max_length=max_length,
num_return_sequences=1, num_return_sequences=1,
logits_warper_kwargs=logits_warper_kwargs,
process_kwargs=process_kwargs,
) )
self.assertIsInstance(output_generate, torch.Tensor) self.assertIsInstance(output_generate, torch.Tensor)
def test_sample_generate_dict_output(self): def test_sample_generate_dict_output(self):
for model_class in self.greedy_sample_model_classes: for model_class in self.greedy_sample_model_classes:
# disable cache # disable cache
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = False config.use_cache = False
model = model_class(config).to(torch_device).eval() model = model_class(config).to(torch_device).eval()
process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
max_length=max_length,
)
output_generate = self._sample_generate( output_generate = self._sample_generate(
model=model, model=model,
input_ids=input_ids.to(torch_device), input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device), attention_mask=attention_mask.to(torch_device),
max_length=max_length,
num_return_sequences=3, num_return_sequences=3,
logits_warper_kwargs=logits_warper_kwargs,
process_kwargs=process_kwargs,
output_scores=True, output_scores=True,
output_hidden_states=True, output_hidden_states=True,
output_attentions=True, output_attentions=True,
...@@ -1508,7 +1467,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, ...@@ -1508,7 +1467,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput) self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
def test_generate_without_input_ids(self): def test_generate_without_input_ids(self):
config, _, _, max_length = self._get_input_ids_and_config() config, _, _ = self._get_input_ids_and_config()
# if no bos token id => cannot generate from None # if no bos token id => cannot generate from None
if config.bos_token_id is None: if config.bos_token_id is None:
...@@ -1518,7 +1477,9 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, ...@@ -1518,7 +1477,9 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
model = model_class(config).to(torch_device) model = model_class(config).to(torch_device)
model.eval() model.eval()
output_ids_generate = model.generate(do_sample=False, max_length=max_length, remove_invalid_values=True) output_ids_generate = model.generate(
do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True
)
self.assertIsNotNone(output_ids_generate) self.assertIsNotNone(output_ids_generate)
@require_torch_fp16 @require_torch_fp16
...@@ -1537,7 +1498,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, ...@@ -1537,7 +1498,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
def test_greedy_generate_stereo_outputs(self): def test_greedy_generate_stereo_outputs(self):
for model_class in self.greedy_sample_model_classes: for model_class in self.greedy_sample_model_classes:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() config, input_ids, attention_mask = self._get_input_ids_and_config()
config.audio_channels = 2 config.audio_channels = 2
model = model_class(config).to(torch_device).eval() model = model_class(config).to(torch_device).eval()
...@@ -1545,7 +1506,6 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, ...@@ -1545,7 +1506,6 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
model=model, model=model,
input_ids=input_ids.to(torch_device), input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device), attention_mask=attention_mask.to(torch_device),
max_length=max_length,
output_scores=True, output_scores=True,
output_hidden_states=True, output_hidden_states=True,
output_attentions=True, output_attentions=True,
......
...@@ -246,34 +246,28 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes ...@@ -246,34 +246,28 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
sequence_length = input_ids.shape[-1] sequence_length = input_ids.shape[-1]
input_ids = input_ids[: batch_size * config.num_codebooks, :] input_ids = input_ids[: batch_size * config.num_codebooks, :]
# generate max 3 tokens
max_length = input_ids.shape[-1] + 3
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long) attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
return config, input_ids, attention_mask, max_length return config, input_ids, attention_mask
@staticmethod @staticmethod
def _get_logits_processor_and_warper_kwargs( def _get_logits_processor_and_warper_kwargs(
input_length, input_length,
forced_bos_token_id=None, forced_bos_token_id=None,
forced_eos_token_id=None, forced_eos_token_id=None,
max_length=None,
): ):
process_kwargs = { process_kwargs = {}
"min_length": input_length + 1 if max_length is None else max_length - 1,
}
warper_kwargs = {} warper_kwargs = {}
return process_kwargs, warper_kwargs return process_kwargs, warper_kwargs
def test_greedy_generate_stereo_outputs(self): def test_greedy_generate_stereo_outputs(self):
for model_class in self.greedy_sample_model_classes: for model_class in self.greedy_sample_model_classes:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() config, input_ids, attention_mask = self._get_input_ids_and_config()
config.audio_channels = 2 config.audio_channels = 2
model = model_class(config).to(torch_device).eval() model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate( output_generate = self._greedy_generate(
model=model, model=model,
input_ids=input_ids.to(torch_device), input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device), attention_mask=attention_mask.to(torch_device),
max_length=max_length,
output_scores=True, output_scores=True,
output_hidden_states=True, output_hidden_states=True,
output_attentions=True, output_attentions=True,
...@@ -1309,9 +1303,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester ...@@ -1309,9 +1303,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
input_ids = input_ids[:batch_size, :] input_ids = input_ids[:batch_size, :]
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long) attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
# generate max 3 tokens return config, input_ids, attention_mask
max_length = 3
return config, input_ids, attention_mask, max_length
# override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen_melody (input / outputs are # override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen_melody (input / outputs are
# different modalities -> different shapes) # different modalities -> different shapes)
...@@ -1320,29 +1312,22 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester ...@@ -1320,29 +1312,22 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
model, model,
input_ids, input_ids,
attention_mask, attention_mask,
max_length,
output_scores=False, output_scores=False,
output_attentions=False, output_attentions=False,
output_hidden_states=False, output_hidden_states=False,
return_dict_in_generate=False, return_dict_in_generate=False,
): ):
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
max_length=max_length,
)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_generate = model.generate( output_generate = model.generate(
input_ids, input_ids,
do_sample=False, do_sample=False,
num_beams=1, num_beams=1,
max_length=max_length, max_new_tokens=self.max_new_tokens,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
output_scores=output_scores, output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate, return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True, remove_invalid_values=True,
**logits_process_kwargs,
**model_kwargs, **model_kwargs,
) )
...@@ -1355,10 +1340,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester ...@@ -1355,10 +1340,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
model, model,
input_ids, input_ids,
attention_mask, attention_mask,
max_length,
num_return_sequences, num_return_sequences,
logits_warper_kwargs,
process_kwargs,
output_scores=False, output_scores=False,
output_attentions=False, output_attentions=False,
output_hidden_states=False, output_hidden_states=False,
...@@ -1370,15 +1352,13 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester ...@@ -1370,15 +1352,13 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
input_ids, input_ids,
do_sample=True, do_sample=True,
num_beams=1, num_beams=1,
max_length=max_length, max_new_tokens=self.max_new_tokens,
num_return_sequences=num_return_sequences, num_return_sequences=num_return_sequences,
output_scores=output_scores, output_scores=output_scores,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate, return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True, remove_invalid_values=True,
**logits_warper_kwargs,
**process_kwargs,
**model_kwargs, **model_kwargs,
) )
...@@ -1389,25 +1369,21 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester ...@@ -1389,25 +1369,21 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
input_length, input_length,
forced_bos_token_id=None, forced_bos_token_id=None,
forced_eos_token_id=None, forced_eos_token_id=None,
max_length=None,
): ):
process_kwargs = { process_kwargs = {}
"min_length": input_length + 1 if max_length is None else max_length - 1,
}
warper_kwargs = {} warper_kwargs = {}
return process_kwargs, warper_kwargs return process_kwargs, warper_kwargs
def test_greedy_generate_dict_outputs(self): def test_greedy_generate_dict_outputs(self):
for model_class in self.greedy_sample_model_classes: for model_class in self.greedy_sample_model_classes:
# disable cache # disable cache
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = False config.use_cache = False
model = model_class(config).to(torch_device).eval() model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate( output_generate = self._greedy_generate(
model=model, model=model,
input_ids=input_ids.to(torch_device), input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device), attention_mask=attention_mask.to(torch_device),
max_length=max_length,
output_scores=True, output_scores=True,
output_hidden_states=True, output_hidden_states=True,
output_attentions=True, output_attentions=True,
...@@ -1421,7 +1397,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester ...@@ -1421,7 +1397,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
def test_greedy_generate_dict_outputs_use_cache(self): def test_greedy_generate_dict_outputs_use_cache(self):
for model_class in self.greedy_sample_model_classes: for model_class in self.greedy_sample_model_classes:
# enable cache # enable cache
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = True config.use_cache = True
config.is_decoder = True config.is_decoder = True
...@@ -1430,7 +1406,6 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester ...@@ -1430,7 +1406,6 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
model=model, model=model,
input_ids=input_ids.to(torch_device), input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device), attention_mask=attention_mask.to(torch_device),
max_length=max_length,
output_scores=True, output_scores=True,
output_hidden_states=True, output_hidden_states=True,
output_attentions=True, output_attentions=True,
...@@ -1441,46 +1416,30 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester ...@@ -1441,46 +1416,30 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
def test_sample_generate(self): def test_sample_generate(self):
for model_class in self.greedy_sample_model_classes: for model_class in self.greedy_sample_model_classes:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() config, input_ids, attention_mask = self._get_input_ids_and_config()
model = model_class(config).to(torch_device).eval() model = model_class(config).to(torch_device).eval()
process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
max_length=max_length,
)
# check `generate()` and `sample()` are equal # check `generate()` and `sample()` are equal
output_generate = self._sample_generate( output_generate = self._sample_generate(
model=model, model=model,
input_ids=input_ids.to(torch_device), input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device), attention_mask=attention_mask.to(torch_device),
max_length=max_length,
num_return_sequences=1, num_return_sequences=1,
logits_warper_kwargs=logits_warper_kwargs,
process_kwargs=process_kwargs,
) )
self.assertIsInstance(output_generate, torch.Tensor) self.assertIsInstance(output_generate, torch.Tensor)
def test_sample_generate_dict_output(self): def test_sample_generate_dict_output(self):
for model_class in self.greedy_sample_model_classes: for model_class in self.greedy_sample_model_classes:
# disable cache # disable cache
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = False config.use_cache = False
model = model_class(config).to(torch_device).eval() model = model_class(config).to(torch_device).eval()
process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
max_length=max_length,
)
output_generate = self._sample_generate( output_generate = self._sample_generate(
model=model, model=model,
input_ids=input_ids.to(torch_device), input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device), attention_mask=attention_mask.to(torch_device),
max_length=max_length,
num_return_sequences=3, num_return_sequences=3,
logits_warper_kwargs=logits_warper_kwargs,
process_kwargs=process_kwargs,
output_scores=True, output_scores=True,
output_hidden_states=True, output_hidden_states=True,
output_attentions=True, output_attentions=True,
...@@ -1490,7 +1449,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester ...@@ -1490,7 +1449,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
def test_generate_without_input_ids(self): def test_generate_without_input_ids(self):
config, _, _, max_length = self._get_input_ids_and_config() config, _, _ = self._get_input_ids_and_config()
# if no bos token id => cannot generate from None # if no bos token id => cannot generate from None
if config.bos_token_id is None: if config.bos_token_id is None:
...@@ -1500,7 +1459,9 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester ...@@ -1500,7 +1459,9 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
model = model_class(config).to(torch_device) model = model_class(config).to(torch_device)
model.eval() model.eval()
output_ids_generate = model.generate(do_sample=False, max_length=max_length, remove_invalid_values=True) output_ids_generate = model.generate(
do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True
)
self.assertIsNotNone(output_ids_generate) self.assertIsNotNone(output_ids_generate)
@require_torch_fp16 @require_torch_fp16
...@@ -1519,7 +1480,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester ...@@ -1519,7 +1480,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
def test_greedy_generate_stereo_outputs(self): def test_greedy_generate_stereo_outputs(self):
for model_class in self.greedy_sample_model_classes: for model_class in self.greedy_sample_model_classes:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() config, input_ids, attention_mask = self._get_input_ids_and_config()
config.audio_channels = 2 config.audio_channels = 2
model = model_class(config).to(torch_device).eval() model = model_class(config).to(torch_device).eval()
...@@ -1527,7 +1488,6 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester ...@@ -1527,7 +1488,6 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
model=model, model=model,
input_ids=input_ids.to(torch_device), input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device), attention_mask=attention_mask.to(torch_device),
max_length=max_length,
output_scores=True, output_scores=True,
output_hidden_states=True, output_hidden_states=True,
output_attentions=True, output_attentions=True,
......
...@@ -686,6 +686,18 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod ...@@ -686,6 +686,18 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod
def test_left_padding_compatibility(self): def test_left_padding_compatibility(self):
pass pass
def _get_input_ids_and_config(self, batch_size=2):
# override because overwise we hit max possible seq length for model (4*8=32)
# decreasing the seq_length in tester causes errors for "training_tests", those need exactly max seq length
# NOTE: seq_length has to be multiple of 4, otherwise it fails for other tests
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict[self.input_name]
input_ids = input_ids[:batch_size, :16]
attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:batch_size, :16]
config.eos_token_id = None
config.forced_eos_token_id = None
return config, input_ids, attention_mask
@require_torch @require_torch
class ReformerLSHAttnModelTest( class ReformerLSHAttnModelTest(
......
...@@ -285,7 +285,7 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest ...@@ -285,7 +285,7 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
input_name = "input_features" input_name = "input_features"
def _get_input_ids_and_config(self, batch_size=2): def _get_input_ids_and_config(self, batch_size=2):
config, input_ids, attention_mask, max_length = GenerationTesterMixin._get_input_ids_and_config(self) config, input_ids, attention_mask = GenerationTesterMixin._get_input_ids_and_config(self)
# `input_ids` is actually `input_features` which is a 3D tensor. # `input_ids` is actually `input_features` which is a 3D tensor.
# We must overwrite the mask to make it 2D since the original `_get_input_ids_and_config` creates an # We must overwrite the mask to make it 2D since the original `_get_input_ids_and_config` creates an
...@@ -294,7 +294,7 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest ...@@ -294,7 +294,7 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
sequence_length = input_ids.shape[1] sequence_length = input_ids.shape[1]
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=attention_mask.device) attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=attention_mask.device)
return config, input_ids, attention_mask, max_length return config, input_ids, attention_mask
def setUp(self): def setUp(self):
self.model_tester = Speech2TextModelTester(self) self.model_tester = Speech2TextModelTester(self)
......
...@@ -477,13 +477,11 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi ...@@ -477,13 +477,11 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
# cut to half length & take max batch_size=batch_size # cut to half length & take max batch_size=batch_size
input_ids = input_ids[:batch_size, :, :] input_ids = input_ids[:batch_size, :, :]
# generate max 3 tokens
max_length = 4
if config.eos_token_id is not None and config.pad_token_id is None: if config.eos_token_id is not None and config.pad_token_id is None:
# hack to allow generate for models such as GPT2 as is done in `generate()` # hack to allow generate for models such as GPT2 as is done in `generate()`
config.pad_token_id = config.eos_token_id config.pad_token_id = config.eos_token_id
return config, input_ids, None, max_length return config, input_ids, None
def test_inputs_embeds(self): def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
...@@ -646,7 +646,8 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi ...@@ -646,7 +646,8 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
seq_len = 1 seq_len = 1
else: else:
# for first item dummy PAD token is appended so need one more # for first item dummy PAD token is appended so need one more
seq_len = (min_length + 1) if idx == 0 else min_length # else offset+dummy_token when using cache
seq_len = (min_length + 1) if idx == 0 else 3
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size) expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
self.assertEqual(layer_hidden_states.shape, expected_shape) self.assertEqual(layer_hidden_states.shape, expected_shape)
...@@ -665,8 +666,11 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi ...@@ -665,8 +666,11 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
tgt_len = min_length tgt_len = min_length
# for first item dummy PAD token is appended so need one more # for first item dummy PAD token is appended so need one more
# every token after consists of offset+dummy_token length when using cache
if idx == 0: if idx == 0:
tgt_len += 1 tgt_len += 1
else:
tgt_len = 3
src_len = min_length + idx + 1 src_len = min_length + idx + 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment