Unverified Commit 82e61f34 authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[SpeechT5HifiGan] Handle batched inputs (#21702)

* [SpeechT5HifiGan] Handle batched inputs

* fix docstring

* rebase and new ruff style
parent 09127c57
...@@ -3030,19 +3030,27 @@ class SpeechT5HifiGan(PreTrainedModel): ...@@ -3030,19 +3030,27 @@ class SpeechT5HifiGan(PreTrainedModel):
def forward(self, spectrogram): def forward(self, spectrogram):
r""" r"""
Converts a single log-mel spectogram into a speech waveform. Converts a log-mel spectrogram into a speech waveform. Passing a batch of log-mel spectrograms returns a batch
of speech waveforms. Passing a single, un-batched log-mel spectrogram returns a single, un-batched speech
waveform.
Args: Args:
spectrogram (`torch.FloatTensor` of shape `(sequence_length, config.model_in_dim)`): spectrogram (`torch.FloatTensor`):
Tensor containing the log-mel spectrogram. Tensor containing the log-mel spectrograms. Can be batched and of shape `(batch_size, sequence_length,
config.model_in_dim)`, or un-batched and of shape `(sequence_length, config.model_in_dim)`.
Returns: Returns:
`torch.FloatTensor`: Tensor of shape `(num_frames,)` containing the speech waveform. `torch.FloatTensor`: Tensor containing the speech waveform. If the input spectrogram is batched, will be of
shape `(batch_size, num_frames,)`. If un-batched, will be of shape `(num_frames,)`.
""" """
if self.config.normalize_before: if self.config.normalize_before:
spectrogram = (spectrogram - self.mean) / self.scale spectrogram = (spectrogram - self.mean) / self.scale
hidden_states = spectrogram.transpose(1, 0).unsqueeze(0) is_batched = spectrogram.dim() == 3
if not is_batched:
spectrogram = spectrogram.unsqueeze(0)
hidden_states = spectrogram.transpose(2, 1)
hidden_states = self.conv_pre(hidden_states) hidden_states = self.conv_pre(hidden_states)
for i in range(self.num_upsamples): for i in range(self.num_upsamples):
...@@ -3058,5 +3066,11 @@ class SpeechT5HifiGan(PreTrainedModel): ...@@ -3058,5 +3066,11 @@ class SpeechT5HifiGan(PreTrainedModel):
hidden_states = self.conv_post(hidden_states) hidden_states = self.conv_post(hidden_states)
hidden_states = torch.tanh(hidden_states) hidden_states = torch.tanh(hidden_states)
waveform = hidden_states.squeeze(0).transpose(1, 0).view(-1) if not is_batched:
# remove batch dim and collapse tensor to 1-d audio waveform
waveform = hidden_states.squeeze(0).transpose(1, 0).view(-1)
else:
# remove seq-len dim since this collapses to 1
waveform = hidden_states.squeeze(1)
return waveform return waveform
...@@ -1545,3 +1545,23 @@ class SpeechT5HifiGanTest(ModelTesterMixin, unittest.TestCase): ...@@ -1545,3 +1545,23 @@ class SpeechT5HifiGanTest(ModelTesterMixin, unittest.TestCase):
# skip because it fails on automapping of SpeechT5HifiGanConfig # skip because it fails on automapping of SpeechT5HifiGanConfig
def test_save_load_fast_init_to_base(self): def test_save_load_fast_init_to_base(self):
pass pass
def test_batched_inputs_outputs(self):
    """A batched spectrogram input must produce an output with the same batch dimension."""
    config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        model = model_class(config)
        # Stack two copies of the un-batched spectrogram along a new leading batch axis.
        batch = inputs["spectrogram"].unsqueeze(0).repeat(2, 1, 1)
        output = model(batch)
        self.assertEqual(batch.shape[0], output.shape[0], msg="Got different batch dims for input and output")
def test_unbatched_inputs_outputs(self):
    """An un-batched spectrogram input must produce a 1-d waveform (no batch dimension)."""
    config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        model = model_class(config)
        waveform = model(inputs["spectrogram"])
        # A rank-1 tensor means the model collapsed the output to a single waveform.
        self.assertTrue(waveform.dim() == 1, msg="Got un-batched inputs but batched output")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment