Unverified Commit 88fa6b7d authored by Patrick von Platen, committed by GitHub

[Dance Diffusion] Add dance diffusion (#803)



* start

* add more logic

* Update src/diffusers/models/unet_2d_condition_flax.py

* match weights

* up

* make model work

* making class more general, fixing missed file rename

* small fix

* make new conversion work

* up

* finalize conversion

* up

* first batch of variable renamings

* remove c and c_prev var names

* add mid and out block structure

* add pipeline

* up

* finish conversion

* finish

* upload

* more fixes

* Apply suggestions from code review

* add attr

* up

* uP

* up

* finish tests

* finish

* uP

* finish

* fix test

* up

* naming consistency in tests

* Apply suggestions from code review
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Nathan Lambert <nathan@huggingface.co>
Co-authored-by: Anton Lozhkov <anton@huggingface.co>

* remove hardcoded 16

* Remove bogus

* fix some stuff

* finish

* improve logging

* docs

* upload
Co-authored-by: Nathan Lambert <nol@berkeley.edu>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Nathan Lambert <nathan@huggingface.co>
Co-authored-by: Anton Lozhkov <anton@huggingface.co>
parent 0b42b074
@@ -92,5 +92,7 @@
title: "Stable Diffusion"
- local: api/pipelines/stochastic_karras_ve
title: "Stochastic Karras VE"
- local: api/pipelines/dance_diffusion
title: "Dance Diffusion"
title: "Pipelines"
title: "API"
@@ -22,6 +22,9 @@ The models are built on the base class [`ModelMixin`] that is a `torch.nn.Module`
## UNet2DOutput
[[autodoc]] models.unet_2d.UNet2DOutput
## UNet1DModel
[[autodoc]] UNet1DModel
## UNet2DModel
[[autodoc]] UNet2DModel
...
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Dance Diffusion
## Overview
[Dance Diffusion](https://github.com/Harmonai-org/sample-generator) by Zach Evans.
Dance Diffusion is the first in a suite of generative audio tools for producers and musicians to be released by Harmonai.
For more info or to get involved in the development of these tools, please visit https://harmonai.org and fill out the form on the front page.
The original codebase of this implementation can be found [here](https://github.com/Harmonai-org/sample-generator).
## Available Pipelines:
| Pipeline | Tasks | Colab |
|---|---|:---:|
| [pipeline_dance_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py) | *Unconditional Audio Generation* | - |
## DanceDiffusionPipeline
[[autodoc]] DanceDiffusionPipeline
- __call__
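
### Usage example

A minimal usage sketch. The checkpoint id `harmonai/maestro-150k` is an assumed Hub id; any converted Dance Diffusion checkpoint can be loaded the same way.

```python
import torch
from diffusers import DanceDiffusionPipeline

pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k")

generator = torch.manual_seed(33)
output = pipe(num_inference_steps=100, generator=generator)

# `audios` is a numpy array of shape (batch_size, num_channels, sample_length)
audio = output.audios[0]
```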
@@ -95,6 +95,10 @@ Original paper can be found [here](https://arxiv.org/abs/2011.13456).
[[autodoc]] ScoreSdeVeScheduler
#### improved pseudo numerical methods for diffusion models (iPNDM)
Original implementation can be found [here](https://github.com/crowsonkb/v-diffusion-pytorch/blob/987f8985e38208345c1959b0ea767a625831cc9b/diffusion/sampling.py#L296).
[[autodoc]] IPNDMScheduler
#### variance preserving stochastic differential equation (SDE) scheduler
Original paper can be found [here](https://arxiv.org/abs/2011.13456).
...
#!/usr/bin/env python3
import argparse
import math
import os
from copy import deepcopy
import torch
from torch import nn
from audio_diffusion.models import DiffusionAttnUnet1D
from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel
from diffusion import sampling
MODELS_MAP = {
"gwf-440k": {
"url": "https://model-server.zqevans2.workers.dev/gwf-440k.ckpt",
"sample_rate": 48000,
"sample_size": 65536,
},
"jmann-small-190k": {
"url": "https://model-server.zqevans2.workers.dev/jmann-small-190k.ckpt",
"sample_rate": 48000,
"sample_size": 65536,
},
"jmann-large-580k": {
"url": "https://model-server.zqevans2.workers.dev/jmann-large-580k.ckpt",
"sample_rate": 48000,
"sample_size": 131072,
},
"maestro-uncond-150k": {
"url": "https://model-server.zqevans2.workers.dev/maestro-uncond-150k.ckpt",
"sample_rate": 16000,
"sample_size": 65536,
},
"unlocked-uncond-250k": {
"url": "https://model-server.zqevans2.workers.dev/unlocked-uncond-250k.ckpt",
"sample_rate": 16000,
"sample_size": 65536,
},
"honk-140k": {
"url": "https://model-server.zqevans2.workers.dev/honk-140k.ckpt",
"sample_rate": 16000,
"sample_size": 65536,
},
}
def alpha_sigma_to_t(alpha, sigma):
"""Returns a timestep, given the scaling factors for the clean image and for
the noise."""
return torch.atan2(sigma, alpha) / math.pi * 2
def get_crash_schedule(t):
sigma = torch.sin(t * math.pi / 2) ** 2
alpha = (1 - sigma**2) ** 0.5
return alpha_sigma_to_t(alpha, sigma)
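# Background (comment added for clarity, not in the original file): in the
# v-diffusion convention t = atan2(sigma, alpha) * 2 / pi, so t=0 is all
# signal and t=1 is all noise; the "crash" schedule warps a uniform t grid
# through sin^2 so that steps are concentrated near the extremes of the
# noise range.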
class Object(object):
pass
class DiffusionUncond(nn.Module):
def __init__(self, global_args):
super().__init__()
self.diffusion = DiffusionAttnUnet1D(global_args, n_attn_layers=4)
self.diffusion_ema = deepcopy(self.diffusion)
self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
def download(model_name):
url = MODELS_MAP[model_name]["url"]
os.system(f"wget {url}")
return f"./{model_name}.ckpt"
DOWN_NUM_TO_LAYER = {
"1": "resnets.0",
"2": "attentions.0",
"3": "resnets.1",
"4": "attentions.1",
"5": "resnets.2",
"6": "attentions.2",
}
UP_NUM_TO_LAYER = {
"8": "resnets.0",
"9": "attentions.0",
"10": "resnets.1",
"11": "attentions.1",
"12": "resnets.2",
"13": "attentions.2",
}
MID_NUM_TO_LAYER = {
"1": "resnets.0",
"2": "attentions.0",
"3": "resnets.1",
"4": "attentions.1",
"5": "resnets.2",
"6": "attentions.2",
"8": "resnets.3",
"9": "attentions.3",
"10": "resnets.4",
"11": "attentions.4",
"12": "resnets.5",
"13": "attentions.5",
}
DEPTH_0_TO_LAYER = {
"0": "resnets.0",
"1": "resnets.1",
"2": "resnets.2",
"4": "resnets.0",
"5": "resnets.1",
"6": "resnets.2",
}
RES_CONV_MAP = {
"skip": "conv_skip",
"main.0": "conv_1",
"main.1": "group_norm_1",
"main.3": "conv_2",
"main.4": "group_norm_2",
}
ATTN_MAP = {
"norm": "group_norm",
"qkv_proj": ["query", "key", "value"],
"out_proj": ["proj_attn"],
}
def convert_resconv_naming(name):
if name.startswith("skip"):
return name.replace("skip", RES_CONV_MAP["skip"])
# name has to be of format main.{digit}
if not name.startswith("main."):
raise ValueError(f"ResConvBlock error with {name}")
return name.replace(name[:6], RES_CONV_MAP[name[:6]])
def convert_attn_naming(name):
for key, value in ATTN_MAP.items():
if name.startswith(key) and not isinstance(value, list):
return name.replace(key, value)
elif name.startswith(key):
return [name.replace(key, v) for v in value]
raise ValueError(f"Attn error with {name}")
def rename(input_string, max_depth=13):
string = input_string
if string.split(".")[0] == "timestep_embed":
return string.replace("timestep_embed", "time_proj")
depth = 0
if string.startswith("net.3."):
depth += 1
string = string[6:]
elif string.startswith("net."):
string = string[4:]
while string.startswith("main.7."):
depth += 1
string = string[7:]
if string.startswith("main."):
string = string[5:]
# mid block
if string[:2].isdigit():
layer_num = string[:2]
string_left = string[2:]
else:
layer_num = string[0]
string_left = string[1:]
if depth == max_depth:
new_layer = MID_NUM_TO_LAYER[layer_num]
prefix = "mid_block"
elif depth > 0 and int(layer_num) < 7:
new_layer = DOWN_NUM_TO_LAYER[layer_num]
prefix = f"down_blocks.{depth}"
elif depth > 0 and int(layer_num) > 7:
new_layer = UP_NUM_TO_LAYER[layer_num]
prefix = f"up_blocks.{max_depth - depth - 1}"
elif depth == 0:
new_layer = DEPTH_0_TO_LAYER[layer_num]
prefix = f"up_blocks.{max_depth - 1}" if int(layer_num) > 3 else "down_blocks.0"
if not string_left.startswith("."):
raise ValueError(f"Naming error with {input_string} and string_left: {string_left}.")
string_left = string_left[1:]
if "resnets" in new_layer:
string_left = convert_resconv_naming(string_left)
elif "attentions" in new_layer:
new_string_left = convert_attn_naming(string_left)
string_left = new_string_left
if not isinstance(string_left, list):
new_string = prefix + "." + new_layer + "." + string_left
else:
new_string = [prefix + "." + new_layer + "." + s for s in string_left]
return new_string
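# Simplest traceable case (illustration): rename("timestep_embed.weight")
# returns "time_proj.weight". Names that land on an attention block return a
# list of new keys instead, which transform_conv_attns below takes care of.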
def rename_orig_weights(state_dict):
new_state_dict = {}
for k, v in state_dict.items():
if k.endswith("kernel"):
# up- and downsample layers don't have trainable weights
continue
new_k = rename(k)
# check if we need to transform from Conv => Linear for attention
if isinstance(new_k, list):
new_state_dict = transform_conv_attns(new_state_dict, new_k, v)
else:
new_state_dict[new_k] = v
return new_state_dict
def transform_conv_attns(new_state_dict, new_k, v):
if len(new_k) == 1:
if len(v.shape) == 3:
# weight
new_state_dict[new_k[0]] = v[:, :, 0]
else:
# bias
new_state_dict[new_k[0]] = v
else:
# qkv matrices
tripled_shape = v.shape[0]
single_shape = tripled_shape // 3
for i in range(3):
if len(v.shape) == 3:
new_state_dict[new_k[i]] = v[i * single_shape : (i + 1) * single_shape, :, 0]
else:
new_state_dict[new_k[i]] = v[i * single_shape : (i + 1) * single_shape]
return new_state_dict
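# Note on the transform above: the original attention projections are 1x1
# Conv1d layers whose weights have shape (out, in, 1); dropping the trailing
# dim via v[..., 0] yields the (out, in) matrix expected by nn.Linear.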
def main(args):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = args.model_path.split("/")[-1].split(".")[0]
if not os.path.isfile(args.model_path):
assert (
model_name == args.model_path
), f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
args.model_path = download(model_name)
sample_rate = MODELS_MAP[model_name]["sample_rate"]
sample_size = MODELS_MAP[model_name]["sample_size"]
config = Object()
config.sample_size = sample_size
config.sample_rate = sample_rate
config.latent_dim = 0
diffusers_model = UNet1DModel(sample_size=sample_size, sample_rate=sample_rate)
diffusers_state_dict = diffusers_model.state_dict()
orig_model = DiffusionUncond(config)
orig_model.load_state_dict(torch.load(args.model_path, map_location=device)["state_dict"])
orig_model = orig_model.diffusion_ema.eval()
orig_model_state_dict = orig_model.state_dict()
renamed_state_dict = rename_orig_weights(orig_model_state_dict)
renamed_minus_diffusers = set(renamed_state_dict.keys()) - set(diffusers_state_dict.keys())
diffusers_minus_renamed = set(diffusers_state_dict.keys()) - set(renamed_state_dict.keys())
assert len(renamed_minus_diffusers) == 0, f"Problem with {renamed_minus_diffusers}"
assert all(k.endswith("kernel") for k in list(diffusers_minus_renamed)), f"Problem with {diffusers_minus_renamed}"
for key, value in renamed_state_dict.items():
assert (
diffusers_state_dict[key].squeeze().shape == value.squeeze().shape
), f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}"
if key == "time_proj.weight":
value = value.squeeze()
diffusers_state_dict[key] = value
diffusers_model.load_state_dict(diffusers_state_dict)
steps = 100
seed = 33
diffusers_scheduler = IPNDMScheduler(num_train_timesteps=steps)
generator = torch.manual_seed(seed)
noise = torch.randn([1, 2, config.sample_size], generator=generator).to(device)
t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
step_list = get_crash_schedule(t)
pipe = DanceDiffusionPipeline(unet=diffusers_model, scheduler=diffusers_scheduler)
generator = torch.manual_seed(33)
audio = pipe(num_inference_steps=steps, generator=generator).audios
generated = sampling.iplms_sample(orig_model, noise, step_list, {})
generated = generated.clamp(-1, 1).cpu()
# the pipeline returns a numpy array; convert it so the comparison also works when `generated` was produced on GPU
audio = torch.from_numpy(audio)
diff_sum = (generated - audio).abs().sum()
diff_max = (generated - audio).abs().max()
if args.save:
pipe.save_pretrained(args.checkpoint_path)
print("Diff sum", diff_sum)
print("Diff max", diff_max)
assert diff_max < 1e-3, f"Diff max: {diff_max} is too much :-/"
print(f"Conversion for {model_name} successful!")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
# note: argparse's `type=bool` treats any non-empty string as True, so parse the flag explicitly
parser.add_argument(
"--save", default=True, type=lambda s: str(s).lower() in ("1", "true", "yes"), required=False, help="Whether to save the converted model or not."
)
parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
args = parser.parse_args()
main(args)
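# Example invocation (illustrative; the script filename and paths are assumptions):
# python convert_dance_diffusion_to_diffusers.py --model_path gwf-440k --checkpoint_path ./gwf-440k-diffusers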
@@ -18,7 +18,7 @@ from .utils import logging
if is_torch_available():
from .modeling_utils import ModelMixin
from .models import AutoencoderKL, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
from .optimization import (
get_constant_schedule,
get_constant_schedule_with_warmup,
@@ -29,10 +29,19 @@ if is_torch_available():
get_scheduler,
)
from .pipeline_utils import DiffusionPipeline
from .pipelines import (
DanceDiffusionPipeline,
DDIMPipeline,
DDPMPipeline,
KarrasVePipeline,
LDMPipeline,
PNDMPipeline,
ScoreSdeVePipeline,
)
from .schedulers import (
DDIMScheduler,
DDPMScheduler,
IPNDMScheduler,
KarrasVeScheduler,
PNDMScheduler,
SchedulerMixin,
...
@@ -16,6 +16,7 @@ from ..utils import is_flax_available, is_torch_available
if is_torch_available():
from .unet_1d import UNet1DModel
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .vae import AutoencoderKL, VQModel
...
@@ -101,17 +101,28 @@ class Timesteps(nn.Module):
class GaussianFourierProjection(nn.Module):
"""Gaussian Fourier embeddings for noise levels."""
def __init__(
self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
):
super().__init__()
self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
self.log = log
self.flip_sin_to_cos = flip_sin_to_cos
if set_W_to_weight:
# to delete later
self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
self.weight = self.W
def forward(self, x):
if self.log:
x = torch.log(x)
x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
if self.flip_sin_to_cos:
out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
else:
out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
return out
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..utils import BaseOutput
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .unet_1d_blocks import get_down_block, get_mid_block, get_up_block
@dataclass
class UNet1DOutput(BaseOutput):
"""
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
Hidden states output. Output of last layer of model.
"""
sample: torch.FloatTensor
class UNet1DModel(ModelMixin, ConfigMixin):
r"""
UNet1DModel is a 1D UNet model that takes a noisy sample and a timestep and returns a sample-shaped output.
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all models (such as downloading or saving, etc.)
Parameters:
sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime.
in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding.
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
Whether to flip sin to cos for fourier time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")`):
Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")`):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`):
Tuple of block output channels.
"""
@register_to_config
def __init__(
self,
sample_size: int = 65536,
sample_rate: Optional[int] = None,
in_channels: int = 2,
out_channels: int = 2,
extra_in_channels: int = 0,
time_embedding_type: str = "fourier",
freq_shift: int = 0,
flip_sin_to_cos: bool = True,
use_timestep_embedding: bool = False,
down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
mid_block_type: str = "UNetMidBlock1D",
up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
block_out_channels: Tuple[int] = (32, 32, 64),
):
super().__init__()
self.sample_size = sample_size
# time
if time_embedding_type == "fourier":
self.time_proj = GaussianFourierProjection(
embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
)
timestep_input_dim = 2 * block_out_channels[0]
elif time_embedding_type == "positional":
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
if use_timestep_embedding:
time_embed_dim = block_out_channels[0] * 4
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])
self.out_block = None
# down
output_channel = in_channels
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
if i == 0:
input_channel += extra_in_channels
down_block = get_down_block(
down_block_type,
in_channels=input_channel,
out_channels=output_channel,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = get_mid_block(
mid_block_type=mid_block_type,
mid_channels=block_out_channels[-1],
in_channels=block_out_channels[-1],
out_channels=None,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else out_channels
up_block = get_up_block(
up_block_type,
in_channels=prev_output_channel,
out_channels=output_channel,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# TODO(PVP, Nathan) placeholder for RL application to be merged shortly
# Totally fine to add another layer with a if statement - no need for nn.Identity here
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
return_dict: bool = True,
) -> Union[UNet1DOutput, Tuple]:
r"""
Args:
sample (`torch.FloatTensor`): `(batch_size, num_channels, sample_size)` noisy inputs tensor
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.
Returns:
[`~models.unet_1d.UNet1DOutput`] or `tuple`: [`~models.unet_1d.UNet1DOutput`] if `return_dict` is True,
otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
# 1. time
if not torch.is_tensor(timestep):
# the signature allows plain floats/ints; wrap them so `.shape` below works
timestep = torch.tensor([timestep], dtype=torch.float, device=sample.device)
elif len(timestep.shape) == 0:
timestep = timestep[None]
timestep_embed = self.time_proj(timestep)[..., None]
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]])
# 2. down
down_block_res_samples = ()
for downsample_block in self.down_blocks:
sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed)
down_block_res_samples += res_samples
# 3. mid
sample = self.mid_block(sample)
# 4. up
for i, upsample_block in enumerate(self.up_blocks):
res_samples = down_block_res_samples[-1:]
down_block_res_samples = down_block_res_samples[:-1]
sample = upsample_block(sample, res_samples)
if not return_dict:
return (sample,)
return UNet1DOutput(sample=sample)
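# Minimal shape-check sketch (an illustration, not part of the original file).
# It assumes extra_in_channels=16 so that the 16 Fourier channels
# (2 * embedding_size) concatenated by the first no-skip down block match the
# block's configured input width, and a length divisible by 2**3 = 8 for the
# three resampling stages:
#
#     model = UNet1DModel(sample_size=2048, extra_in_channels=16)
#     sample = torch.randn(1, 2, 2048)  # (batch, channels, length)
#     out = model(sample, torch.tensor([0.5])).sample
#     assert out.shape == sample.shape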
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
import torch.nn.functional as F
from torch import nn
_kernels = {
"linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8],
"cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875],
"lanczos3": [
0.003689131001010537,
0.015056144446134567,
-0.03399861603975296,
-0.066637322306633,
0.13550527393817902,
0.44638532400131226,
0.44638532400131226,
0.13550527393817902,
-0.066637322306633,
-0.03399861603975296,
0.015056144446134567,
0.003689131001010537,
],
}
class Downsample1d(nn.Module):
def __init__(self, kernel="linear", pad_mode="reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor(_kernels[kernel])
self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer("kernel", kernel_1d)
def forward(self, hidden_states):
hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
weight[indices, indices] = self.kernel.to(weight)
return F.conv1d(hidden_states, weight, stride=2)
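# The weight tensor built above is diagonal across channels: each output
# channel only sees its own input channel, so the same 1-D lowpass kernel is
# applied to every channel independently (a depthwise filter expressed as a
# full convolution).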
class Upsample1d(nn.Module):
def __init__(self, kernel="linear", pad_mode="reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor(_kernels[kernel]) * 2
self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer("kernel", kernel_1d)
def forward(self, hidden_states):
hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
weight[indices, indices] = self.kernel.to(weight)
return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1)
class SelfAttention1d(nn.Module):
def __init__(self, in_channels, n_head=1, dropout_rate=0.0):
super().__init__()
self.channels = in_channels
self.group_norm = nn.GroupNorm(1, num_channels=in_channels)
self.num_heads = n_head
self.query = nn.Linear(self.channels, self.channels)
self.key = nn.Linear(self.channels, self.channels)
self.value = nn.Linear(self.channels, self.channels)
self.proj_attn = nn.Linear(self.channels, self.channels, bias=True)
self.dropout = nn.Dropout(dropout_rate, inplace=True)
def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
return new_projection
def forward(self, hidden_states):
residual = hidden_states
batch, channel_dim, seq = hidden_states.shape
hidden_states = self.group_norm(hidden_states)
hidden_states = hidden_states.transpose(1, 2)
query_proj = self.query(hidden_states)
key_proj = self.key(hidden_states)
value_proj = self.value(hidden_states)
query_states = self.transpose_for_scores(query_proj)
key_states = self.transpose_for_scores(key_proj)
value_states = self.transpose_for_scores(value_proj)
scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1]))
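# Note: scaling q and k each by d**-0.25 (the double sqrt above) is exactly
# equivalent to the usual 1/sqrt(d) softmax scaling, but keeps the
# intermediate products smaller, which is friendlier to fp16.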
attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
attention_probs = torch.softmax(attention_scores, dim=-1)
# compute attention output
hidden_states = torch.matmul(attention_probs, value_states)
hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
hidden_states = hidden_states.view(new_hidden_states_shape)
# compute next hidden_states
hidden_states = self.proj_attn(hidden_states)
hidden_states = hidden_states.transpose(1, 2)
hidden_states = self.dropout(hidden_states)
output = hidden_states + residual
return output
class ResConvBlock(nn.Module):
def __init__(self, in_channels, mid_channels, out_channels, is_last=False):
super().__init__()
self.is_last = is_last
self.has_conv_skip = in_channels != out_channels
if self.has_conv_skip:
self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False)
self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2)
self.group_norm_1 = nn.GroupNorm(1, mid_channels)
self.gelu_1 = nn.GELU()
self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2)
if not self.is_last:
self.group_norm_2 = nn.GroupNorm(1, out_channels)
self.gelu_2 = nn.GELU()
def forward(self, hidden_states):
residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states
hidden_states = self.conv_1(hidden_states)
hidden_states = self.group_norm_1(hidden_states)
hidden_states = self.gelu_1(hidden_states)
hidden_states = self.conv_2(hidden_states)
if not self.is_last:
hidden_states = self.group_norm_2(hidden_states)
hidden_states = self.gelu_2(hidden_states)
output = hidden_states + residual
return output
def get_down_block(down_block_type, out_channels, in_channels):
if down_block_type == "DownBlock1D":
return DownBlock1D(out_channels=out_channels, in_channels=in_channels)
elif down_block_type == "AttnDownBlock1D":
return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels)
elif down_block_type == "DownBlock1DNoSkip":
return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels)
raise ValueError(f"{down_block_type} does not exist.")
def get_up_block(up_block_type, in_channels, out_channels):
if up_block_type == "UpBlock1D":
return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
elif up_block_type == "AttnUpBlock1D":
return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
elif up_block_type == "UpBlock1DNoSkip":
return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
raise ValueError(f"{up_block_type} does not exist.")
def get_mid_block(mid_block_type, in_channels, mid_channels, out_channels):
if mid_block_type == "UNetMidBlock1D":
return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
raise ValueError(f"{mid_block_type} does not exist.")
class UNetMidBlock1D(nn.Module):
def __init__(self, mid_channels, in_channels, out_channels=None):
super().__init__()
out_channels = in_channels if out_channels is None else out_channels
# there is always at least one resnet
self.down = Downsample1d("cubic")
resnets = [
ResConvBlock(in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
attentions = [
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(out_channels, out_channels // 32),
]
self.up = Upsample1d(kernel="cubic")
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states):
hidden_states = self.down(hidden_states)
for attn, resnet in zip(self.attentions, self.resnets):
hidden_states = resnet(hidden_states)
hidden_states = attn(hidden_states)
hidden_states = self.up(hidden_states)
return hidden_states
class AttnDownBlock1D(nn.Module):
def __init__(self, out_channels, in_channels, mid_channels=None):
super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels
self.down = Downsample1d("cubic")
resnets = [
ResConvBlock(in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
attentions = [
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(out_channels, out_channels // 32),
]
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None):
hidden_states = self.down(hidden_states)
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states)
hidden_states = attn(hidden_states)
return hidden_states, (hidden_states,)
class DownBlock1D(nn.Module):
def __init__(self, out_channels, in_channels, mid_channels=None):
super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels
self.down = Downsample1d("cubic")
resnets = [
ResConvBlock(in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None):
hidden_states = self.down(hidden_states)
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
return hidden_states, (hidden_states,)
class DownBlock1DNoSkip(nn.Module):
def __init__(self, out_channels, in_channels, mid_channels=None):
super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels
resnets = [
ResConvBlock(in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None):
hidden_states = torch.cat([hidden_states, temb], dim=1)
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
return hidden_states, (hidden_states,)
class AttnUpBlock1D(nn.Module):
def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels
resnets = [
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
attentions = [
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(out_channels, out_channels // 32),
]
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.up = Upsample1d(kernel="cubic")
def forward(self, hidden_states, res_hidden_states_tuple):
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states)
hidden_states = attn(hidden_states)
hidden_states = self.up(hidden_states)
return hidden_states
class UpBlock1D(nn.Module):
def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
mid_channels = in_channels if mid_channels is None else mid_channels
resnets = [
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
self.resnets = nn.ModuleList(resnets)
self.up = Upsample1d(kernel="cubic")
def forward(self, hidden_states, res_hidden_states_tuple):
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
hidden_states = self.up(hidden_states)
return hidden_states
class UpBlock1DNoSkip(nn.Module):
def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
mid_channels = in_channels if mid_channels is None else mid_channels
resnets = [
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True),
]
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, res_hidden_states_tuple):
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
return hidden_states
@@ -21,7 +21,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..utils import BaseOutput
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
@dataclass
...
@@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..utils import BaseOutput, logging
from .embeddings import TimestepEmbedding, Timesteps
from .unet_2d_blocks import (
CrossAttnDownBlock2D,
CrossAttnUpBlock2D,
DownBlock2D,
...
@@ -23,7 +23,7 @@ from ..configuration_utils import ConfigMixin, flax_register_to_config
from ..modeling_flax_utils import FlaxModelMixin
from ..utils import BaseOutput
from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from .unet_2d_blocks_flax import (
FlaxCrossAttnDownBlock2D,
FlaxCrossAttnUpBlock2D,
FlaxDownBlock2D,
...
@@ -21,7 +21,7 @@ import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..utils import BaseOutput
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
@dataclass
...
@@ -93,6 +93,20 @@ class ImagePipelineOutput(BaseOutput):
images: Union[List[PIL.Image.Image], np.ndarray]
@dataclass
class AudioPipelineOutput(BaseOutput):
"""
Output class for audio pipelines.
Args:
audios (`np.ndarray`):
List of denoised audio samples of shape `(batch_size, num_channels, sample_length)`. NumPy array holding
the denoised audio samples of the diffusion pipeline.
"""
audios: np.ndarray
class DiffusionPipeline(ConfigMixin):
r"""
Base class for all models.
...
@@ -2,6 +2,7 @@ from ..utils import is_flax_available, is_onnx_available, is_torch_available, is
if is_torch_available():
from .dance_diffusion import DanceDiffusionPipeline
from .ddim import DDIMPipeline
from .ddpm import DDPMPipeline
from .latent_diffusion_uncond import LDMPipeline
...
# flake8: noqa
from .pipeline_dance_diffusion import DanceDiffusionPipeline
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import torch
from ...pipeline_utils import AudioPipelineOutput, DiffusionPipeline
from ...utils import logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class DanceDiffusionPipeline(DiffusionPipeline):
r"""
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Parameters:
unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded audio.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded audio. Can be one of
[`IPNDMScheduler`].
"""
def __init__(self, unet, scheduler):
super().__init__()
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()
def __call__(
self,
batch_size: int = 1,
num_inference_steps: int = 100,
generator: Optional[torch.Generator] = None,
sample_length_in_s: Optional[float] = None,
return_dict: bool = True,
) -> Union[AudioPipelineOutput, Tuple]:
r"""
Args:
batch_size (`int`, *optional*, defaults to 1):
The number of audio samples to generate.
num_inference_steps (`int`, *optional*, defaults to 100):
The number of denoising steps. More denoising steps usually lead to a higher quality audio sample at
the expense of slower inference.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
sample_length_in_s (`float`, *optional*):
The length of the generated audio sample in seconds; defaults to `unet.sample_size / unet.sample_rate`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipeline_utils.AudioPipelineOutput`] instead of a plain tuple.
Returns:
[`~pipeline_utils.AudioPipelineOutput`] or `tuple`: [`~pipeline_utils.AudioPipelineOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated audio samples.
"""
if sample_length_in_s is None:
sample_length_in_s = self.unet.sample_size / self.unet.sample_rate
sample_size = sample_length_in_s * self.unet.sample_rate
down_scale_factor = 2 ** len(self.unet.up_blocks)
if sample_size < 3 * down_scale_factor:
raise ValueError(
f"{sample_length_in_s} is too small. Make sure it's bigger or equal to"
f" {3 * down_scale_factor / self.unet.sample_rate}."
)
original_sample_size = int(sample_size)
if sample_size % down_scale_factor != 0:
sample_size = ((sample_length_in_s * self.unet.sample_rate) // down_scale_factor + 1) * down_scale_factor
logger.info(
f"{sample_length_in_s} is increased to {sample_size / self.unet.sample_rate} so that it can be handled"
f" by the model. It will be cut to {original_sample_size / self.unet.sample_rate} after the denoising"
" process."
)
sample_size = int(sample_size)
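# e.g. with 3 up blocks (down_scale_factor = 8), a request of 65537 samples
# is padded to ((65537 // 8) + 1) * 8 = 65544 for the network and trimmed
# back to 65537 after denoising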
audio = torch.randn((batch_size, self.unet.in_channels, sample_size), generator=generator, device=self.device)
# set step values
self.scheduler.set_timesteps(num_inference_steps, device=audio.device)
for t in self.progress_bar(self.scheduler.timesteps):
# 1. predict noise model_output
model_output = self.unet(audio, t).sample
# 2. compute previous audio sample: x_t -> x_t-1
audio = self.scheduler.step(model_output, t, audio).prev_sample
audio = audio.clamp(-1, 1).cpu().numpy()
audio = audio[:, :, :original_sample_size]
if not return_dict:
return (audio,)
return AudioPipelineOutput(audios=audio)