Unverified Commit 69e72b1d authored by Yoach Lacombe, committed by GitHub

Stable Audio integration (#8716)



* WIP modeling code and pipeline

* add custom attention processor + custom activation + add to init

* correct ProjectionModel forward

* add stable audio to __init__

* add autoencoder and update pipeline and modeling code

* add half Rope

* add partial rotary v2

* add temporary modifications to scheduler

* add EDM DPM Solver

* remove TODOs

* clean GLU

* remove attn.group_norm from attn processor

* revert back src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

* refactor GLU -> SwiGLU

* remove redundant args

* add channel multiples in autoencoder docstrings

* changes in docstrings and copyright headers

* clean pipeline

* further cleaning

* remove peft and lora and fromoriginalmodel

* Delete src/diffusers/pipelines/stable_audio/diffusers.code-workspace

* make style

* dummy models

* fix copied from

* add fast oobleck tests

* add brownian tree

* oobleck autoencoder slow tests

* remove TODO

* fast stable audio pipeline tests

* add slow tests

* make style

* add first version of docs

* wrap is_torchsde_available to the scheduler

* fix slow test

* test with input waveform

* add input waveform

* remove some todos

* create stableaudio gaussian projection + make style

* add pipeline to toctree

* fix copied from

* make quality

* refactor timestep_features->time_proj

* refactor joint_attention_kwargs->cross_attention_kwargs

* remove forward_chunk

* move StableAudioDiTModel to transformers folder

* correct convert + remove partial rotary embed

* apply suggestions from yiyixuxu -> removing attn.kv_heads

* remove temb

* remove cross_attention_kwargs

* further removal of cross_attention_kwargs

* remove text encoder autocast to fp16

* continue removing autocast

* make style

* refactor how text and audio are embedded

* add paper

* update example code

* make style

* unify projection model forward + fix device placement

* make style

* remove fuse qkv

* apply suggestions from review

* Update src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* make style

* smaller models in fast tests

* pass sequential offloading fast tests

* add docs for vae and autoencoder

* make style and update example

* remove useless import

* add cosine scheduler

* dummy classes

* cosine scheduler docs

* better description of scheduler

---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent 8c4856cd
......@@ -239,6 +239,8 @@
title: AsymmetricAutoencoderKL
- local: api/models/autoencoder_tiny
title: Tiny AutoEncoder
- local: api/models/autoencoder_oobleck
title: Oobleck AutoEncoder
- local: api/models/consistency_decoder_vae
title: ConsistencyDecoderVAE
- local: api/models/transformer2d
......@@ -259,6 +261,8 @@
title: TransformerTemporalModel
- local: api/models/sd3_transformer2d
title: SD3Transformer2DModel
- local: api/models/stable_audio_transformer
title: StableAudioDiTModel
- local: api/models/prior_transformer
title: PriorTransformer
- local: api/models/controlnet
......@@ -362,6 +366,8 @@
title: Semantic Guidance
- local: api/pipelines/shap_e
title: Shap-E
- local: api/pipelines/stable_audio
title: Stable Audio
- local: api/pipelines/stable_cascade
title: Stable Cascade
- sections:
......@@ -425,6 +431,8 @@
title: CMStochasticIterativeScheduler
- local: api/schedulers/consistency_decoder
title: ConsistencyDecoderScheduler
- local: api/schedulers/cosine_dpm
title: CosineDPMSolverMultistepScheduler
- local: api/schedulers/ddim_inverse
title: DDIMInverseScheduler
- local: api/schedulers/ddim
......
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# AutoencoderOobleck
The Oobleck variational autoencoder (VAE) model with KL loss was introduced in [Stability-AI/stable-audio-tools](https://github.com/Stability-AI/stable-audio-tools) and [Stable Audio Open](https://huggingface.co/papers/2407.14358) by Stability AI. The model is used in 🤗 Diffusers to encode audio waveforms into latents and to decode latent representations into audio waveforms.
The abstract from the paper is:
*Open generative models are vitally important for the community, allowing for fine-tunes and serving as baselines when presenting new models. However, most current text-to-audio models are private and not accessible for artists and researchers to build upon. Here we describe the architecture and training process of a new open-weights text-to-audio model trained with Creative Commons data. Our evaluation shows that the model's performance is competitive with the state-of-the-art across various metrics. Notably, the reported FDopenl3 results (measuring the realism of the generations) showcase its potential for high-quality stereo sound synthesis at 44.1kHz.*
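A minimal sketch of round-tripping a waveform through the autoencoder. The repository id below is an assumption for illustration; the `vae` subfolder name follows the pipeline layout used in this PR's conversion script.

```python
import torch
from diffusers import AutoencoderOobleck

# Assumed repository id; the autoencoder is stored under the pipeline's "vae" subfolder.
vae = AutoencoderOobleck.from_pretrained("stabilityai/stable-audio-open-1.0", subfolder="vae")

# Stereo waveform of shape (batch_size, audio_channels, num_samples), sampled at 44.1 kHz.
waveform = torch.randn(1, 2, 2**16)

with torch.no_grad():
    posterior = vae.encode(waveform).latent_dist
    latents = posterior.sample()  # or posterior.mode() for the distribution mean
    reconstruction = vae.decode(latents).sample
```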
## AutoencoderOobleck
[[autodoc]] AutoencoderOobleck
- decode
- encode
- all
## OobleckDecoderOutput
[[autodoc]] models.autoencoders.autoencoder_oobleck.OobleckDecoderOutput
## AutoencoderOobleckOutput
[[autodoc]] models.autoencoders.autoencoder_oobleck.AutoencoderOobleckOutput
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# StableAudioDiTModel
A Transformer model for audio waveforms from [Stable Audio Open](https://huggingface.co/papers/2407.14358).
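A minimal loading sketch (the repository id is an assumption; the `transformer` subfolder name follows the pipeline layout used in this PR's conversion script):

```python
from diffusers import StableAudioDiTModel

# Assumed repository id for the converted pipeline.
transformer = StableAudioDiTModel.from_pretrained(
    "stabilityai/stable-audio-open-1.0", subfolder="transformer"
)
```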
## StableAudioDiTModel
[[autodoc]] StableAudioDiTModel
......@@ -71,6 +71,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
| [Semantic Guidance](semantic_stable_diffusion) | text2image |
| [Shap-E](shap_e) | text-to-3D, image-to-3D |
| [Spectrogram Diffusion](spectrogram_diffusion) | |
| [Stable Audio](stable_audio) | text2audio |
| [Stable Diffusion](stable_diffusion/overview) | text2image, image2image, depth2image, inpainting, image variation, latent upscaler, super-resolution |
| [Stable Diffusion Model Editing](model_editing) | model editing |
| [Stable Diffusion XL](stable_diffusion/stable_diffusion_xl) | text2image, image2image, inpainting |
......
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Stable Audio
Stable Audio was proposed in [Stable Audio Open](https://arxiv.org/abs/2407.14358) by Zach Evans et al. It takes a text prompt as input and predicts the corresponding sound or music sample.
Stable Audio Open generates variable-length (up to 47s) stereo audio at 44.1kHz from text prompts. It comprises three components: an autoencoder that compresses waveforms into a manageable sequence length, a T5-based text embedding for text conditioning, and a transformer-based diffusion (DiT) model that operates in the latent space of the autoencoder.
Stable Audio is trained on a corpus of around 48k audio recordings, where around 47k are from Freesound and the rest are from the Free Music Archive (FMA). All audio files are licensed under CC0, CC BY, or CC Sampling+. This data is used to train the autoencoder and the DiT.
The abstract of the paper is the following:
*Open generative models are vitally important for the community, allowing for fine-tunes and serving as baselines when presenting new models. However, most current text-to-audio models are private and not accessible for artists and researchers to build upon. Here we describe the architecture and training process of a new open-weights text-to-audio model trained with Creative Commons data. Our evaluation shows that the model's performance is competitive with the state-of-the-art across various metrics. Notably, the reported FDopenl3 results (measuring the realism of the generations) showcase its potential for high-quality stereo sound synthesis at 44.1kHz.*
This pipeline was contributed by [Yoach Lacombe](https://huggingface.co/ylacombe). The original codebase can be found at [Stability-AI/stable-audio-tools](https://github.com/Stability-AI/stable-audio-tools).
## Tips
When constructing a prompt, keep in mind:
* Descriptive prompt inputs work best; use adjectives to describe the sound (for example, "high quality" or "clear") and make the prompt context specific where possible (e.g. "melodic techno with a fast beat and synths" works better than "techno").
* Using a *negative prompt* can significantly improve the quality of the generated audio. Try using a negative prompt of "low quality, average quality".
During inference:
* The _quality_ of the generated audio sample can be controlled by the `num_inference_steps` argument; higher steps give higher quality audio at the expense of slower inference.
* Multiple waveforms can be generated in one go: set `num_waveforms_per_prompt` to a value greater than 1 to enable it. Automatic scoring is performed between the generated waveforms and the prompt text, and the waveforms are ranked from best to worst accordingly, as shown in the example below.
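A minimal usage sketch putting the tips above together (the repository id and the output post-processing are assumptions, not taken from this PR):

```python
import torch
import soundfile as sf
from diffusers import StableAudioPipeline

# Assumed repository id for the converted pipeline.
pipe = StableAudioPipeline.from_pretrained("stabilityai/stable-audio-open-1.0", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

audio = pipe(
    prompt="melodic techno with a fast beat and synths",
    negative_prompt="low quality, average quality",
    num_inference_steps=200,
    num_waveforms_per_prompt=3,  # waveforms are returned ranked from best to worst
    generator=torch.Generator("cuda").manual_seed(0),
).audios

# Assumed output layout: (num_waveforms, audio_channels, num_samples); keep the top-ranked waveform.
sf.write("output.wav", audio[0].T.float().cpu().numpy(), pipe.vae.sampling_rate)
```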
## StableAudioPipeline
[[autodoc]] StableAudioPipeline
- all
- __call__
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# CosineDPMSolverMultistepScheduler
The [`CosineDPMSolverMultistepScheduler`] is a variant of [`DPMSolverMultistepScheduler`] with a cosine noise schedule, as proposed by Nichol and Dhariwal (2021).
It is used in the [Stable Audio Open](https://arxiv.org/abs/2407.14358) paper and in the [Stability-AI/stable-audio-tools](https://github.com/Stability-AI/stable-audio-tools) codebase.
This scheduler was contributed by [Yoach Lacombe](https://huggingface.co/ylacombe).
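A minimal sketch of constructing the scheduler with the Stable Audio Open settings used in this PR's conversion script:

```python
from diffusers import CosineDPMSolverMultistepScheduler

scheduler = CosineDPMSolverMultistepScheduler(
    sigma_min=0.3,
    sigma_max=500,
    solver_order=2,
    prediction_type="v_prediction",
    sigma_data=1.0,
    sigma_schedule="exponential",
)
```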
## CosineDPMSolverMultistepScheduler
[[autodoc]] CosineDPMSolverMultistepScheduler
## SchedulerOutput
[[autodoc]] schedulers.scheduling_utils.SchedulerOutput
# Run this script to convert the Stable Audio model weights to a diffusers pipeline.
import argparse
import json
import os
from contextlib import nullcontext
import torch
from safetensors.torch import load_file
from transformers import (
AutoTokenizer,
T5EncoderModel,
)
from diffusers import (
AutoencoderOobleck,
CosineDPMSolverMultistepScheduler,
StableAudioDiTModel,
StableAudioPipeline,
StableAudioProjectionModel,
)
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.utils import is_accelerate_available
if is_accelerate_available():
from accelerate import init_empty_weights
def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_layers=5):
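# Split the original checkpoint into three state dicts and rename keys to diffusers conventions:
# the conditioning/projection model ("conditioner.conditioners.*"), the DiT ("model.model.*"),
# and the Oobleck autoencoder ("pretransform.model.*").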
projection_model_state_dict = {
k.replace("conditioner.conditioners.", "").replace("embedder.embedding", "time_positional_embedding"): v
for (k, v) in state_dict.items()
if "conditioner.conditioners" in k
}
# NOTE: we assume here that there's no projection layer from the text encoder to the latent space, script should be adapted a bit if there is.
for key, value in list(projection_model_state_dict.items()):
new_key = key.replace("seconds_start", "start_number_conditioner").replace(
"seconds_total", "end_number_conditioner"
)
projection_model_state_dict[new_key] = projection_model_state_dict.pop(key)
model_state_dict = {k.replace("model.model.", ""): v for (k, v) in state_dict.items() if "model.model." in k}
for key, value in list(model_state_dict.items()):
# attention layers
new_key = (
key.replace("transformer.", "")
.replace("layers", "transformer_blocks")
.replace("self_attn", "attn1")
.replace("cross_attn", "attn2")
.replace("ff.ff", "ff.net")
)
new_key = (
new_key.replace("pre_norm", "norm1")
.replace("cross_attend_norm", "norm2")
.replace("ff_norm", "norm3")
.replace("to_out", "to_out.0")
)
new_key = new_key.replace("gamma", "weight").replace("beta", "bias") # replace layernorm
# other layers
new_key = (
new_key.replace("project", "proj")
.replace("to_timestep_embed", "timestep_proj")
.replace("timestep_features", "time_proj")
.replace("to_global_embed", "global_proj")
.replace("to_cond_embed", "cross_attention_proj")
)
# we're using diffusers implementation of time_proj (GaussianFourierProjection) which creates a 1D tensor
if new_key == "time_proj.weight":
model_state_dict[key] = model_state_dict[key].squeeze(1)
if "to_qkv" in new_key:
q, k, v = torch.chunk(model_state_dict.pop(key), 3, dim=0)
model_state_dict[new_key.replace("qkv", "q")] = q
model_state_dict[new_key.replace("qkv", "k")] = k
model_state_dict[new_key.replace("qkv", "v")] = v
elif "to_kv" in new_key:
k, v = torch.chunk(model_state_dict.pop(key), 2, dim=0)
model_state_dict[new_key.replace("kv", "k")] = k
model_state_dict[new_key.replace("kv", "v")] = v
else:
model_state_dict[new_key] = model_state_dict.pop(key)
autoencoder_state_dict = {
k.replace("pretransform.model.", "").replace("coder.layers.0", "coder.conv1"): v
for (k, v) in state_dict.items()
if "pretransform.model." in k
}
for key, _ in list(autoencoder_state_dict.items()):
new_key = key
if "coder.layers" in new_key:
# get idx of the layer
idx = int(new_key.split("coder.layers.")[1].split(".")[0])
new_key = new_key.replace(f"coder.layers.{idx}", f"coder.block.{idx-1}")
if "encoder" in new_key:
for i in range(3):
new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i+1}")
new_key = new_key.replace(f"block.{idx-1}.layers.3", f"block.{idx-1}.snake1")
new_key = new_key.replace(f"block.{idx-1}.layers.4", f"block.{idx-1}.conv1")
else:
for i in range(2, 5):
new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i-1}")
new_key = new_key.replace(f"block.{idx-1}.layers.0", f"block.{idx-1}.snake1")
new_key = new_key.replace(f"block.{idx-1}.layers.1", f"block.{idx-1}.conv_t1")
new_key = new_key.replace("layers.0.beta", "snake1.beta")
new_key = new_key.replace("layers.0.alpha", "snake1.alpha")
new_key = new_key.replace("layers.2.beta", "snake2.beta")
new_key = new_key.replace("layers.2.alpha", "snake2.alpha")
new_key = new_key.replace("layers.1.bias", "conv1.bias")
new_key = new_key.replace("layers.1.weight_", "conv1.weight_")
new_key = new_key.replace("layers.3.bias", "conv2.bias")
new_key = new_key.replace("layers.3.weight_", "conv2.weight_")
if idx == num_autoencoder_layers + 1:
new_key = new_key.replace(f"block.{idx-1}", "snake1")
elif idx == num_autoencoder_layers + 2:
new_key = new_key.replace(f"block.{idx-1}", "conv2")
else:
new_key = new_key
value = autoencoder_state_dict.pop(key)
if "snake" in new_key:
value = value.unsqueeze(0).unsqueeze(-1)
if new_key in autoencoder_state_dict:
raise ValueError(f"{new_key} already in state dict.")
autoencoder_state_dict[new_key] = value
return model_state_dict, projection_model_state_dict, autoencoder_state_dict
parser = argparse.ArgumentParser(description="Convert Stable Audio 1.0 model weights to a diffusers pipeline")
parser.add_argument("--model_folder_path", type=str, help="Location of Stable Audio weights and config")
parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion")
parser.add_argument(
"--save_directory",
type=str,
default="./tmp/stable-audio-1.0",
help="Directory to save a pipeline to. Will be created if it doesn't exist.",
)
parser.add_argument(
"--repo_id",
type=str,
default="stable-audio-1.0",
help="Hub organization to save the pipelines to",
)
parser.add_argument("--push_to_hub", action="store_true", help="Push to hub")
parser.add_argument("--variant", type=str, help="Set to bf16 to save bfloat16 weights")
args = parser.parse_args()
checkpoint_path = (
os.path.join(args.model_folder_path, "model.safetensors")
if args.use_safetensors
else os.path.join(args.model_folder_path, "model.ckpt")
)
config_path = os.path.join(args.model_folder_path, "model_config.json")
device = "cpu"
if args.variant == "bf16":
dtype = torch.bfloat16
else:
dtype = torch.float32
with open(config_path) as f_in:
config_dict = json.load(f_in)
conditioning_dict = {
conditioning["id"]: conditioning["config"] for conditioning in config_dict["model"]["conditioning"]["configs"]
}
t5_model_config = conditioning_dict["prompt"]
# T5 Text encoder
text_encoder = T5EncoderModel.from_pretrained(t5_model_config["t5_model_name"])
tokenizer = AutoTokenizer.from_pretrained(
t5_model_config["t5_model_name"], truncation=True, model_max_length=t5_model_config["max_length"]
)
# scheduler
scheduler = CosineDPMSolverMultistepScheduler(
sigma_min=0.3,
sigma_max=500,
solver_order=2,
prediction_type="v_prediction",
sigma_data=1.0,
sigma_schedule="exponential",
)
ctx = init_empty_weights if is_accelerate_available() else nullcontext
if args.use_safetensors:
orig_state_dict = load_file(checkpoint_path, device=device)
else:
orig_state_dict = torch.load(checkpoint_path, map_location=device)
model_config = config_dict["model"]["diffusion"]["config"]
model_state_dict, projection_model_state_dict, autoencoder_state_dict = convert_stable_audio_state_dict_to_diffusers(
orig_state_dict
)
with ctx():
projection_model = StableAudioProjectionModel(
text_encoder_dim=text_encoder.config.d_model,
conditioning_dim=config_dict["model"]["conditioning"]["cond_dim"],
min_value=conditioning_dict["seconds_start"][
"min_val"
], # assume `seconds_start` and `seconds_total` have the same min / max values.
max_value=conditioning_dict["seconds_start"][
"max_val"
], # assume `seconds_start` and `seconds_total` have the same min / max values.
)
if is_accelerate_available():
load_model_dict_into_meta(projection_model, projection_model_state_dict)
else:
projection_model.load_state_dict(projection_model_state_dict)
attention_head_dim = model_config["embed_dim"] // model_config["num_heads"]
with ctx():
model = StableAudioDiTModel(
sample_size=int(config_dict["sample_size"])
/ int(config_dict["model"]["pretransform"]["config"]["downsampling_ratio"]),
in_channels=model_config["io_channels"],
num_layers=model_config["depth"],
attention_head_dim=attention_head_dim,
num_key_value_attention_heads=model_config["cond_token_dim"] // attention_head_dim,
num_attention_heads=model_config["num_heads"],
out_channels=model_config["io_channels"],
cross_attention_dim=model_config["cond_token_dim"],
time_proj_dim=256,
global_states_input_dim=model_config["global_cond_dim"],
cross_attention_input_dim=model_config["cond_token_dim"],
)
if is_accelerate_available():
load_model_dict_into_meta(model, model_state_dict)
else:
model.load_state_dict(model_state_dict)
autoencoder_config = config_dict["model"]["pretransform"]["config"]
with ctx():
autoencoder = AutoencoderOobleck(
encoder_hidden_size=autoencoder_config["encoder"]["config"]["channels"],
downsampling_ratios=autoencoder_config["encoder"]["config"]["strides"],
decoder_channels=autoencoder_config["decoder"]["config"]["channels"],
decoder_input_channels=autoencoder_config["decoder"]["config"]["latent_dim"],
audio_channels=autoencoder_config["io_channels"],
channel_multiples=autoencoder_config["encoder"]["config"]["c_mults"],
sampling_rate=config_dict["sample_rate"],
)
if is_accelerate_available():
load_model_dict_into_meta(autoencoder, autoencoder_state_dict)
else:
autoencoder.load_state_dict(autoencoder_state_dict)
# Stable Audio pipeline
pipeline = StableAudioPipeline(
transformer=model,
tokenizer=tokenizer,
text_encoder=text_encoder,
scheduler=scheduler,
vae=autoencoder,
projection_model=projection_model,
)
pipeline.to(dtype).save_pretrained(
args.save_directory, repo_id=args.repo_id, push_to_hub=args.push_to_hub, variant=args.variant
)
......@@ -79,6 +79,7 @@ else:
"AuraFlowTransformer2DModel",
"AutoencoderKL",
"AutoencoderKLTemporalDecoder",
"AutoencoderOobleck",
"AutoencoderTiny",
"ConsistencyDecoderVAE",
"ControlNetModel",
......@@ -100,6 +101,7 @@ else:
"SD3MultiControlNetModel",
"SD3Transformer2DModel",
"SparseControlNetModel",
"StableAudioDiTModel",
"StableCascadeUNet",
"T2IAdapter",
"T5FilmDecoder",
......@@ -210,7 +212,7 @@ except OptionalDependencyNotAvailable:
]
else:
_import_structure["schedulers"].extend(["DPMSolverSDEScheduler"])
_import_structure["schedulers"].extend(["CosineDPMSolverMultistepScheduler", "DPMSolverSDEScheduler"])
try:
if not (is_torch_available() and is_transformers_available()):
......@@ -293,6 +295,8 @@ else:
"SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline",
"ShapEPipeline",
"StableAudioPipeline",
"StableAudioProjectionModel",
"StableCascadeCombinedPipeline",
"StableCascadeDecoderPipeline",
"StableCascadePriorPipeline",
......@@ -515,6 +519,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AuraFlowTransformer2DModel,
AutoencoderKL,
AutoencoderKLTemporalDecoder,
AutoencoderOobleck,
AutoencoderTiny,
ConsistencyDecoderVAE,
ControlNetModel,
......@@ -536,6 +541,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
SD3MultiControlNetModel,
SD3Transformer2DModel,
SparseControlNetModel,
StableAudioDiTModel,
T2IAdapter,
T5FilmDecoder,
Transformer2DModel,
......@@ -632,7 +638,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_torchsde_objects import * # noqa F403
else:
from .schedulers import DPMSolverSDEScheduler
from .schedulers import CosineDPMSolverMultistepScheduler, DPMSolverSDEScheduler
try:
if not (is_torch_available() and is_transformers_available()):
......@@ -707,6 +713,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline,
ShapEPipeline,
StableAudioPipeline,
StableAudioProjectionModel,
StableCascadeCombinedPipeline,
StableCascadeDecoderPipeline,
StableCascadePriorPipeline,
......
......@@ -29,6 +29,7 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["autoencoders.vq_model"] = ["VQModel"]
......@@ -47,6 +48,7 @@ if is_torch_available():
_import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"]
_import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
_import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
_import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
......@@ -75,6 +77,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AsymmetricAutoencoderKL,
AutoencoderKL,
AutoencoderKLTemporalDecoder,
AutoencoderOobleck,
AutoencoderTiny,
ConsistencyDecoderVAE,
VQModel,
......@@ -96,6 +99,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
PixArtTransformer2DModel,
PriorTransformer,
SD3Transformer2DModel,
StableAudioDiTModel,
T5FilmDecoder,
Transformer2DModel,
TransformerTemporalModel,
......
......@@ -123,6 +123,28 @@ class GEGLU(nn.Module):
return hidden_states * self.gelu(gate)
class SwiGLU(nn.Module):
r"""
A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. It's similar to `GEGLU`
but uses SiLU / Swish instead of GeLU.
Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
"""
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
self.activation = nn.SiLU()
def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
hidden_states, gate = hidden_states.chunk(2, dim=-1)
return hidden_states * self.activation(gate)
class ApproximateGELU(nn.Module):
r"""
The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this
......
......@@ -19,7 +19,7 @@ from torch import nn
from ..utils import deprecate, logging
from ..utils.torch_utils import maybe_allow_in_graph
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
from .attention_processor import Attention, JointAttnProcessor2_0
from .embeddings import SinusoidalPositionalEmbedding
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
......@@ -820,6 +820,8 @@ class FeedForward(nn.Module):
act_fn = GEGLU(dim, inner_dim, bias=bias)
elif activation_fn == "geglu-approximate":
act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
elif activation_fn == "swiglu":
act_fn = SwiGLU(dim, inner_dim, bias=bias)
self.net = nn.ModuleList([])
# project in
......
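As a quick illustration of the new `"swiglu"` option wired into `FeedForward` above, a minimal sketch (illustrative only, not part of the diff):

import torch
from diffusers.models.attention import FeedForward

# "swiglu" routes the feed-forward projection through the SwiGLU module added in this PR.
ff = FeedForward(dim=256, activation_fn="swiglu")
out = ff(torch.randn(2, 10, 256))  # output keeps the input shape: (2, 10, 256)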
......@@ -13,7 +13,7 @@
# limitations under the License.
import inspect
import math
from typing import Callable, List, Optional, Union
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
......@@ -49,6 +49,10 @@ class Attention(nn.Module):
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
heads (`int`, *optional*, defaults to 8):
The number of heads to use for multi-head attention.
kv_heads (`int`, *optional*, defaults to `None`):
The number of key and value heads to use for multi-head attention. Defaults to `heads`. If
`kv_heads=heads`, the model will use Multi Head Attention (MHA); if `kv_heads=1`, it will use Multi
Query Attention (MQA); otherwise, Grouped Query Attention (GQA) is used.
dim_head (`int`, *optional*, defaults to 64):
The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0):
......@@ -1624,6 +1628,137 @@ class AttnProcessor2_0:
return hidden_states
class StableAudioAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the Stable Audio model. It applies rotary embeddings to the query and key vectors, and supports MHA, GQA and MQA.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def apply_partial_rotary_emb(
self,
x: torch.Tensor,
freqs_cis: Tuple[torch.Tensor],
) -> torch.Tensor:
from .embeddings import apply_rotary_emb
rot_dim = freqs_cis[0].shape[-1]
x_to_rotate, x_unrotated = x[..., :rot_dim], x[..., rot_dim:]
x_rotated = apply_rotary_emb(x_to_rotate, freqs_cis, use_real=True, use_real_unbind_dim=-2)
out = torch.cat((x_rotated, x_unrotated), dim=-1)
return out
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from .embeddings import apply_rotary_emb
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
head_dim = query.shape[-1] // attn.heads
kv_heads = key.shape[-1] // head_dim
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
if kv_heads != attn.heads:
# if GQA or MQA, repeat the key/value heads to reach the number of query heads.
heads_per_kv_head = attn.heads // kv_heads
key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)
value = torch.repeat_interleave(value, heads_per_kv_head, dim=1)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if rotary_emb is not None:
query_dtype = query.dtype
key_dtype = key.dtype
query = query.to(torch.float32)
key = key.to(torch.float32)
rot_dim = rotary_emb[0].shape[-1]
query_to_rotate, query_unrotated = query[..., :rot_dim], query[..., rot_dim:]
query_rotated = apply_rotary_emb(query_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2)
query = torch.cat((query_rotated, query_unrotated), dim=-1)
if not attn.is_cross_attention:
key_to_rotate, key_unrotated = key[..., :rot_dim], key[..., rot_dim:]
key_rotated = apply_rotary_emb(key_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2)
key = torch.cat((key_rotated, key_unrotated), dim=-1)
query = query.to(query_dtype)
key = key.to(key_dtype)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class HunyuanAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
......
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_oobleck import AutoencoderOobleck
from .autoencoder_tiny import AutoencoderTiny
from .consistency_decoder_vae import ConsistencyDecoderVAE
from .vq_model import VQModel
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput
from ...utils.accelerate_utils import apply_forward_hook
from ...utils.torch_utils import randn_tensor
from ..modeling_utils import ModelMixin
class Snake1d(nn.Module):
"""
A 1-dimensional Snake activation function module.
"""
def __init__(self, hidden_dim, logscale=True):
super().__init__()
self.alpha = nn.Parameter(torch.zeros(1, hidden_dim, 1))
self.beta = nn.Parameter(torch.zeros(1, hidden_dim, 1))
self.alpha.requires_grad = True
self.beta.requires_grad = True
self.logscale = logscale
def forward(self, hidden_states):
shape = hidden_states.shape
alpha = self.alpha if not self.logscale else torch.exp(self.alpha)
beta = self.beta if not self.logscale else torch.exp(self.beta)
hidden_states = hidden_states.reshape(shape[0], shape[1], -1)
hidden_states = hidden_states + (beta + 1e-9).reciprocal() * torch.sin(alpha * hidden_states).pow(2)
hidden_states = hidden_states.reshape(shape)
return hidden_states
class OobleckResidualUnit(nn.Module):
"""
A residual unit composed of Snake1d and weight-normalized Conv1d layers with dilations.
"""
def __init__(self, dimension: int = 16, dilation: int = 1):
super().__init__()
pad = ((7 - 1) * dilation) // 2
self.snake1 = Snake1d(dimension)
self.conv1 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad))
self.snake2 = Snake1d(dimension)
self.conv2 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=1))
def forward(self, hidden_state):
"""
Forward pass through the residual unit.
Args:
hidden_state (`torch.Tensor` of shape `(batch_size, channels, time_steps)`):
Input tensor.
Returns:
output_tensor (`torch.Tensor` of shape `(batch_size, channels, time_steps)`):
Input tensor after passing through the residual unit.
"""
output_tensor = hidden_state
output_tensor = self.conv1(self.snake1(output_tensor))
output_tensor = self.conv2(self.snake2(output_tensor))
padding = (hidden_state.shape[-1] - output_tensor.shape[-1]) // 2
if padding > 0:
hidden_state = hidden_state[..., padding:-padding]
output_tensor = hidden_state + output_tensor
return output_tensor
class OobleckEncoderBlock(nn.Module):
"""Encoder block used in Oobleck encoder."""
def __init__(self, input_dim, output_dim, stride: int = 1):
super().__init__()
self.res_unit1 = OobleckResidualUnit(input_dim, dilation=1)
self.res_unit2 = OobleckResidualUnit(input_dim, dilation=3)
self.res_unit3 = OobleckResidualUnit(input_dim, dilation=9)
self.snake1 = Snake1d(input_dim)
self.conv1 = weight_norm(
nn.Conv1d(input_dim, output_dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2))
)
def forward(self, hidden_state):
hidden_state = self.res_unit1(hidden_state)
hidden_state = self.res_unit2(hidden_state)
hidden_state = self.snake1(self.res_unit3(hidden_state))
hidden_state = self.conv1(hidden_state)
return hidden_state
class OobleckDecoderBlock(nn.Module):
"""Decoder block used in Oobleck decoder."""
def __init__(self, input_dim, output_dim, stride: int = 1):
super().__init__()
self.snake1 = Snake1d(input_dim)
self.conv_t1 = weight_norm(
nn.ConvTranspose1d(
input_dim,
output_dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
)
)
self.res_unit1 = OobleckResidualUnit(output_dim, dilation=1)
self.res_unit2 = OobleckResidualUnit(output_dim, dilation=3)
self.res_unit3 = OobleckResidualUnit(output_dim, dilation=9)
def forward(self, hidden_state):
hidden_state = self.snake1(hidden_state)
hidden_state = self.conv_t1(hidden_state)
hidden_state = self.res_unit1(hidden_state)
hidden_state = self.res_unit2(hidden_state)
hidden_state = self.res_unit3(hidden_state)
return hidden_state
class OobleckDiagonalGaussianDistribution(object):
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
self.parameters = parameters
self.mean, self.scale = parameters.chunk(2, dim=1)
self.std = nn.functional.softplus(self.scale) + 1e-4
self.var = self.std * self.std
self.logvar = torch.log(self.var)
self.deterministic = deterministic
def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
# make sure sample is on the same device as the parameters and has same dtype
sample = randn_tensor(
self.mean.shape,
generator=generator,
device=self.parameters.device,
dtype=self.parameters.dtype,
)
x = self.mean + self.std * sample
return x
def kl(self, other: "OobleckDiagonalGaussianDistribution" = None) -> torch.Tensor:
if self.deterministic:
return torch.Tensor([0.0])
else:
if other is None:
return (self.mean * self.mean + self.var - self.logvar - 1.0).sum(1).mean()
else:
normalized_diff = torch.pow(self.mean - other.mean, 2) / other.var
var_ratio = self.var / other.var
logvar_diff = self.logvar - other.logvar
kl = normalized_diff + var_ratio + logvar_diff - 1
kl = kl.sum(1).mean()
return kl
def mode(self) -> torch.Tensor:
return self.mean
@dataclass
class AutoencoderOobleckOutput(BaseOutput):
"""
Output of AutoencoderOobleck encoding method.
Args:
latent_dist (`OobleckDiagonalGaussianDistribution`):
Encoded outputs of `Encoder` represented as the mean and standard deviation of
`OobleckDiagonalGaussianDistribution`. `OobleckDiagonalGaussianDistribution` allows for sampling latents
from the distribution.
"""
latent_dist: "OobleckDiagonalGaussianDistribution" # noqa: F821
@dataclass
class OobleckDecoderOutput(BaseOutput):
r"""
Output of decoding method.
Args:
sample (`torch.Tensor` of shape `(batch_size, audio_channels, sequence_length)`):
The decoded output sample from the last layer of the model.
"""
sample: torch.Tensor
class OobleckEncoder(nn.Module):
"""Oobleck Encoder"""
def __init__(self, encoder_hidden_size, audio_channels, downsampling_ratios, channel_multiples):
super().__init__()
strides = downsampling_ratios
channel_multiples = [1] + channel_multiples
# Create first convolution
self.conv1 = weight_norm(nn.Conv1d(audio_channels, encoder_hidden_size, kernel_size=7, padding=3))
self.block = []
# Create EncoderBlocks that double channels as they downsample by `stride`
for stride_index, stride in enumerate(strides):
self.block += [
OobleckEncoderBlock(
input_dim=encoder_hidden_size * channel_multiples[stride_index],
output_dim=encoder_hidden_size * channel_multiples[stride_index + 1],
stride=stride,
)
]
self.block = nn.ModuleList(self.block)
d_model = encoder_hidden_size * channel_multiples[-1]
self.snake1 = Snake1d(d_model)
self.conv2 = weight_norm(nn.Conv1d(d_model, encoder_hidden_size, kernel_size=3, padding=1))
def forward(self, hidden_state):
hidden_state = self.conv1(hidden_state)
for module in self.block:
hidden_state = module(hidden_state)
hidden_state = self.snake1(hidden_state)
hidden_state = self.conv2(hidden_state)
return hidden_state
class OobleckDecoder(nn.Module):
"""Oobleck Decoder"""
def __init__(self, channels, input_channels, audio_channels, upsampling_ratios, channel_multiples):
super().__init__()
strides = upsampling_ratios
channel_multiples = [1] + channel_multiples
# Add first conv layer
self.conv1 = weight_norm(nn.Conv1d(input_channels, channels * channel_multiples[-1], kernel_size=7, padding=3))
# Add upsampling + MRF blocks
block = []
for stride_index, stride in enumerate(strides):
block += [
OobleckDecoderBlock(
input_dim=channels * channel_multiples[len(strides) - stride_index],
output_dim=channels * channel_multiples[len(strides) - stride_index - 1],
stride=stride,
)
]
self.block = nn.ModuleList(block)
output_dim = channels
self.snake1 = Snake1d(output_dim)
self.conv2 = weight_norm(nn.Conv1d(channels, audio_channels, kernel_size=7, padding=3, bias=False))
def forward(self, hidden_state):
hidden_state = self.conv1(hidden_state)
for layer in self.block:
hidden_state = layer(hidden_state)
hidden_state = self.snake1(hidden_state)
hidden_state = self.conv2(hidden_state)
return hidden_state
class AutoencoderOobleck(ModelMixin, ConfigMixin):
r"""
An autoencoder for encoding waveforms into latents and decoding latent representations into waveforms. First
introduced in Stable Audio.
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
for all models (such as downloading or saving).
Parameters:
encoder_hidden_size (`int`, *optional*, defaults to 128):
Intermediate representation dimension for the encoder.
downsampling_ratios (`List[int]`, *optional*, defaults to `[2, 4, 4, 8, 8]`):
Ratios for downsampling in the encoder. These are used in reverse order for upsampling in the decoder.
channel_multiples (`List[int]`, *optional*, defaults to `[1, 2, 4, 8, 16]`):
Multiples used to determine the hidden sizes of the hidden layers.
decoder_channels (`int`, *optional*, defaults to 128):
Intermediate representation dimension for the decoder.
decoder_input_channels (`int`, *optional*, defaults to 64):
Input dimension for the decoder. Corresponds to the latent dimension.
audio_channels (`int`, *optional*, defaults to 2):
Number of channels in the audio data. Either 1 for mono or 2 for stereo.
sampling_rate (`int`, *optional*, defaults to 44100):
The sampling rate at which the audio waveform should be digitized, expressed in hertz (Hz).
"""
_supports_gradient_checkpointing = False
@register_to_config
def __init__(
self,
encoder_hidden_size=128,
downsampling_ratios=[2, 4, 4, 8, 8],
channel_multiples=[1, 2, 4, 8, 16],
decoder_channels=128,
decoder_input_channels=64,
audio_channels=2,
sampling_rate=44100,
):
super().__init__()
self.encoder_hidden_size = encoder_hidden_size
self.downsampling_ratios = downsampling_ratios
self.decoder_channels = decoder_channels
self.upsampling_ratios = downsampling_ratios[::-1]
self.hop_length = int(np.prod(downsampling_ratios))
self.sampling_rate = sampling_rate
self.encoder = OobleckEncoder(
encoder_hidden_size=encoder_hidden_size,
audio_channels=audio_channels,
downsampling_ratios=downsampling_ratios,
channel_multiples=channel_multiples,
)
self.decoder = OobleckDecoder(
channels=decoder_channels,
input_channels=decoder_input_channels,
audio_channels=audio_channels,
upsampling_ratios=self.upsampling_ratios,
channel_multiples=channel_multiples,
)
self.use_slicing = False
def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor into slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderOobleckOutput, Tuple[OobleckDiagonalGaussianDistribution]]:
"""
Encode a batch of audio waveforms into latents.
Args:
x (`torch.Tensor`): Input batch of audio waveforms.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoders.autoencoder_oobleck.AutoencoderOobleckOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded audio. If `return_dict` is True, an
[`~models.autoencoders.autoencoder_oobleck.AutoencoderOobleckOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self.encoder(x)
posterior = OobleckDiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderOobleckOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[OobleckDecoderOutput, torch.Tensor]:
dec = self.decoder(z)
if not return_dict:
return (dec,)
return OobleckDecoderOutput(sample=dec)
@apply_forward_hook
def decode(
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
) -> Union[OobleckDecoderOutput, torch.FloatTensor]:
"""
Decode a batch of latent representations into audio waveforms.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.OobleckDecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.OobleckDecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.OobleckDecoderOutput`] is returned, otherwise a plain `tuple`
is returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return OobleckDecoderOutput(sample=decoded)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[OobleckDecoderOutput, torch.Tensor]:
r"""
Args:
sample (`torch.Tensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`OobleckDecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z).sample
if not return_dict:
return (dec,)
return OobleckDecoderOutput(sample=dec)
......@@ -352,7 +352,13 @@ def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, n
def get_1d_rotary_pos_embed(
dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False, linear_factor=1.0, ntk_factor=1.0
dim: int,
pos: Union[np.ndarray, int],
theta: float = 10000.0,
use_real=False,
linear_factor=1.0,
ntk_factor=1.0,
repeat_interleave_real=True,
):
"""
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
......@@ -372,6 +378,9 @@ def get_1d_rotary_pos_embed(
Scaling factor for the context extrapolation. Defaults to 1.0.
ntk_factor (`float`, *optional*, defaults to 1.0):
Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
repeat_interleave_real (`bool`, *optional*, defaults to `True`):
If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
Otherwise, they are concatenated with themselves.
Returns:
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
"""
......@@ -383,10 +392,14 @@ def get_1d_rotary_pos_embed(
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor # [D/2]
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
if use_real:
if use_real and repeat_interleave_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
return freqs_cos, freqs_sin
elif use_real:
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1) # [S, D]
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1) # [S, D]
return freqs_cos, freqs_sin
else:
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs_cis
......@@ -396,6 +409,7 @@ def apply_rotary_emb(
x: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
use_real: bool = True,
use_real_unbind_dim: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
......@@ -417,8 +431,17 @@ def apply_rotary_emb(
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
if use_real_unbind_dim == -1:
# Used for example in Lumina
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
elif use_real_unbind_dim == -2:
# Used for example in Stable Audio
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
else:
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
......
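To illustrate the two RoPE layouts introduced above, here is a minimal sketch using the updated helpers (illustrative only, not part of the diff):

import numpy as np
import torch
from diffusers.models.embeddings import apply_rotary_emb, get_1d_rotary_pos_embed

dim, seq_len = 64, 16
pos = np.arange(seq_len)

# Default layout: cos/sin entries are element-wise interleaved and pairs are rotated via unbind(-1).
cos_i, sin_i = get_1d_rotary_pos_embed(dim, pos, use_real=True, repeat_interleave_real=True)

# Stable Audio layout: cos/sin are tiled as two halves and the halves are rotated via unbind(-2).
cos_c, sin_c = get_1d_rotary_pos_embed(dim, pos, use_real=True, repeat_interleave_real=False)

x = torch.randn(1, 8, seq_len, dim)  # (batch, heads, seq_len, head_dim)
out_interleaved = apply_rotary_emb(x, (cos_i, sin_i), use_real=True, use_real_unbind_dim=-1)
out_concatenated = apply_rotary_emb(x, (cos_c, sin_c), use_real=True, use_real_unbind_dim=-2)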
......@@ -10,6 +10,7 @@ if is_torch_available():
from .lumina_nextdit2d import LuminaNextDiT2DModel
from .pixart_transformer_2d import PixArtTransformer2DModel
from .prior_transformer import PriorTransformer
from .stable_audio_transformer import StableAudioDiTModel
from .t5_film_transformer import T5FilmDecoder
from .transformer_2d import Transformer2DModel
from .transformer_sd3 import SD3Transformer2DModel
......
# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Union
import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.attention import FeedForward
from ...models.attention_processor import (
Attention,
AttentionProcessor,
StableAudioAttnProcessor2_0,
)
from ...models.modeling_utils import ModelMixin
from ...models.transformers.transformer_2d import Transformer2DModelOutput
from ...utils import is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class StableAudioGaussianFourierProjection(nn.Module):
"""Gaussian Fourier embeddings for noise levels."""
# Copied from diffusers.models.embeddings.GaussianFourierProjection.__init__
def __init__(
self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
):
super().__init__()
self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
self.log = log
self.flip_sin_to_cos = flip_sin_to_cos
if set_W_to_weight:
# to delete later
del self.weight
self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
self.weight = self.W
del self.W
def forward(self, x):
if self.log:
x = torch.log(x)
x_proj = 2 * np.pi * x[:, None] @ self.weight[None, :]
if self.flip_sin_to_cos:
out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
else:
out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
return out
@maybe_allow_in_graph
class StableAudioDiTBlock(nn.Module):
r"""
Transformer block used in the Stable Audio model (https://github.com/Stability-AI/stable-audio-tools). Allows skip
connection and QK norm.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for the query states.
num_key_value_attention_heads (`int`): The number of heads to use for the key and value states.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
upcast_attention (`bool`, *optional*):
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
num_key_value_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
upcast_attention: bool = False,
norm_eps: float = 1e-5,
ff_inner_dim: Optional[int] = None,
):
super().__init__()
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
self.norm1 = nn.LayerNorm(dim, elementwise_affine=True, eps=norm_eps)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=False,
upcast_attention=upcast_attention,
out_bias=False,
processor=StableAudioAttnProcessor2_0(),
)
# 2. Cross-Attn
self.norm2 = nn.LayerNorm(dim, norm_eps, True)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
kv_heads=num_key_value_attention_heads,
dropout=dropout,
bias=False,
upcast_attention=upcast_attention,
out_bias=False,
processor=StableAudioAttnProcessor2_0(),
) # is self-attn if encoder_hidden_states is none
# 3. Feed-forward
self.norm3 = nn.LayerNorm(dim, norm_eps, True)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn="swiglu",
final_dropout=False,
inner_dim=ff_inner_dim,
bias=True,
)
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
rotary_embedding: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 1. Self-Attention
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn1(
norm_hidden_states,
attention_mask=attention_mask,
rotary_emb=rotary_embedding,
)
hidden_states = attn_output + hidden_states
# 2. Cross-Attention
norm_hidden_states = self.norm2(hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
)
hidden_states = attn_output + hidden_states
# 3. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
ff_output = self.ff(norm_hidden_states)
hidden_states = ff_output + hidden_states
return hidden_states
class StableAudioDiTModel(ModelMixin, ConfigMixin):
"""
The Diffusion Transformer model introduced in Stable Audio.
Reference: https://github.com/Stability-AI/stable-audio-tools
Parameters:
sample_size ( `int`, *optional*, defaults to 1024): The size of the input sample.
in_channels (`int`, *optional*, defaults to 64): The number of channels in the input.
num_layers (`int`, *optional*, defaults to 24): The number of layers of Transformer blocks to use.
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for the query states.
num_key_value_attention_heads (`int`, *optional*, defaults to 12):
The number of heads to use for the key and value states.
out_channels (`int`, defaults to 64): Number of output channels.
cross_attention_dim ( `int`, *optional*, defaults to 768): Dimension of the cross-attention projection.
time_proj_dim ( `int`, *optional*, defaults to 256): Dimension of the timestep inner projection.
global_states_input_dim ( `int`, *optional*, defaults to 1536):
Input dimension of the global hidden states projection.
cross_attention_input_dim ( `int`, *optional*, defaults to 768):
Input dimension of the cross-attention projection
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
sample_size: int = 1024,
in_channels: int = 64,
num_layers: int = 24,
attention_head_dim: int = 64,
num_attention_heads: int = 24,
num_key_value_attention_heads: int = 12,
out_channels: int = 64,
cross_attention_dim: int = 768,
time_proj_dim: int = 256,
global_states_input_dim: int = 1536,
cross_attention_input_dim: int = 768,
):
super().__init__()
self.sample_size = sample_size
self.out_channels = out_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.time_proj = StableAudioGaussianFourierProjection(
embedding_size=time_proj_dim // 2,
flip_sin_to_cos=True,
log=False,
set_W_to_weight=False,
)
self.timestep_proj = nn.Sequential(
nn.Linear(time_proj_dim, self.inner_dim, bias=True),
nn.SiLU(),
nn.Linear(self.inner_dim, self.inner_dim, bias=True),
)
self.global_proj = nn.Sequential(
nn.Linear(global_states_input_dim, self.inner_dim, bias=False),
nn.SiLU(),
nn.Linear(self.inner_dim, self.inner_dim, bias=False),
)
self.cross_attention_proj = nn.Sequential(
nn.Linear(cross_attention_input_dim, cross_attention_dim, bias=False),
nn.SiLU(),
nn.Linear(cross_attention_dim, cross_attention_dim, bias=False),
)
self.preprocess_conv = nn.Conv1d(in_channels, in_channels, 1, bias=False)
self.proj_in = nn.Linear(in_channels, self.inner_dim, bias=False)
self.transformer_blocks = nn.ModuleList(
[
StableAudioDiTBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
num_key_value_attention_heads=num_key_value_attention_heads,
attention_head_dim=attention_head_dim,
cross_attention_dim=cross_attention_dim,
)
for _ in range(num_layers)
]
)
self.proj_out = nn.Linear(self.inner_dim, self.out_channels, bias=False)
self.postprocess_conv = nn.Conv1d(self.out_channels, self.out_channels, 1, bias=False)
self.gradient_checkpointing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model,
indexed by their weight names.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.transformers.hunyuan_transformer_2d.HunyuanDiT2DModel.set_default_attn_processor with Hunyuan->StableAudio
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
self.set_attn_processor(StableAudioAttnProcessor2_0())
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.FloatTensor,
timestep: torch.LongTensor = None,
encoder_hidden_states: torch.FloatTensor = None,
global_hidden_states: torch.FloatTensor = None,
rotary_embedding: torch.FloatTensor = None,
return_dict: bool = True,
attention_mask: Optional[torch.LongTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`StableAudioDiTModel`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, in_channels, sequence_len)`):
Input `hidden_states`.
timestep (`torch.LongTensor`):
Used to indicate the denoising step.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, encoder_sequence_len, cross_attention_input_dim)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
global_hidden_states (`torch.FloatTensor` of shape `(batch size, global_sequence_len, global_states_input_dim)`):
Global embeddings that will be prepended to the hidden states.
rotary_embedding (`torch.Tensor`):
The rotary embeddings to apply on query and key tensors during attention calculation.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_len)`, *optional*):
Mask to avoid performing self-attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
encoder_attention_mask (`torch.Tensor` of shape `(batch_size, encoder_sequence_len)`, *optional*):
Mask to avoid performing cross-attention on padding token indices of the conditional embeddings. Mask
values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
cross_attention_hidden_states = self.cross_attention_proj(encoder_hidden_states)
global_hidden_states = self.global_proj(global_hidden_states)
time_hidden_states = self.timestep_proj(self.time_proj(timestep.to(self.dtype)))
global_hidden_states = global_hidden_states + time_hidden_states.unsqueeze(1)
hidden_states = self.preprocess_conv(hidden_states) + hidden_states
# (batch_size, dim, sequence_length) -> (batch_size, sequence_length, dim)
hidden_states = hidden_states.transpose(1, 2)
hidden_states = self.proj_in(hidden_states)
# prepend global states to hidden states
hidden_states = torch.cat([global_hidden_states, hidden_states], dim=-2)
if attention_mask is not None:
prepend_mask = torch.ones((hidden_states.shape[0], 1), device=hidden_states.device, dtype=torch.bool)
attention_mask = torch.cat([prepend_mask, attention_mask], dim=-1)
for block in self.transformer_blocks:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
attention_mask,
cross_attention_hidden_states,
encoder_attention_mask,
rotary_embedding,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=cross_attention_hidden_states,
encoder_attention_mask=encoder_attention_mask,
rotary_embedding=rotary_embedding,
)
hidden_states = self.proj_out(hidden_states)
# (batch_size, sequence_length, dim) -> (batch_size, dim, sequence_length)
# remove the prepended token that was added for the global hidden states
hidden_states = hidden_states.transpose(1, 2)[:, :, 1:]
hidden_states = self.postprocess_conv(hidden_states) + hidden_states
if not return_dict:
return (hidden_states,)
return Transformer2DModelOutput(sample=hidden_states)
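# A quick sketch (illustrative, with made-up tiny hyperparameters) of how the tensors documented above line up
# with a small StableAudioDiTModel. The `from diffusers import StableAudioDiTModel` export is assumed to be the
# one added by this PR. The real rotary embeddings are computed by the pipeline and are omitted here, so this
# only builds the model and its inputs rather than running a full forward pass.
import torch
from diffusers import StableAudioDiTModel

model = StableAudioDiTModel(
    sample_size=32,
    in_channels=4,
    num_layers=2,
    attention_head_dim=8,
    num_attention_heads=2,
    num_key_value_attention_heads=1,
    out_channels=4,
    cross_attention_dim=16,
    time_proj_dim=8,
    global_states_input_dim=12,
    cross_attention_input_dim=10,
)
print(len(model.attn_processors))  # 4: one self-attention and one cross-attention processor per block

hidden_states = torch.randn(2, 4, 32)          # (batch_size, in_channels, sequence_len)
timestep = torch.tensor([1, 1])                # (batch_size,)
encoder_hidden_states = torch.randn(2, 8, 10)  # (batch_size, encoder_sequence_len, cross_attention_input_dim)
global_hidden_states = torch.randn(2, 1, 12)   # (batch_size, 1, global_states_input_dim)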
......@@ -231,6 +231,10 @@ else:
_import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
_import_structure["stable_audio"] = [
"StableAudioProjectionModel",
"StableAudioPipeline",
]
_import_structure["stable_cascade"] = [
"StableCascadeCombinedPipeline",
"StableCascadeDecoderPipeline",
......@@ -533,6 +537,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
from .stable_audio import StableAudioPipeline, StableAudioProjectionModel
from .stable_cascade import (
StableCascadeCombinedPipeline,
StableCascadeDecoderPipeline,
......
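# With the pipeline exposed above, generation runs end to end. A usage sketch, assuming the public
# "stabilityai/stable-audio-open-1.0" checkpoint and a CUDA device (both are assumptions of this example):
import soundfile as sf
import torch
from diffusers import StableAudioPipeline

pipe = StableAudioPipeline.from_pretrained("stabilityai/stable-audio-open-1.0", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

audio = pipe(
    "The sound of a hammer hitting a wooden surface.",
    negative_prompt="Low quality.",
    num_inference_steps=200,
    audio_end_in_s=10.0,
    generator=torch.Generator("cuda").manual_seed(0),
).audios

# each waveform has shape (channels, samples) at the VAE sampling rate
sf.write("hammer.wav", audio[0].T.float().cpu().numpy(), pipe.vae.sampling_rate)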
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
is_transformers_version,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["modeling_stable_audio"] = ["StableAudioProjectionModel"]
_import_structure["pipeline_stable_audio"] = ["StableAudioPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .modeling_stable_audio import StableAudioProjectionModel
from .pipeline_stable_audio import StableAudioPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
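# Illustrative sketch of what the lazy-module plumbing above provides: importing the subpackage does not load
# the torch-heavy modeling code; that only happens on first attribute access. When torch/transformers are
# unavailable, dummy placeholder classes are registered instead and raise an informative error when used.
import diffusers.pipelines.stable_audio as stable_audio

print(type(stable_audio))                       # _LazyModule; modeling code not imported yet
print(stable_audio.StableAudioProjectionModel)  # first access triggers the real (or dummy) import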
# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from math import pi
from typing import Optional
import torch
import torch.nn as nn
import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.modeling_utils import ModelMixin
from ...utils import BaseOutput, logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class StableAudioPositionalEmbedding(nn.Module):
"""Used for continuous time"""
def __init__(self, dim: int):
super().__init__()
assert (dim % 2) == 0
half_dim = dim // 2
self.weights = nn.Parameter(torch.randn(half_dim))
def forward(self, times: torch.Tensor) -> torch.Tensor:
times = times[..., None]
freqs = times * self.weights[None] * 2 * pi
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
fouriered = torch.cat((times, fouriered), dim=-1)
return fouriered
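# Illustrative shape check (not part of the upstream module): the embedding returns `dim + 1` features because
# the raw time value is concatenated to the `dim` sin/cos Fourier features. This is why the linear layer in
# `StableAudioNumberConditioner` below takes `internal_dim + 1` input features.
_pos_emb = StableAudioPositionalEmbedding(dim=256)
_times = torch.rand(4)         # 4 continuous time values in [0, 1]
print(_pos_emb(_times).shape)  # torch.Size([4, 257]) -> [t, sin(2*pi*w*t), cos(2*pi*w*t)]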
@dataclass
class StableAudioProjectionModelOutput(BaseOutput):
"""
Class for StableAudio projection layer's outputs.
Args:
text_hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states obtained by linearly projecting the hidden-states for the text encoder.
seconds_start_hidden_states (`torch.Tensor` of shape `(batch_size, 1, hidden_size)`, *optional*):
Sequence of hidden-states obtained by linearly projecting the audio start hidden states.
seconds_end_hidden_states (`torch.Tensor` of shape `(batch_size, 1, hidden_size)`, *optional*):
Sequence of hidden-states obtained by linearly projecting the audio end hidden states.
"""
text_hidden_states: Optional[torch.Tensor] = None
seconds_start_hidden_states: Optional[torch.Tensor] = None
seconds_end_hidden_states: Optional[torch.Tensor] = None
class StableAudioNumberConditioner(nn.Module):
"""
A simple module that maps scalar values (e.g. durations in seconds) to a latent space, using a Fourier
positional embedding followed by a linear projection.
Args:
number_embedding_dim (`int`):
Dimensionality of the number embeddings.
min_value (`int`):
The minimum value accepted by the seconds conditioning module.
max_value (`int`):
The maximum value accepted by the seconds conditioning module.
internal_dim (`int`, *optional*, defaults to 256):
Dimensionality of the intermediate number hidden states.
"""
def __init__(
self,
number_embedding_dim,
min_value,
max_value,
internal_dim: Optional[int] = 256,
):
super().__init__()
self.time_positional_embedding = nn.Sequential(
StableAudioPositionalEmbedding(internal_dim),
nn.Linear(in_features=internal_dim + 1, out_features=number_embedding_dim),
)
self.number_embedding_dim = number_embedding_dim
self.min_value = min_value
self.max_value = max_value
def forward(
self,
floats: torch.Tensor,
):
floats = floats.clamp(self.min_value, self.max_value)
normalized_floats = (floats - self.min_value) / (self.max_value - self.min_value)
# Cast floats to same type as embedder
embedder_dtype = next(self.time_positional_embedding.parameters()).dtype
normalized_floats = normalized_floats.to(embedder_dtype)
embedding = self.time_positional_embedding(normalized_floats)
float_embeds = embedding.view(-1, 1, self.number_embedding_dim)
return float_embeds
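# Illustrative sketch with assumed values (the real min/max come from the checkpoint config): conditioning on a
# duration of 5.0 seconds yields a single embedding token per batch item.
_seconds_cond = StableAudioNumberConditioner(number_embedding_dim=768, min_value=0, max_value=512)
_seconds = torch.tensor([5.0])        # clamped to [0, 512] and normalized to ~0.0098
print(_seconds_cond(_seconds).shape)  # torch.Size([1, 1, 768])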
class StableAudioProjectionModel(ModelMixin, ConfigMixin):
"""
A simple linear projection model to map the conditioning values to a shared latent space.
Args:
text_encoder_dim (`int`):
Dimensionality of the text embeddings from the text encoder (T5).
conditioning_dim (`int`):
Dimensionality of the output conditioning tensors.
min_value (`int`):
The minimum value accepted by the seconds conditioning modules.
max_value (`int`):
The maximum value accepted by the seconds conditioning modules.
"""
@register_to_config
def __init__(self, text_encoder_dim, conditioning_dim, min_value, max_value):
super().__init__()
self.text_projection = (
nn.Identity() if conditioning_dim == text_encoder_dim else nn.Linear(text_encoder_dim, conditioning_dim)
)
self.start_number_conditioner = StableAudioNumberConditioner(conditioning_dim, min_value, max_value)
self.end_number_conditioner = StableAudioNumberConditioner(conditioning_dim, min_value, max_value)
def forward(
self,
text_hidden_states: Optional[torch.Tensor] = None,
start_seconds: Optional[torch.Tensor] = None,
end_seconds: Optional[torch.Tensor] = None,
):
text_hidden_states = (
text_hidden_states if text_hidden_states is None else self.text_projection(text_hidden_states)
)
seconds_start_hidden_states = (
start_seconds if start_seconds is None else self.start_number_conditioner(start_seconds)
)
seconds_end_hidden_states = end_seconds if end_seconds is None else self.end_number_conditioner(end_seconds)
return StableAudioProjectionModelOutput(
text_hidden_states=text_hidden_states,
seconds_start_hidden_states=seconds_start_hidden_states,
seconds_end_hidden_states=seconds_end_hidden_states,
)
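# Usage sketch with assumed dimensions (a T5 hidden size of 768, conditioning_dim of 768 and a 0-512 second
# range are assumptions of this example): the projection model turns text hidden states and start/end seconds
# into the conditioning tensors consumed by the transformer.
import torch

projection = StableAudioProjectionModel(text_encoder_dim=768, conditioning_dim=768, min_value=0, max_value=512)
text_hidden_states = torch.randn(2, 64, 768)  # (batch_size, text_sequence_len, text_encoder_dim)
start_seconds = torch.tensor([0.0, 0.0])
end_seconds = torch.tensor([10.0, 30.0])

out = projection(text_hidden_states=text_hidden_states, start_seconds=start_seconds, end_seconds=end_seconds)
print(out.text_hidden_states.shape)           # torch.Size([2, 64, 768]); identity projection since dims match
print(out.seconds_start_hidden_states.shape)  # torch.Size([2, 1, 768])
print(out.seconds_end_hidden_states.shape)    # torch.Size([2, 1, 768])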