Unverified Commit b94880e5 authored by Sanchit Gandhi and committed by GitHub

Add AudioLDM (#2232)



* Add AudioLDM

* up

* add vocoder

* start unet

* unconditional unet

* clap, vocoder and vae

* clean-up: conversion scripts

* fix: conversion script token_type_ids

* clean-up: pipeline docstring

* tests: from SD

* clean-up: cpu offload vocoder instead of safety checker

* feat: adapt tests to audioldm

* feat: add docs

* clean-up: amend pipeline docstrings

* clean-up: make style

* clean-up: make fix-copies

* fix: add doc path to toctree

* clean-up: args for conversion script

* clean-up: paths to checkpoints

* fix: use conditional unet

* clean-up: make style

* fix: type hints for UNet

* clean-up: docstring for UNet

* clean-up: make style

* clean-up: remove duplicate in docstring

* clean-up: make style

* clean-up: make fix-copies

* clean-up: move imports to start in code snippet

* fix: pass cross_attention_dim as a list/tuple to unet

* clean-up: make fix-copies

* fix: update checkpoint path

* fix: unet cross_attention_dim in tests

* film embeddings -> class embeddings

* Apply suggestions from code review
Co-authored-by: Will Berman <wlbberman@gmail.com>

* fix: unet film embed to use existing args

* fix: unet tests to use existing args

* fix: make style

* fix: transformers import and version in init

* clean-up: make style

* Revert "clean-up: make style"

This reverts commit 5d6d1f8b324f5583e7805dc01e2c86e493660d66.

* clean-up: make style

* clean-up: use pipeline tester mixin tests where poss

* clean-up: skip attn slicing test

* fix: add torch dtype to docs

* fix: remove conversion script out of src

* fix: remove .detach from 1d waveform

* fix: reduce default num inf steps

* fix: swap height/width -> audio_length_in_s

* clean-up: make style

* fix: remove nightly tests

* fix: imports in conversion script

* clean-up: slim-down to two slow tests

* clean-up: slim-down fast tests

* fix: batch consistent tests

* clean-up: make style

* clean-up: remove vae slicing fast test

* clean-up: propagate changes to doc

* fix: increase test tol to 1e-2

* clean-up: finish docs

* clean-up: make style

* feat: vocoder / VAE compatibility check

* feat: possibly expand / cut audio waveform

* fix: pipeline call signature test

* fix: slow tests output len

* clean-up: make style

* make style

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: William Berman <WLBberman@gmail.com>
parent 1870fb05
......@@ -134,6 +134,8 @@
title: AltDiffusion
- local: api/pipelines/audio_diffusion
title: Audio Diffusion
- local: api/pipelines/audioldm
title: AudioLDM
- local: api/pipelines/cycle_diffusion
title: Cycle Diffusion
- local: api/pipelines/dance_diffusion
......
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# AudioLDM
## Overview
AudioLDM was proposed in [AudioLDM: Text-to-Audio Generation with Latent Diffusion Models](https://arxiv.org/abs/2301.12503) by Haohe Liu et al.
Inspired by [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview), AudioLDM
is a text-to-audio _latent diffusion model (LDM)_ that learns continuous audio representations from [CLAP](https://huggingface.co/docs/transformers/main/model_doc/clap)
latents. AudioLDM takes a text prompt as input and predicts the corresponding audio. It can generate text-conditional
sound effects, human speech and music.
This pipeline was contributed by [sanchit-gandhi](https://huggingface.co/sanchit-gandhi). The original codebase can be found [here](https://github.com/haoheliu/AudioLDM).
## Text-to-Audio
The [`AudioLDMPipeline`] can be used to load pre-trained weights from [cvssp/audioldm](https://huggingface.co/cvssp/audioldm) and generate text-conditional audio outputs:
```python
from diffusers import AudioLDMPipeline
import torch
import scipy
repo_id = "cvssp/audioldm"
pipe = AudioLDMPipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")
prompt = "Techno music with a strong, upbeat tempo and high melodic riffs"
audio = pipe(prompt, num_inference_steps=10, audio_length_in_s=5.0).audios[0]
# save the audio sample as a .wav file
scipy.io.wavfile.write("techno.wav", rate=16000, data=audio)
```
### Tips
Prompts:
* Descriptive prompt inputs work best: use adjectives to describe the sound (e.g. "high quality" or "clear") and make the prompt context-specific (e.g. "water stream in a forest" instead of "stream").
* It's best to use general terms like 'cat' or 'dog' instead of specific names or abstract objects that the model may not be familiar with.
Inference:
* The _quality_ of the predicted audio sample can be controlled by the `num_inference_steps` argument: higher steps give higher quality audio at the expense of slower inference.
* The _length_ of the predicted audio sample can be controlled by varying the `audio_length_in_s` argument, as shown in the example below.
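For example, both arguments can be combined in a single call. The following minimal sketch reuses the `pipe` and `prompt` objects from the snippet above; the specific step counts and clip lengths are illustrative only:
```python
# higher quality, longer clip: more denoising steps and a 10 second target length
audio_hq = pipe(prompt, num_inference_steps=50, audio_length_in_s=10.0).audios[0]

# quick draft: fewer steps and a shorter clip (faster, lower quality)
audio_draft = pipe(prompt, num_inference_steps=5, audio_length_in_s=2.5).audios[0]
```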
### How to load and use different schedulers
The AudioLDM pipeline uses the [`DDIMScheduler`] by default. However, `diffusers` provides many other schedulers
that can be used with the AudioLDM pipeline, such as [`PNDMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`],
[`EulerAncestralDiscreteScheduler`], etc. We recommend the [`DPMSolverMultistepScheduler`], as it is currently the
fastest scheduler available.
To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`]
method, or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the
[`DPMSolverMultistepScheduler`], you can do the following:
```python
>>> from diffusers import AudioLDMPipeline, DPMSolverMultistepScheduler
>>> import torch
>>> pipeline = AudioLDMPipeline.from_pretrained("cvssp/audioldm", torch_dtype=torch.float16)
>>> pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
>>> # or
>>> dpm_scheduler = DPMSolverMultistepScheduler.from_pretrained("cvssp/audioldm", subfolder="scheduler")
>>> pipeline = AudioLDMPipeline.from_pretrained("cvssp/audioldm", scheduler=dpm_scheduler, torch_dtype=torch.float16)
```
## AudioLDMPipeline
[[autodoc]] AudioLDMPipeline
- all
- __call__
......@@ -112,6 +112,7 @@ else:
from .pipelines import (
AltDiffusionImg2ImgPipeline,
AltDiffusionPipeline,
AudioLDMPipeline,
CycleDiffusionPipeline,
LDMTextToImagePipeline,
PaintByExamplePipeline,
......
......@@ -86,13 +86,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
If `None`, it will skip the normalization and activation layers in post-processing
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
The dimension of the cross attention features.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to None):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
`"timestep"`, `"identity"`, or `"projection"`.
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
num_class_embeds (`int`, *optional*, defaults to None):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
......@@ -106,6 +107,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
embeddings with the class embeddings.
"""
_supports_gradient_checkpointing = True
......@@ -135,7 +138,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
cross_attention_dim: Union[int, Tuple[int]] = 1280,
attention_head_dim: Union[int, Tuple[int]] = 8,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
......@@ -149,6 +152,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
conv_in_kernel: int = 3,
conv_out_kernel: int = 3,
projection_class_embeddings_input_dim: Optional[int] = None,
class_embeddings_concat: bool = False,
):
super().__init__()
......@@ -175,6 +179,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
)
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
)
# input
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = nn.Conv2d(
......@@ -228,6 +237,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
elif class_embed_type == "simple_projection":
if projection_class_embeddings_input_dim is None:
raise ValueError(
"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
)
self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
else:
self.class_embedding = None
......@@ -240,6 +255,17 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
if isinstance(cross_attention_dim, int):
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
if class_embeddings_concat:
# The time embeddings are concatenated with the class embeddings. The dimension of the
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the
# regular time embeddings
blocks_time_embed_dim = time_embed_dim * 2
else:
blocks_time_embed_dim = time_embed_dim
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
......@@ -252,12 +278,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
temb_channels=blocks_time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim[i],
attn_num_head_channels=attention_head_dim[i],
downsample_padding=downsample_padding,
dual_cross_attention=dual_cross_attention,
......@@ -272,12 +298,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
if mid_block_type == "UNetMidBlock2DCrossAttn":
self.mid_block = UNetMidBlock2DCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
temb_channels=blocks_time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim[-1],
attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=dual_cross_attention,
......@@ -287,11 +313,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
self.mid_block = UNetMidBlock2DSimpleCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
temb_channels=blocks_time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
cross_attention_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim[-1],
attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
......@@ -307,6 +333,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
......@@ -330,12 +357,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
temb_channels=blocks_time_embed_dim,
add_upsample=add_upsample,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
cross_attention_dim=reversed_cross_attention_dim[i],
attn_num_head_channels=reversed_attention_head_dim[i],
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
......@@ -571,6 +598,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
class_labels = self.time_proj(class_labels)
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
if self.config.class_embeddings_concat:
emb = torch.cat([emb, class_emb], dim=-1)
else:
emb = emb + class_emb
# 2. pre-process
......
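Taken together, these `UNet2DConditionModel` changes allow a per-block `cross_attention_dim` and a `"simple_projection"` class embedding that is concatenated with (rather than summed into) the time embeddings, which is how AudioLDM's FiLM conditioning is mapped onto existing UNet arguments (cf. the "film embeddings -> class embeddings" commit above). A minimal sketch of how the new arguments combine, using the same small dummy configuration as the fast tests further down (the sizes are illustrative, not the real AudioLDM checkpoint config):
```python
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    # one cross-attention dim per down block
    cross_attention_dim=(32, 64),
    # linear projection of the class labels instead of sinusoidal timestep embeddings
    class_embed_type="simple_projection",
    projection_class_embeddings_input_dim=32,
    # concatenate class and time embeddings: blocks receive 2 * time_embed_dim channels
    class_embeddings_concat=True,
)
```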
......@@ -44,6 +44,7 @@ except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
from .audioldm import AudioLDMPipeline
from .latent_diffusion import LDMTextToImagePipeline
from .paint_by_example import PaintByExamplePipeline
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
......
from ...utils import (
OptionalDependencyNotAvailable,
is_torch_available,
is_transformers_available,
is_transformers_version,
)
try:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import (
AudioLDMPipeline,
)
else:
from .pipeline_audioldm import AudioLDMPipeline
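The version guard above exists because the CLAP text encoder only ships with recent `transformers` releases; on older installs, `AudioLDMPipeline` resolves to a dummy object that raises when used. A rough user-side sketch of the same requirement (the `packaging`/`importlib.metadata` check below is an assumption for illustration, not part of this PR):
```python
from importlib.metadata import version

from packaging.version import parse

# AudioLDM needs the CLAP text encoder, which is available in transformers >= 4.27.0
if parse(version("transformers")) < parse("4.27.0"):
    raise RuntimeError("AudioLDMPipeline requires transformers>=4.27.0 (CLAP support)")

from diffusers import AudioLDMPipeline  # noqa: E402 -- safe to use past this point
```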
......@@ -167,13 +167,14 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
If `None`, it will skip the normalization and activation layers in post-processing
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
The dimension of the cross attention features.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to None):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
`"timestep"`, `"identity"`, or `"projection"`.
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
num_class_embeds (`int`, *optional*, defaults to None):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
......@@ -187,6 +188,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
embeddings with the class embeddings.
"""
_supports_gradient_checkpointing = True
......@@ -221,7 +224,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
cross_attention_dim: Union[int, Tuple[int]] = 1280,
attention_head_dim: Union[int, Tuple[int]] = 8,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
......@@ -235,6 +238,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
conv_in_kernel: int = 3,
conv_out_kernel: int = 3,
projection_class_embeddings_input_dim: Optional[int] = None,
class_embeddings_concat: bool = False,
):
super().__init__()
......@@ -265,6 +269,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
f" {attention_head_dim}. `down_block_types`: {down_block_types}."
)
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
raise ValueError(
"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`:"
f" {cross_attention_dim}. `down_block_types`: {down_block_types}."
)
# input
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = LinearMultiDim(
......@@ -318,6 +328,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
elif class_embed_type == "simple_projection":
if projection_class_embeddings_input_dim is None:
raise ValueError(
"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
)
self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
else:
self.class_embedding = None
......@@ -330,6 +346,17 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
if isinstance(cross_attention_dim, int):
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
if class_embeddings_concat:
# The time embeddings are concatenated with the class embeddings. The dimension of the
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the
# regular time embeddings
blocks_time_embed_dim = time_embed_dim * 2
else:
blocks_time_embed_dim = time_embed_dim
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
......@@ -342,12 +369,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
temb_channels=blocks_time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim[i],
attn_num_head_channels=attention_head_dim[i],
downsample_padding=downsample_padding,
dual_cross_attention=dual_cross_attention,
......@@ -362,12 +389,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
if mid_block_type == "UNetMidBlockFlatCrossAttn":
self.mid_block = UNetMidBlockFlatCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
temb_channels=blocks_time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim[-1],
attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=dual_cross_attention,
......@@ -377,11 +404,11 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
elif mid_block_type == "UNetMidBlockFlatSimpleCrossAttn":
self.mid_block = UNetMidBlockFlatSimpleCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
temb_channels=blocks_time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
cross_attention_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim[-1],
attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
......@@ -397,6 +424,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
......@@ -420,12 +448,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
temb_channels=blocks_time_embed_dim,
add_upsample=add_upsample,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
cross_attention_dim=reversed_cross_attention_dim[i],
attn_num_head_channels=reversed_attention_head_dim[i],
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
......@@ -661,6 +689,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
class_labels = self.time_proj(class_labels)
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
if self.config.class_embeddings_concat:
emb = torch.cat([emb, class_emb], dim=-1)
else:
emb = emb + class_emb
# 2. pre-process
......
......@@ -32,6 +32,21 @@ class AltDiffusionPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class AudioLDMPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class CycleDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
......
......@@ -199,6 +199,74 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_model_with_cross_attention_dim_tuple(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["cross_attention_dim"] = (32, 32)
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_model_with_simple_projection(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
batch_size, _, _, sample_size = inputs_dict["sample"].shape
init_dict["class_embed_type"] = "simple_projection"
init_dict["projection_class_embeddings_input_dim"] = sample_size
inputs_dict["class_labels"] = floats_tensor((batch_size, sample_size)).to(torch_device)
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_model_with_class_embeddings_concat(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
batch_size, _, _, sample_size = inputs_dict["sample"].shape
init_dict["class_embed_type"] = "simple_projection"
init_dict["projection_class_embeddings_input_dim"] = sample_size
init_dict["class_embeddings_concat"] = True
inputs_dict["class_labels"] = floats_tensor((batch_size, sample_size)).to(torch_device)
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_model_attention_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
......
......@@ -103,6 +103,19 @@ UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])
UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])
TEXT_TO_AUDIO_PARAMS = frozenset(
[
"prompt",
"audio_length_in_s",
"guidance_scale",
"negative_prompt",
"prompt_embeds",
"negative_prompt_embeds",
"cross_attention_kwargs",
]
)
TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
TOKENS_TO_AUDIO_GENERATION_PARAMS = frozenset(["input_tokens"])
TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS = frozenset(["input_tokens"])
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import numpy as np
import torch
import torch.nn.functional as F
from transformers import (
ClapTextConfig,
ClapTextModelWithProjection,
RobertaTokenizer,
SpeechT5HifiGan,
SpeechT5HifiGanConfig,
)
from diffusers import (
AudioLDMPipeline,
AutoencoderKL,
DDIMScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
UNet2DConditionModel,
)
from diffusers.utils import slow, torch_device
from ...pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
from ...test_pipelines_common import PipelineTesterMixin
class AudioLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = AudioLDMPipeline
params = TEXT_TO_AUDIO_PARAMS
batch_params = TEXT_TO_AUDIO_BATCH_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"num_waveforms_per_prompt",
"generator",
"latents",
"output_type",
"return_dict",
"callback",
"callback_steps",
]
)
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=(32, 64),
class_embed_type="simple_projection",
projection_class_embeddings_input_dim=32,
class_embeddings_concat=True,
)
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
)
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[32, 64],
in_channels=1,
out_channels=1,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
)
torch.manual_seed(0)
text_encoder_config = ClapTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
projection_dim=32,
)
text_encoder = ClapTextModelWithProjection(text_encoder_config)
tokenizer = RobertaTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta", model_max_length=77)
vocoder_config = SpeechT5HifiGanConfig(
model_in_dim=8,
sampling_rate=16000,
upsample_initial_channel=16,
upsample_rates=[2, 2],
upsample_kernel_sizes=[4, 4],
resblock_kernel_sizes=[3, 7],
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5]],
normalize_before=False,
)
vocoder = SpeechT5HifiGan(vocoder_config)
components = {
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"vocoder": vocoder,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A hammer hitting a wooden surface",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
}
return inputs
def test_audioldm_ddim(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
audioldm_pipe = AudioLDMPipeline(**components)
audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
output = audioldm_pipe(**inputs)
audio = output.audios[0]
assert audio.ndim == 1
assert len(audio) == 256
audio_slice = audio[:10]
expected_slice = np.array(
[-0.0050, 0.0050, -0.0060, 0.0033, -0.0026, 0.0033, -0.0027, 0.0033, -0.0028, 0.0033]
)
assert np.abs(audio_slice - expected_slice).max() < 1e-2
def test_audioldm_prompt_embeds(self):
components = self.get_dummy_components()
audioldm_pipe = AudioLDMPipeline(**components)
audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt"] = 3 * [inputs["prompt"]]
# forward
output = audioldm_pipe(**inputs)
audio_1 = output.audios[0]
inputs = self.get_dummy_inputs(torch_device)
prompt = 3 * [inputs.pop("prompt")]
text_inputs = audioldm_pipe.tokenizer(
prompt,
padding="max_length",
max_length=audioldm_pipe.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_inputs = text_inputs["input_ids"].to(torch_device)
prompt_embeds = audioldm_pipe.text_encoder(
text_inputs,
)
prompt_embeds = prompt_embeds.text_embeds
# additional L_2 normalization over each hidden-state
prompt_embeds = F.normalize(prompt_embeds, dim=-1)
inputs["prompt_embeds"] = prompt_embeds
# forward
output = audioldm_pipe(**inputs)
audio_2 = output.audios[0]
assert np.abs(audio_1 - audio_2).max() < 1e-2
def test_audioldm_negative_prompt_embeds(self):
components = self.get_dummy_components()
audioldm_pipe = AudioLDMPipeline(**components)
audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
negative_prompt = 3 * ["this is a negative prompt"]
inputs["negative_prompt"] = negative_prompt
inputs["prompt"] = 3 * [inputs["prompt"]]
# forward
output = audioldm_pipe(**inputs)
audio_1 = output.audios[0]
inputs = self.get_dummy_inputs(torch_device)
prompt = 3 * [inputs.pop("prompt")]
embeds = []
for p in [prompt, negative_prompt]:
text_inputs = audioldm_pipe.tokenizer(
p,
padding="max_length",
max_length=audioldm_pipe.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_inputs = text_inputs["input_ids"].to(torch_device)
text_embeds = audioldm_pipe.text_encoder(
text_inputs,
)
text_embeds = text_embeds.text_embeds
# additional L_2 normalization over each hidden-state
text_embeds = F.normalize(text_embeds, dim=-1)
embeds.append(text_embeds)
inputs["prompt_embeds"], inputs["negative_prompt_embeds"] = embeds
# forward
output = audioldm_pipe(**inputs)
audio_2 = output.audios[0]
assert np.abs(audio_1 - audio_2).max() < 1e-2
def test_audioldm_negative_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
audioldm_pipe = AudioLDMPipeline(**components)
audioldm_pipe = audioldm_pipe.to(device)
audioldm_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
negative_prompt = "egg cracking"
output = audioldm_pipe(**inputs, negative_prompt=negative_prompt)
audio = output.audios[0]
assert audio.ndim == 1
assert len(audio) == 256
audio_slice = audio[:10]
expected_slice = np.array(
[-0.0051, 0.0050, -0.0060, 0.0034, -0.0026, 0.0033, -0.0027, 0.0033, -0.0028, 0.0032]
)
assert np.abs(audio_slice - expected_slice).max() < 1e-2
def test_audioldm_num_waveforms_per_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
audioldm_pipe = AudioLDMPipeline(**components)
audioldm_pipe = audioldm_pipe.to(device)
audioldm_pipe.set_progress_bar_config(disable=None)
prompt = "A hammer hitting a wooden surface"
# test num_waveforms_per_prompt=1 (default)
audios = audioldm_pipe(prompt, num_inference_steps=2).audios
assert audios.shape == (1, 256)
# test num_waveforms_per_prompt=1 (default) for batch of prompts
batch_size = 2
audios = audioldm_pipe([prompt] * batch_size, num_inference_steps=2).audios
assert audios.shape == (batch_size, 256)
# test num_waveforms_per_prompt for single prompt
num_waveforms_per_prompt = 2
audios = audioldm_pipe(prompt, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt).audios
assert audios.shape == (num_waveforms_per_prompt, 256)
# test num_waveforms_per_prompt for batch of prompts
batch_size = 2
audios = audioldm_pipe(
[prompt] * batch_size, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt
).audios
assert audios.shape == (batch_size * num_waveforms_per_prompt, 256)
def test_audioldm_audio_length_in_s(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
audioldm_pipe = AudioLDMPipeline(**components)
audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe.set_progress_bar_config(disable=None)
vocoder_sampling_rate = audioldm_pipe.vocoder.config.sampling_rate
inputs = self.get_dummy_inputs(device)
output = audioldm_pipe(audio_length_in_s=0.016, **inputs)
audio = output.audios[0]
assert audio.ndim == 1
assert len(audio) / vocoder_sampling_rate == 0.016
output = audioldm_pipe(audio_length_in_s=0.032, **inputs)
audio = output.audios[0]
assert audio.ndim == 1
assert len(audio) / vocoder_sampling_rate == 0.032
def test_audioldm_vocoder_model_in_dim(self):
components = self.get_dummy_components()
audioldm_pipe = AudioLDMPipeline(**components)
audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe.set_progress_bar_config(disable=None)
prompt = ["hey"]
output = audioldm_pipe(prompt, num_inference_steps=1)
audio_shape = output.audios.shape
assert audio_shape == (1, 256)
config = audioldm_pipe.vocoder.config
config.model_in_dim *= 2
audioldm_pipe.vocoder = SpeechT5HifiGan(config).to(torch_device)
output = audioldm_pipe(prompt, num_inference_steps=1)
audio_shape = output.audios.shape
# waveform shape is unchanged, we just have 2x the number of mel channels in the spectrogram
assert audio_shape == (1, 256)
def test_attention_slicing_forward_pass(self):
self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False)
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(test_mean_pixel_difference=False)
@slow
# @require_torch_gpu
class AudioLDMPipelineSlowTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
latents = np.random.RandomState(seed).standard_normal((1, 8, 128, 16))
latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
inputs = {
"prompt": "A hammer hitting a wooden surface",
"latents": latents,
"generator": generator,
"num_inference_steps": 3,
"guidance_scale": 2.5,
}
return inputs
def test_audioldm(self):
audioldm_pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm")
audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
inputs["num_inference_steps"] = 25
audio = audioldm_pipe(**inputs).audios[0]
assert audio.ndim == 1
assert len(audio) == 81920
audio_slice = audio[77230:77240]
expected_slice = np.array(
[-0.4884, -0.4607, 0.0023, 0.5007, 0.5896, 0.5151, 0.3813, -0.0208, -0.3687, -0.4315]
)
max_diff = np.abs(expected_slice - audio_slice).max()
assert max_diff < 1e-2
def test_audioldm_lms(self):
audioldm_pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm")
audioldm_pipe.scheduler = LMSDiscreteScheduler.from_config(audioldm_pipe.scheduler.config)
audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
audio = audioldm_pipe(**inputs).audios[0]
assert audio.ndim == 1
assert len(audio) == 81920
audio_slice = audio[27780:27790]
expected_slice = np.array([-0.2131, -0.0873, -0.0124, -0.0189, 0.0569, 0.1373, 0.1883, 0.2886, 0.3297, 0.2212])
max_diff = np.abs(expected_slice - audio_slice).max()
assert max_diff < 1e-2