Commit ba21735c authored by anton-l

DDPM training example

parent 2d1f7de2
@@ -9,6 +9,6 @@ from .models.unet import UNetModel
 from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .models.unet_ldm import UNetLDMModel
 from .pipeline_utils import DiffusionPipeline
-from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, BDDMPipeline
+from .pipelines import DDIM, DDPM, GLIDE, BDDMPipeline, LatentDiffusion
 from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin
 from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
@@ -225,11 +225,11 @@ class ConfigMixin:
         text = reader.read()
         return json.loads(text)

-    def __eq__(self, other):
-        return self.__dict__ == other.__dict__
+    # def __eq__(self, other):
+    #     return self.__dict__ == other.__dict__

-    def __repr__(self):
-        return f"{self.__class__.__name__} {self.to_json_string()}"
+    # def __repr__(self):
+    #     return f"{self.__class__.__name__} {self.to_json_string()}"

     @property
     def config(self) -> Dict[str, Any]:
...
+from .pipeline_bddm import BDDMPipeline
 from .pipeline_ddim import DDIM
 from .pipeline_ddpm import DDPM
 from .pipeline_glide import GLIDE
 from .pipeline_latent_diffusion import LatentDiffusion
-from .pipeline_bddm import BDDMPipeline
@@ -97,7 +97,9 @@ superres_model = GLIDESuperResUNetModel(
 superres_model.load_state_dict(ups_state_dict, strict=False)

-upscale_scheduler = DDIMScheduler(timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02, tensor_format="pt")
+upscale_scheduler = DDIMScheduler(
+    timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02, tensor_format="pt"
+)

 glide = GLIDE(
     text_unet=text2im_model,
...
@@ -13,14 +13,16 @@
 import math

 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import tqdm

-from ..modeling_utils import ModelMixin
 from ..configuration_utils import ConfigMixin
+from ..modeling_utils import ModelMixin
 from ..pipeline_utils import DiffusionPipeline
@@ -46,8 +48,7 @@ def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):
     _embed = np.log(10000) / (half_dim - 1)
     _embed = torch.exp(torch.arange(half_dim) * -_embed).cuda()
     _embed = diffusion_steps * _embed
-    diffusion_step_embed = torch.cat((torch.sin(_embed),
-                                      torch.cos(_embed)), 1)
+    diffusion_step_embed = torch.cat((torch.sin(_embed), torch.cos(_embed)), 1)
     return diffusion_step_embed
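For intuition: this helper builds the standard sinusoidal timestep embedding. A minimal CPU sketch of the same computation (the .cuda() call dropped; diffusion_steps assumed to have shape (batch, 1)):

import numpy as np
import torch

# Sinusoidal step embedding, mirroring calc_diffusion_step_embedding above.
# Output shape: (batch, embed_dim); first half sin, second half cos.
def step_embedding(diffusion_steps, embed_dim=128):
    half_dim = embed_dim // 2
    scale = np.log(10000) / (half_dim - 1)
    freqs = torch.exp(torch.arange(half_dim) * -scale)
    args = diffusion_steps * freqs  # broadcasts to (batch, half_dim)
    return torch.cat((torch.sin(args), torch.cos(args)), dim=1)

emb = step_embedding(torch.arange(4).float().unsqueeze(1))  # -> torch.Size([4, 128])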
@@ -67,8 +68,7 @@ class Conv(nn.Module):
     def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1):
         super().__init__()
         self.padding = dilation * (kernel_size - 1) // 2
-        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
-                              dilation=dilation, padding=self.padding)
+        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=self.padding)
         self.conv = nn.utils.weight_norm(self.conv)
         nn.init.kaiming_normal_(self.conv.weight)
@@ -94,8 +94,7 @@ class ZeroConv1d(nn.Module):
 # every residual block (named residual layer in paper)
 # contains one noncausal dilated conv
 class ResidualBlock(nn.Module):
-    def __init__(self, res_channels, skip_channels, dilation,
-                 diffusion_step_embed_dim_out):
+    def __init__(self, res_channels, skip_channels, dilation, diffusion_step_embed_dim_out):
         super().__init__()
         self.res_channels = res_channels
@@ -103,15 +102,12 @@ class ResidualBlock(nn.Module):
         self.fc_t = nn.Linear(diffusion_step_embed_dim_out, self.res_channels)

         # Dilated conv layer
-        self.dilated_conv_layer = Conv(self.res_channels, 2 * self.res_channels,
-                                       kernel_size=3, dilation=dilation)
+        self.dilated_conv_layer = Conv(self.res_channels, 2 * self.res_channels, kernel_size=3, dilation=dilation)

         # Add mel spectrogram upsampler and conditioner conv1x1 layer
         self.upsample_conv2d = nn.ModuleList()
         for s in [16, 16]:
-            conv_trans2d = nn.ConvTranspose2d(1, 1, (3, 2 * s),
-                                              padding=(1, s // 2),
-                                              stride=(1, s))
+            conv_trans2d = nn.ConvTranspose2d(1, 1, (3, 2 * s), padding=(1, s // 2), stride=(1, s))
             conv_trans2d = nn.utils.weight_norm(conv_trans2d)
             nn.init.kaiming_normal_(conv_trans2d.weight)
             self.upsample_conv2d.append(conv_trans2d)
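As a shape check on the conditioner: each ConvTranspose2d stretches only the time axis by its stride, so the pair of stride-(1, 16) layers upsamples mel frames by 16 * 16 = 256 in total. A standalone sketch of one layer (the 80-bin, 63-frame mel input here is an assumed example):

import torch
import torch.nn as nn

# One upsampling layer as constructed above, with s = 16:
# kernel (3, 32), padding (1, 8), stride (1, 16).
up = nn.ConvTranspose2d(1, 1, (3, 32), padding=(1, 8), stride=(1, 16))
mel = torch.randn(1, 1, 80, 63)  # (batch, 1, mel_bins, frames)
print(up(mel).shape)  # torch.Size([1, 1, 80, 1008]) -- frames * 16, mel bins unchanged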
@@ -157,7 +153,7 @@ class ResidualBlock(nn.Module):
         h += mel_spec

         # Gated-tanh nonlinearity
-        out = torch.tanh(h[:, :self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels:, :])
+        out = torch.tanh(h[:, : self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels :, :])

         # Residual and skip outputs
         res = self.res_conv(out)
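The gated-tanh unit is the WaveNet-style activation: the dilated conv doubles the channels, then the first res_channels channels pass through tanh while the second half gates them through a sigmoid. In isolation:

import torch

res_channels = 64
h = torch.randn(2, 2 * res_channels, 100)  # (batch, 2 * res_channels, time)
out = torch.tanh(h[:, :res_channels, :]) * torch.sigmoid(h[:, res_channels:, :])
assert out.shape == (2, res_channels, 100)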
@@ -169,10 +165,16 @@ class ResidualBlock(nn.Module):

 class ResidualGroup(nn.Module):
-    def __init__(self, res_channels, skip_channels, num_res_layers, dilation_cycle,
-                 diffusion_step_embed_dim_in,
-                 diffusion_step_embed_dim_mid,
-                 diffusion_step_embed_dim_out):
+    def __init__(
+        self,
+        res_channels,
+        skip_channels,
+        num_res_layers,
+        dilation_cycle,
+        diffusion_step_embed_dim_in,
+        diffusion_step_embed_dim_mid,
+        diffusion_step_embed_dim_out,
+    ):
         super().__init__()
         self.num_res_layers = num_res_layers
         self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in
@@ -185,16 +187,19 @@ class ResidualGroup(nn.Module):
         self.residual_blocks = nn.ModuleList()
         for n in range(self.num_res_layers):
             self.residual_blocks.append(
-                ResidualBlock(res_channels, skip_channels,
-                              dilation=2 ** (n % dilation_cycle),
-                              diffusion_step_embed_dim_out=diffusion_step_embed_dim_out))
+                ResidualBlock(
+                    res_channels,
+                    skip_channels,
+                    dilation=2 ** (n % dilation_cycle),
+                    diffusion_step_embed_dim_out=diffusion_step_embed_dim_out,
+                )
+            )

     def forward(self, input_data):
         x, mel_spectrogram, diffusion_steps = input_data

         # Embed diffusion step t
-        diffusion_step_embed = calc_diffusion_step_embedding(
-            diffusion_steps, self.diffusion_step_embed_dim_in)
+        diffusion_step_embed = calc_diffusion_step_embedding(diffusion_steps, self.diffusion_step_embed_dim_in)
         diffusion_step_embed = swish(self.fc_t1(diffusion_step_embed))
         diffusion_step_embed = swish(self.fc_t2(diffusion_step_embed))
@@ -239,20 +244,24 @@ class DiffWave(ModelMixin, ConfigMixin):
             diffusion_step_embed_dim_out=diffusion_step_embed_dim_out,
         )

         # Initial conv1x1 with relu
         self.init_conv = nn.Sequential(Conv(in_channels, res_channels, kernel_size=1), nn.ReLU(inplace=False))

         # All residual layers
-        self.residual_layer = ResidualGroup(res_channels,
-                                            skip_channels,
-                                            num_res_layers,
-                                            dilation_cycle,
-                                            diffusion_step_embed_dim_in,
-                                            diffusion_step_embed_dim_mid,
-                                            diffusion_step_embed_dim_out)
+        self.residual_layer = ResidualGroup(
+            res_channels,
+            skip_channels,
+            num_res_layers,
+            dilation_cycle,
+            diffusion_step_embed_dim_in,
+            diffusion_step_embed_dim_mid,
+            diffusion_step_embed_dim_out,
+        )

         # Final conv1x1 -> relu -> zeroconv1x1
-        self.final_conv = nn.Sequential(Conv(skip_channels, skip_channels, kernel_size=1),
-                                        nn.ReLU(inplace=False), ZeroConv1d(skip_channels, out_channels))
+        self.final_conv = nn.Sequential(
+            Conv(skip_channels, skip_channels, kernel_size=1),
+            nn.ReLU(inplace=False),
+            ZeroConv1d(skip_channels, out_channels),
+        )

     def forward(self, input_data):
         audio, mel_spectrogram, diffusion_steps = input_data
...
@@ -28,12 +28,7 @@ from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import (
-    ModelOutput,
-    add_start_docstrings_to_model_forward,
-    logging,
-    replace_return_docstrings,
-)
+from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings

 from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from ..pipeline_utils import DiffusionPipeline
@@ -871,7 +866,12 @@ class GLIDE(DiffusionPipeline):
         # Sample gaussian noise to begin loop
         image = torch.randn(
-            (batch_size, self.upscale_unet.in_channels // 2, self.upscale_unet.resolution, self.upscale_unet.resolution),
+            (
+                batch_size,
+                self.upscale_unet.in_channels // 2,
+                self.upscale_unet.resolution,
+                self.upscale_unet.resolution,
+            ),
             generator=generator,
         )
         image = image.to(torch_device) * upsample_temp
...
@@ -56,6 +56,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
+        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1 - self.alphas_cumprod)
         self.one = np.array(1.0)

         self.set_format(tensor_format=tensor_format)
@@ -131,5 +133,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         return pred_prev_image

+    def forward_step(self, original_image, noise, t):
+        noisy_image = self.sqrt_alphas_cumprod[t] * original_image + self.sqrt_one_minus_alphas_cumprod[t] * noise
+        return noisy_image
+
     def __len__(self):
         return self.timesteps
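The new forward_step applies the closed-form forward process x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, so training can noise an image to any timestep t in one shot instead of iterating. A standalone numpy sketch, assuming the scheduler's default linear beta schedule:

import numpy as np

betas = np.linspace(0.0001, 0.02, 1000)  # assumed linear schedule
alphas_cumprod = np.cumprod(1.0 - betas)

def forward_step(original_image, noise, t):
    return np.sqrt(alphas_cumprod[t]) * original_image + np.sqrt(1 - alphas_cumprod[t]) * noise

x0 = np.random.randn(3, 64, 64)
x500 = forward_step(x0, np.random.randn(*x0.shape), t=500)  # heavily noised sample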
import random
import numpy as np
import torch
import torch.nn.functional as F
import PIL.Image
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPM, DDPMScheduler, UNetModel
from torchvision.transforms import CenterCrop, Compose, Lambda, RandomHorizontalFlip, Resize, ToTensor
from tqdm.auto import tqdm
from transformers import get_linear_schedule_with_warmup
def set_seed(seed):
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
set_seed(0)
accelerator = Accelerator(mixed_precision="fp16")
model = UNetModel(ch=128, ch_mult=(1, 2, 4, 8), resolution=64)
noise_scheduler = DDPMScheduler(timesteps=1000)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
num_epochs = 100
batch_size = 8
gradient_accumulation_steps = 8
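# with batch_size = 8 and 8 accumulation steps, each optimizer update sees an effective batch of 64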
augmentations = Compose(
    [
        Resize(64),
        CenterCrop(64),
        RandomHorizontalFlip(),
        ToTensor(),
        Lambda(lambda x: x * 2 - 1),
    ]
)
dataset = load_dataset("huggan/pokemon", split="train")
def transforms(examples):
    images = [augmentations(image.convert("RGB")) for image in examples["image"]]
    return {"input": images}
dataset = dataset.shuffle(seed=0)
dataset.set_transform(transforms)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=1000,
    num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
)
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, lr_scheduler
)
for epoch in range(num_epochs):
    model.train()
    pbar = tqdm(total=len(train_dataloader), unit="ba")
    pbar.set_description(f"Epoch {epoch}")
    for step, batch in enumerate(train_dataloader):
        clean_images = batch["input"]
        noisy_images = torch.empty_like(clean_images)
        bsz = clean_images.shape[0]
        # sample a random diffusion timestep for each image in the batch
        timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()
        # noise each image to its timestep with the closed-form forward process q(x_t | x_0)
        for idx in range(bsz):
            noise = torch.randn_like(clean_images[0]).to(clean_images.device)
            noisy_images[idx] = noise_scheduler.forward_step(clean_images[idx], noise, timesteps[idx])
        if (step + 1) % gradient_accumulation_steps != 0:
            # accumulation step: keep gradients local and skip the DDP all-reduce
            with accelerator.no_sync(model):
                output = model(noisy_images, timesteps)
                loss = F.l1_loss(output, clean_images)
                accelerator.backward(loss)
        else:
            # update step: sync gradients and apply the accumulated update
            output = model(noisy_images, timesteps)
            loss = F.l1_loss(output, clean_images)
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
        pbar.update(1)
        pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])
    # eval
    model.eval()
    with torch.no_grad():
        # unwrap the model so the pipeline holds the plain module, not a distributed wrapper
        pipeline = DDPM(unet=accelerator.unwrap_model(model), noise_scheduler=noise_scheduler)
        generator = torch.Generator()
        generator = generator.manual_seed(0)
        # run pipeline in inference (sample random noise and denoise)
        image = pipeline(generator=generator)
        # process image to PIL
        image_processed = image.cpu().permute(0, 2, 3, 1)
        image_processed = (image_processed + 1.0) * 127.5
        image_processed = image_processed.type(torch.uint8).numpy()
        image_pil = PIL.Image.fromarray(image_processed[0])
        # save image
        pipeline.save_pretrained("./poke-ddpm")
        image_pil.save(f"./poke-ddpm/test_{epoch}.png")
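To sample from a saved checkpoint later, the pipeline can be restored with from_pretrained (a sketch; it assumes the save_pretrained call above wrote a standard DiffusionPipeline layout to ./poke-ddpm):

import torch
from diffusers import DDPM

# reload the trained unet + scheduler and draw one sample
pipeline = DDPM.from_pretrained("./poke-ddpm")
image = pipeline(generator=torch.Generator().manual_seed(42))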