# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import math
from functools import partial
from inspect import isfunction
from pathlib import Path

import torch
from einops import rearrange
from PIL import Image
from torch import einsum, nn
from torch.cuda.amp import GradScaler, autocast
from torch.optim import Adam
from torch.utils import data
from torchvision import transforms, utils
from tqdm import tqdm

from ..configuration_utils import Config
from ..modeling_utils import PreTrainedModel


# NOTE: the following file is completely copied from
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/master/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py


# helper functions


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def cycle(dl):
    while True:
        for data_dl in dl:
            yield data_dl


def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr


def normalize_to_neg_one_to_one(img):
    return img * 2 - 1


def unnormalize_to_zero_to_one(t):
    return (t + 1) * 0.5


# small helper modules


class EMA:
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


def Upsample(dim):
    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)


def Downsample(dim):
    return nn.Conv2d(dim, dim, 4, 2, 1)


class LayerNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)


# building block modules

class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x


class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2)) if exists(time_emb_dim) else None

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, "b c -> b c 1 1")
            scale_shift = time_emb.chunk(2, dim=1)

        h = self.block1(x, scale_shift=scale_shift)
        h = self.block2(h)
        return h + self.res_conv(x)


class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), LayerNorm(dim))

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv)

        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)

        q = q * self.scale
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)


class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv)
        q = q * self.scale

        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        return self.to_out(out)

class UNetModel(PreTrainedModel, Config):
    def __init__(
        self,
        dim=64,
        dim_mults=(1, 2, 4, 8),
        init_dim=None,
        out_dim=None,
        channels=3,
        with_time_emb=True,
        resnet_block_groups=8,
        learned_variance=False,
    ):
        super().__init__()
        self.register(
            dim=dim,
            dim_mults=dim_mults,
            init_dim=init_dim,
            out_dim=out_dim,
            channels=channels,
            with_time_emb=with_time_emb,
            resnet_block_groups=resnet_block_groups,
            learned_variance=learned_variance,
        )

        # determine dimensions
        self.channels = channels

        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # time embeddings
        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.Sequential(
                SinusoidalPosEmb(dim), nn.Linear(dim, time_dim), nn.GELU(), nn.Linear(time_dim, time_dim)
            )
        else:
            time_dim = None
            self.time_mlp = None

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(
                nn.ModuleList(
                    [
                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Upsample(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

        default_out_dim = channels * (1 if not learned_variance else 2)
        self.out_dim = default(out_dim, default_out_dim)

        self.final_conv = nn.Sequential(block_klass(dim, dim), nn.Conv2d(dim, self.out_dim, 1))

    def forward(self, x, time):
        x = self.init_conv(x)

        t = self.time_mlp(time) if exists(self.time_mlp) else None

        h = []

        # downsampling path, keeping skip activations in `h`
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)

        # bottleneck
        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        # upsampling path, concatenating skip activations
        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)

        return self.final_conv(x)
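
# Minimal usage sketch for `UNetModel` (illustration only; assumes the `Config`/`PreTrainedModel`
# mixins of this library handle `self.register(...)` as a config registration). The model maps a
# batch of noisy images plus per-sample timesteps to a prediction with `out_dim` channels; spatial
# dimensions must be divisible by 2 ** (len(dim_mults) - 1) so the down- and upsampling paths match.
#
#     model = UNetModel(dim=64, dim_mults=(1, 2, 4, 8), channels=3)
#     x = torch.randn(2, 3, 64, 64)              # noisy images
#     t = torch.randint(0, 1000, (2,)).float()   # one diffusion timestep per sample
#     out = model(x, t)                          # -> tensor of shape (2, 3, 64, 64)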

# dataset classes


class Dataset(data.Dataset):
    def __init__(self, folder, image_size, exts=["jpg", "jpeg", "png"]):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.paths = [p for ext in exts for p in Path(f"{folder}").glob(f"**/*.{ext}")]

        self.transform = transforms.Compose(
            [
                transforms.Resize(image_size),
                transforms.RandomHorizontalFlip(),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
            ]
        )

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)


# trainer class


class Trainer(object):
    def __init__(
        self,
        diffusion_model,
        folder,
        *,
        ema_decay=0.995,
        image_size=128,
        train_batch_size=32,
        train_lr=1e-4,
        train_num_steps=100000,
        gradient_accumulate_every=2,
        amp=False,
        step_start_ema=2000,
        update_ema_every=10,
        save_and_sample_every=1000,
        results_folder="./results",
    ):
        super().__init__()
        self.model = diffusion_model
        self.ema = EMA(ema_decay)
        self.ema_model = copy.deepcopy(self.model)
        self.update_ema_every = update_ema_every

        self.step_start_ema = step_start_ema
        self.save_and_sample_every = save_and_sample_every

        self.batch_size = train_batch_size
        self.image_size = diffusion_model.image_size
        self.gradient_accumulate_every = gradient_accumulate_every
        self.train_num_steps = train_num_steps

        self.ds = Dataset(folder, image_size)
        self.dl = cycle(data.DataLoader(self.ds, batch_size=train_batch_size, shuffle=True, pin_memory=True))
        self.opt = Adam(diffusion_model.parameters(), lr=train_lr)

        self.step = 0

        self.amp = amp
        self.scaler = GradScaler(enabled=amp)

        self.results_folder = Path(results_folder)
        self.results_folder.mkdir(exist_ok=True)

        self.reset_parameters()

    def reset_parameters(self):
        self.ema_model.load_state_dict(self.model.state_dict())

    def step_ema(self):
        if self.step < self.step_start_ema:
            self.reset_parameters()
            return
        self.ema.update_model_average(self.ema_model, self.model)

    def save(self, milestone):
        data = {
            "step": self.step,
            "model": self.model.state_dict(),
            "ema": self.ema_model.state_dict(),
            "scaler": self.scaler.state_dict(),
        }
        torch.save(data, str(self.results_folder / f"model-{milestone}.pt"))

    def load(self, milestone):
        data = torch.load(str(self.results_folder / f"model-{milestone}.pt"))

        self.step = data["step"]
        self.model.load_state_dict(data["model"])
        self.ema_model.load_state_dict(data["ema"])
        self.scaler.load_state_dict(data["scaler"])

    def train(self):
        with tqdm(initial=self.step, total=self.train_num_steps) as pbar:
            while self.step < self.train_num_steps:
                for i in range(self.gradient_accumulate_every):
                    data = next(self.dl).cuda()

                    with autocast(enabled=self.amp):
                        loss = self.model(data)

                    self.scaler.scale(loss / self.gradient_accumulate_every).backward()

                    pbar.set_description(f"loss: {loss.item():.4f}")

                self.scaler.step(self.opt)
                self.scaler.update()
                self.opt.zero_grad()

                if self.step % self.update_ema_every == 0:
                    self.step_ema()

                if self.step != 0 and self.step % self.save_and_sample_every == 0:
                    self.ema_model.eval()

                    milestone = self.step // self.save_and_sample_every
                    batches = num_to_groups(36, self.batch_size)
                    all_images_list = list(map(lambda n: self.ema_model.sample(batch_size=n), batches))
                    all_images = torch.cat(all_images_list, dim=0)
                    utils.save_image(all_images, str(self.results_folder / f"sample-{milestone}.png"), nrow=6)
                    self.save(milestone)

                self.step += 1
                pbar.update(1)

        print("training complete")
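
# Minimal training sketch (an assumption, for illustration): `Trainer` expects `diffusion_model` to
# be a diffusion wrapper around the UNet, e.g. the `GaussianDiffusion` class from the upstream
# lucidrains repository, which exposes an `image_size` attribute, returns a scalar loss from
# `forward(images)`, and implements `sample(batch_size=...)`. The loop also assumes a CUDA device,
# since batches are moved with `.cuda()`.
#
#     unet = UNetModel(dim=64)
#     diffusion = GaussianDiffusion(unet, image_size=128, timesteps=1000)  # hypothetical wrapper, not defined here
#     trainer = Trainer(diffusion, "path/to/images", image_size=128, train_batch_size=32, amp=True)
#     trainer.train()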