Unverified Commit 24c5e770 authored by Sanchit Gandhi, committed by GitHub

[AudioLDM2] Doc fixes (#4739)



* [AudioLDM2] Doc fixes

* update docstrings

* fix unet docstring

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent cd21b965
@@ -20,10 +20,10 @@ Inspired by [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelin
is a text-to-audio _latent diffusion model (LDM)_ that learns continuous audio representations from text embeddings. Two
text encoder models are used to compute the text embeddings from a prompt input: the text-branch of [CLAP](https://huggingface.co/docs/transformers/main/en/model_doc/clap)
and the encoder of [Flan-T5](https://huggingface.co/docs/transformers/main/en/model_doc/flan-t5). These text embeddings
-are then projected to a shared embedding space by an [AudioLDM2ProjectionModel](https://huggingface.co/docs/diffusers/api/pipelines/audioldm2/AudioLDM2ProjectionModel).
+are then projected to a shared embedding space by an [AudioLDM2ProjectionModel](https://huggingface.co/docs/diffusers/main/api/pipelines/audioldm2#diffusers.AudioLDM2ProjectionModel).
A [GPT2](https://huggingface.co/docs/transformers/main/en/model_doc/gpt2) _language model (LM)_ is used to auto-regressively
predict eight new embedding vectors, conditional on the projected CLAP and Flan-T5 embeddings. The generated embedding
-vectors and Flan-T5 text embeddings are used as cross-attention conditioning in the LDM. The [UNet](https://huggingface.co/docs/diffusers/api/pipelines/audioldm2/AudioLDM2UNet2DConditionModel)
+vectors and Flan-T5 text embeddings are used as cross-attention conditioning in the LDM. The [UNet](https://huggingface.co/docs/diffusers/main/en/api/pipelines/audioldm2#diffusers.AudioLDM2UNet2DConditionModel)
of AudioLDM 2 is unique in the sense that it takes **two** cross-attention embeddings, as opposed to one cross-attention
conditioning, as in most other LDMs.
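As a rough illustration of the architecture described above, each of these sub-models is exposed as an attribute on the loaded pipeline. The attribute names below (`text_encoder`, `text_encoder_2`, `projection_model`, `language_model`, `unet`) reflect the diffusers implementation at the time of writing and are worth double-checking against your installed version:

```python
import torch
from diffusers import AudioLDM2Pipeline

# load the base text-to-audio checkpoint in half precision
pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2", torch_dtype=torch.float16)

# inspect the components described above: two text encoders (CLAP and Flan-T5),
# the projection model, the GPT2 language model, and the dual cross-attention UNet
print(type(pipe.text_encoder).__name__)      # CLAP model (text branch is used)
print(type(pipe.text_encoder_2).__name__)    # Flan-T5 encoder
print(type(pipe.projection_model).__name__)  # AudioLDM2ProjectionModel (shared embedding space)
print(type(pipe.language_model).__name__)    # GPT2 model that predicts the eight new embedding vectors
print(type(pipe.unet).__name__)              # AudioLDM2UNet2DConditionModel (two cross-attention inputs)
```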
@@ -38,13 +38,17 @@ found at [haoheliu/audioldm2](https://github.com/haoheliu/audioldm2).
### Choosing a checkpoint
-AudioLDM2 comes in three variants. Two of these checkpoints are applicable to the general task of text-to-audio generation. The third checkpoint is trained exclusively on text-to-music generation. See table below for details on the three official checkpoints:
+AudioLDM2 comes in three variants. Two of these checkpoints are applicable to the general task of text-to-audio
+generation. The third checkpoint is trained exclusively on text-to-music generation.
-| Checkpoint | Task | Model Size | Training Data / h |
-|-----------------------------------------------------------------|---------------|------------|-------------------|
-| [audioldm2](https://huggingface.co/cvssp/audioldm2) | Text-to-audio | 1.1B | 1150k |
-| [audioldm2-music](https://huggingface.co/cvssp/audioldm2-music) | Text-to-music | 1.1B | 665k |
-| [audioldm2-large](https://huggingface.co/cvssp/audioldm2-large) | Text-to-audio | 1.5B | 1150k |
+All checkpoints share the same model size for the text encoders and VAE. They differ in the size and depth of the UNet.
+See table below for details on the three checkpoints:
+| Checkpoint | Task | UNet Model Size | Total Model Size | Training Data / h |
+|-----------------------------------------------------------------|---------------|-----------------|------------------|-------------------|
+| [audioldm2](https://huggingface.co/cvssp/audioldm2) | Text-to-audio | 350M | 1.1B | 1150k |
+| [audioldm2-large](https://huggingface.co/cvssp/audioldm2-large) | Text-to-audio | 750M | 1.5B | 1150k |
+| [audioldm2-music](https://huggingface.co/cvssp/audioldm2-music) | Text-to-music | 350M | 1.1B | 665k |
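As a quick sanity check on the sizes quoted in the table above, you can count the parameters of a loaded pipeline directly. This is a minimal sketch that assumes the component attribute names used by the diffusers pipeline (`unet`, `vae`, `text_encoder`, `text_encoder_2`, `projection_model`, `language_model`, `vocoder`); the exact total depends on which components you include:

```python
import torch
from diffusers import AudioLDM2Pipeline

pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2-large", torch_dtype=torch.float16)

# count the UNet parameters separately from the rest of the pipeline
unet_params = sum(p.numel() for p in pipe.unet.parameters())
total_params = sum(
    sum(p.numel() for p in module.parameters())
    for module in (
        pipe.vae,
        pipe.text_encoder,
        pipe.text_encoder_2,
        pipe.projection_model,
        pipe.language_model,
        pipe.unet,
        pipe.vocoder,
    )
)
print(f"UNet: {unet_params / 1e6:.0f}M parameters")
print(f"Total: {total_params / 1e9:.2f}B parameters")
```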
### Constructing a prompt
@@ -62,37 +66,7 @@ AudioLDM2 comes in three variants. Two of these checkpoints are applicable to th
* The quality of the generated waveforms can vary significantly based on the seed. Try generating with different seeds until you find a satisfactory generation
* Multiple waveforms can be generated in one go: set `num_waveforms_per_prompt` to a value greater than 1. Automatic scoring will be performed between the generated waveforms and prompt text, and the audios ranked from best to worst accordingly.
-The following example demonstrates how to construct good music generation using the aforementioned tips:
-```python
-import scipy
-import torch
-from diffusers import AudioLDM2Pipeline
-# load the best weights for music generation
-repo_id = "cvssp/audioldm2-music"
-pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
-pipe = pipe.to("cuda")
-# define the prompts
-prompt = "Techno music with a strong, upbeat tempo and high melodic riffs"
-negative_prompt = "Low quality."
-# set the seed
-generator = torch.Generator("cuda").manual_seed(0)
-# run the generation
-audio = pipe(
-    prompt,
-    negative_prompt=negative_prompt,
-    num_inference_steps=200,
-    audio_length_in_s=10.0,
-    num_waveforms_per_prompt=3,
-).audios
-# save the best audio sample (index 0) as a .wav file
-scipy.io.wavfile.write("techno.wav", rate=16000, data=audio[0])
-```
+The following example demonstrates how to construct good music generation using the aforementioned tips: [example](https://huggingface.co/docs/diffusers/main/en/api/pipelines/audioldm2#diffusers.AudioLDM2Pipeline.__call__.example).
<Tip>
@@ -114,3 +88,6 @@ section to learn how to efficiently load the same components into multiple pipel
## AudioLDM2UNet2DConditionModel
[[autodoc]] AudioLDM2UNet2DConditionModel
- forward
+## AudioPipelineOutput
+[[autodoc]] pipelines.AudioPipelineOutput
\ No newline at end of file
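For reference, calling the pipeline returns an `AudioPipelineOutput` whose `audios` field is a NumPy array of generated waveforms. Below is a minimal sketch assuming that interface and the 16 kHz vocoder sample rate used by these checkpoints (the exact number of samples may be rounded by the pipeline):

```python
import torch
from diffusers import AudioLDM2Pipeline

pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# the return value is an AudioPipelineOutput; `.audios` holds the waveforms
output = pipe("A dog barking in the distance", num_inference_steps=50, audio_length_in_s=5.0)
print(type(output).__name__)  # AudioPipelineOutput
print(output.audios.shape)    # (num_waveforms, num_samples), roughly 5 s at 16 kHz
```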
@@ -659,7 +659,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
encoder_attention_mask_1: Optional[torch.Tensor] = None,
) -> Union[UNet2DConditionOutput, Tuple]:
r"""
-The [`UNet2DConditionModel`] forward method.
+The [`AudioLDM2UNet2DConditionModel`] forward method.
Args:
sample (`torch.FloatTensor`):
......
@@ -51,19 +51,33 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
->>> from diffusers import AudioLDM2Pipeline
->>> import torch
+>>> import scipy
+>>> import torch
+>>> from diffusers import AudioLDM2Pipeline
>>> repo_id = "cvssp/audioldm2"
>>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
>>> pipe = pipe.to("cuda")
>>> prompt = "Techno music with a strong, upbeat tempo and high melodic riffs"
>>> audio = pipe(prompt, num_inference_steps=200, audio_length_in_s=10.0).audios[0]
>>> # define the prompts
>>> prompt = "The sound of a hammer hitting a wooden surface."
>>> negative_prompt = "Low quality."
>>> # save the audio sample as a .wav file
>>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio)
>>> # set the seed for generator
>>> generator = torch.Generator("cuda").manual_seed(0)
>>> # run the generation
>>> audio = pipe(
... prompt,
... negative_prompt=negative_prompt,
... num_inference_steps=200,
... audio_length_in_s=10.0,
... num_waveforms_per_prompt=3,
... generator=generator,
... ).audios
>>> # save the best audio sample (index 0) as a .wav file
>>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio[0])
```
"""
@@ -315,6 +329,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
Example:
```python
+>>> import scipy
>>> import torch
>>> from diffusers import AudioLDM2Pipeline
@@ -337,6 +352,9 @@ class AudioLDM2Pipeline(DiffusionPipeline):
... num_inference_steps=200,
... audio_length_in_s=10.0,
... ).audios[0]
+>>> # save generated audio sample
+>>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio)
```"""
if prompt is not None and isinstance(prompt, str):
batch_size = 1
......