Unverified Commit 4309abed authored by Sihan Chen, committed by GitHub

Add speecht5 batch generation and fix wrong attention mask when padding (#25943)

* fix speecht5 wrong attention mask when padding

* enable batch generation and add parameter attention_mask

* fix doc

* fix format

* batch postnet inputs, return batched lengths, and keep the API consistent with the old one

* fix format

* fix format

* fix the format

* fix doc-builder error

* add test, cross attention and docstring

* optimize code based on reviews

* docbuild

* refine

* do not skip the slow test

* add consistent dropout for batching

* loosen atol

* add another test regarding the consistency of the vocoder

* fix format

* refactor

* add return_concrete_lengths as a parameter for consistency with/without batching

* fix review issues

* fix cross_attention issue
parent ee4fb326
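A rough usage sketch of the batch-generation path this commit adds (the checkpoint names, padding settings, and zero speaker embedding are taken from the new test in the diff below; real use would pass actual x-vector speaker embeddings):

```python
import torch
from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor

processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

texts = [
    "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
    "nor is mister quilter's manner less interesting than his matter",
]
# Padding produces unequal-length inputs, so the attention mask marks the padded
# positions; building it incorrectly under padding is the bug this commit addresses.
inputs = processor(text=texts, padding="max_length", max_length=128, return_tensors="pt")
speaker_embeddings = torch.zeros((1, 512))  # placeholder embedding, as in the test

# return_output_lengths=True also returns the valid length of each padded output.
waveforms, waveform_lengths = model.generate_speech(
    input_ids=inputs["input_ids"],
    speaker_embeddings=speaker_embeddings,
    attention_mask=inputs["attention_mask"],
    vocoder=vocoder,
    return_output_lengths=True,
)
audios = [wav[:length] for wav, length in zip(waveforms, waveform_lengths)]  # strip padding
```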
@@ -1026,14 +1026,21 @@ class SpeechT5ForTextToSpeechTest(ModelTesterMixin, unittest.TestCase):
@require_torch
@require_sentencepiece
@require_tokenizers
@slow
class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
    @cached_property
    def default_model(self):
        return SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")

    @cached_property
    def default_processor(self):
        return SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")

    @cached_property
    def default_vocoder(self):
        return SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

    def test_generation(self):
-       model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
+       model = self.default_model
        model.to(torch_device)
        processor = self.default_processor
@@ -1045,7 +1052,7 @@ class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
        input_ids = processor(text=input_text, return_tensors="pt").input_ids.to(torch_device)
        generated_speech = model.generate_speech(input_ids, speaker_embeddings=speaker_embeddings)
-       self.assertEqual(generated_speech.shape, (228, model.config.num_mel_bins))
+       self.assertEqual(generated_speech.shape, (230, model.config.num_mel_bins))

        set_seed(555)  # make deterministic
@@ -1053,7 +1060,76 @@ class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
        generated_speech_with_generate = model.generate(
            input_ids, attention_mask=None, speaker_embeddings=speaker_embeddings
        )
-       self.assertEqual(generated_speech_with_generate.shape, (228, model.config.num_mel_bins))
+       self.assertEqual(generated_speech_with_generate.shape, (230, model.config.num_mel_bins))
    def test_batch_generation(self):
        model = self.default_model
        model.to(torch_device)
        processor = self.default_processor
        vocoder = self.default_vocoder

        set_seed(555)  # make deterministic
        input_text = [
            "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
            "nor is mister quilter's manner less interesting than his matter",
            "he tells us that at this festive season of the year with christmas and rosebeaf looming before us",
        ]
        inputs = processor(text=input_text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device)
        speaker_embeddings = torch.zeros((1, 512), device=torch_device)

        spectrograms, spectrogram_lengths = model.generate_speech(
            input_ids=inputs["input_ids"],
            speaker_embeddings=speaker_embeddings,
            attention_mask=inputs["attention_mask"],
            return_output_lengths=True,
        )
        self.assertEqual(spectrograms.shape, (3, 262, model.config.num_mel_bins))

        waveforms = vocoder(spectrograms)
        waveform_lengths = [int(waveforms.size(1) / max(spectrogram_lengths)) * i for i in spectrogram_lengths]

        # Check waveform results are the same with or without using the vocoder
        set_seed(555)
        waveforms_with_vocoder, waveform_lengths_with_vocoder = model.generate_speech(
            input_ids=inputs["input_ids"],
            speaker_embeddings=speaker_embeddings,
            attention_mask=inputs["attention_mask"],
            vocoder=vocoder,
            return_output_lengths=True,
        )
        self.assertTrue(torch.allclose(waveforms, waveforms_with_vocoder, atol=1e-8))
        self.assertEqual(waveform_lengths, waveform_lengths_with_vocoder)

        # Check waveform results are the same with return_output_lengths=True/False
        set_seed(555)
        waveforms_with_vocoder_no_lengths = model.generate_speech(
            input_ids=inputs["input_ids"],
            speaker_embeddings=speaker_embeddings,
            attention_mask=inputs["attention_mask"],
            vocoder=vocoder,
            return_output_lengths=False,
        )
        self.assertTrue(torch.allclose(waveforms_with_vocoder_no_lengths, waveforms_with_vocoder, atol=1e-8))

        # Check that results with batching are consistent with results without batching
        for i, text in enumerate(input_text):
            set_seed(555)  # make deterministic
            inputs = processor(text=text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device)
            spectrogram = model.generate_speech(
                input_ids=inputs["input_ids"],
                speaker_embeddings=speaker_embeddings,
            )
            self.assertEqual(spectrogram.shape, spectrograms[i][: spectrogram_lengths[i]].shape)
            self.assertTrue(torch.allclose(spectrogram, spectrograms[i][: spectrogram_lengths[i]], atol=5e-3))

            waveform = vocoder(spectrogram)
            self.assertEqual(waveform.shape, waveforms[i][: waveform_lengths[i]].shape)
            # Check whether waveforms are the same with/without passing the vocoder
            set_seed(555)
            waveform_with_vocoder = model.generate_speech(
                input_ids=inputs["input_ids"],
                speaker_embeddings=speaker_embeddings,
                vocoder=vocoder,
            )
            self.assertTrue(torch.allclose(waveform, waveform_with_vocoder, atol=1e-8))
@require_torch
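As a side note on the length bookkeeping in test_batch_generation above: the per-utterance waveform lengths are recovered from the returned spectrogram lengths by assuming the vocoder upsamples every mel frame by the same fixed factor. A hypothetical helper illustrating that arithmetic (the function name is illustrative and not part of the commit):

```python
import torch

def trim_batched_waveforms(waveforms: torch.Tensor, spectrogram_lengths: list) -> list:
    """Recover per-utterance audio from a padded (batch, samples) waveform tensor.

    Mirrors the arithmetic in the new test: the HiFi-GAN vocoder turns each mel
    frame into a constant number of samples, so the valid waveform length is
    proportional to the valid spectrogram length.
    """
    samples_per_frame = waveforms.size(1) // max(spectrogram_lengths)
    return [wav[: samples_per_frame * length] for wav, length in zip(waveforms, spectrogram_lengths)]
```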