Unverified Commit 6d3f9c1e authored by Yoach Lacombe, committed by GitHub

add generate method to SpeechT5ForTextToSpeech (#25233)



* add generate method to SpeechT5ForTextToSpeech

* update SpeechT5ForTextToSpeech docstrings

* Remove defaults to None in generate docstrings
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 8455346c
@@ -71,7 +71,7 @@ This model was contributed by [Matthijs](https://huggingface.co/Matthijs). The o
 [[autodoc]] SpeechT5ForTextToSpeech
     - forward
-    - generate_speech
+    - generate
 
 ## SpeechT5ForSpeechToSpeech
......
@@ -2717,7 +2717,7 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
         >>> set_seed(555)  # make deterministic
 
         >>> # generate speech
-        >>> speech = model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)
+        >>> speech = model.generate(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)
         >>> speech.shape
         torch.Size([15872])
         ```
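For reference, the docstring example above assumes that a processor, model, vocoder, and speaker embedding have already been set up. A minimal sketch of that setup, using the `microsoft/speecht5_tts` and `microsoft/speecht5_hifigan` checkpoints from the SpeechT5 docs and a random placeholder in place of a real x-vector speaker embedding:

```python
import torch
from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor, set_seed

processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

inputs = processor(text="Hello, my dog is cute.", return_tensors="pt")

# Placeholder speaker embedding; in practice this would be an x-vector extracted from
# reference audio (config.speaker_embedding_dim is 512 for this checkpoint).
speaker_embeddings = torch.randn(1, 512)

set_seed(555)  # make deterministic
# The new `generate` is called exactly like `generate_speech` was in the old example.
speech = model.generate(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)
```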
@@ -2783,6 +2783,65 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
             encoder_attentions=outputs.encoder_attentions,
         )
 
+    @torch.no_grad()
+    def generate(
+        self,
+        input_ids: torch.LongTensor,
+        speaker_embeddings: Optional[torch.FloatTensor] = None,
+        threshold: float = 0.5,
+        minlenratio: float = 0.0,
+        maxlenratio: float = 20.0,
+        vocoder: Optional[nn.Module] = None,
+        output_cross_attentions: bool = False,
+        **kwargs,
+    ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]:
+        r"""
+        Converts a sequence of input tokens into a sequence of mel spectrograms, which are subsequently turned into a
+        speech waveform using a vocoder.
+
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. The `batch_size` should be 1 currently.
+
+                Indices can be obtained using [`SpeechT5Tokenizer`]. See [`~PreTrainedTokenizer.encode`] and
+                [`~PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*):
+                Tensor containing the speaker embeddings.
+            threshold (`float`, *optional*, defaults to 0.5):
+                The generated sequence ends when the predicted stop token probability exceeds this value.
+            minlenratio (`float`, *optional*, defaults to 0.0):
+                Used to calculate the minimum required length for the output sequence.
+            maxlenratio (`float`, *optional*, defaults to 20.0):
+                Used to calculate the maximum allowed length for the output sequence.
+            vocoder (`nn.Module`, *optional*):
+                The vocoder that converts the mel spectrogram into a speech waveform. If `None`, the output is the mel
+                spectrogram.
+            output_cross_attentions (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the attentions tensors of the decoder's cross-attention layers.
+
+        Returns:
+            `tuple(torch.FloatTensor)` comprising various elements depending on the inputs:
+            - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
+              `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram.
+            - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
+              `(num_frames,)` -- The predicted speech waveform.
+            - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`)
+              `torch.FloatTensor` of shape `(config.decoder_layers, config.decoder_attention_heads,
+              output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
+        """
+        return _generate_speech(
+            self,
+            input_ids,
+            speaker_embeddings,
+            threshold,
+            minlenratio,
+            maxlenratio,
+            vocoder,
+            output_cross_attentions,
+        )
+
     @torch.no_grad()
     def generate_speech(
         self,
......
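To make the Returns section above concrete, a rough sketch of the possible outputs, reusing the `model`, `inputs`, `speaker_embeddings`, and `vocoder` objects from the earlier snippet (the exact sequence lengths depend on the input text and the seed):

```python
# Without a vocoder, the predicted log-mel spectrogram is returned.
spectrogram = model.generate(inputs["input_ids"], speaker_embeddings)
print(spectrogram.shape)  # (output_sequence_length, config.num_mel_bins)

# With a vocoder, the waveform is returned instead; requesting cross-attentions
# turns the result into a (waveform, cross_attentions) tuple.
waveform, cross_attentions = model.generate(
    inputs["input_ids"],
    speaker_embeddings,
    vocoder=vocoder,
    output_cross_attentions=True,
)
print(waveform.shape)          # (num_frames,)
print(cross_attentions.shape)  # (decoder_layers, decoder_attention_heads, output_len, input_len)
```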
@@ -1020,6 +1020,10 @@ class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
         generated_speech = model.generate_speech(input_ids)
         self.assertEqual(generated_speech.shape, (1820, model.config.num_mel_bins))
 
+        # test model.generate, same method as generate_speech but with **kwargs to absorb extra arguments such as attention_mask
+        generated_speech_with_generate = model.generate(input_ids, attention_mask=None)
+        self.assertEqual(generated_speech_with_generate.shape, (1820, model.config.num_mel_bins))
+
 
 @require_torch
 class SpeechT5ForSpeechToSpeechTester:
......
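The new test exercises the `**kwargs` catch-all: `generate` silently absorbs unused arguments such as `attention_mask`, so a caller can forward the full processor output without filtering it. A hypothetical usage sketch, again reusing the objects from the earlier snippets:

```python
# The processor output contains both input_ids and attention_mask; attention_mask
# (and any other unused key) is absorbed by **kwargs, so splatting it is safe.
inputs = processor(text="Hello, my dog is cute.", return_tensors="pt")
speech = model.generate(**inputs, speaker_embeddings=speaker_embeddings, vocoder=vocoder)
```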