Commit 84bd65bc authored by anton-l

Merge remote-tracking branch 'origin/main'

parents 0deeb06a 97fcc4c6
...@@ -53,17 +53,23 @@ The class provides functionality to compute previous image according to alpha, b
## Quickstart

### Installation

**Note**: If you want to run PyTorch on a CUDA-compatible GPU, please make sure to install the corresponding `torch` version from the
[official website](https://pytorch.org/).

```
git clone https://github.com/huggingface/diffusers.git
cd diffusers && pip install -e .
```

### 1. `diffusers` as a toolbox for schedulers and models

`diffusers` is more modularized than `transformers`. The idea is that researchers and engineers can easily use only the parts of the library they need for their own use cases.
It could become a central place for all kinds of models, schedulers, training utils and processors that one can mix and match for one's own use case.
Both models and schedulers should be loadable and saveable from the Hub.
For more examples see [schedulers](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers) and [models](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models)
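As a minimal sketch of this mix-and-match idea (reusing only names that appear elsewhere on this page — `UNetModel`, `DDPMScheduler`, `PNDMScheduler` and the `fusing/ddim-celeba-hq` checkpoint; whether a given scheduler actually matches a given checkpoint, and whether the schedulers are constructible without arguments, are assumptions here):

```python
from diffusers import UNetModel, DDPMScheduler, PNDMScheduler

# one pretrained denoising model ...
model = UNetModel.from_pretrained("fusing/ddim-celeba-hq")

# ... and interchangeable schedulers: swapping one for another changes the sampling
# procedure without touching the model weights (default arguments assumed here)
ddpm_scheduler = DDPMScheduler()
pndm_scheduler = PNDMScheduler()
```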
#### **Example for [DDPM](https://arxiv.org/abs/2006.11239):**

```python
...@@ -171,25 +177,35 @@ image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save("test.png")
```

### 2. `diffusers` as a collection of popular Diffusion systems (GLIDE, DALL-E, ...)

The `models` directory in the repository hosts the complete code necessary to run a diffusion system as well as to train it. A `DiffusionPipeline` class allows you to easily run the diffusion model in inference.
For more examples see [pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines).
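The previous revision of the example below used the generic pipeline loader instead of assembling the pipeline by hand; as a rough sketch (checkpoint id taken from that earlier revision), the same kind of inference can be started from a single Hub id:

```python
from diffusers import DiffusionPipeline

# download a complete diffusion system (model + scheduler) from the Hub and sample from it
ddpm = DiffusionPipeline.from_pretrained("fusing/ddpm-lsun-bedroom")
image = ddpm()
```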
#### **Example image generation with PNDM**

```python
from diffusers import PNDM, UNetModel, PNDMScheduler
import PIL.Image
import numpy as np
import torch

model_id = "fusing/ddim-celeba-hq"

model = UNetModel.from_pretrained(model_id)
scheduler = PNDMScheduler()

# load model and scheduler
ddpm = PNDM(unet=model, noise_scheduler=scheduler)

# run pipeline in inference (sample random noise and denoise)
with torch.no_grad():
    image = ddpm()

# process image to PIL
image_processed = image.cpu().permute(0, 2, 3, 1)
image_processed = (image_processed + 1.0) / 2
image_processed = torch.clamp(image_processed, 0.0, 1.0)
image_processed = image_processed * 255
image_processed = image_processed.numpy().astype(np.uint8)
image_pil = PIL.Image.fromarray(image_processed[0])
......
...@@ -144,9 +144,11 @@ if __name__ == "__main__":
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose "
            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10. "
            "and an Nvidia Ampere GPU."
        ),
    )
    args = parser.parse_args()
......
...@@ -87,7 +87,6 @@ _deps = [
    "regex!=2019.12.17",
    "requests",
    "torch>=1.4",
    "torchvision",
]

# this is a lookup table with items like:
...@@ -172,7 +171,6 @@ install_requires = [
    deps["regex"],
    deps["requests"],
    deps["torch"],
    deps["torchvision"],
    deps["Pillow"],
]
......
...@@ -6,10 +6,10 @@ __version__ = "0.0.3"
from .modeling_utils import ModelMixin
from .models.unet import UNetModel
from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
from .models.unet_grad_tts import UNetGradTTSModel
from .models.unet_ldm import UNetLDMModel
from .pipeline_utils import DiffusionPipeline
from .pipelines import BDDM, DDIM, DDPM, GLIDE, PNDM, LatentDiffusion
from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
...@@ -13,5 +13,4 @@ deps = {
    "regex": "regex!=2019.12.17",
    "requests": "requests",
    "torch": "torch>=1.4",
    "torchvision": "torchvision",
}
...@@ -17,6 +17,6 @@
# limitations under the License.
from .unet import UNetModel
from .unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
from .unet_grad_tts import UNetGradTTSModel
from .unet_ldm import UNetLDMModel
...@@ -26,7 +26,6 @@ from torch.optim import Adam
from torch.utils import data

from PIL import Image
from torchvision import transforms, utils
from tqdm import tqdm

from ..configuration_utils import ConfigMixin
...@@ -331,171 +330,3 @@ class UNetModel(ModelMixin, ConfigMixin):
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
# dataset classes
class Dataset(data.Dataset):
    def __init__(self, folder, image_size, exts=["jpg", "jpeg", "png"]):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.paths = [p for ext in exts for p in Path(f"{folder}").glob(f"**/*.{ext}")]

        self.transform = transforms.Compose(
            [
                transforms.Resize(image_size),
                transforms.RandomHorizontalFlip(),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
            ]
        )

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)


# trainer class
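# EMA keeps an exponentially-weighted moving average of the model weights
# (new_avg = beta * old_avg + (1 - beta) * new); sampling from this smoothed copy is the usual
# trick for getting more stable images out of a diffusion model than the raw training weights give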
class EMA:
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new


def cycle(dl):
    while True:
        for data_dl in dl:
            yield data_dl


def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr
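# e.g. num_to_groups(36, 16) == [16, 16, 4]: split a total sample count into full batches plus a remainder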
class Trainer(object):
    def __init__(
        self,
        diffusion_model,
        folder,
        *,
        ema_decay=0.995,
        image_size=128,
        train_batch_size=32,
        train_lr=1e-4,
        train_num_steps=100000,
        gradient_accumulate_every=2,
        amp=False,
        step_start_ema=2000,
        update_ema_every=10,
        save_and_sample_every=1000,
        results_folder="./results",
    ):
        super().__init__()
        self.model = diffusion_model
        self.ema = EMA(ema_decay)
        self.ema_model = copy.deepcopy(self.model)
        self.update_ema_every = update_ema_every

        self.step_start_ema = step_start_ema
        self.save_and_sample_every = save_and_sample_every

        self.batch_size = train_batch_size
        self.image_size = diffusion_model.image_size
        self.gradient_accumulate_every = gradient_accumulate_every
        self.train_num_steps = train_num_steps

        self.ds = Dataset(folder, image_size)
        self.dl = cycle(data.DataLoader(self.ds, batch_size=train_batch_size, shuffle=True, pin_memory=True))
        self.opt = Adam(diffusion_model.parameters(), lr=train_lr)

        self.step = 0

        self.amp = amp
        self.scaler = GradScaler(enabled=amp)

        self.results_folder = Path(results_folder)
        self.results_folder.mkdir(exist_ok=True)

        self.reset_parameters()

    def reset_parameters(self):
        self.ema_model.load_state_dict(self.model.state_dict())

    def step_ema(self):
        if self.step < self.step_start_ema:
            self.reset_parameters()
            return
        self.ema.update_model_average(self.ema_model, self.model)

    def save(self, milestone):
        data = {
            "step": self.step,
            "model": self.model.state_dict(),
            "ema": self.ema_model.state_dict(),
            "scaler": self.scaler.state_dict(),
        }
        torch.save(data, str(self.results_folder / f"model-{milestone}.pt"))

    def load(self, milestone):
        data = torch.load(str(self.results_folder / f"model-{milestone}.pt"))

        self.step = data["step"]
        self.model.load_state_dict(data["model"])
        self.ema_model.load_state_dict(data["ema"])
        self.scaler.load_state_dict(data["scaler"])

    def train(self):
        with tqdm(initial=self.step, total=self.train_num_steps) as pbar:
            while self.step < self.train_num_steps:
                for i in range(self.gradient_accumulate_every):
                    data = next(self.dl).cuda()

                    with autocast(enabled=self.amp):
                        loss = self.model(data)

                    self.scaler.scale(loss / self.gradient_accumulate_every).backward()
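                    # scaling the loss by 1 / gradient_accumulate_every keeps the accumulated gradient equal
                    # to the average over the accumulated mini-batches, so the effective batch size grows
                    # without changing the learning-rate scale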
                    pbar.set_description(f"loss: {loss.item():.4f}")

                self.scaler.step(self.opt)
                self.scaler.update()
                self.opt.zero_grad()

                if self.step % self.update_ema_every == 0:
                    self.step_ema()

                if self.step != 0 and self.step % self.save_and_sample_every == 0:
                    self.ema_model.eval()

                    milestone = self.step // self.save_and_sample_every
                    batches = num_to_groups(36, self.batch_size)
                    all_images_list = list(map(lambda n: self.ema_model.sample(batch_size=n), batches))
                    all_images = torch.cat(all_images_list, dim=0)
                    utils.save_image(all_images, str(self.results_folder / f"sample-{milestone}.png"), nrow=6)
                    self.save(milestone)

                self.step += 1
                pbar.update(1)

        print("training complete")
...@@ -2,6 +2,7 @@ import math
import torch

try:
    from einops import rearrange, repeat
except:
...@@ -11,6 +12,7 @@ except:
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin


class Mish(torch.nn.Module):
    def forward(self, x):
        return x * torch.tanh(torch.nn.functional.softplus(x))
...@@ -47,9 +49,9 @@ class Rezero(torch.nn.Module):
class Block(torch.nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super(Block, self).__init__()
        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
        )

    def forward(self, x, mask):
        output = self.block(x * mask)
...@@ -59,8 +61,7 @@ class Block(torch.nn.Module):
class ResnetBlock(torch.nn.Module):
    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
        super(ResnetBlock, self).__init__()
        self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
...@@ -88,13 +89,11 @@ class LinearAttention(torch.nn.Module):
    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
        k = k.softmax(dim=-1)
        context = torch.einsum("bhdn,bhen->bhde", k, v)
        out = torch.einsum("bhde,bhdn->bhen", context, q)
        out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
        return self.to_out(out)
...@@ -124,16 +123,7 @@ class SinusoidalPosEmb(torch.nn.Module):
class UNetGradTTSModel(ModelMixin, ConfigMixin):
    def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000):
        super(UNetGradTTSModel, self).__init__()

        self.register(
...@@ -143,7 +133,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
            n_feats=n_feats,
            pe_scale=pe_scale,
        )

        self.dim = dim
...@@ -154,11 +144,11 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
        self.pe_scale = pe_scale

        if n_spks > 1:
            self.spk_mlp = torch.nn.Sequential(
                torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
            )
        self.time_pos_emb = SinusoidalPosEmb(dim)
        self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim))

        dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
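        # e.g. dim=64, dim_mults=(1, 2, 4), single speaker: dims == [2, 64, 128, 256] and
        # in_out == [(2, 64), (64, 128), (128, 256)], i.e. one (down, up) stage per multiplier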
...@@ -168,11 +158,16 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
            self.downs.append(
                torch.nn.ModuleList(
                    [
                        ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
                        ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
                        Residual(Rezero(LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else torch.nn.Identity(),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
...@@ -180,11 +175,16 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            self.ups.append(
                torch.nn.ModuleList(
                    [
                        ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
                        ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
                        Residual(Rezero(LinearAttention(dim_in))),
                        Upsample(dim_in),
                    ]
                )
            )

        self.final_block = Block(dim, dim)
        self.final_conv = torch.nn.Conv2d(dim, 1, 1)
......
from .pipeline_bddm import BDDM
from .pipeline_ddim import DDIM
from .pipeline_ddpm import DDPM
from .pipeline_glide import GLIDE
from .pipeline_latent_diffusion import LatentDiffusion
from .pipeline_pndm import PNDM
...@@ -24,11 +24,15 @@ import torch.utils.checkpoint
from torch import nn

import tqdm

try:
    from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer
    from transformers.activations import ACT2FN
    from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
    from transformers.modeling_utils import PreTrainedModel
    from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
except:
    print("Transformers is not installed")
    pass

from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from ..pipeline_utils import DiffusionPipeline
...@@ -832,9 +836,7 @@ class GLIDE(DiffusionPipeline):
        # 1. Sample gaussian noise
        batch_size = 2  # second image is empty for classifier-free guidance
        image = torch.randn((batch_size, self.text_unet.in_channels, 64, 64), generator=generator).to(torch_device)

        # 2. Encode tokens
        # an empty input is needed to guide the model away from it
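        # (classifier-free guidance: the model is evaluated on both the real prompt and the empty prompt,
        # and the two noise predictions are combined as eps_uncond + guidance_scale * (eps_cond - eps_uncond))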
......
...@@ -39,14 +39,13 @@ def generate_path(duration, mask):
    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path * mask
    return path


def duration_loss(logw, logw_, lengths):
    loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths)
    return loss
...@@ -62,7 +61,7 @@ class LayerNorm(nn.Module):
    def forward(self, x):
        n_dims = len(x.shape)
        mean = torch.mean(x, 1, keepdim=True)
        variance = torch.mean((x - mean) ** 2, 1, keepdim=True)

        x = (x - mean) * torch.rsqrt(variance + self.eps)
...@@ -72,8 +71,7 @@ class LayerNorm(nn.Module):
class ConvReluNorm(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
        super(ConvReluNorm, self).__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
...@@ -84,13 +82,13 @@ class ConvReluNorm(nn.Module):
        self.conv_layers = torch.nn.ModuleList()
        self.norm_layers = torch.nn.ModuleList()
        self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(
                torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
            )
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
...@@ -114,11 +112,9 @@ class DurationPredictor(nn.Module):
        self.p_dropout = p_dropout

        self.drop = torch.nn.Dropout(p_dropout)
        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_1 = LayerNorm(filter_channels)
        self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_2 = LayerNorm(filter_channels)
        self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
...@@ -136,9 +132,17 @@ class DurationPredictor(nn.Module):
class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        channels,
        out_channels,
        n_heads,
        window_size=None,
        heads_share=True,
        p_dropout=0.0,
        proximal_bias=False,
        proximal_init=False,
    ):
        super(MultiHeadAttention, self).__init__()
        assert channels % n_heads == 0
...@@ -158,10 +162,12 @@ class MultiHeadAttention(nn.Module):
        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = torch.nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev
            )
            self.emb_rel_v = torch.nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev
            )
        self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
        self.drop = torch.nn.Dropout(p_dropout)
...@@ -198,8 +204,7 @@ class MultiHeadAttention(nn.Module):
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
        p_attn = torch.nn.functional.softmax(scores, dim=-1)
...@@ -208,8 +213,7 @@ class MultiHeadAttention(nn.Module):
        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
            output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
        output = output.transpose(2, 3).contiguous().view(b, d, t_t)
        return output, p_attn
...@@ -227,28 +231,27 @@ class MultiHeadAttention(nn.Module):
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = torch.nn.functional.pad(
                relative_embeddings, convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])
            )
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        batch, heads, length, _ = x.size()
        x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
        return x_final
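    # the two reshaping helpers above and below use the pad-and-reshape ("skew") trick: padding one extra
    # column and viewing the buffer with a shifted shape moves each row over by one, converting between
    # relative-position and absolute-position score layouts without any gather ops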
    def _absolute_position_to_relative_position(self, x):
        batch, heads, length, _ = x.size()
        x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
        x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
...@@ -258,8 +261,7 @@ class MultiHeadAttention(nn.Module):
class FFN(nn.Module):
    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0):
        super(FFN, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
...@@ -267,10 +269,8 @@ class FFN(nn.Module):
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.drop = torch.nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
...@@ -282,8 +282,17 @@ class FFN(nn.Module):
class Encoder(nn.Module):
    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        window_size=None,
        **kwargs,
    ):
        super(Encoder, self).__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
...@@ -299,11 +308,15 @@ class Encoder(nn.Module):
        self.ffn_layers = torch.nn.ModuleList()
        self.norm_layers_2 = torch.nn.ModuleList()
        for _ in range(self.n_layers):
            self.attn_layers.append(
                MultiHeadAttention(
                    hidden_channels, hidden_channels, n_heads, window_size=window_size, p_dropout=p_dropout
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask):
...@@ -321,9 +334,21 @@ class Encoder(nn.Module):
class TextEncoder(ModelMixin, ConfigMixin):
    def __init__(
        self,
        n_vocab,
        n_feats,
        n_channels,
        filter_channels,
        filter_channels_dp,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        window_size=None,
        spk_emb_dim=64,
        n_spks=1,
    ):
        super(TextEncoder, self).__init__()
        self.register(
...@@ -338,10 +363,9 @@ class TextEncoder(ModelMixin, ConfigMixin):
            p_dropout=p_dropout,
            window_size=window_size,
            spk_emb_dim=spk_emb_dim,
            n_spks=n_spks,
        )

        self.n_vocab = n_vocab
        self.n_feats = n_feats
        self.n_channels = n_channels
...@@ -358,15 +382,22 @@ class TextEncoder(ModelMixin, ConfigMixin):
        self.emb = torch.nn.Embedding(n_vocab, n_channels)
        torch.nn.init.normal_(self.emb.weight, 0.0, n_channels**-0.5)
        self.prenet = ConvReluNorm(n_channels, n_channels, n_channels, kernel_size=5, n_layers=3, p_dropout=0.5)
        self.encoder = Encoder(
            n_channels + (spk_emb_dim if n_spks > 1 else 0),
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            window_size=window_size,
        )
        self.proj_m = torch.nn.Conv1d(n_channels + (spk_emb_dim if n_spks > 1 else 0), n_feats, 1)
        self.proj_w = DurationPredictor(
            n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp, kernel_size, p_dropout
        )

    def forward(self, x, x_lengths, spk=None):
        x = self.emb(x) * math.sqrt(self.n_channels)
......
...@@ -84,7 +84,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
        inference_step_times = list(range(0, self.timesteps, self.timesteps // num_inference_steps))
        warmup_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(
            np.array([0, self.timesteps // num_inference_steps // 2]), self.pndm_order
        )
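        # e.g. with timesteps=1000, num_inference_steps=50 and a fourth-order warm-up (pndm_order == 4,
        # an assumption here), the last plain steps are [920, 940, 960, 980] and the interleaved
        # half-steps give warmup_time_steps == [920, 930, 940, 950, 960, 970, 980, 990]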
        self.warmup_time_steps[num_inference_steps] = list(reversed(warmup_time_steps[:-1].repeat(2)[1:-1]))

        return self.warmup_time_steps[num_inference_steps]
...@@ -137,7 +139,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
        at = alphas_cump[t + 1].view(-1, 1, 1, 1)
        at_next = alphas_cump[t_next + 1].view(-1, 1, 1, 1)

        x_delta = (at_next - at) * (
            (1 / (at.sqrt() * (at.sqrt() + at_next.sqrt()))) * x
            - 1 / (at.sqrt() * (((1 - at_next) * at).sqrt() + ((1 - at) * at_next).sqrt())) * et
        )
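        # algebraically this is the DDIM update
        #   x_next = sqrt(at_next / at) * x + (sqrt(1 - at_next) - sqrt(at_next * (1 - at) / at)) * et
        # rewritten as an increment on x with rationalized denominators (the PNDM "transfer" step),
        # which avoids subtracting two nearly equal terms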
        x_next = x + x_delta

        return x_next
......
...@@ -19,7 +19,18 @@ import unittest
import torch

from diffusers import (
    BDDM,
    DDIM,
    DDPM,
    GLIDE,
    PNDM,
    DDIMScheduler,
    DDPMScheduler,
    LatentDiffusion,
    PNDMScheduler,
    UNetModel,
)
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.pipeline_bddm import DiffWave
......