Unverified Commit 0db5d911 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix `SpeechT5ForSpeechToSpeechIntegrationTests` device issue (#21460)



* fix

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 59d5edef
...@@ -2869,7 +2869,7 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel): ...@@ -2869,7 +2869,7 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
predicted mel spectrogram, or a tensor with shape `(num_frames,)` containing the speech waveform. predicted mel spectrogram, or a tensor with shape `(num_frames,)` containing the speech waveform.
""" """
if speaker_embeddings is None: if speaker_embeddings is None:
speaker_embeddings = torch.zeros((1, 512)) speaker_embeddings = torch.zeros((1, 512), device=input_values.device)
return _generate_speech( return _generate_speech(
self, self,
......
...@@ -1423,7 +1423,7 @@ class SpeechT5ForSpeechToSpeechIntegrationTests(unittest.TestCase): ...@@ -1423,7 +1423,7 @@ class SpeechT5ForSpeechToSpeechIntegrationTests(unittest.TestCase):
input_speech = self._load_datasamples(1) input_speech = self._load_datasamples(1)
input_values = processor(audio=input_speech, return_tensors="pt").input_values.to(torch_device) input_values = processor(audio=input_speech, return_tensors="pt").input_values.to(torch_device)
speaker_embeddings = torch.zeros((1, 512)) speaker_embeddings = torch.zeros((1, 512), device=torch_device)
generated_speech = model.generate_speech(input_values, speaker_embeddings=speaker_embeddings) generated_speech = model.generate_speech(input_values, speaker_embeddings=speaker_embeddings)
self.assertEqual(generated_speech.shape[1], model.config.num_mel_bins) self.assertEqual(generated_speech.shape[1], model.config.num_mel_bins)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment