Unverified Commit 69e72b1d authored by Yoach Lacombe, committed by GitHub

Stable Audio integration (#8716)



* WIP modeling code and pipeline

* add custom attention processor + custom activation + add to init

* correct ProjectionModel forward

* add stable audio to __init__

* add autoencoder and update pipeline and modeling code

* add half Rope

* add partial rotary v2

* add temporary modifications to scheduler

* add EDM DPM Solver

* remove TODOs

* clean GLU

* remove attn.group_norm from attn processor

* revert back src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

* refactor GLU -> SwiGLU

* remove redundant args

* add channel multiples in autoencoder docstrings

* changes in docstrings and copyright headers

* clean pipeline

* further cleaning

* remove peft, lora and FromOriginalModelMixin

* Delete src/diffusers/pipelines/stable_audio/diffusers.code-workspace

* make style

* dummy models

* fix copied from

* add fast oobleck tests

* add brownian tree

* oobleck autoencoder slow tests

* remove TODO

* fast stable audio pipeline tests

* add slow tests

* make style

* add first version of docs

* wrap is_torchsde_available around the scheduler

* fix slow test

* test with input waveform

* add input waveform

* remove some todos

* create stableaudio gaussian projection + make style

* add pipeline to toctree

* fix copied from

* make quality

* refactor timestep_features->time_proj

* refactor joint_attention_kwargs->cross_attention_kwargs

* remove forward_chunk

* move StableAudioDiTModel to transformers folder

* correct convert + remove partial rotary embed

* apply suggestions from yiyixuxu -> removing attn.kv_heads

* remove temb

* remove cross_attention_kwargs

* further removal of cross_attention_kwargs

* remove text encoder autocast to fp16

* continue removing autocast

* make style

* refactor how text and audio are embedded

* add paper

* update example code

* make style

* unify projection model forward + fix device placement

* make style

* remove fuse qkv

* apply suggestions from review

* Update src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* make style

* smaller models in fast tests

* pass sequential offloading fast tests

* add docs for vae and autoencoder

* make style and update example

* remove useless import

* add cosine scheduler

* dummy classes

* cosine scheduler docs

* better description of scheduler

---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent 8c4856cd
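For context, a minimal usage sketch of the pipeline this PR introduces, pieced together from the tests and docs added below. The checkpoint id is the temporary one used in the slow tests (marked "TODO (YL): change" in the diff), and saving with soundfile is an assumption for illustration, not part of this PR.

import torch
import soundfile as sf  # assumed helper for saving; not part of this diff
from diffusers import StableAudioPipeline

# Temporary checkpoint id taken from the slow tests below ("TODO (YL): change repo id once moved").
pipe = StableAudioPipeline.from_pretrained("ylacombe/stable-audio-1.0", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

generator = torch.Generator("cuda").manual_seed(0)
audio = pipe(
    "A hammer hitting a wooden surface",
    negative_prompt="low quality",
    num_inference_steps=25,
    audio_end_in_s=10.0,
    guidance_scale=2.5,
    generator=generator,
).audios[0]  # (channels, samples) waveform tensor

# Write the stereo waveform at the VAE's native sampling rate.
sf.write("hammer.wav", audio.T.float().cpu().numpy(), pipe.vae.sampling_rate)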
...@@ -118,6 +118,7 @@ except OptionalDependencyNotAvailable:
_dummy_modules.update(get_objects_from_module(dummy_torch_and_torchsde_objects))
else:
_import_structure["scheduling_cosine_dpmsolver_multistep"] = ["CosineDPMSolverMultistepScheduler"]
_import_structure["scheduling_dpmsolver_sde"] = ["DPMSolverSDEScheduler"] _import_structure["scheduling_dpmsolver_sde"] = ["DPMSolverSDEScheduler"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
...@@ -205,6 +206,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -205,6 +206,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_torchsde_objects import * # noqa F403 from ..utils.dummy_torch_and_torchsde_objects import * # noqa F403
else: else:
from .scheduling_cosine_dpmsolver_multistep import CosineDPMSolverMultistepScheduler
from .scheduling_dpmsolver_sde import DPMSolverSDEScheduler
else:
...
...@@ -134,7 +134,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = self.precondition_noise(sigmas)
-        self.sigmas = self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
# setable values
self.num_inference_steps = None
...
...@@ -62,6 +62,21 @@ class AutoencoderKLTemporalDecoder(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class AutoencoderOobleck(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderTiny(metaclass=DummyObject):
_backends = ["torch"]
...@@ -377,6 +392,21 @@ class SparseControlNetModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class StableAudioDiTModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class T2IAdapter(metaclass=DummyObject):
_backends = ["torch"]
...
...@@ -2,6 +2,21 @@
from ..utils import DummyObject, requires_backends
class CosineDPMSolverMultistepScheduler(metaclass=DummyObject):
_backends = ["torch", "torchsde"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "torchsde"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "torchsde"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "torchsde"])
class DPMSolverSDEScheduler(metaclass=DummyObject):
_backends = ["torch", "torchsde"]
...
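The dummy classes above are what make the optional torchsde dependency soft: the top-level import always succeeds, and a missing backend only surfaces, with an install hint, when the object is actually constructed. A rough sketch of the behaviour in an environment without torchsde installed:

from diffusers import CosineDPMSolverMultistepScheduler  # resolves to the DummyObject above when torchsde is missing

try:
    scheduler = CosineDPMSolverMultistepScheduler()
except ImportError as err:
    # requires_backends raises here and points the user at installing torchsde
    print(err)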
...@@ -992,6 +992,36 @@ class ShapEPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class StableAudioPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableAudioProjectionModel(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableCascadeCombinedPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
...
...@@ -18,12 +18,14 @@ import unittest
import numpy as np
import torch
from datasets import load_dataset
from parameterized import parameterized
from diffusers import (
AsymmetricAutoencoderKL,
AutoencoderKL,
AutoencoderKLTemporalDecoder,
AutoencoderOobleck,
AutoencoderTiny,
ConsistencyDecoderVAE,
StableDiffusionPipeline,
...@@ -128,6 +130,18 @@ def get_consistency_vae_config(block_out_channels=None, norm_num_groups=None):
}
def get_autoencoder_oobleck_config(block_out_channels=None):
init_dict = {
"encoder_hidden_size": 12,
"decoder_channels": 12,
"decoder_input_channels": 6,
"audio_channels": 2,
"downsampling_ratios": [2, 4],
"channel_multiples": [1, 2],
}
return init_dict
class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = AutoencoderKL
main_input_name = "sample"
...@@ -480,6 +494,41 @@ class AutoencoderKLTemporalDecoderFastTests(ModelTesterMixin, unittest.TestCase)
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = AutoencoderOobleck
main_input_name = "sample"
base_precision = 1e-2
@property
def dummy_input(self):
batch_size = 4
num_channels = 2
seq_len = 24
waveform = floats_tensor((batch_size, num_channels, seq_len)).to(torch_device)
return {"sample": waveform, "sample_posterior": False}
@property
def input_shape(self):
return (2, 24)
@property
def output_shape(self):
return (2, 24)
def prepare_init_args_and_inputs_for_common(self):
init_dict = get_autoencoder_oobleck_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_forward_signature(self):
pass
def test_forward_with_norm_groups(self):
pass
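To make the fast-test configuration above concrete, here is a hedged sketch that builds the same tiny AutoencoderOobleck and round-trips the dummy waveform shape. The commented shapes are expectations derived from the tests' own shape assertions (latent channels = decoder_input_channels, temporal length divided by the product of downsampling_ratios), not guarantees.

import torch
from diffusers import AutoencoderOobleck

# Same tiny configuration as get_autoencoder_oobleck_config() above.
vae = AutoencoderOobleck(
    encoder_hidden_size=12,
    decoder_channels=12,
    decoder_input_channels=6,
    audio_channels=2,
    downsampling_ratios=[2, 4],
    channel_multiples=[1, 2],
)
vae.eval()

waveform = torch.randn(4, 2, 24)  # (batch, channels, samples), as in dummy_input

with torch.no_grad():
    # Deterministic round trip: use the posterior mean instead of sampling.
    posterior = vae.encode(waveform).latent_dist
    latents = posterior.mean
    reconstruction = vae.decode(latents).sample

print(latents.shape)         # expected (4, 6, 3): 24 samples / (2 * 4) downsampling
print(reconstruction.shape)  # expected (4, 2, 24), matching the input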
@slow
class AutoencoderTinyIntegrationTests(unittest.TestCase):
def tearDown(self):
...@@ -1100,3 +1149,118 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
for shape in shapes:
image = torch.zeros(shape, device=torch_device, dtype=pipe.vae.dtype)
pipe.vae.decode(image)
@slow
class AutoencoderOobleckIntegrationTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
def _load_datasamples(self, num_samples):
ds = load_dataset(
"hf-internal-testing/librispeech_asr_dummy", "clean", split="validation", trust_remote_code=True
)
# automatic decoding with librispeech
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
return torch.nn.utils.rnn.pad_sequence(
[torch.from_numpy(x["array"]) for x in speech_samples], batch_first=True
)
def get_audio(self, audio_sample_size=2097152, fp16=False):
dtype = torch.float16 if fp16 else torch.float32
audio = self._load_datasamples(2).to(torch_device).to(dtype)
# pad / crop to audio_sample_size
audio = torch.nn.functional.pad(audio[:, :audio_sample_size], pad=(0, audio_sample_size - audio.shape[-1]))
# todo channel
audio = audio.unsqueeze(1).repeat(1, 2, 1).to(torch_device)
return audio
def get_oobleck_vae_model(
self, model_id="ylacombe/stable-audio-1.0", fp16=False
): # TODO (YL): change repo id once moved
torch_dtype = torch.float16 if fp16 else torch.float32
model = AutoencoderOobleck.from_pretrained(
model_id,
subfolder="vae",
torch_dtype=torch_dtype,
)
model.to(torch_device)
return model
def get_generator(self, seed=0):
generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
if torch_device != "mps":
return torch.Generator(device=generator_device).manual_seed(seed)
return torch.manual_seed(seed)
@parameterized.expand(
[
# fmt: off
[33, [1.193e-4, 6.56e-05, 1.314e-4, 3.80e-05, -4.01e-06], 0.001192],
[44, [2.77e-05, -2.65e-05, 1.18e-05, -6.94e-05, -9.57e-05], 0.001196],
# fmt: on
]
)
def test_stable_diffusion(self, seed, expected_slice, expected_mean_absolute_diff):
model = self.get_oobleck_vae_model()
audio = self.get_audio()
generator = self.get_generator(seed)
with torch.no_grad():
sample = model(audio, generator=generator, sample_posterior=True).sample
assert sample.shape == audio.shape
assert ((sample - audio).abs().mean() - expected_mean_absolute_diff).abs() <= 1e-6
output_slice = sample[-1, 1, 5:10].cpu()
expected_output_slice = torch.tensor(expected_slice)
assert torch_all_close(output_slice, expected_output_slice, atol=1e-5)
def test_stable_diffusion_mode(self):
model = self.get_oobleck_vae_model()
audio = self.get_audio()
with torch.no_grad():
sample = model(audio, sample_posterior=False).sample
assert sample.shape == audio.shape
@parameterized.expand(
[
# fmt: off
[33, [1.193e-4, 6.56e-05, 1.314e-4, 3.80e-05, -4.01e-06], 0.001192],
[44, [2.77e-05, -2.65e-05, 1.18e-05, -6.94e-05, -9.57e-05], 0.001196],
# fmt: on
]
)
def test_stable_diffusion_encode_decode(self, seed, expected_slice, expected_mean_absolute_diff):
model = self.get_oobleck_vae_model()
audio = self.get_audio()
generator = self.get_generator(seed)
with torch.no_grad():
x = audio
posterior = model.encode(x).latent_dist
z = posterior.sample(generator=generator)
sample = model.decode(z).sample
# (batch_size, latent_dim, sequence_length)
assert posterior.mean.shape == (audio.shape[0], model.config.decoder_input_channels, 1024)
assert sample.shape == audio.shape
assert ((sample - audio).abs().mean() - expected_mean_absolute_diff).abs() <= 1e-6
output_slice = sample[-1, 1, 5:10].cpu()
expected_output_slice = torch.tensor(expected_slice)
assert torch_all_close(output_slice, expected_output_slice, atol=1e-5)
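One detail worth spelling out from the shape assertion above: with an input padded to audio_sample_size = 2,097,152 samples and an asserted latent length of 1,024 frames, the pretrained Oobleck VAE compresses the waveform by 2,048 samples per latent frame. A quick sanity check of that arithmetic:

audio_sample_size = 2_097_152   # padded/cropped input length used in get_audio()
latent_length = 1_024           # asserted posterior length in the test above

samples_per_latent_frame = audio_sample_size // latent_length
print(samples_per_latent_frame)  # 2048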
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import numpy as np
import torch
from transformers import (
T5EncoderModel,
T5Tokenizer,
)
from diffusers import (
AutoencoderOobleck,
CosineDPMSolverMultistepScheduler,
StableAudioDiTModel,
StableAudioPipeline,
StableAudioProjectionModel,
)
from diffusers.utils import is_xformers_available
from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device
from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class StableAudioPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = StableAudioPipeline
params = frozenset(
[
"prompt",
"audio_end_in_s",
"audio_start_in_s",
"guidance_scale",
"negative_prompt",
"prompt_embeds",
"negative_prompt_embeds",
"initial_audio_waveforms",
]
)
batch_params = TEXT_TO_AUDIO_BATCH_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"num_waveforms_per_prompt",
"generator",
"latents",
"output_type",
"return_dict",
"callback",
"callback_steps",
]
)
def get_dummy_components(self):
torch.manual_seed(0)
transformer = StableAudioDiTModel(
sample_size=4,
in_channels=3,
num_layers=2,
attention_head_dim=4,
num_key_value_attention_heads=2,
out_channels=3,
cross_attention_dim=4,
time_proj_dim=8,
global_states_input_dim=8,
cross_attention_input_dim=4,
)
scheduler = CosineDPMSolverMultistepScheduler(
solver_order=2,
prediction_type="v_prediction",
sigma_data=1.0,
sigma_schedule="exponential",
)
torch.manual_seed(0)
vae = AutoencoderOobleck(
encoder_hidden_size=6,
downsampling_ratios=[1, 2],
decoder_channels=3,
decoder_input_channels=3,
audio_channels=2,
channel_multiples=[2, 4],
sampling_rate=4,
)
torch.manual_seed(0)
t5_repo_id = "hf-internal-testing/tiny-random-T5ForConditionalGeneration"
text_encoder = T5EncoderModel.from_pretrained(t5_repo_id)
tokenizer = T5Tokenizer.from_pretrained(t5_repo_id, truncation=True, model_max_length=25)
torch.manual_seed(0)
projection_model = StableAudioProjectionModel(
text_encoder_dim=text_encoder.config.d_model,
conditioning_dim=4,
min_value=0,
max_value=32,
)
components = {
"transformer": transformer,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"projection_model": projection_model,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A hammer hitting a wooden surface",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
}
return inputs
def test_save_load_local(self):
# increase tolerance from 1e-4 -> 7e-3 to account for large composite model
super().test_save_load_local(expected_max_difference=7e-3)
def test_save_load_optional_components(self):
# increase tolerance from 1e-4 -> 7e-3 to account for large composite model
super().test_save_load_optional_components(expected_max_difference=7e-3)
def test_stable_audio_ddim(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
stable_audio_pipe = StableAudioPipeline(**components)
stable_audio_pipe = stable_audio_pipe.to(torch_device)
stable_audio_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
output = stable_audio_pipe(**inputs)
audio = output.audios[0]
assert audio.ndim == 2
assert audio.shape == (2, 7)
def test_stable_audio_without_prompts(self):
components = self.get_dummy_components()
stable_audio_pipe = StableAudioPipeline(**components)
stable_audio_pipe = stable_audio_pipe.to(torch_device)
stable_audio_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt"] = 3 * [inputs["prompt"]]
# forward
output = stable_audio_pipe(**inputs)
audio_1 = output.audios[0]
inputs = self.get_dummy_inputs(torch_device)
prompt = 3 * [inputs.pop("prompt")]
text_inputs = stable_audio_pipe.tokenizer(
prompt,
padding="max_length",
max_length=stable_audio_pipe.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
).to(torch_device)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask
prompt_embeds = stable_audio_pipe.text_encoder(
text_input_ids,
attention_mask=attention_mask,
)[0]
inputs["prompt_embeds"] = prompt_embeds
inputs["attention_mask"] = attention_mask
# forward
output = stable_audio_pipe(**inputs)
audio_2 = output.audios[0]
assert (audio_1 - audio_2).abs().max() < 1e-2
def test_stable_audio_negative_without_prompts(self):
components = self.get_dummy_components()
stable_audio_pipe = StableAudioPipeline(**components)
stable_audio_pipe = stable_audio_pipe.to(torch_device)
stable_audio_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
negative_prompt = 3 * ["this is a negative prompt"]
inputs["negative_prompt"] = negative_prompt
inputs["prompt"] = 3 * [inputs["prompt"]]
# forward
output = stable_audio_pipe(**inputs)
audio_1 = output.audios[0]
inputs = self.get_dummy_inputs(torch_device)
prompt = 3 * [inputs.pop("prompt")]
text_inputs = stable_audio_pipe.tokenizer(
prompt,
padding="max_length",
max_length=stable_audio_pipe.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
).to(torch_device)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask
prompt_embeds = stable_audio_pipe.text_encoder(
text_input_ids,
attention_mask=attention_mask,
)[0]
inputs["prompt_embeds"] = prompt_embeds
inputs["attention_mask"] = attention_mask
negative_text_inputs = stable_audio_pipe.tokenizer(
negative_prompt,
padding="max_length",
max_length=stable_audio_pipe.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
).to(torch_device)
negative_text_input_ids = negative_text_inputs.input_ids
negative_attention_mask = negative_text_inputs.attention_mask
negative_prompt_embeds = stable_audio_pipe.text_encoder(
negative_text_input_ids,
attention_mask=negative_attention_mask,
)[0]
inputs["negative_prompt_embeds"] = negative_prompt_embeds
inputs["negative_attention_mask"] = negative_attention_mask
# forward
output = stable_audio_pipe(**inputs)
audio_2 = output.audios[0]
assert (audio_1 - audio_2).abs().max() < 1e-2
def test_stable_audio_negative_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
stable_audio_pipe = StableAudioPipeline(**components)
stable_audio_pipe = stable_audio_pipe.to(device)
stable_audio_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
negative_prompt = "egg cracking"
output = stable_audio_pipe(**inputs, negative_prompt=negative_prompt)
audio = output.audios[0]
assert audio.ndim == 2
assert audio.shape == (2, 7)
def test_stable_audio_num_waveforms_per_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
stable_audio_pipe = StableAudioPipeline(**components)
stable_audio_pipe = stable_audio_pipe.to(device)
stable_audio_pipe.set_progress_bar_config(disable=None)
prompt = "A hammer hitting a wooden surface"
# test num_waveforms_per_prompt=1 (default)
audios = stable_audio_pipe(prompt, num_inference_steps=2).audios
assert audios.shape == (1, 2, 7)
# test num_waveforms_per_prompt=1 (default) for batch of prompts
batch_size = 2
audios = stable_audio_pipe([prompt] * batch_size, num_inference_steps=2).audios
assert audios.shape == (batch_size, 2, 7)
# test num_waveforms_per_prompt for single prompt
num_waveforms_per_prompt = 2
audios = stable_audio_pipe(
prompt, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt
).audios
assert audios.shape == (num_waveforms_per_prompt, 2, 7)
# test num_waveforms_per_prompt for batch of prompts
batch_size = 2
audios = stable_audio_pipe(
[prompt] * batch_size, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt
).audios
assert audios.shape == (batch_size * num_waveforms_per_prompt, 2, 7)
def test_stable_audio_audio_end_in_s(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
stable_audio_pipe = StableAudioPipeline(**components)
stable_audio_pipe = stable_audio_pipe.to(torch_device)
stable_audio_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
output = stable_audio_pipe(audio_end_in_s=1.5, **inputs)
audio = output.audios[0]
assert audio.ndim == 2
assert audio.shape[1] / stable_audio_pipe.vae.sampling_rate == 1.5
output = stable_audio_pipe(audio_end_in_s=1.1875, **inputs)
audio = output.audios[0]
assert audio.ndim == 2
assert audio.shape[1] / stable_audio_pipe.vae.sampling_rate == 1.0
def test_attention_slicing_forward_pass(self):
self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False)
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=5e-4)
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_attention_forwardGenerator_pass(self):
self._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False)
def test_stable_audio_input_waveform(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
stable_audio_pipe = StableAudioPipeline(**components)
stable_audio_pipe = stable_audio_pipe.to(device)
stable_audio_pipe.set_progress_bar_config(disable=None)
prompt = "A hammer hitting a wooden surface"
initial_audio_waveforms = torch.ones((1, 5))
# test raises error when no sampling rate
with self.assertRaises(ValueError):
audios = stable_audio_pipe(
prompt, num_inference_steps=2, initial_audio_waveforms=initial_audio_waveforms
).audios
# test raises error when wrong sampling rate
with self.assertRaises(ValueError):
audios = stable_audio_pipe(
prompt,
num_inference_steps=2,
initial_audio_waveforms=initial_audio_waveforms,
initial_audio_sampling_rate=stable_audio_pipe.vae.sampling_rate - 1,
).audios
audios = stable_audio_pipe(
prompt,
num_inference_steps=2,
initial_audio_waveforms=initial_audio_waveforms,
initial_audio_sampling_rate=stable_audio_pipe.vae.sampling_rate,
).audios
assert audios.shape == (1, 2, 7)
# test works with num_waveforms_per_prompt
num_waveforms_per_prompt = 2
audios = stable_audio_pipe(
prompt,
num_inference_steps=2,
num_waveforms_per_prompt=num_waveforms_per_prompt,
initial_audio_waveforms=initial_audio_waveforms,
initial_audio_sampling_rate=stable_audio_pipe.vae.sampling_rate,
).audios
assert audios.shape == (num_waveforms_per_prompt, 2, 7)
# test num_waveforms_per_prompt for batch of prompts and input audio (two channels)
batch_size = 2
initial_audio_waveforms = torch.ones((batch_size, 2, 5))
audios = stable_audio_pipe(
[prompt] * batch_size,
num_inference_steps=2,
num_waveforms_per_prompt=num_waveforms_per_prompt,
initial_audio_waveforms=initial_audio_waveforms,
initial_audio_sampling_rate=stable_audio_pipe.vae.sampling_rate,
).audios
assert audios.shape == (batch_size * num_waveforms_per_prompt, 2, 7)
@unittest.skip("Not supported yet")
def test_sequential_cpu_offload_forward_pass(self):
pass
@unittest.skip("Not supported yet")
def test_sequential_offload_forward_pass_twice(self):
pass
@nightly
@require_torch_gpu
class StableAudioPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
latents = np.random.RandomState(seed).standard_normal((1, 64, 1024))
latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
inputs = {
"prompt": "A hammer hitting a wooden surface",
"latents": latents,
"generator": generator,
"num_inference_steps": 3,
"audio_end_in_s": 30,
"guidance_scale": 2.5,
}
return inputs
def test_stable_audio(self):
stable_audio_pipe = StableAudioPipeline.from_pretrained(
"ylacombe/stable-audio-1.0"
) # TODO (YL): change once changed
stable_audio_pipe = stable_audio_pipe.to(torch_device)
stable_audio_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
inputs["num_inference_steps"] = 25
audio = stable_audio_pipe(**inputs).audios[0]
assert audio.ndim == 2
assert audio.shape == (2, int(inputs["audio_end_in_s"] * stable_audio_pipe.vae.sampling_rate))
# check the portion of the generated audio with the largest dynamic range (reduces flakiness)
audio_slice = audio[0, 447590:447600]
# fmt: off
expected_slice = np.array(
[-0.0278, 0.1096, 0.1877, 0.3178, 0.5329, 0.6990, 0.6972, 0.6186, 0.5608, 0.5060]
)
# fmt: on
max_diff = np.abs(expected_slice - audio_slice.detach().cpu().numpy()).max()
assert max_diff < 1.5e-3