Unverified Commit 56a76082 authored by Nguyễn Công Tú Anh, committed by GitHub

Add AudioLDM2 TTS (#5381)



* add audioldm2 tts

* change gpt2 max new tokens

* remove unnecessary pipeline and class

* add TTS to AudioLDM2Pipeline

* add TTS docs

* delete unnecessary file

* remove unnecessary import

* add audioldm2 slow testcase

* fix code quality

* remove AudioLDMLearnablePositionalEmbedding

* add variable check vits encoder

* add use_learned_position_embedding

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent 6133d98f
...@@ -20,7 +20,8 @@ The abstract of the paper is the following:

*Although audio generation shares commonalities across different types of audio, such as speech, music, and sound effects, designing models for each type requires careful consideration of specific objectives and biases that can significantly differ from those of other types. To bring us closer to a unified perspective of audio generation, this paper proposes a framework that utilizes the same learning method for speech, music, and sound effect generation. Our framework introduces a general representation of audio, called "language of audio" (LOA). Any audio can be translated into LOA based on AudioMAE, a self-supervised pre-trained representation learning model. In the generation process, we translate any modalities into LOA by using a GPT-2 model, and we perform self-supervised audio generation learning with a latent diffusion model conditioned on LOA. The proposed framework naturally brings advantages such as in-context learning abilities and reusable self-supervised pretrained AudioMAE and latent diffusion models. Experiments on the major benchmarks of text-to-audio, text-to-music, and text-to-speech demonstrate state-of-the-art or competitive performance against previous approaches. Our code, pretrained model, and demo are available at [this https URL](https://audioldm.github.io/audioldm2).*

This pipeline was contributed by [sanchit-gandhi](https://huggingface.co/sanchit-gandhi) and [Nguyễn Công Tú Anh](https://github.com/tuanh123789). The original codebase can be
found at [haoheliu/audioldm2](https://github.com/haoheliu/audioldm2).

## Tips
...@@ -36,6 +37,8 @@ See table below for details on the three checkpoints:

| [audioldm2](https://huggingface.co/cvssp/audioldm2) | Text-to-audio | 350M | 1.1B | 1150k |
| [audioldm2-large](https://huggingface.co/cvssp/audioldm2-large) | Text-to-audio | 750M | 1.5B | 1150k |
| [audioldm2-music](https://huggingface.co/cvssp/audioldm2-music) | Text-to-music | 350M | 1.1B | 665k |
| [audioldm2-gigaspeech](https://huggingface.co/anhnct/audioldm2_gigaspeech) | Text-to-speech | 350M | 1.1B | 10k |
| [audioldm2-ljspeech](https://huggingface.co/anhnct/audioldm2_ljspeech) | Text-to-speech | 350M | 1.1B | |
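
The two text-to-speech checkpoints use a Vits text encoder in place of T5 and therefore also expect a `transcription` (the text to be spoken) at call time. A minimal loading sketch, mirroring the pipeline docstring example added in this commit (the prompt and transcript strings are illustrative):

```python
import torch
from diffusers import AudioLDM2Pipeline

# text-to-speech checkpoint: uses a Vits text encoder, so `transcription` is required
pipe = AudioLDM2Pipeline.from_pretrained("anhnct/audioldm2_gigaspeech", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

audio = pipe(
    "A female reporter is speaking",            # voice/style prompt (illustrative)
    transcription="wish you have a good day",   # the text that should be spoken (illustrative)
    num_inference_steps=200,
    audio_length_in_s=10.0,
    max_new_tokens=512,  # TTS requires max_new_tokens=512 for the GPT-2 language model
).audios[0]
```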
### Constructing a prompt

...@@ -53,7 +56,7 @@ See table below for details on the three checkpoints:

* The quality of the generated waveforms can vary significantly based on the seed. Try generating with different seeds until you find a satisfactory generation.
* Multiple waveforms can be generated in one go: set `num_waveforms_per_prompt` to a value greater than 1. Automatic scoring will be performed between the generated waveforms and prompt text, and the audios ranked from best to worst accordingly.

The following example demonstrates how to construct good music and speech generation using the aforementioned tips: [example](https://huggingface.co/docs/diffusers/main/en/api/pipelines/audioldm2#diffusers.AudioLDM2Pipeline.__call__.example).
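
A minimal sketch applying these tips to text-to-audio with the base checkpoint (the prompt strings are illustrative):

```python
import scipy
import torch
from diffusers import AudioLDM2Pipeline

pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2", torch_dtype=torch.float16).to("cuda")

# fix the seed so a good generation can be reproduced
generator = torch.Generator("cuda").manual_seed(0)

audio = pipe(
    "Techno music with a strong, upbeat tempo and high melodic riffs",  # illustrative prompt
    negative_prompt="Low quality",  # describe what the audio should not sound like
    num_inference_steps=200,
    audio_length_in_s=10.0,
    num_waveforms_per_prompt=3,  # generate several candidates; they are returned ranked best to worst
    generator=generator,
).audios

# index 0 is the waveform scored closest to the text prompt
scipy.io.wavfile.write("techno.wav", rate=16000, data=audio[0])
```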
<Tip>
...
...@@ -95,7 +95,14 @@ class AudioLDM2ProjectionModel(ModelMixin, ConfigMixin):
    """

    @register_to_config
    def __init__(
        self,
        text_encoder_dim,
        text_encoder_1_dim,
        langauge_model_dim,
        use_learned_position_embedding=None,
        max_seq_length=None,
    ):
        super().__init__()
        # additional projection layers for each text encoder
        self.projection = nn.Linear(text_encoder_dim, langauge_model_dim)
...@@ -108,6 +115,14 @@ class AudioLDM2ProjectionModel(ModelMixin, ConfigMixin):
        self.sos_embed_1 = nn.Parameter(torch.ones(langauge_model_dim))
        self.eos_embed_1 = nn.Parameter(torch.ones(langauge_model_dim))

        self.use_learned_position_embedding = use_learned_position_embedding

        # learnable positional embedding for the Vits encoder
        if self.use_learned_position_embedding is not None:
            self.learnable_positional_embedding = torch.nn.Parameter(
                torch.zeros((1, text_encoder_1_dim, max_seq_length))
            )

    def forward(
        self,
        hidden_states: Optional[torch.FloatTensor] = None,
...@@ -120,6 +135,10 @@ class AudioLDM2ProjectionModel(ModelMixin, ConfigMixin):
            hidden_states, attention_mask, sos_token=self.sos_embed, eos_token=self.eos_embed
        )

        # add the learned positional embedding to the Vits hidden states
        if self.use_learned_position_embedding is not None:
            hidden_states_1 = (hidden_states_1.permute(0, 2, 1) + self.learnable_positional_embedding).permute(0, 2, 1)

        hidden_states_1 = self.projection_1(hidden_states_1)
        hidden_states_1, attention_mask_1 = add_special_tokens(
            hidden_states_1, attention_mask_1, sos_token=self.sos_embed_1, eos_token=self.eos_embed_1
...
...@@ -27,6 +27,8 @@ from transformers import (
    T5EncoderModel,
    T5Tokenizer,
    T5TokenizerFast,
    VitsModel,
    VitsTokenizer,
)

from ...models import AutoencoderKL
...@@ -79,6 +81,37 @@ EXAMPLE_DOC_STRING = """
        >>> # save the best audio sample (index 0) as a .wav file
        >>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio[0])
        ```
        ```
        >>> # Using AudioLDM2 for text-to-speech
        >>> import scipy
        >>> import torch
        >>> from diffusers import AudioLDM2Pipeline

        >>> repo_id = "anhnct/audioldm2_gigaspeech"
        >>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
        >>> pipe = pipe.to("cuda")

        >>> # define the prompts
        >>> prompt = "A female reporter is speaking"
        >>> transcript = "wish you have a good day"

        >>> # set the seed for the generator
        >>> generator = torch.Generator("cuda").manual_seed(0)

        >>> # run the generation
        >>> audio = pipe(
        ...     prompt,
        ...     transcription=transcript,
        ...     num_inference_steps=200,
        ...     audio_length_in_s=10.0,
        ...     num_waveforms_per_prompt=2,
        ...     generator=generator,
        ...     max_new_tokens=512,  # must set max_new_tokens equal to 512 for TTS
        ... ).audios

        >>> # save the best audio sample (index 0) as a .wav file
        >>> scipy.io.wavfile.write("tts.wav", rate=16000, data=audio[0])
        ```
"""
...@@ -116,20 +149,23 @@ class AudioLDM2Pipeline(DiffusionPipeline):
            specifically the [laion/clap-htsat-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant. The
            text branch is used to encode the text prompt to a prompt embedding. The full audio-text model is used to
            rank generated waveforms against the text prompt by computing similarity scores.
        text_encoder_2 ([`~transformers.T5EncoderModel`, `~transformers.VitsModel`]):
            Second frozen text-encoder. For text-to-audio, AudioLDM2 uses the encoder of
            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
            [google/flan-t5-large](https://huggingface.co/google/flan-t5-large) variant. For text-to-speech, AudioLDM2
            uses the encoder of
            [Vits](https://huggingface.co/docs/transformers/model_doc/vits#transformers.VitsModel).
        projection_model ([`AudioLDM2ProjectionModel`]):
            A trained model used to linearly project the hidden-states from the first and second text encoder models
            and insert learned SOS and EOS token embeddings. The projected hidden-states from the two text encoders are
            concatenated to give the input to the language model. For text-to-speech, a learned position embedding is
            also added to the Vits hidden-states.
        language_model ([`~transformers.GPT2Model`]):
            An auto-regressive language model used to generate a sequence of hidden-states conditioned on the projected
            outputs from the two text encoders.
        tokenizer ([`~transformers.RobertaTokenizer`]):
            Tokenizer to tokenize text for the first frozen text-encoder.
        tokenizer_2 ([`~transformers.T5Tokenizer`, `~transformers.VitsTokenizer`]):
            Tokenizer to tokenize text for the second frozen text-encoder.
        feature_extractor ([`~transformers.ClapFeatureExtractor`]):
            Feature extractor to pre-process generated audio waveforms to log-mel spectrograms for automatic scoring.
...@@ -146,11 +182,11 @@ class AudioLDM2Pipeline(DiffusionPipeline):
        self,
        vae: AutoencoderKL,
        text_encoder: ClapModel,
        text_encoder_2: Union[T5EncoderModel, VitsModel],
        projection_model: AudioLDM2ProjectionModel,
        language_model: GPT2Model,
        tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
        tokenizer_2: Union[T5Tokenizer, T5TokenizerFast, VitsTokenizer],
        feature_extractor: ClapFeatureExtractor,
        unet: AudioLDM2UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
...@@ -273,6 +309,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
        device,
        num_waveforms_per_prompt,
        do_classifier_free_guidance,
        transcription=None,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
...@@ -288,6 +325,8 @@ class AudioLDM2Pipeline(DiffusionPipeline):
        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            transcription (`str` or `List[str]`):
                transcription of the text for text-to-speech generation
            device (`torch.device`):
                torch device
            num_waveforms_per_prompt (`int`):
...@@ -368,16 +407,26 @@ class AudioLDM2Pipeline(DiffusionPipeline):
        # Define tokenizers and text encoders
        tokenizers = [self.tokenizer, self.tokenizer_2]
        is_vits_text_encoder = isinstance(self.text_encoder_2, VitsModel)

        if is_vits_text_encoder:
            text_encoders = [self.text_encoder, self.text_encoder_2.text_encoder]
        else:
            text_encoders = [self.text_encoder, self.text_encoder_2]

        if prompt_embeds is None:
            prompt_embeds_list = []
            attention_mask_list = []

            for tokenizer, text_encoder in zip(tokenizers, text_encoders):
                use_prompt = isinstance(
                    tokenizer, (RobertaTokenizer, RobertaTokenizerFast, T5Tokenizer, T5TokenizerFast)
                )
                text_inputs = tokenizer(
                    prompt if use_prompt else transcription,
                    padding="max_length"
                    if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast, VitsTokenizer))
                    else True,
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
...@@ -407,6 +456,18 @@ class AudioLDM2Pipeline(DiffusionPipeline):
                    prompt_embeds = prompt_embeds[:, None, :]
                    # make sure that we attend to this single hidden-state
                    attention_mask = attention_mask.new_ones((batch_size, 1))
                elif is_vits_text_encoder:
                    # replace the first zero (padding) token of each phoneme sequence with the end token id (182)
                    # and extend the attention mask to cover it
                    for text_input_id, text_attention_mask in zip(text_input_ids, attention_mask):
                        for idx, phoneme_id in enumerate(text_input_id):
                            if phoneme_id == 0:
                                text_input_id[idx] = 182
                                text_attention_mask[idx] = 1
                                break

                    prompt_embeds = text_encoder(
                        text_input_ids, attention_mask=attention_mask, padding_mask=attention_mask.unsqueeze(-1)
                    )
                    prompt_embeds = prompt_embeds[0]
                else:
                    prompt_embeds = text_encoder(
                        text_input_ids,
...@@ -485,7 +546,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
                    uncond_tokens,
                    padding="max_length",
                    max_length=tokenizer.model_max_length
                    if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast, VitsTokenizer))
                    else max_length,
                    truncation=True,
                    return_tensors="pt",
...@@ -503,6 +564,15 @@ class AudioLDM2Pipeline(DiffusionPipeline):
                    negative_prompt_embeds = negative_prompt_embeds[:, None, :]
                    # make sure that we attend to this single hidden-state
                    negative_attention_mask = negative_attention_mask.new_ones((batch_size, 1))
                elif is_vits_text_encoder:
                    # use all-zero embeddings and attention mask as the unconditional input for the Vits encoder
                    negative_prompt_embeds = torch.zeros(
                        batch_size,
                        tokenizer.model_max_length,
                        text_encoder.config.hidden_size,
                    ).to(dtype=self.text_encoder_2.dtype, device=device)
                    negative_attention_mask = torch.zeros(batch_size, tokenizer.model_max_length).to(
                        dtype=self.text_encoder_2.dtype, device=device
                    )
                else:
                    negative_prompt_embeds = text_encoder(
                        uncond_input_ids,
...@@ -623,6 +693,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
        audio_length_in_s,
        vocoder_upsample_factor,
        callback_steps,
        transcription=None,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
...@@ -690,6 +761,14 @@ class AudioLDM2Pipeline(DiffusionPipeline):
                    f"`attention_mask: {attention_mask.shape} != `prompt_embeds` {prompt_embeds.shape}"
                )

        if transcription is None:
            if self.text_encoder_2.config.model_type == "vits":
                raise ValueError("Cannot forward without transcription. Please make sure to pass a transcription")
        elif not isinstance(transcription, (str, list)):
            raise ValueError(f"`transcription` has to be of type `str` or `list` but is {type(transcription)}")

        if generated_prompt_embeds is not None and negative_generated_prompt_embeds is not None:
            if generated_prompt_embeds.shape != negative_generated_prompt_embeds.shape:
                raise ValueError(
...@@ -734,6 +813,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        transcription: Union[str, List[str]] = None,
        audio_length_in_s: Optional[float] = None,
        num_inference_steps: int = 200,
        guidance_scale: float = 3.5,
...@@ -761,6 +841,8 @@ class AudioLDM2Pipeline(DiffusionPipeline):
        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`.
            transcription (`str` or `List[str]`, *optional*):
                The transcript for text-to-speech generation.
            audio_length_in_s (`int`, *optional*, defaults to 10.24):
                The length of the generated audio sample in seconds.
            num_inference_steps (`int`, *optional*, defaults to 200):
...@@ -857,6 +939,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
            audio_length_in_s,
            vocoder_upsample_factor,
            callback_steps,
            transcription,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
...@@ -886,6 +969,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
            device,
            num_waveforms_per_prompt,
            do_classifier_free_guidance,
            transcription,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
...
...@@ -516,6 +516,20 @@ class AudioLDM2PipelineSlowTests(unittest.TestCase):
        }
        return inputs

    def get_inputs_tts(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        generator = torch.Generator(device=generator_device).manual_seed(seed)
        latents = np.random.RandomState(seed).standard_normal((1, 8, 128, 16))
        latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
        inputs = {
            "prompt": "A men saying",
            "transcription": "hello my name is John",
            "latents": latents,
            "generator": generator,
            "num_inference_steps": 3,
            "guidance_scale": 2.5,
        }
        return inputs

    def test_audioldm2(self):
        audioldm_pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")
        audioldm_pipe = audioldm_pipe.to(torch_device)
...@@ -572,3 +586,22 @@ class AudioLDM2PipelineSlowTests(unittest.TestCase):
        )
        max_diff = np.abs(expected_slice - audio_slice).max()
        assert max_diff < 1e-3

    def test_audioldm2_tts(self):
        audioldm_tts_pipe = AudioLDM2Pipeline.from_pretrained("anhnct/audioldm2_gigaspeech")
        audioldm_tts_pipe = audioldm_tts_pipe.to(torch_device)
        audioldm_tts_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_inputs_tts(torch_device)
        audio = audioldm_tts_pipe(**inputs).audios[0]

        assert audio.ndim == 1
        assert len(audio) == 81952

        # check the portion of the generated audio with the largest dynamic range (reduces flakiness)
        audio_slice = audio[8825:8835]
        expected_slice = np.array(
            [-0.1829, -0.1461, 0.0759, -0.1493, -0.1396, 0.5783, 0.3001, -0.3038, -0.0639, -0.2244]
        )
        max_diff = np.abs(expected_slice - audio_slice).max()
        assert max_diff < 1e-3