Unverified Commit 44a0490d authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[MusicGen] Add sampling rate to config (#26136)



* [MusicGen] Add sampling rate to config

* remove tiny

* make property

* Update tests/pipelines/test_pipelines_text_to_audio.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* style

---------
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>
parent 8881f38a
...@@ -226,3 +226,8 @@ class MusicgenConfig(PretrainedConfig): ...@@ -226,3 +226,8 @@ class MusicgenConfig(PretrainedConfig):
decoder=decoder_config.to_dict(), decoder=decoder_config.to_dict(),
**kwargs, **kwargs,
) )
@property
# This is a property because you might want to change the codec model on the fly
def sampling_rate(self):
return self.audio_encoder.sampling_rate
...@@ -41,35 +41,32 @@ class TextToAudioPipelineTests(unittest.TestCase): ...@@ -41,35 +41,32 @@ class TextToAudioPipelineTests(unittest.TestCase):
@slow @slow
@require_torch @require_torch
def test_small_model_pt(self): def test_small_musicgen_pt(self):
speech_generator = pipeline(task="text-to-audio", model="facebook/musicgen-small", framework="pt") music_generator = pipeline(task="text-to-audio", model="facebook/musicgen-small", framework="pt")
forward_params = { forward_params = {
"do_sample": False, "do_sample": False,
"max_new_tokens": 250, "max_new_tokens": 250,
} }
outputs = speech_generator("This is a test", forward_params=forward_params) outputs = music_generator("This is a test", forward_params=forward_params)
# musicgen sampling_rate is not straightforward to get self.assertEqual({"audio": ANY(np.ndarray), "sampling_rate": 32000}, outputs)
self.assertIsNone(outputs["sampling_rate"])
audio = outputs["audio"]
self.assertEqual(ANY(np.ndarray), audio)
# test two examples side-by-side # test two examples side-by-side
outputs = speech_generator(["This is a test", "This is a second test"], forward_params=forward_params) outputs = music_generator(["This is a test", "This is a second test"], forward_params=forward_params)
audio = [output["audio"] for output in outputs] audio = [output["audio"] for output in outputs]
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio) self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
# test batching # test batching
outputs = speech_generator( outputs = music_generator(
["This is a test", "This is a second test"], forward_params=forward_params, batch_size=2 ["This is a test", "This is a second test"], forward_params=forward_params, batch_size=2
) )
self.assertEqual(ANY(np.ndarray), outputs[0]["audio"]) audio = [output["audio"] for output in outputs]
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
@slow @slow
@require_torch @require_torch
def test_large_model_pt(self): def test_small_bark_pt(self):
speech_generator = pipeline(task="text-to-audio", model="suno/bark-small", framework="pt") speech_generator = pipeline(task="text-to-audio", model="suno/bark-small", framework="pt")
forward_params = { forward_params = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment