Unverified Commit 69e72b1d authored by Yoach Lacombe's avatar Yoach Lacombe Committed by GitHub
Browse files

Stable Audio integration (#8716)



* WIP modeling code and pipeline

* add custom attention processor + custom activation + add to init

* correct ProjectionModel forward

* add stable audio to __initèè

* add autoencoder and update pipeline and modeling code

* add half Rope

* add partial rotary v2

* add temporary modfis to scheduler

* add EDM DPM Solver

* remove TODOs

* clean GLU

* remove att.group_norm to attn processor

* revert back src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

* refactor GLU -> SwiGLU

* remove redundant args

* add channel multiples in autoencoder docstrings

* changes in docsrtings and copyright headers

* clean pipeline

* further cleaning

* remove peft and lora and fromoriginalmodel

* Delete src/diffusers/pipelines/stable_audio/diffusers.code-workspace

* make style

* dummy models

* fix copied from

* add fast oobleck tests

* add brownian tree

* oobleck autoencoder slow tests

* remove TODO

* fast stable audio pipeline tests

* add slow tests

* make style

* add first version of docs

* wrap is_torchsde_available to the scheduler

* fix slow test

* test with input waveform

* add input waveform

* remove some todos

* create stableaudio gaussian projection + make style

* add pipeline to toctree

* fix copied from

* make quality

* refactor timestep_features->time_proj

* refactor joint_attention_kwargs->cross_attention_kwargs

* remove forward_chunk

* move StableAudioDitModel to transformers folder

* correct convert + remove partial rotary embed

* apply suggestions from yiyixuxu -> removing attn.kv_heads

* remove temb

* remove cross_attention_kwargs

* further removal of cross_attention_kwargs

* remove text encoder autocast to fp16

* continue removing autocast

* make style

* refactor how text and audio are embedded

* add paper

* update example code

* make style

* unify projection model forward + fix device placement

* make style

* remove fuse qkv

* apply suggestions from review

* Update src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>

* make style

* smaller models in fast tests

* pass sequential offloading fast tests

* add docs for vae and autoencoder

* make style and update example

* remove useless import

* add cosine scheduler

* dummy classes

* cosine scheduler docs

* better description of scheduler

---------
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
parent 8c4856cd
# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, List, Optional, Union
import torch
from transformers import (
T5EncoderModel,
T5Tokenizer,
T5TokenizerFast,
)
from ...models import AutoencoderOobleck, StableAudioDiTModel
from ...models.embeddings import get_1d_rotary_pos_embed
from ...schedulers import EDMDPMSolverMultistepScheduler
from ...utils import (
logging,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
from .modeling_stable_audio import StableAudioProjectionModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import scipy
>>> import torch
>>> import soundfile as sf
>>> from diffusers import StableAudioPipeline
>>> repo_id = "ylacombe/stable-audio-1.0" # TODO (YL): change once set
>>> pipe = StableAudioPipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
>>> pipe = pipe.to("cuda")
>>> # define the prompts
>>> prompt = "The sound of a hammer hitting a wooden surface."
>>> negative_prompt = "Low quality."
>>> # set the seed for generator
>>> generator = torch.Generator("cuda").manual_seed(0)
>>> # run the generation
>>> audio = pipe(
... prompt,
... negative_prompt=negative_prompt,
... num_inference_steps=200,
... audio_end_in_s=10.0,
... num_waveforms_per_prompt=3,
... generator=generator,
... ).audios
>>> output = audio[0].T.float().cpu().numpy()
>>> sf.write("hammer.wav", output, pipe.vae.sampling_rate)
```
"""
class StableAudioPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-audio generation using StableAudio.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
Args:
vae ([`AutoencoderOobleck`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
text_encoder ([`~transformers.T5EncoderModel`]):
Frozen text-encoder. StableAudio uses the encoder of
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
[google-t5/t5-base](https://huggingface.co/google-t5/t5-base) variant.
projection_model ([`StableAudioProjectionModel`]):
A trained model used to linearly project the hidden-states from the text encoder model and the start and
end seconds. The projected hidden-states from the encoder and the conditional seconds are concatenated to
give the input to the transformer model.
tokenizer ([`~transformers.T5Tokenizer`]):
Tokenizer to tokenize text for the frozen text-encoder.
transformer ([`StableAudioDiTModel`]):
A `StableAudioDiTModel` to denoise the encoded audio latents.
scheduler ([`EDMDPMSolverMultistepScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded audio latents.
"""
model_cpu_offload_seq = "text_encoder->projection_model->transformer->vae"
def __init__(
self,
vae: AutoencoderOobleck,
text_encoder: T5EncoderModel,
projection_model: StableAudioProjectionModel,
tokenizer: Union[T5Tokenizer, T5TokenizerFast],
transformer: StableAudioDiTModel,
scheduler: EDMDPMSolverMultistepScheduler,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
projection_model=projection_model,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
)
self.rotary_embed_dim = self.transformer.config.attention_head_dim // 2
# Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
# Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
def encode_prompt(
self,
prompt,
device,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
negative_attention_mask: Optional[torch.LongTensor] = None,
):
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# 1. Tokenize text
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
f"The following part of your input was truncated because {self.text_encoder.config.model_type} can "
f"only handle sequences up to {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids.to(device)
attention_mask = attention_mask.to(device)
# 2. Text encoder forward
self.text_encoder.eval()
prompt_embeds = self.text_encoder(
text_input_ids,
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]
if do_classifier_free_guidance and negative_prompt is not None:
uncond_tokens: List[str]
if type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
# 1. Tokenize text
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
uncond_input_ids = uncond_input.input_ids.to(device)
negative_attention_mask = uncond_input.attention_mask.to(device)
# 2. Text encoder forward
self.text_encoder.eval()
negative_prompt_embeds = self.text_encoder(
uncond_input_ids,
attention_mask=negative_attention_mask,
)
negative_prompt_embeds = negative_prompt_embeds[0]
if negative_attention_mask is not None:
# set the masked tokens to the null embed
negative_prompt_embeds = torch.where(
negative_attention_mask.to(torch.bool).unsqueeze(2), negative_prompt_embeds, 0.0
)
# 3. Project prompt_embeds and negative_prompt_embeds
if do_classifier_free_guidance and negative_prompt_embeds is not None:
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the negative and text embeddings into a single batch
# to avoid doing two forward passes
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if attention_mask is not None and negative_attention_mask is None:
negative_attention_mask = torch.ones_like(attention_mask)
elif attention_mask is None and negative_attention_mask is not None:
attention_mask = torch.ones_like(negative_attention_mask)
if attention_mask is not None:
attention_mask = torch.cat([negative_attention_mask, attention_mask])
prompt_embeds = self.projection_model(
text_hidden_states=prompt_embeds,
).text_hidden_states
if attention_mask is not None:
prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype)
prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype)
return prompt_embeds
def encode_duration(
self,
audio_start_in_s,
audio_end_in_s,
device,
do_classifier_free_guidance,
batch_size,
):
audio_start_in_s = audio_start_in_s if isinstance(audio_start_in_s, list) else [audio_start_in_s]
audio_end_in_s = audio_end_in_s if isinstance(audio_end_in_s, list) else [audio_end_in_s]
if len(audio_start_in_s) == 1:
audio_start_in_s = audio_start_in_s * batch_size
if len(audio_end_in_s) == 1:
audio_end_in_s = audio_end_in_s * batch_size
# Cast the inputs to floats
audio_start_in_s = [float(x) for x in audio_start_in_s]
audio_start_in_s = torch.tensor(audio_start_in_s).to(device)
audio_end_in_s = [float(x) for x in audio_end_in_s]
audio_end_in_s = torch.tensor(audio_end_in_s).to(device)
projection_output = self.projection_model(
start_seconds=audio_start_in_s,
end_seconds=audio_end_in_s,
)
seconds_start_hidden_states = projection_output.seconds_start_hidden_states
seconds_end_hidden_states = projection_output.seconds_end_hidden_states
# For classifier free guidance, we need to do two forward passes.
# Here we repeat the audio hidden states to avoid doing two forward passes
if do_classifier_free_guidance:
seconds_start_hidden_states = torch.cat([seconds_start_hidden_states, seconds_start_hidden_states], dim=0)
seconds_end_hidden_states = torch.cat([seconds_end_hidden_states, seconds_end_hidden_states], dim=0)
return seconds_start_hidden_states, seconds_end_hidden_states
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(
self,
prompt,
audio_start_in_s,
audio_end_in_s,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
attention_mask=None,
negative_attention_mask=None,
initial_audio_waveforms=None,
initial_audio_sampling_rate=None,
):
if audio_end_in_s < audio_start_in_s:
raise ValueError(
f"`audio_end_in_s={audio_end_in_s}' must be higher than 'audio_start_in_s={audio_start_in_s}` but "
)
if (
audio_start_in_s < self.projection_model.config.min_value
or audio_start_in_s > self.projection_model.config.max_value
):
raise ValueError(
f"`audio_start_in_s` must be greater than or equal to {self.projection_model.config.min_value}, and lower than or equal to {self.projection_model.config.max_value} but "
f"is {audio_start_in_s}."
)
if (
audio_end_in_s < self.projection_model.config.min_value
or audio_end_in_s > self.projection_model.config.max_value
):
raise ValueError(
f"`audio_end_in_s` must be greater than or equal to {self.projection_model.config.min_value}, and lower than or equal to {self.projection_model.config.max_value} but "
f"is {audio_end_in_s}."
)
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and (prompt_embeds is None):
raise ValueError(
"Provide either `prompt`, or `prompt_embeds`. Cannot leave"
"`prompt` undefined without specifying `prompt_embeds`."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if attention_mask is not None and attention_mask.shape != prompt_embeds.shape[:2]:
raise ValueError(
"`attention_mask should have the same batch size and sequence length as `prompt_embeds`, but got:"
f"`attention_mask: {attention_mask.shape} != `prompt_embeds` {prompt_embeds.shape}"
)
if initial_audio_sampling_rate is None and initial_audio_waveforms is not None:
raise ValueError(
"`initial_audio_waveforms' is provided but the sampling rate is not. Make sure to pass `initial_audio_sampling_rate`."
)
if initial_audio_sampling_rate is not None and initial_audio_sampling_rate != self.vae.sampling_rate:
raise ValueError(
f"`initial_audio_sampling_rate` must be {self.vae.hop_length}' but is `{initial_audio_sampling_rate}`."
"Make sure to resample the `initial_audio_waveforms` and to correct the sampling rate. "
)
def prepare_latents(
self,
batch_size,
num_channels_vae,
sample_size,
dtype,
device,
generator,
latents=None,
initial_audio_waveforms=None,
num_waveforms_per_prompt=None,
audio_channels=None,
):
shape = (batch_size, num_channels_vae, sample_size)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
# encode the initial audio for use by the model
if initial_audio_waveforms is not None:
# check dimension
if initial_audio_waveforms.ndim == 2:
initial_audio_waveforms = initial_audio_waveforms.unsqueeze(1)
elif initial_audio_waveforms.ndim != 3:
raise ValueError(
f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but has `{initial_audio_waveforms.ndim}` dimensions"
)
audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length
audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length)
# check num_channels
if initial_audio_waveforms.shape[1] == 1 and audio_channels == 2:
initial_audio_waveforms = initial_audio_waveforms.repeat(1, 2, 1)
elif initial_audio_waveforms.shape[1] == 2 and audio_channels == 1:
initial_audio_waveforms = initial_audio_waveforms.mean(1, keepdim=True)
if initial_audio_waveforms.shape[:2] != audio_shape[:2]:
raise ValueError(
f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but is of shape `{initial_audio_waveforms.shape}`"
)
# crop or pad
audio_length = initial_audio_waveforms.shape[-1]
if audio_length < audio_vae_length:
logger.warning(
f"The provided input waveform is shorter ({audio_length}) than the required audio length ({audio_vae_length}) of the model and will thus be padded."
)
elif audio_length > audio_vae_length:
logger.warning(
f"The provided input waveform is longer ({audio_length}) than the required audio length ({audio_vae_length}) of the model and will thus be cropped."
)
audio = initial_audio_waveforms.new_zeros(audio_shape)
audio[:, :, : min(audio_length, audio_vae_length)] = initial_audio_waveforms[:, :, :audio_vae_length]
encoded_audio = self.vae.encode(audio).latent_dist.sample(generator)
encoded_audio = encoded_audio.repeat((num_waveforms_per_prompt, 1, 1))
latents = encoded_audio + latents
return latents
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
audio_end_in_s: Optional[float] = None,
audio_start_in_s: Optional[float] = 0.0,
num_inference_steps: int = 100,
guidance_scale: float = 7.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_waveforms_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
initial_audio_waveforms: Optional[torch.Tensor] = None,
initial_audio_sampling_rate: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
negative_attention_mask: Optional[torch.LongTensor] = None,
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
callback_steps: Optional[int] = 1,
output_type: Optional[str] = "pt",
):
r"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`.
audio_end_in_s (`float`, *optional*, defaults to 47.55):
Audio end index in seconds.
audio_start_in_s (`float`, *optional*, defaults to 0):
Audio start index in seconds.
num_inference_steps (`int`, *optional*, defaults to 100):
The number of denoising steps. More denoising steps usually lead to a higher quality audio at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.0):
A higher guidance scale value encourages the model to generate audio that is closely linked to the text
`prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide what to not include in audio generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
num_waveforms_per_prompt (`int`, *optional*, defaults to 1):
The number of waveforms to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for audio
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
initial_audio_waveforms (`torch.Tensor`, *optional*):
Optional initial audio waveforms to use as the initial audio waveform for generation. Must be of shape
`(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)`, where `batch_size`
corresponds to the number of prompts passed to the model.
initial_audio_sampling_rate (`int`, *optional*):
Sampling rate of the `initial_audio_waveforms`, if they are provided. Must be the same as the model.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-computed text embeddings from the text encoder model. Can be used to easily tweak text inputs,
*e.g.* prompt weighting. If not provided, text embeddings will be computed from `prompt` input
argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-computed negative text embeddings from the text encoder model. Can be used to easily tweak text
inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
`negative_prompt` input argument.
attention_mask (`torch.LongTensor`, *optional*):
Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will
be computed from `prompt` input argument.
negative_attention_mask (`torch.LongTensor`, *optional*):
Pre-computed attention mask to be applied to the `negative_text_audio_duration_embeds`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
output_type (`str`, *optional*, defaults to `"pt"`):
The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or
`"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion
model (LDM) output.
Examples:
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
otherwise a `tuple` is returned where the first element is a list with the generated audio.
"""
# 0. Convert audio input length from seconds to latent length
downsample_ratio = self.vae.hop_length
max_audio_length_in_s = self.transformer.config.sample_size * downsample_ratio / self.vae.config.sampling_rate
if audio_end_in_s is None:
audio_end_in_s = max_audio_length_in_s
if audio_end_in_s - audio_start_in_s > max_audio_length_in_s:
raise ValueError(
f"The total audio length requested ({audio_end_in_s-audio_start_in_s}s) is longer than the model maximum possible length ({max_audio_length_in_s}). Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'."
)
waveform_start = int(audio_start_in_s * self.vae.config.sampling_rate)
waveform_end = int(audio_end_in_s * self.vae.config.sampling_rate)
waveform_length = int(self.transformer.config.sample_size)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
audio_start_in_s,
audio_end_in_s,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
attention_mask,
negative_attention_mask,
initial_audio_waveforms,
initial_audio_sampling_rate,
)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
prompt_embeds = self.encode_prompt(
prompt,
device,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
attention_mask,
negative_attention_mask,
)
# Encode duration
seconds_start_hidden_states, seconds_end_hidden_states = self.encode_duration(
audio_start_in_s,
audio_end_in_s,
device,
do_classifier_free_guidance and (negative_prompt is not None or negative_prompt_embeds is not None),
batch_size,
)
# Create text_audio_duration_embeds and audio_duration_embeds
text_audio_duration_embeds = torch.cat(
[prompt_embeds, seconds_start_hidden_states, seconds_end_hidden_states], dim=1
)
audio_duration_embeds = torch.cat([seconds_start_hidden_states, seconds_end_hidden_states], dim=2)
# In case of classifier free guidance without negative prompt, we need to create unconditional embeddings and
# to concatenate it to the embeddings
if do_classifier_free_guidance and negative_prompt_embeds is None and negative_prompt is None:
negative_text_audio_duration_embeds = torch.zeros_like(
text_audio_duration_embeds, device=text_audio_duration_embeds.device
)
text_audio_duration_embeds = torch.cat(
[negative_text_audio_duration_embeds, text_audio_duration_embeds], dim=0
)
audio_duration_embeds = torch.cat([audio_duration_embeds, audio_duration_embeds], dim=0)
bs_embed, seq_len, hidden_size = text_audio_duration_embeds.shape
# duplicate audio_duration_embeds and text_audio_duration_embeds for each generation per prompt, using mps friendly method
text_audio_duration_embeds = text_audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1)
text_audio_duration_embeds = text_audio_duration_embeds.view(
bs_embed * num_waveforms_per_prompt, seq_len, hidden_size
)
audio_duration_embeds = audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1)
audio_duration_embeds = audio_duration_embeds.view(
bs_embed * num_waveforms_per_prompt, -1, audio_duration_embeds.shape[-1]
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_vae = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_waveforms_per_prompt,
num_channels_vae,
waveform_length,
text_audio_duration_embeds.dtype,
device,
generator,
latents,
initial_audio_waveforms,
num_waveforms_per_prompt,
audio_channels=self.vae.config.audio_channels,
)
# 6. Prepare extra step kwargs
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Prepare rotary positional embedding
rotary_embedding = get_1d_rotary_pos_embed(
self.rotary_embed_dim,
latents.shape[2] + audio_duration_embeds.shape[1],
use_real=True,
repeat_interleave_real=False,
)
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.transformer(
latent_model_input,
t.unsqueeze(0),
encoder_hidden_states=text_audio_duration_embeds,
global_hidden_states=audio_duration_embeds,
rotary_embedding=rotary_embedding,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
# 9. Post-processing
if not output_type == "latent":
audio = self.vae.decode(latents).sample
else:
return AudioPipelineOutput(audios=latents)
audio = audio[:, :, waveform_start:waveform_end]
if output_type == "np":
audio = audio.cpu().float().numpy()
self.maybe_free_model_hooks()
if not return_dict:
return (audio,)
return AudioPipelineOutput(audios=audio)
......@@ -118,6 +118,7 @@ except OptionalDependencyNotAvailable:
_dummy_modules.update(get_objects_from_module(dummy_torch_and_torchsde_objects))
else:
_import_structure["scheduling_cosine_dpmsolver_multistep"] = ["CosineDPMSolverMultistepScheduler"]
_import_structure["scheduling_dpmsolver_sde"] = ["DPMSolverSDEScheduler"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
......@@ -205,6 +206,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_torchsde_objects import * # noqa F403
else:
from .scheduling_cosine_dpmsolver_multistep import CosineDPMSolverMultistepScheduler
from .scheduling_dpmsolver_sde import DPMSolverSDEScheduler
else:
......
# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm
import math
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_dpmsolver_sde import BrownianTreeNoiseSampler
from .scheduling_utils import SchedulerMixin, SchedulerOutput
class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
Implements a variant of `DPMSolverMultistepScheduler` with cosine schedule, proposed by Nichol and Dhariwal (2021).
This scheduler was used in Stable Audio Open [1].
[1] Evans, Parker, et al. "Stable Audio Open" https://arxiv.org/abs/2407.14358
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
Args:
sigma_min (`float`, *optional*, defaults to 0.3):
Minimum noise magnitude in the sigma schedule. This was set to 0.3 in Stable Audio Open [1].
sigma_max (`float`, *optional*, defaults to 500):
Maximum noise magnitude in the sigma schedule. This was set to 500 in Stable Audio Open [1].
sigma_data (`float`, *optional*, defaults to 1.0):
The standard deviation of the data distribution. This is set to 1.0 in Stable Audio Open [1].
sigma_schedule (`str`, *optional*, defaults to `exponential`):
Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper
(https://arxiv.org/abs/2206.00364). Other acceptable value is "exponential". The exponential schedule was
incorporated in this model: https://huggingface.co/stabilityai/cosxl.
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
solver_order (`int`, defaults to 2):
The DPMSolver order which can be `1` or `2`. It is recommended to use `solver_order=2`.
prediction_type (`str`, defaults to `v_prediction`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
Video](https://imagen.research.google/video/paper.pdf) paper).
solver_type (`str`, defaults to `midpoint`):
Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
lower_order_final (`bool`, defaults to `True`):
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
euler_at_final (`bool`, defaults to `False`):
Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
steps, but sometimes may result in blurring.
final_sigmas_type (`str`, defaults to `"zero"`):
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
"""
_compatibles = []
order = 1
@register_to_config
def __init__(
self,
sigma_min: float = 0.3,
sigma_max: float = 500,
sigma_data: float = 1.0,
sigma_schedule: str = "exponential",
num_train_timesteps: int = 1000,
solver_order: int = 2,
prediction_type: str = "v_prediction",
rho: float = 7.0,
solver_type: str = "midpoint",
lower_order_final: bool = True,
euler_at_final: bool = False,
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
):
if solver_type not in ["midpoint", "heun"]:
if solver_type in ["logrho", "bh1", "bh2"]:
self.register_to_config(solver_type="midpoint")
else:
raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
ramp = torch.linspace(0, 1, num_train_timesteps)
if sigma_schedule == "karras":
sigmas = self._compute_karras_sigmas(ramp)
elif sigma_schedule == "exponential":
sigmas = self._compute_exponential_sigmas(ramp)
self.timesteps = self.precondition_noise(sigmas)
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
# setable values
self.num_inference_steps = None
self.model_outputs = [None] * solver_order
self.lower_order_nums = 0
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
def init_noise_sigma(self):
# standard deviation of the initial noise distribution
return (self.config.sigma_max**2 + 1) ** 0.5
@property
def step_index(self):
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs
def precondition_inputs(self, sample, sigma):
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
scaled_sample = sample * c_in
return scaled_sample
def precondition_noise(self, sigma):
if not isinstance(sigma, torch.Tensor):
sigma = torch.tensor([sigma])
return sigma.atan() / math.pi * 2
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs
def precondition_outputs(self, sample, model_output, sigma):
sigma_data = self.config.sigma_data
c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
if self.config.prediction_type == "epsilon":
c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
elif self.config.prediction_type == "v_prediction":
c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
else:
raise ValueError(f"Prediction type {self.config.prediction_type} is not supported.")
denoised = c_skip * sample + c_out * model_output
return denoised
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
Args:
sample (`torch.Tensor`):
The input sample.
timestep (`int`, *optional*):
The current timestep in the diffusion chain.
Returns:
`torch.Tensor`:
A scaled input sample.
"""
if self.step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[self.step_index]
sample = self.precondition_inputs(sample, sigma)
self.is_scale_input_called = True
return sample
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
self.num_inference_steps = num_inference_steps
ramp = torch.linspace(0, 1, self.num_inference_steps)
if self.config.sigma_schedule == "karras":
sigmas = self._compute_karras_sigmas(ramp)
elif self.config.sigma_schedule == "exponential":
sigmas = self._compute_exponential_sigmas(ramp)
sigmas = sigmas.to(dtype=torch.float32, device=device)
self.timesteps = self.precondition_noise(sigmas)
if self.config.final_sigmas_type == "sigma_min":
sigma_last = self.config.sigma_min
elif self.config.final_sigmas_type == "zero":
sigma_last = 0
else:
raise ValueError(
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
)
self.sigmas = torch.cat([sigmas, torch.tensor([sigma_last], dtype=torch.float32, device=device)])
self.model_outputs = [
None,
] * self.config.solver_order
self.lower_order_nums = 0
# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# if a noise sampler is used, reinitialise it
self.noise_sampler = None
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas
def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min = sigma_min or self.config.sigma_min
sigma_max = sigma_max or self.config.sigma_max
rho = self.config.rho
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
"""Implementation closely follows k-diffusion.
https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
"""
sigma_min = sigma_min or self.config.sigma_min
sigma_max = sigma_max or self.config.sigma_max
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0)
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
# get distribution
dists = log_sigma - log_sigmas[:, np.newaxis]
# get sigmas range
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
low = log_sigmas[low_idx]
high = log_sigmas[high_idx]
# interpolate sigmas
w = (low - log_sigma) / (low - high)
w = np.clip(w, 0, 1)
# transform interpolation to time range
t = (1 - w) * low_idx + w * high_idx
t = t.reshape(sigma.shape)
return t
def _sigma_to_alpha_sigma_t(self, sigma):
alpha_t = torch.tensor(1) # Inputs are pre-scaled before going into unet, so alpha_t = 1
sigma_t = sigma
return alpha_t, sigma_t
def convert_model_output(
self,
model_output: torch.Tensor,
sample: torch.Tensor = None,
) -> torch.Tensor:
"""
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
integral of the data prediction model.
<Tip>
The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
prediction and data prediction models.
</Tip>
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.Tensor`:
The converted model output.
"""
sigma = self.sigmas[self.step_index]
x0_pred = self.precondition_outputs(sample, model_output, sigma)
return x0_pred
def dpm_solver_first_order_update(
self,
model_output: torch.Tensor,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
One step for the first-order DPMSolver (equivalent to DDIM).
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
h = lambda_t - lambda_s
assert noise is not None
x_t = (
(sigma_t / sigma_s * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
return x_t
def multistep_dpm_solver_second_order_update(
self,
model_output_list: List[torch.Tensor],
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
One step for the second-order multistep DPMSolver.
Args:
model_output_list (`List[torch.Tensor]`):
The direct outputs from learned diffusion model at current and latter timesteps.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
sigma_t, sigma_s0, sigma_s1 = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1],
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
m0, m1 = model_output_list[-1], model_output_list[-2]
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
r0 = h_0 / h
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
# sde-dpmsolver++
assert noise is not None
if self.config.solver_type == "midpoint":
x_t = (
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+ 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.solver_type == "heun":
x_t = (
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
return x_t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
index_candidates = (schedule_timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(self.timesteps) - 1
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
return step_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
"""
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def step(
self,
model_output: torch.Tensor,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
generator=None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
the multistep DPMSolver.
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
if self.step_index is None:
self._init_step_index(timestep)
# Improve numerical stability for small number of steps
lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
self.config.euler_at_final
or (self.config.lower_order_final and len(self.timesteps) < 15)
or self.config.final_sigmas_type == "zero"
)
lower_order_second = (
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
)
model_output = self.convert_model_output(model_output, sample=sample)
for i in range(self.config.solver_order - 1):
self.model_outputs[i] = self.model_outputs[i + 1]
self.model_outputs[-1] = model_output
if self.noise_sampler is None:
seed = None
if generator is not None:
seed = (
[g.initial_seed() for g in generator] if isinstance(generator, list) else generator.initial_seed()
)
self.noise_sampler = BrownianTreeNoiseSampler(
model_output, sigma_min=self.config.sigma_min, sigma_max=self.config.sigma_max, seed=seed
)
noise = self.noise_sampler(self.sigmas[self.step_index], self.sigmas[self.step_index + 1]).to(
model_output.device
)
if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
if self.lower_order_nums < self.config.solver_order:
self.lower_order_nums += 1
# upon completion increase step index by one
self._step_index += 1
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# mps does not support float64
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
elif self.step_index is not None:
# add_noise is called after first denoising step (for inpainting)
step_indices = [self.step_index] * timesteps.shape[0]
else:
# add noise is called before first denoising step to create initial latent(img2img)
step_indices = [self.begin_index] * timesteps.shape[0]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
noisy_samples = original_samples + noise * sigma
return noisy_samples
def __len__(self):
return self.config.num_train_timesteps
......@@ -134,7 +134,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = self.precondition_noise(sigmas)
self.sigmas = self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
# setable values
self.num_inference_steps = None
......
......@@ -62,6 +62,21 @@ class AutoencoderKLTemporalDecoder(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class AutoencoderOobleck(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderTiny(metaclass=DummyObject):
_backends = ["torch"]
......@@ -377,6 +392,21 @@ class SparseControlNetModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class StableAudioDiTModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class T2IAdapter(metaclass=DummyObject):
_backends = ["torch"]
......
......@@ -2,6 +2,21 @@
from ..utils import DummyObject, requires_backends
class CosineDPMSolverMultistepScheduler(metaclass=DummyObject):
_backends = ["torch", "torchsde"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "torchsde"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "torchsde"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "torchsde"])
class DPMSolverSDEScheduler(metaclass=DummyObject):
_backends = ["torch", "torchsde"]
......
......@@ -992,6 +992,36 @@ class ShapEPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class StableAudioPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableAudioProjectionModel(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableCascadeCombinedPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
......
......@@ -18,12 +18,14 @@ import unittest
import numpy as np
import torch
from datasets import load_dataset
from parameterized import parameterized
from diffusers import (
AsymmetricAutoencoderKL,
AutoencoderKL,
AutoencoderKLTemporalDecoder,
AutoencoderOobleck,
AutoencoderTiny,
ConsistencyDecoderVAE,
StableDiffusionPipeline,
......@@ -128,6 +130,18 @@ def get_consistency_vae_config(block_out_channels=None, norm_num_groups=None):
}
def get_autoencoder_oobleck_config(block_out_channels=None):
init_dict = {
"encoder_hidden_size": 12,
"decoder_channels": 12,
"decoder_input_channels": 6,
"audio_channels": 2,
"downsampling_ratios": [2, 4],
"channel_multiples": [1, 2],
}
return init_dict
class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = AutoencoderKL
main_input_name = "sample"
......@@ -480,6 +494,41 @@ class AutoencoderKLTemporalDecoderFastTests(ModelTesterMixin, unittest.TestCase)
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = AutoencoderOobleck
main_input_name = "sample"
base_precision = 1e-2
@property
def dummy_input(self):
batch_size = 4
num_channels = 2
seq_len = 24
waveform = floats_tensor((batch_size, num_channels, seq_len)).to(torch_device)
return {"sample": waveform, "sample_posterior": False}
@property
def input_shape(self):
return (2, 24)
@property
def output_shape(self):
return (2, 24)
def prepare_init_args_and_inputs_for_common(self):
init_dict = get_autoencoder_oobleck_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_forward_signature(self):
pass
def test_forward_with_norm_groups(self):
pass
@slow
class AutoencoderTinyIntegrationTests(unittest.TestCase):
def tearDown(self):
......@@ -1100,3 +1149,118 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
for shape in shapes:
image = torch.zeros(shape, device=torch_device, dtype=pipe.vae.dtype)
pipe.vae.decode(image)
@slow
class AutoencoderOobleckIntegrationTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
def _load_datasamples(self, num_samples):
ds = load_dataset(
"hf-internal-testing/librispeech_asr_dummy", "clean", split="validation", trust_remote_code=True
)
# automatic decoding with librispeech
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
return torch.nn.utils.rnn.pad_sequence(
[torch.from_numpy(x["array"]) for x in speech_samples], batch_first=True
)
def get_audio(self, audio_sample_size=2097152, fp16=False):
dtype = torch.float16 if fp16 else torch.float32
audio = self._load_datasamples(2).to(torch_device).to(dtype)
# pad / crop to audio_sample_size
audio = torch.nn.functional.pad(audio[:, :audio_sample_size], pad=(0, audio_sample_size - audio.shape[-1]))
# todo channel
audio = audio.unsqueeze(1).repeat(1, 2, 1).to(torch_device)
return audio
def get_oobleck_vae_model(
self, model_id="ylacombe/stable-audio-1.0", fp16=False
): # TODO (YL): change repo id once moved
torch_dtype = torch.float16 if fp16 else torch.float32
model = AutoencoderOobleck.from_pretrained(
model_id,
subfolder="vae",
torch_dtype=torch_dtype,
)
model.to(torch_device)
return model
def get_generator(self, seed=0):
generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
if torch_device != "mps":
return torch.Generator(device=generator_device).manual_seed(seed)
return torch.manual_seed(seed)
@parameterized.expand(
[
# fmt: off
[33, [1.193e-4, 6.56e-05, 1.314e-4, 3.80e-05, -4.01e-06], 0.001192],
[44, [2.77e-05, -2.65e-05, 1.18e-05, -6.94e-05, -9.57e-05], 0.001196],
# fmt: on
]
)
def test_stable_diffusion(self, seed, expected_slice, expected_mean_absolute_diff):
model = self.get_oobleck_vae_model()
audio = self.get_audio()
generator = self.get_generator(seed)
with torch.no_grad():
sample = model(audio, generator=generator, sample_posterior=True).sample
assert sample.shape == audio.shape
assert ((sample - audio).abs().mean() - expected_mean_absolute_diff).abs() <= 1e-6
output_slice = sample[-1, 1, 5:10].cpu()
expected_output_slice = torch.tensor(expected_slice)
assert torch_all_close(output_slice, expected_output_slice, atol=1e-5)
def test_stable_diffusion_mode(self):
model = self.get_oobleck_vae_model()
audio = self.get_audio()
with torch.no_grad():
sample = model(audio, sample_posterior=False).sample
assert sample.shape == audio.shape
@parameterized.expand(
[
# fmt: off
[33, [1.193e-4, 6.56e-05, 1.314e-4, 3.80e-05, -4.01e-06], 0.001192],
[44, [2.77e-05, -2.65e-05, 1.18e-05, -6.94e-05, -9.57e-05], 0.001196],
# fmt: on
]
)
def test_stable_diffusion_encode_decode(self, seed, expected_slice, expected_mean_absolute_diff):
model = self.get_oobleck_vae_model()
audio = self.get_audio()
generator = self.get_generator(seed)
with torch.no_grad():
x = audio
posterior = model.encode(x).latent_dist
z = posterior.sample(generator=generator)
sample = model.decode(z).sample
# (batch_size, latent_dim, sequence_length)
assert posterior.mean.shape == (audio.shape[0], model.config.decoder_input_channels, 1024)
assert sample.shape == audio.shape
assert ((sample - audio).abs().mean() - expected_mean_absolute_diff).abs() <= 1e-6
output_slice = sample[-1, 1, 5:10].cpu()
expected_output_slice = torch.tensor(expected_slice)
assert torch_all_close(output_slice, expected_output_slice, atol=1e-5)
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import numpy as np
import torch
from transformers import (
T5EncoderModel,
T5Tokenizer,
)
from diffusers import (
AutoencoderOobleck,
CosineDPMSolverMultistepScheduler,
StableAudioDiTModel,
StableAudioPipeline,
StableAudioProjectionModel,
)
from diffusers.utils import is_xformers_available
from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device
from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class StableAudioPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = StableAudioPipeline
params = frozenset(
[
"prompt",
"audio_end_in_s",
"audio_start_in_s",
"guidance_scale",
"negative_prompt",
"prompt_embeds",
"negative_prompt_embeds",
"initial_audio_waveforms",
]
)
batch_params = TEXT_TO_AUDIO_BATCH_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"num_waveforms_per_prompt",
"generator",
"latents",
"output_type",
"return_dict",
"callback",
"callback_steps",
]
)
def get_dummy_components(self):
torch.manual_seed(0)
transformer = StableAudioDiTModel(
sample_size=4,
in_channels=3,
num_layers=2,
attention_head_dim=4,
num_key_value_attention_heads=2,
out_channels=3,
cross_attention_dim=4,
time_proj_dim=8,
global_states_input_dim=8,
cross_attention_input_dim=4,
)
scheduler = CosineDPMSolverMultistepScheduler(
solver_order=2,
prediction_type="v_prediction",
sigma_data=1.0,
sigma_schedule="exponential",
)
torch.manual_seed(0)
vae = AutoencoderOobleck(
encoder_hidden_size=6,
downsampling_ratios=[1, 2],
decoder_channels=3,
decoder_input_channels=3,
audio_channels=2,
channel_multiples=[2, 4],
sampling_rate=4,
)
torch.manual_seed(0)
t5_repo_id = "hf-internal-testing/tiny-random-T5ForConditionalGeneration"
text_encoder = T5EncoderModel.from_pretrained(t5_repo_id)
tokenizer = T5Tokenizer.from_pretrained(t5_repo_id, truncation=True, model_max_length=25)
torch.manual_seed(0)
projection_model = StableAudioProjectionModel(
text_encoder_dim=text_encoder.config.d_model,
conditioning_dim=4,
min_value=0,
max_value=32,
)
components = {
"transformer": transformer,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"projection_model": projection_model,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A hammer hitting a wooden surface",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
}
return inputs
def test_save_load_local(self):
# increase tolerance from 1e-4 -> 7e-3 to account for large composite model
super().test_save_load_local(expected_max_difference=7e-3)
def test_save_load_optional_components(self):
# increase tolerance from 1e-4 -> 7e-3 to account for large composite model
super().test_save_load_optional_components(expected_max_difference=7e-3)
def test_stable_audio_ddim(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
stable_audio_pipe = StableAudioPipeline(**components)
stable_audio_pipe = stable_audio_pipe.to(torch_device)
stable_audio_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
output = stable_audio_pipe(**inputs)
audio = output.audios[0]
assert audio.ndim == 2
assert audio.shape == (2, 7)
def test_stable_audio_without_prompts(self):
components = self.get_dummy_components()
stable_audio_pipe = StableAudioPipeline(**components)
stable_audio_pipe = stable_audio_pipe.to(torch_device)
stable_audio_pipe = stable_audio_pipe.to(torch_device)
stable_audio_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt"] = 3 * [inputs["prompt"]]
# forward
output = stable_audio_pipe(**inputs)
audio_1 = output.audios[0]
inputs = self.get_dummy_inputs(torch_device)
prompt = 3 * [inputs.pop("prompt")]
text_inputs = stable_audio_pipe.tokenizer(
prompt,
padding="max_length",
max_length=stable_audio_pipe.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
).to(torch_device)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask
prompt_embeds = stable_audio_pipe.text_encoder(
text_input_ids,
attention_mask=attention_mask,
)[0]
inputs["prompt_embeds"] = prompt_embeds
inputs["attention_mask"] = attention_mask
# forward
output = stable_audio_pipe(**inputs)
audio_2 = output.audios[0]
assert (audio_1 - audio_2).abs().max() < 1e-2
def test_stable_audio_negative_without_prompts(self):
components = self.get_dummy_components()
stable_audio_pipe = StableAudioPipeline(**components)
stable_audio_pipe = stable_audio_pipe.to(torch_device)
stable_audio_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
negative_prompt = 3 * ["this is a negative prompt"]
inputs["negative_prompt"] = negative_prompt
inputs["prompt"] = 3 * [inputs["prompt"]]
# forward
output = stable_audio_pipe(**inputs)
audio_1 = output.audios[0]
inputs = self.get_dummy_inputs(torch_device)
prompt = 3 * [inputs.pop("prompt")]
text_inputs = stable_audio_pipe.tokenizer(
prompt,
padding="max_length",
max_length=stable_audio_pipe.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
).to(torch_device)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask
prompt_embeds = stable_audio_pipe.text_encoder(
text_input_ids,
attention_mask=attention_mask,
)[0]
inputs["prompt_embeds"] = prompt_embeds
inputs["attention_mask"] = attention_mask
negative_text_inputs = stable_audio_pipe.tokenizer(
negative_prompt,
padding="max_length",
max_length=stable_audio_pipe.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
).to(torch_device)
negative_text_input_ids = negative_text_inputs.input_ids
negative_attention_mask = negative_text_inputs.attention_mask
negative_prompt_embeds = stable_audio_pipe.text_encoder(
negative_text_input_ids,
attention_mask=negative_attention_mask,
)[0]
inputs["negative_prompt_embeds"] = negative_prompt_embeds
inputs["negative_attention_mask"] = negative_attention_mask
# forward
output = stable_audio_pipe(**inputs)
audio_2 = output.audios[0]
assert (audio_1 - audio_2).abs().max() < 1e-2
def test_stable_audio_negative_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
stable_audio_pipe = StableAudioPipeline(**components)
stable_audio_pipe = stable_audio_pipe.to(device)
stable_audio_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
negative_prompt = "egg cracking"
output = stable_audio_pipe(**inputs, negative_prompt=negative_prompt)
audio = output.audios[0]
assert audio.ndim == 2
assert audio.shape == (2, 7)
def test_stable_audio_num_waveforms_per_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
stable_audio_pipe = StableAudioPipeline(**components)
stable_audio_pipe = stable_audio_pipe.to(device)
stable_audio_pipe.set_progress_bar_config(disable=None)
prompt = "A hammer hitting a wooden surface"
# test num_waveforms_per_prompt=1 (default)
audios = stable_audio_pipe(prompt, num_inference_steps=2).audios
assert audios.shape == (1, 2, 7)
# test num_waveforms_per_prompt=1 (default) for batch of prompts
batch_size = 2
audios = stable_audio_pipe([prompt] * batch_size, num_inference_steps=2).audios
assert audios.shape == (batch_size, 2, 7)
# test num_waveforms_per_prompt for single prompt
num_waveforms_per_prompt = 2
audios = stable_audio_pipe(
prompt, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt
).audios
assert audios.shape == (num_waveforms_per_prompt, 2, 7)
# test num_waveforms_per_prompt for batch of prompts
batch_size = 2
audios = stable_audio_pipe(
[prompt] * batch_size, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt
).audios
assert audios.shape == (batch_size * num_waveforms_per_prompt, 2, 7)
def test_stable_audio_audio_end_in_s(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
stable_audio_pipe = StableAudioPipeline(**components)
stable_audio_pipe = stable_audio_pipe.to(torch_device)
stable_audio_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
output = stable_audio_pipe(audio_end_in_s=1.5, **inputs)
audio = output.audios[0]
assert audio.ndim == 2
assert audio.shape[1] / stable_audio_pipe.vae.sampling_rate == 1.5
output = stable_audio_pipe(audio_end_in_s=1.1875, **inputs)
audio = output.audios[0]
assert audio.ndim == 2
assert audio.shape[1] / stable_audio_pipe.vae.sampling_rate == 1.0
def test_attention_slicing_forward_pass(self):
self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False)
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=5e-4)
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_attention_forwardGenerator_pass(self):
self._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False)
def test_stable_audio_input_waveform(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
stable_audio_pipe = StableAudioPipeline(**components)
stable_audio_pipe = stable_audio_pipe.to(device)
stable_audio_pipe.set_progress_bar_config(disable=None)
prompt = "A hammer hitting a wooden surface"
initial_audio_waveforms = torch.ones((1, 5))
# test raises error when no sampling rate
with self.assertRaises(ValueError):
audios = stable_audio_pipe(
prompt, num_inference_steps=2, initial_audio_waveforms=initial_audio_waveforms
).audios
# test raises error when wrong sampling rate
with self.assertRaises(ValueError):
audios = stable_audio_pipe(
prompt,
num_inference_steps=2,
initial_audio_waveforms=initial_audio_waveforms,
initial_audio_sampling_rate=stable_audio_pipe.vae.sampling_rate - 1,
).audios
audios = stable_audio_pipe(
prompt,
num_inference_steps=2,
initial_audio_waveforms=initial_audio_waveforms,
initial_audio_sampling_rate=stable_audio_pipe.vae.sampling_rate,
).audios
assert audios.shape == (1, 2, 7)
# test works with num_waveforms_per_prompt
num_waveforms_per_prompt = 2
audios = stable_audio_pipe(
prompt,
num_inference_steps=2,
num_waveforms_per_prompt=num_waveforms_per_prompt,
initial_audio_waveforms=initial_audio_waveforms,
initial_audio_sampling_rate=stable_audio_pipe.vae.sampling_rate,
).audios
assert audios.shape == (num_waveforms_per_prompt, 2, 7)
# test num_waveforms_per_prompt for batch of prompts and input audio (two channels)
batch_size = 2
initial_audio_waveforms = torch.ones((batch_size, 2, 5))
audios = stable_audio_pipe(
[prompt] * batch_size,
num_inference_steps=2,
num_waveforms_per_prompt=num_waveforms_per_prompt,
initial_audio_waveforms=initial_audio_waveforms,
initial_audio_sampling_rate=stable_audio_pipe.vae.sampling_rate,
).audios
assert audios.shape == (batch_size * num_waveforms_per_prompt, 2, 7)
@unittest.skip("Not supported yet")
def test_sequential_cpu_offload_forward_pass(self):
pass
@unittest.skip("Not supported yet")
def test_sequential_offload_forward_pass_twice(self):
pass
@nightly
@require_torch_gpu
class StableAudioPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
latents = np.random.RandomState(seed).standard_normal((1, 64, 1024))
latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
inputs = {
"prompt": "A hammer hitting a wooden surface",
"latents": latents,
"generator": generator,
"num_inference_steps": 3,
"audio_end_in_s": 30,
"guidance_scale": 2.5,
}
return inputs
def test_stable_audio(self):
stable_audio_pipe = StableAudioPipeline.from_pretrained(
"ylacombe/stable-audio-1.0"
) # TODO (YL): change once changed
stable_audio_pipe = stable_audio_pipe.to(torch_device)
stable_audio_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
inputs["num_inference_steps"] = 25
audio = stable_audio_pipe(**inputs).audios[0]
assert audio.ndim == 2
assert audio.shape == (2, int(inputs["audio_end_in_s"] * stable_audio_pipe.vae.sampling_rate))
# check the portion of the generated audio with the largest dynamic range (reduces flakiness)
audio_slice = audio[0, 447590:447600]
# fmt: off
expected_slice = np.array(
[-0.0278, 0.1096, 0.1877, 0.3178, 0.5329, 0.6990, 0.6972, 0.6186, 0.5608, 0.5060]
)
# fmt: one
max_diff = np.abs(expected_slice - audio_slice.detach().cpu().numpy()).max()
assert max_diff < 1.5e-3
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment