Unverified Commit b439129e authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[VITS] Add to TTA pipeline (#25906)



* [VITS] Add to TTA pipeline

* Update tests/pipelines/test_pipelines_text_to_audio.py
Co-authored-by: default avatarYoach Lacombe <52246514+ylacombe@users.noreply.github.com>

* remove extra spaces

---------
Co-authored-by: default avatarYoach Lacombe <52246514+ylacombe@users.noreply.github.com>
parent be0e189b
...@@ -1036,6 +1036,7 @@ MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = OrderedDict( ...@@ -1036,6 +1036,7 @@ MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = OrderedDict(
# Model for Text-To-Waveform mapping # Model for Text-To-Waveform mapping
("bark", "BarkModel"), ("bark", "BarkModel"),
("musicgen", "MusicgenForConditionalGeneration"), ("musicgen", "MusicgenForConditionalGeneration"),
("vits", "VitsModel"),
] ]
) )
......
...@@ -56,8 +56,6 @@ class TextToAudioPipeline(Pipeline): ...@@ -56,8 +56,6 @@ class TextToAudioPipeline(Pipeline):
if self.framework == "tf": if self.framework == "tf":
raise ValueError("The TextToAudioPipeline is only available in PyTorch.") raise ValueError("The TextToAudioPipeline is only available in PyTorch.")
self.forward_method = self.model.generate if self.model.can_generate() else self.model
self.vocoder = None self.vocoder = None
if self.model.__class__ in MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING.values(): if self.model.__class__ in MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING.values():
self.vocoder = ( self.vocoder = (
...@@ -110,8 +108,10 @@ class TextToAudioPipeline(Pipeline): ...@@ -110,8 +108,10 @@ class TextToAudioPipeline(Pipeline):
# we expect some kwargs to be additional tensors which need to be on the right device # we expect some kwargs to be additional tensors which need to be on the right device
kwargs = self._ensure_tensor_on_device(kwargs, device=self.device) kwargs = self._ensure_tensor_on_device(kwargs, device=self.device)
# call the generate by defaults or the forward method if the model cannot generate if self.model.can_generate():
output = self.forward_method(**model_inputs, **kwargs) output = self.model.generate(**model_inputs, **kwargs)
else:
output = self.model(**model_inputs, **kwargs)[0]
if self.vocoder is not None: if self.vocoder is not None:
# in that case, the output is a spectrogram that needs to be converted into a waveform # in that case, the output is a spectrogram that needs to be converted into a waveform
......
...@@ -37,7 +37,7 @@ from .test_pipelines_common import ANY ...@@ -37,7 +37,7 @@ from .test_pipelines_common import ANY
@require_torch_or_tf @require_torch_or_tf
class TextToAudioPipelineTests(unittest.TestCase): class TextToAudioPipelineTests(unittest.TestCase):
model_mapping = MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING model_mapping = MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING
# for now only text_to_waveform and not text_to_spectrogram # for now only test text_to_waveform and not text_to_spectrogram
@slow @slow
@require_torch @require_torch
...@@ -50,26 +50,21 @@ class TextToAudioPipelineTests(unittest.TestCase): ...@@ -50,26 +50,21 @@ class TextToAudioPipelineTests(unittest.TestCase):
} }
outputs = speech_generator("This is a test", forward_params=forward_params) outputs = speech_generator("This is a test", forward_params=forward_params)
# musicgen sampling_rate is not straightforward to get # musicgen sampling_rate is not straightforward to get
self.assertIsNone(outputs["sampling_rate"]) self.assertIsNone(outputs["sampling_rate"])
audio = outputs["audio"] audio = outputs["audio"]
self.assertEqual(ANY(np.ndarray), audio) self.assertEqual(ANY(np.ndarray), audio)
# test two examples side-by-side # test two examples side-by-side
outputs = speech_generator(["This is a test", "This is a second test"], forward_params=forward_params) outputs = speech_generator(["This is a test", "This is a second test"], forward_params=forward_params)
audio = [output["audio"] for output in outputs] audio = [output["audio"] for output in outputs]
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio) self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
# test batching # test batching
outputs = speech_generator( outputs = speech_generator(
["This is a test", "This is a second test"], forward_params=forward_params, batch_size=2 ["This is a test", "This is a second test"], forward_params=forward_params, batch_size=2
) )
self.assertEqual(ANY(np.ndarray), outputs[0]["audio"]) self.assertEqual(ANY(np.ndarray), outputs[0]["audio"])
@slow @slow
...@@ -77,8 +72,6 @@ class TextToAudioPipelineTests(unittest.TestCase): ...@@ -77,8 +72,6 @@ class TextToAudioPipelineTests(unittest.TestCase):
def test_large_model_pt(self): def test_large_model_pt(self):
speech_generator = pipeline(task="text-to-audio", model="suno/bark-small", framework="pt") speech_generator = pipeline(task="text-to-audio", model="suno/bark-small", framework="pt")
# test text-to-speech
forward_params = { forward_params = {
# Using `do_sample=False` to force deterministic output # Using `do_sample=False` to force deterministic output
"do_sample": False, "do_sample": False,
...@@ -86,7 +79,6 @@ class TextToAudioPipelineTests(unittest.TestCase): ...@@ -86,7 +79,6 @@ class TextToAudioPipelineTests(unittest.TestCase):
} }
outputs = speech_generator("This is a test", forward_params=forward_params) outputs = speech_generator("This is a test", forward_params=forward_params)
self.assertEqual( self.assertEqual(
{"audio": ANY(np.ndarray), "sampling_rate": 24000}, {"audio": ANY(np.ndarray), "sampling_rate": 24000},
outputs, outputs,
...@@ -97,13 +89,10 @@ class TextToAudioPipelineTests(unittest.TestCase): ...@@ -97,13 +89,10 @@ class TextToAudioPipelineTests(unittest.TestCase):
["This is a test", "This is a second test"], ["This is a test", "This is a second test"],
forward_params=forward_params, forward_params=forward_params,
) )
audio = [output["audio"] for output in outputs] audio = [output["audio"] for output in outputs]
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio) self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
# test other generation strategy # test other generation strategy
forward_params = { forward_params = {
"do_sample": True, "do_sample": True,
"semantic_max_new_tokens": 100, "semantic_max_new_tokens": 100,
...@@ -111,9 +100,7 @@ class TextToAudioPipelineTests(unittest.TestCase): ...@@ -111,9 +100,7 @@ class TextToAudioPipelineTests(unittest.TestCase):
} }
outputs = speech_generator("This is a test", forward_params=forward_params) outputs = speech_generator("This is a test", forward_params=forward_params)
audio = outputs["audio"] audio = outputs["audio"]
self.assertEqual(ANY(np.ndarray), audio) self.assertEqual(ANY(np.ndarray), audio)
# test using a speaker embedding # test using a speaker embedding
...@@ -127,9 +114,7 @@ class TextToAudioPipelineTests(unittest.TestCase): ...@@ -127,9 +114,7 @@ class TextToAudioPipelineTests(unittest.TestCase):
forward_params=forward_params, forward_params=forward_params,
batch_size=2, batch_size=2,
) )
audio = [output["audio"] for output in outputs] audio = [output["audio"] for output in outputs]
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio) self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
@slow @slow
...@@ -151,7 +136,6 @@ class TextToAudioPipelineTests(unittest.TestCase): ...@@ -151,7 +136,6 @@ class TextToAudioPipelineTests(unittest.TestCase):
"return_token_type_ids": False, "return_token_type_ids": False,
"padding": "max_length", "padding": "max_length",
} }
outputs = speech_generator( outputs = speech_generator(
"This is a test", "This is a test",
forward_params=forward_params, forward_params=forward_params,
...@@ -163,28 +147,44 @@ class TextToAudioPipelineTests(unittest.TestCase): ...@@ -163,28 +147,44 @@ class TextToAudioPipelineTests(unittest.TestCase):
forward_params["history_prompt"] = history_prompt forward_params["history_prompt"] = history_prompt
# history_prompt is a torch.Tensor passed as a forward_param # history_prompt is a torch.Tensor passed as a forward_param
# if generation is successfull, it means that it was passed to the right device # if generation is successful, it means that it was passed to the right device
outputs = speech_generator( outputs = speech_generator(
"This is a test", forward_params=forward_params, preprocess_params=preprocess_params "This is a test", forward_params=forward_params, preprocess_params=preprocess_params
) )
self.assertEqual( self.assertEqual(
{"audio": ANY(np.ndarray), "sampling_rate": 24000}, {"audio": ANY(np.ndarray), "sampling_rate": 24000},
outputs, outputs,
) )
@slow
@require_torch
def test_vits_model_pt(self):
speech_generator = pipeline(task="text-to-audio", model="facebook/mms-tts-eng", framework="pt")
outputs = speech_generator("This is a test")
self.assertEqual(outputs["sampling_rate"], 16000)
audio = outputs["audio"]
self.assertEqual(ANY(np.ndarray), audio)
# test two examples side-by-side
outputs = speech_generator(["This is a test", "This is a second test"])
audio = [output["audio"] for output in outputs]
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
# test batching
outputs = speech_generator(["This is a test", "This is a second test"], batch_size=2)
self.assertEqual(ANY(np.ndarray), outputs[0]["audio"])
def get_test_pipeline(self, model, tokenizer, processor): def get_test_pipeline(self, model, tokenizer, processor):
speech_generator = TextToAudioPipeline(model=model, tokenizer=tokenizer) speech_generator = TextToAudioPipeline(model=model, tokenizer=tokenizer)
return speech_generator, ["This is a test", "Another test"] return speech_generator, ["This is a test", "Another test"]
def run_pipeline_test(self, speech_generator, _): def run_pipeline_test(self, speech_generator, _):
outputs = speech_generator("This is a test") outputs = speech_generator("This is a test")
self.assertEqual(ANY(np.ndarray), outputs["audio"]) self.assertEqual(ANY(np.ndarray), outputs["audio"])
forward_params = {"num_return_sequences": 2, "do_sample": True} forward_params = {"num_return_sequences": 2, "do_sample": True}
outputs = speech_generator(["This is great !", "Something else"], forward_params=forward_params) outputs = speech_generator(["This is great !", "Something else"], forward_params=forward_params)
audio = [output["audio"] for output in outputs] audio = [output["audio"] for output in outputs]
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio) self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment