Unverified Commit c043ce6c authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Generate] correct encoder_outputs are passed without attention_mask (#14980)

* [Generate] correct encoder_outputs are passed without attention_mask

* Apply suggestions from code review

* up
parent a1392883
......@@ -1019,8 +1019,10 @@ class GenerationMixin:
model_kwargs["output_hidden_states"] = output_hidden_states
model_kwargs["use_cache"] = use_cache
has_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
if model_kwargs.get("attention_mask", None) is None and has_attention_mask:
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor, pad_token_id, eos_token_id
)
......
......@@ -1887,3 +1887,19 @@ class GenerationIntegrationTests(unittest.TestCase):
self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
self.assertEqual(output_sequences.shape, (2, 5))
def test_generate_encoder_outputs_attention_mask(self):
input_values = floats_tensor((2, 250)).to(torch_device)
attention_mask = torch.ones_like(input_values)
model = SpeechEncoderDecoderModel.from_pretrained("hf-internal-testing/tiny-random-speech-encoder-decoder")
model = model.to(torch_device)
encoder = model.get_encoder()
encoder_outputs = encoder(input_values)
output_sequences_no_mask = model.generate(encoder_outputs=encoder_outputs).cpu()
output_sequences_with_mask = model.generate(encoder_outputs=encoder_outputs, attention_mask=attention_mask)
output_sequences_with_mask = output_sequences_with_mask.cpu()
self.assertListEqual(output_sequences_no_mask.tolist(), output_sequences_with_mask.tolist())
......@@ -215,3 +215,20 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
filename = ds[40]["file"]
output = speech_recognizer(filename)
self.assertEqual(output, {"text": "Ein Mann sagte zu dem Universum, Sir, ich bin da."})
@slow
@require_torch
@require_torchaudio
def test_speech_to_text_leveraged(self):
speech_recognizer = pipeline(
task="automatic-speech-recognition",
model="patrickvonplaten/wav2vec2-2-bart-base",
feature_extractor="patrickvonplaten/wav2vec2-2-bart-base",
tokenizer=AutoTokenizer.from_pretrained("patrickvonplaten/wav2vec2-2-bart-base"),
framework="pt",
)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
self.assertEqual(output, {"text": "a man said to the universe sir i exist"})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment