Commit ba21735c authored by anton-l

DDPM training example
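Adds a standalone DDPM training example (unconditional image generation on `huggan/pokemon` with `accelerate`), adds a `forward_step` method and precomputed `sqrt_alphas_cumprod` buffers to `DDPMScheduler`, comments out `ConfigMixin.__eq__`/`__repr__`, and applies formatting and import-order cleanups to the BDDM, GLIDE, and scheduler modules.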

parent 2d1f7de2
@@ -9,6 +9,6 @@ from .models.unet import UNetModel
 from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .models.unet_ldm import UNetLDMModel
 from .pipeline_utils import DiffusionPipeline
-from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, BDDMPipeline
+from .pipelines import DDIM, DDPM, GLIDE, BDDMPipeline, LatentDiffusion
 from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin
 from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
@@ -225,11 +225,11 @@ class ConfigMixin:
             text = reader.read()
         return json.loads(text)

-    def __eq__(self, other):
-        return self.__dict__ == other.__dict__
+    # def __eq__(self, other):
+    #     return self.__dict__ == other.__dict__

-    def __repr__(self):
-        return f"{self.__class__.__name__} {self.to_json_string()}"
+    # def __repr__(self):
+    #     return f"{self.__class__.__name__} {self.to_json_string()}"

     @property
     def config(self) -> Dict[str, Any]:
+from .pipeline_bddm import BDDMPipeline
 from .pipeline_ddim import DDIM
 from .pipeline_ddpm import DDPM
 from .pipeline_glide import GLIDE
 from .pipeline_latent_diffusion import LatentDiffusion
-from .pipeline_bddm import BDDMPipeline
@@ -97,7 +97,9 @@ superres_model = GLIDESuperResUNetModel(
 superres_model.load_state_dict(ups_state_dict, strict=False)

-upscale_scheduler = DDIMScheduler(timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02, tensor_format="pt")
+upscale_scheduler = DDIMScheduler(
+    timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02, tensor_format="pt"
+)
 glide = GLIDE(
     text_unet=text2im_model,
@@ -13,14 +13,16 @@
 import math

 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import tqdm

-from ..modeling_utils import ModelMixin
 from ..configuration_utils import ConfigMixin
+from ..modeling_utils import ModelMixin
 from ..pipeline_utils import DiffusionPipeline
@@ -46,8 +48,7 @@ def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):
     _embed = np.log(10000) / (half_dim - 1)
     _embed = torch.exp(torch.arange(half_dim) * -_embed).cuda()
     _embed = diffusion_steps * _embed
-    diffusion_step_embed = torch.cat((torch.sin(_embed),
-                                      torch.cos(_embed)), 1)
+    diffusion_step_embed = torch.cat((torch.sin(_embed), torch.cos(_embed)), 1)

     return diffusion_step_embed
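For reference, this helper is the standard Transformer-style sinusoidal embedding: with half_dim = diffusion_step_embed_dim_in / 2 and frequencies w_i = exp(-i * ln(10000) / (half_dim - 1)), it returns [sin(t * w_0), ..., sin(t * w_{half_dim-1}), cos(t * w_0), ..., cos(t * w_{half_dim-1})] for each diffusion step t.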
@@ -67,8 +68,7 @@ class Conv(nn.Module):
     def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1):
         super().__init__()
         self.padding = dilation * (kernel_size - 1) // 2
-        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
-                              dilation=dilation, padding=self.padding)
+        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=self.padding)
         self.conv = nn.utils.weight_norm(self.conv)
         nn.init.kaiming_normal_(self.conv.weight)
@@ -94,8 +94,7 @@ class ZeroConv1d(nn.Module):
 # every residual block (named residual layer in paper)
 # contains one noncausal dilated conv
 class ResidualBlock(nn.Module):
-    def __init__(self, res_channels, skip_channels, dilation,
-                 diffusion_step_embed_dim_out):
+    def __init__(self, res_channels, skip_channels, dilation, diffusion_step_embed_dim_out):
         super().__init__()
         self.res_channels = res_channels
@@ -103,15 +102,12 @@ class ResidualBlock(nn.Module):
         self.fc_t = nn.Linear(diffusion_step_embed_dim_out, self.res_channels)

         # Dilated conv layer
-        self.dilated_conv_layer = Conv(self.res_channels, 2 * self.res_channels,
-                                       kernel_size=3, dilation=dilation)
+        self.dilated_conv_layer = Conv(self.res_channels, 2 * self.res_channels, kernel_size=3, dilation=dilation)

         # Add mel spectrogram upsampler and conditioner conv1x1 layer
         self.upsample_conv2d = nn.ModuleList()
         for s in [16, 16]:
-            conv_trans2d = nn.ConvTranspose2d(1, 1, (3, 2 * s),
-                                              padding=(1, s // 2),
-                                              stride=(1, s))
+            conv_trans2d = nn.ConvTranspose2d(1, 1, (3, 2 * s), padding=(1, s // 2), stride=(1, s))
             conv_trans2d = nn.utils.weight_norm(conv_trans2d)
             nn.init.kaiming_normal_(conv_trans2d.weight)
             self.upsample_conv2d.append(conv_trans2d)
@@ -157,7 +153,7 @@ class ResidualBlock(nn.Module):
             h += mel_spec

         # Gated-tanh nonlinearity
-        out = torch.tanh(h[:, :self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels:, :])
+        out = torch.tanh(h[:, : self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels :, :])

         # Residual and skip outputs
         res = self.res_conv(out)
@@ -169,10 +165,16 @@ class ResidualBlock(nn.Module):
 class ResidualGroup(nn.Module):
-    def __init__(self, res_channels, skip_channels, num_res_layers, dilation_cycle,
-                 diffusion_step_embed_dim_in,
-                 diffusion_step_embed_dim_mid,
-                 diffusion_step_embed_dim_out):
+    def __init__(
+        self,
+        res_channels,
+        skip_channels,
+        num_res_layers,
+        dilation_cycle,
+        diffusion_step_embed_dim_in,
+        diffusion_step_embed_dim_mid,
+        diffusion_step_embed_dim_out,
+    ):
         super().__init__()
         self.num_res_layers = num_res_layers
         self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in
@@ -185,16 +187,19 @@ class ResidualGroup(nn.Module):
         self.residual_blocks = nn.ModuleList()
         for n in range(self.num_res_layers):
             self.residual_blocks.append(
-                ResidualBlock(res_channels, skip_channels,
-                              dilation=2 ** (n % dilation_cycle),
-                              diffusion_step_embed_dim_out=diffusion_step_embed_dim_out))
+                ResidualBlock(
+                    res_channels,
+                    skip_channels,
+                    dilation=2 ** (n % dilation_cycle),
+                    diffusion_step_embed_dim_out=diffusion_step_embed_dim_out,
+                )
+            )

     def forward(self, input_data):
         x, mel_spectrogram, diffusion_steps = input_data

         # Embed diffusion step t
-        diffusion_step_embed = calc_diffusion_step_embedding(
-            diffusion_steps, self.diffusion_step_embed_dim_in)
+        diffusion_step_embed = calc_diffusion_step_embedding(diffusion_steps, self.diffusion_step_embed_dim_in)
         diffusion_step_embed = swish(self.fc_t1(diffusion_step_embed))
         diffusion_step_embed = swish(self.fc_t2(diffusion_step_embed))
@@ -239,20 +244,24 @@ class DiffWave(ModelMixin, ConfigMixin):
             diffusion_step_embed_dim_out=diffusion_step_embed_dim_out,
         )

         # Initial conv1x1 with relu
         self.init_conv = nn.Sequential(Conv(in_channels, res_channels, kernel_size=1), nn.ReLU(inplace=False))

         # All residual layers
-        self.residual_layer = ResidualGroup(res_channels,
-                                            skip_channels,
-                                            num_res_layers,
-                                            dilation_cycle,
-                                            diffusion_step_embed_dim_in,
-                                            diffusion_step_embed_dim_mid,
-                                            diffusion_step_embed_dim_out)
+        self.residual_layer = ResidualGroup(
+            res_channels,
+            skip_channels,
+            num_res_layers,
+            dilation_cycle,
+            diffusion_step_embed_dim_in,
+            diffusion_step_embed_dim_mid,
+            diffusion_step_embed_dim_out,
+        )

         # Final conv1x1 -> relu -> zeroconv1x1
-        self.final_conv = nn.Sequential(Conv(skip_channels, skip_channels, kernel_size=1),
-                                        nn.ReLU(inplace=False), ZeroConv1d(skip_channels, out_channels))
+        self.final_conv = nn.Sequential(
+            Conv(skip_channels, skip_channels, kernel_size=1),
+            nn.ReLU(inplace=False),
+            ZeroConv1d(skip_channels, out_channels),
+        )

     def forward(self, input_data):
         audio, mel_spectrogram, diffusion_steps = input_data
@@ -267,12 +276,12 @@ class BDDMPipeline(DiffusionPipeline):
         super().__init__()
         noise_scheduler = noise_scheduler.set_format("pt")
         self.register_modules(diffwave=diffwave, noise_scheduler=noise_scheduler)

     @torch.no_grad()
-    def __call__(self, mel_spectrogram, generator):
+    def __call__(self, mel_spectrogram, generator, torch_device=None):
         if torch_device is None:
             torch_device = "cuda" if torch.cuda.is_available() else "cpu"

         self.diffwave.to(torch_device)

         audio_length = mel_spectrogram.size(-1) * self.config.hop_len
@@ -301,4 +310,4 @@ class BDDMPipeline(DiffusionPipeline):
             # 4. set current audio to prev_audio: x_t -> x_t-1
             audio = pred_prev_audio + variance

-        return audio
\ No newline at end of file
+        return audio
@@ -28,12 +28,7 @@ from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import (
-    ModelOutput,
-    add_start_docstrings_to_model_forward,
-    logging,
-    replace_return_docstrings,
-)
+from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings

 from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from ..pipeline_utils import DiffusionPipeline
@@ -871,7 +866,12 @@ class GLIDE(DiffusionPipeline):
         # Sample gaussian noise to begin loop
         image = torch.randn(
-            (batch_size, self.upscale_unet.in_channels // 2, self.upscale_unet.resolution, self.upscale_unet.resolution),
+            (
+                batch_size,
+                self.upscale_unet.in_channels // 2,
+                self.upscale_unet.resolution,
+                self.upscale_unet.resolution,
+            ),
             generator=generator,
         )
         image = image.to(torch_device) * upsample_temp
@@ -39,7 +39,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
             beta_schedule=beta_schedule,
         )
         self.timesteps = int(timesteps)
-        self.timestep_values = timestep_values # save the fixed timestep values for BDDM
+        self.timestep_values = timestep_values  # save the fixed timestep values for BDDM
         self.clip_image = clip_predicted_image

         if trained_betas is not None:
@@ -56,6 +56,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
+        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1 - self.alphas_cumprod)
         self.one = np.array(1.0)

         self.set_format(tensor_format=tensor_format)
@@ -131,5 +133,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):

         return pred_prev_image

+    def forward_step(self, original_image, noise, t):
+        noisy_image = self.sqrt_alphas_cumprod[t] * original_image + self.sqrt_one_minus_alphas_cumprod[t] * noise
+        return noisy_image
+
     def __len__(self):
         return self.timesteps
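The new forward_step is the closed-form DDPM forward process q(x_t | x_0), sampled as x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * noise with noise ~ N(0, I); this is what the training example below uses to noise clean images in a single step. A minimal usage sketch (illustrative only; the variable names are assumptions, not part of the commit):

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(timesteps=1000)
x0 = torch.randn(3, 64, 64)  # stand-in for a clean training image in [-1, 1]
noise = torch.randn_like(x0)  # epsilon ~ N(0, I)
t = torch.randint(0, scheduler.timesteps, (1,)).long()[0]  # random diffusion step
xt = scheduler.forward_step(x0, noise, t)  # noised image at step t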
import random
import numpy as np
import torch
import torch.nn.functional as F
import PIL.Image
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPM, DDPMScheduler, UNetModel
from torchvision.transforms import CenterCrop, Compose, Lambda, RandomHorizontalFlip, Resize, ToTensor
from tqdm.auto import tqdm
from transformers import get_linear_schedule_with_warmup
def set_seed(seed):
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
set_seed(0)
accelerator = Accelerator(mixed_precision="fp16")
model = UNetModel(ch=128, ch_mult=(1, 2, 4, 8), resolution=64)
noise_scheduler = DDPMScheduler(timesteps=1000)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
num_epochs = 100
batch_size = 8
gradient_accumulation_steps = 8
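# With batch_size = 8 and gradient_accumulation_steps = 8, the effective batch
# size is 8 * 8 = 64 per process (times the number of processes when launched
# on multiple devices with accelerate).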
augmentations = Compose(
    [
        Resize(64),
        CenterCrop(64),
        RandomHorizontalFlip(),
        ToTensor(),
        Lambda(lambda x: x * 2 - 1),  # map [0, 1] pixels to the [-1, 1] range the model trains on
    ]
)
dataset = load_dataset("huggan/pokemon", split="train")
def transforms(examples):
    images = [augmentations(image.convert("RGB")) for image in examples["image"]]
    return {"input": images}
dataset = dataset.shuffle(seed=0)
dataset.set_transform(transforms)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=1000,
    num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
)

model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, lr_scheduler
)
for epoch in range(num_epochs):
    model.train()
    pbar = tqdm(total=len(train_dataloader), unit="ba")
    pbar.set_description(f"Epoch {epoch}")
    for step, batch in enumerate(train_dataloader):
        clean_images = batch["input"]
        noisy_images = torch.empty_like(clean_images)
        bsz = clean_images.shape[0]
        # Sample a random diffusion timestep for each image in the batch
        timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()
        # Noise each image to its sampled timestep with the scheduler's forward process
        for idx in range(bsz):
            noise = torch.randn_like(clean_images[0]).to(clean_images.device)
            noisy_images[idx] = noise_scheduler.forward_step(clean_images[idx], noise, timesteps[idx])
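        # Note: the per-sample loop above is simple but slow; a vectorized
        # equivalent (an illustrative sketch, assuming the scheduler's buffers
        # are torch tensors, e.g. after noise_scheduler.set_format("pt")):
        #   noise = torch.randn_like(clean_images)
        #   a_t = noise_scheduler.sqrt_alphas_cumprod[timesteps].reshape(-1, 1, 1, 1)
        #   s_t = noise_scheduler.sqrt_one_minus_alphas_cumprod[timesteps].reshape(-1, 1, 1, 1)
        #   noisy_images = a_t * clean_images + s_t * noise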
        if (step + 1) % gradient_accumulation_steps != 0:
            # Accumulate gradients locally; no_sync skips the DDP all-reduce on these steps
            with accelerator.no_sync(model):
                output = model(noisy_images, timesteps)
                loss = F.l1_loss(output, clean_images)
                accelerator.backward(loss)
        else:
            # Every gradient_accumulation_steps batches: sync gradients and take an optimizer step
            output = model(noisy_images, timesteps)
            loss = F.l1_loss(output, clean_images)
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
        pbar.update(1)
        pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])
    # Flush any gradients left over from an incomplete accumulation window at epoch end
    optimizer.step()
    optimizer.zero_grad()
    # eval
    model.eval()
    with torch.no_grad():
        pipeline = DDPM(unet=model, noise_scheduler=noise_scheduler)
        generator = torch.Generator()
        generator = generator.manual_seed(0)
        # run pipeline in inference (sample random noise and denoise)
        image = pipeline(generator=generator)

        # process image to PIL: invert the [-1, 1] training range back to uint8 pixels
        image_processed = image.cpu().permute(0, 2, 3, 1)
        image_processed = (image_processed + 1.0) * 127.5
        image_processed = image_processed.type(torch.uint8).numpy()
        image_pil = PIL.Image.fromarray(image_processed[0])

        # save image
        pipeline.save_pretrained("./poke-ddpm")
        image_pil.save(f"./poke-ddpm/test_{epoch}.png")
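Usage note (an assumption, since the diff does not show the script's path or launch instructions): with the `Accelerator(mixed_precision="fp16")` setup above, the example would typically be started with `accelerate launch` pointing at the script, or with plain `python` on a single device.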