"vscode:/vscode.git/clone" did not exist on "a6271967c971cab5d882bba87c19406f30cdc7e4"
Unverified commit db611aab, authored by Yoach Lacombe and committed by GitHub
Browse files

🚨 🚨 Raise error when no speaker embeddings in speecht5._generate_speech (#26418)

* add warning when no speaker embeddings in speecht5._generate_speech

* modify warning to error

* adapt generation test
parent 41c42f85
......@@ -2550,6 +2550,14 @@ def _generate_speech(
vocoder: Optional[nn.Module] = None,
output_cross_attentions: bool = False,
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]:
if speaker_embeddings is None:
raise ValueError(
"""`speaker_embeddings` must be specified. For example, you can use a speaker embeddings by following
the code snippet provided in this link:
https://huggingface.co/datasets/Matthijs/cmu-arctic-xvectors
"""
)
encoder_attention_mask = torch.ones_like(input_values)
encoder_out = model.speecht5.encoder(
......
......@@ -1015,15 +1015,21 @@ class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
set_seed(555) # make deterministic
speaker_embeddings = torch.zeros((1, 512)).to(torch_device)
input_text = "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
input_ids = processor(text=input_text, return_tensors="pt").input_ids.to(torch_device)
generated_speech = model.generate_speech(input_ids)
self.assertEqual(generated_speech.shape, (1820, model.config.num_mel_bins))
generated_speech = model.generate_speech(input_ids, speaker_embeddings=speaker_embeddings)
self.assertEqual(generated_speech.shape, (228, model.config.num_mel_bins))
set_seed(555) # make deterministic
# test model.generate, same method as generate_speech but with additional kwargs to absorb kwargs such as attention_mask
generated_speech_with_generate = model.generate(input_ids, attention_mask=None)
self.assertEqual(generated_speech_with_generate.shape, (1820, model.config.num_mel_bins))
generated_speech_with_generate = model.generate(
input_ids, attention_mask=None, speaker_embeddings=speaker_embeddings
)
self.assertEqual(generated_speech_with_generate.shape, (228, model.config.num_mel_bins))
@require_torch
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment