Unverified Commit 79d62b2d authored by Wang, Yi's avatar Wang, Yi Committed by GitHub
Browse files

if output is tuple like facebook/hf-seamless-m4t-medium, waveform is the first element (#29722)



* if output is tuple like facebook/hf-seamless-m4t-medium, waveform is the first element
Signed-off-by: default avatarWang, Yi <yi.a.wang@intel.com>

* add test and fix batch issue
Signed-off-by: default avatarWang, Yi <yi.a.wang@intel.com>

* add dict output support for seamless_m4t
Signed-off-by: default avatarWang, Yi <yi.a.wang@intel.com>

---------
Signed-off-by: default avatarWang, Yi <yi.a.wang@intel.com>
parent 8b52fa6b
...@@ -3496,7 +3496,6 @@ class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel): ...@@ -3496,7 +3496,6 @@ class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel):
self.device self.device
) )
kwargs_speech["decoder_input_ids"] = t2u_decoder_input_ids kwargs_speech["decoder_input_ids"] = t2u_decoder_input_ids
# second generation # second generation
unit_ids = self.t2u_model.generate(inputs_embeds=t2u_input_embeds, **kwargs_speech) unit_ids = self.t2u_model.generate(inputs_embeds=t2u_input_embeds, **kwargs_speech)
output_unit_ids = unit_ids.detach().clone() output_unit_ids = unit_ids.detach().clone()
......
...@@ -128,9 +128,12 @@ class PipelineIterator(IterableDataset): ...@@ -128,9 +128,12 @@ class PipelineIterator(IterableDataset):
# Try to infer the size of the batch # Try to infer the size of the batch
if isinstance(processed, torch.Tensor): if isinstance(processed, torch.Tensor):
first_tensor = processed first_tensor = processed
elif isinstance(processed, tuple):
first_tensor = processed[0]
else: else:
key = list(processed.keys())[0] key = list(processed.keys())[0]
first_tensor = processed[key] first_tensor = processed[key]
if isinstance(first_tensor, list): if isinstance(first_tensor, list):
observed_batch_size = len(first_tensor) observed_batch_size = len(first_tensor)
else: else:
...@@ -140,7 +143,7 @@ class PipelineIterator(IterableDataset): ...@@ -140,7 +143,7 @@ class PipelineIterator(IterableDataset):
# elements. # elements.
self.loader_batch_size = observed_batch_size self.loader_batch_size = observed_batch_size
# Setting internal index to unwrap the batch # Setting internal index to unwrap the batch
self._loader_batch_data = processed self._loader_batch_data = processed[0] if isinstance(processed, tuple) else processed
self._loader_batch_index = 0 self._loader_batch_index = 0
return self.loader_batch_item() return self.loader_batch_item()
else: else:
......
...@@ -200,7 +200,10 @@ class TextToAudioPipeline(Pipeline): ...@@ -200,7 +200,10 @@ class TextToAudioPipeline(Pipeline):
def postprocess(self, waveform): def postprocess(self, waveform):
output_dict = {} output_dict = {}
if isinstance(waveform, dict):
waveform = waveform["waveform"]
elif isinstance(waveform, tuple):
waveform = waveform[0]
output_dict["audio"] = waveform.cpu().float().numpy() output_dict["audio"] = waveform.cpu().float().numpy()
output_dict["sampling_rate"] = self.sampling_rate output_dict["sampling_rate"] = self.sampling_rate
......
...@@ -66,6 +66,27 @@ class TextToAudioPipelineTests(unittest.TestCase): ...@@ -66,6 +66,27 @@ class TextToAudioPipelineTests(unittest.TestCase):
audio = [output["audio"] for output in outputs] audio = [output["audio"] for output in outputs]
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio) self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
@slow
@require_torch
def test_medium_seamless_m4t_pt(self):
speech_generator = pipeline(task="text-to-audio", model="facebook/hf-seamless-m4t-medium", framework="pt")
for forward_params in [{"tgt_lang": "eng"}, {"return_intermediate_token_ids": True, "tgt_lang": "eng"}]:
outputs = speech_generator("This is a test", forward_params=forward_params)
self.assertEqual({"audio": ANY(np.ndarray), "sampling_rate": 16000}, outputs)
# test two examples side-by-side
outputs = speech_generator(["This is a test", "This is a second test"], forward_params=forward_params)
audio = [output["audio"] for output in outputs]
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
# test batching
outputs = speech_generator(
["This is a test", "This is a second test"], forward_params=forward_params, batch_size=2
)
audio = [output["audio"] for output in outputs]
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
@slow @slow
@require_torch @require_torch
def test_small_bark_pt(self): def test_small_bark_pt(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment