# Copyright 2022 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # helpers functions import copy import math from functools import partial from inspect import isfunction from pathlib import Path import torch import torch.nn.functional as F from torch import einsum, nn from torch.cuda.amp import GradScaler, autocast from torch.optim import Adam from torch.utils import data from einops import rearrange from PIL import Image from torchvision import transforms, utils from tqdm import tqdm from ...modeling_utils import PreTrainedModel from .configuration_unet import UNetConfig # NOTE: the following file is completely copied from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/master/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py def exists(x): return x is not None def default(val, d): if exists(val): return val return d() if isfunction(d) else d def cycle(dl): while True: for data_dl in dl: yield data_dl def num_to_groups(num, divisor): groups = num // divisor remainder = num % divisor arr = [divisor] * groups if remainder > 0: arr.append(remainder) return arr def normalize_to_neg_one_to_one(img): return img * 2 - 1 def unnormalize_to_zero_to_one(t): return (t + 1) * 0.5 # small helper modules class EMA: def __init__(self, beta): super().__init__() self.beta = beta def update_model_average(self, ma_model, current_model): for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): old_weight, up_weight = ma_params.data, current_params.data ma_params.data = self.update_average(old_weight, up_weight) def update_average(self, old, new): if old is None: return new return old * self.beta + (1 - self.beta) * new class Residual(nn.Module): def __init__(self, fn): super().__init__() self.fn = fn def forward(self, x, *args, **kwargs): return self.fn(x, *args, **kwargs) + x class SinusoidalPosEmb(nn.Module): def __init__(self, dim): super().__init__() self.dim = dim def forward(self, x): device = x.device half_dim = self.dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, device=device) * -emb) emb = x[:, None] * emb[None, :] emb = torch.cat((emb.sin(), emb.cos()), dim=-1) return emb def Upsample(dim): return nn.ConvTranspose2d(dim, dim, 4, 2, 1) def Downsample(dim): return nn.Conv2d(dim, dim, 4, 2, 1) class LayerNorm(nn.Module): def __init__(self, dim, eps=1e-5): super().__init__() self.eps = eps self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) self.b = nn.Parameter(torch.zeros(1, dim, 1, 1)) def forward(self, x): var = torch.var(x, dim=1, unbiased=False, keepdim=True) mean = torch.mean(x, dim=1, keepdim=True) return (x - mean) / (var + self.eps).sqrt() * self.g + self.b class PreNorm(nn.Module): def __init__(self, dim, fn): super().__init__() self.fn = fn self.norm = LayerNorm(dim) def forward(self, x): x = self.norm(x) return self.fn(x) # building block modules class Block(nn.Module): def __init__(self, dim, dim_out, groups=8): super().__init__() self.proj = nn.Conv2d(dim, dim_out, 3, padding=1) self.norm = nn.GroupNorm(groups, dim_out) self.act = nn.SiLU() def forward(self, x, scale_shift=None): x = self.proj(x) x = self.norm(x) if exists(scale_shift): scale, shift = scale_shift x = x * (scale + 1) + shift x = self.act(x) return x class ResnetBlock(nn.Module): def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8): super().__init__() self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2)) if exists(time_emb_dim) else None self.block1 = Block(dim, dim_out, groups=groups) self.block2 = Block(dim_out, dim_out, groups=groups) self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() def forward(self, x, time_emb=None): scale_shift = None if exists(self.mlp) and exists(time_emb): time_emb = self.mlp(time_emb) time_emb = rearrange(time_emb, "b c -> b c 1 1") scale_shift = time_emb.chunk(2, dim=1) h = self.block1(x, scale_shift=scale_shift) h = self.block2(h) return h + self.res_conv(x) class LinearAttention(nn.Module): def __init__(self, dim, heads=4, dim_head=32): super().__init__() self.scale = dim_head**-0.5 self.heads = heads hidden_dim = dim_head * heads self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), LayerNorm(dim)) def forward(self, x): b, c, h, w = x.shape qkv = self.to_qkv(x).chunk(3, dim=1) q, k, v = map(lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv) q = q.softmax(dim=-2) k = k.softmax(dim=-1) q = q * self.scale context = torch.einsum("b h d n, b h e n -> b h d e", k, v) out = torch.einsum("b h d e, b h d n -> b h e n", context, q) out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w) return self.to_out(out) class Attention(nn.Module): def __init__(self, dim, heads=4, dim_head=32): super().__init__() self.scale = dim_head**-0.5 self.heads = heads hidden_dim = dim_head * heads self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) self.to_out = nn.Conv2d(hidden_dim, dim, 1) def forward(self, x): b, c, h, w = x.shape qkv = self.to_qkv(x).chunk(3, dim=1) q, k, v = map(lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv) q = q * self.scale sim = einsum("b h d i, b h d j -> b h i j", q, k) sim = sim - sim.amax(dim=-1, keepdim=True).detach() attn = sim.softmax(dim=-1) out = einsum("b h i j, b h d j -> b h i d", attn, v) out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w) return self.to_out(out) class UNetModel(PreTrainedModel): config_class = UNetConfig def __init__(self, config): super().__init__(config) init_dim = None out_dim = None channels = 3 with_time_emb = True resnet_block_groups = 8 learned_variance = False # determine dimensions dim_mults = config.dim_mults dim = config.dim self.channels = config.channels init_dim = default(init_dim, dim // 3 * 2) self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3) dims = [init_dim, *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) block_klass = partial(ResnetBlock, groups=resnet_block_groups) # time embeddings if with_time_emb: time_dim = dim * 4 self.time_mlp = nn.Sequential( SinusoidalPosEmb(dim), nn.Linear(dim, time_dim), nn.GELU(), nn.Linear(time_dim, time_dim) ) else: time_dim = None self.time_mlp = None # layers self.downs = nn.ModuleList([]) self.ups = nn.ModuleList([]) num_resolutions = len(in_out) for ind, (dim_in, dim_out) in enumerate(in_out): is_last = ind >= (num_resolutions - 1) self.downs.append( nn.ModuleList( [ block_klass(dim_in, dim_out, time_emb_dim=time_dim), block_klass(dim_out, dim_out, time_emb_dim=time_dim), Residual(PreNorm(dim_out, LinearAttention(dim_out))), Downsample(dim_out) if not is_last else nn.Identity(), ] ) ) mid_dim = dims[-1] self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): is_last = ind >= (num_resolutions - 1) self.ups.append( nn.ModuleList( [ block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim), block_klass(dim_in, dim_in, time_emb_dim=time_dim), Residual(PreNorm(dim_in, LinearAttention(dim_in))), Upsample(dim_in) if not is_last else nn.Identity(), ] ) ) default_out_dim = channels * (1 if not learned_variance else 2) self.out_dim = default(out_dim, default_out_dim) self.final_conv = nn.Sequential(block_klass(dim, dim), nn.Conv2d(dim, self.out_dim, 1)) def forward(self, x, time): x = self.init_conv(x) t = self.time_mlp(time) if exists(self.time_mlp) else None h = [] for block1, block2, attn, downsample in self.downs: x = block1(x, t) x = block2(x, t) x = attn(x) h.append(x) x = downsample(x) x = self.mid_block1(x, t) x = self.mid_attn(x) x = self.mid_block2(x, t) for block1, block2, attn, upsample in self.ups: x = torch.cat((x, h.pop()), dim=1) x = block1(x, t) x = block2(x, t) x = attn(x) x = upsample(x) return self.final_conv(x) # gaussian diffusion trainer class def extract(a, t, x_shape): b, *_ = t.shape out = a.gather(-1, t) return out.reshape(b, *((1,) * (len(x_shape) - 1))) def noise_like(shape, device, repeat=False): def repeat_noise(): return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) def noise(): return torch.randn(shape, device=device) return repeat_noise() if repeat else noise() def linear_beta_schedule(timesteps): scale = 1000 / timesteps beta_start = scale * 0.0001 beta_end = scale * 0.02 return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64) def cosine_beta_schedule(timesteps, s=0.008): """ cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ """ steps = timesteps + 1 x = torch.linspace(0, timesteps, steps, dtype=torch.float64) alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 alphas_cumprod = alphas_cumprod / alphas_cumprod[0] betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) return torch.clip(betas, 0, 0.999) class GaussianDiffusion(nn.Module): def __init__( self, denoise_fn, *, image_size, channels=3, timesteps=1000, loss_type="l1", objective="pred_noise", beta_schedule="cosine", ): super().__init__() assert not (type(self) == GaussianDiffusion and denoise_fn.channels != denoise_fn.out_dim) self.channels = channels self.image_size = image_size self.denoise_fn = denoise_fn self.objective = objective if beta_schedule == "linear": betas = linear_beta_schedule(timesteps) elif beta_schedule == "cosine": betas = cosine_beta_schedule(timesteps) else: raise ValueError(f"unknown beta schedule {beta_schedule}") alphas = 1.0 - betas alphas_cumprod = torch.cumprod(alphas, axis=0) alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0) (timesteps,) = betas.shape self.num_timesteps = int(timesteps) self.loss_type = loss_type # helper function to register buffer from float64 to float32 def register_buffer(name, val): self.register_buffer(name, val.to(torch.float32)) register_buffer("betas", betas) register_buffer("alphas_cumprod", alphas_cumprod) register_buffer("alphas_cumprod_prev", alphas_cumprod_prev) # calculations for diffusion q(x_t | x_{t-1}) and others register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod)) register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod)) register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod)) register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod)) register_buffer("sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1)) # calculations for posterior q(x_{t-1} | x_t, x_0) posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) register_buffer("posterior_variance", posterior_variance) # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain register_buffer("posterior_log_variance_clipped", torch.log(posterior_variance.clamp(min=1e-20))) register_buffer("posterior_mean_coef1", betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)) register_buffer( "posterior_mean_coef2", (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod) ) def predict_start_from_noise(self, x_t, t, noise): return ( extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise ) def q_posterior(self, x_start, x_t, t): posterior_mean = ( extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t ) posterior_variance = extract(self.posterior_variance, t, x_t.shape) posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) return posterior_mean, posterior_variance, posterior_log_variance_clipped def p_mean_variance(self, x, t, clip_denoised: bool): model_output = self.denoise_fn(x, t) if self.objective == "pred_noise": x_start = self.predict_start_from_noise(x, t=t, noise=model_output) elif self.objective == "pred_x0": x_start = model_output else: raise ValueError(f"unknown objective {self.objective}") if clip_denoised: x_start.clamp_(-1.0, 1.0) model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_start, x_t=x, t=t) return model_mean, posterior_variance, posterior_log_variance @torch.no_grad() def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): b, *_, device = *x.shape, x.device model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) noise = noise_like(x.shape, device, repeat_noise) # no noise when t == 0 nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) result = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise return result @torch.no_grad() def p_sample_loop(self, shape): device = self.betas.device b = shape[0] img = torch.randn(shape, device=device) for i in tqdm( reversed(range(0, self.num_timesteps)), desc="sampling loop time step", total=self.num_timesteps ): img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long)) img = unnormalize_to_zero_to_one(img) return img @torch.no_grad() def sample(self, batch_size=16): image_size = self.image_size channels = self.channels return self.p_sample_loop((batch_size, channels, image_size, image_size)) @torch.no_grad() def interpolate(self, x1, x2, t=None, lam=0.5): b, *_, device = *x1.shape, x1.device t = default(t, self.num_timesteps - 1) assert x1.shape == x2.shape t_batched = torch.stack([torch.tensor(t, device=device)] * b) xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2)) img = (1 - lam) * xt1 + lam * xt2 for i in tqdm(reversed(range(0, t)), desc="interpolation sample time step", total=t): img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long)) return img def q_sample(self, x_start, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start)) return ( extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise ) @property def loss_fn(self): if self.loss_type == "l1": return F.l1_loss elif self.loss_type == "l2": return F.mse_loss else: raise ValueError(f"invalid loss type {self.loss_type}") def p_losses(self, x_start, t, noise=None): b, c, h, w = x_start.shape noise = default(noise, lambda: torch.randn_like(x_start)) x = self.q_sample(x_start=x_start, t=t, noise=noise) model_out = self.denoise_fn(x, t) if self.objective == "pred_noise": target = noise elif self.objective == "pred_x0": target = x_start else: raise ValueError(f"unknown objective {self.objective}") loss = self.loss_fn(model_out, target) return loss def forward(self, img, *args, **kwargs): b, _, h, w, device, img_size, = ( *img.shape, img.device, self.image_size, ) assert h == img_size and w == img_size, f"height and width of image must be {img_size}" t = torch.randint(0, self.num_timesteps, (b,), device=device).long() img = normalize_to_neg_one_to_one(img) return self.p_losses(img, t, *args, **kwargs) # dataset classes class Dataset(data.Dataset): def __init__(self, folder, image_size, exts=["jpg", "jpeg", "png"]): super().__init__() self.folder = folder self.image_size = image_size self.paths = [p for ext in exts for p in Path(f"{folder}").glob(f"**/*.{ext}")] self.transform = transforms.Compose( [ transforms.Resize(image_size), transforms.RandomHorizontalFlip(), transforms.CenterCrop(image_size), transforms.ToTensor(), ] ) def __len__(self): return len(self.paths) def __getitem__(self, index): path = self.paths[index] img = Image.open(path) return self.transform(img) # trainer class class Trainer(object): def __init__( self, diffusion_model, folder, *, ema_decay=0.995, image_size=128, train_batch_size=32, train_lr=1e-4, train_num_steps=100000, gradient_accumulate_every=2, amp=False, step_start_ema=2000, update_ema_every=10, save_and_sample_every=1000, results_folder="./results", ): super().__init__() self.model = diffusion_model self.ema = EMA(ema_decay) self.ema_model = copy.deepcopy(self.model) self.update_ema_every = update_ema_every self.step_start_ema = step_start_ema self.save_and_sample_every = save_and_sample_every self.batch_size = train_batch_size self.image_size = diffusion_model.image_size self.gradient_accumulate_every = gradient_accumulate_every self.train_num_steps = train_num_steps self.ds = Dataset(folder, image_size) self.dl = cycle(data.DataLoader(self.ds, batch_size=train_batch_size, shuffle=True, pin_memory=True)) self.opt = Adam(diffusion_model.parameters(), lr=train_lr) self.step = 0 self.amp = amp self.scaler = GradScaler(enabled=amp) self.results_folder = Path(results_folder) self.results_folder.mkdir(exist_ok=True) self.reset_parameters() def reset_parameters(self): self.ema_model.load_state_dict(self.model.state_dict()) def step_ema(self): if self.step < self.step_start_ema: self.reset_parameters() return self.ema.update_model_average(self.ema_model, self.model) def save(self, milestone): data = { "step": self.step, "model": self.model.state_dict(), "ema": self.ema_model.state_dict(), "scaler": self.scaler.state_dict(), } torch.save(data, str(self.results_folder / f"model-{milestone}.pt")) def load(self, milestone): data = torch.load(str(self.results_folder / f"model-{milestone}.pt")) self.step = data["step"] self.model.load_state_dict(data["model"]) self.ema_model.load_state_dict(data["ema"]) self.scaler.load_state_dict(data["scaler"]) def train(self): with tqdm(initial=self.step, total=self.train_num_steps) as pbar: while self.step < self.train_num_steps: for i in range(self.gradient_accumulate_every): data = next(self.dl).cuda() with autocast(enabled=self.amp): loss = self.model(data) self.scaler.scale(loss / self.gradient_accumulate_every).backward() pbar.set_description(f"loss: {loss.item():.4f}") self.scaler.step(self.opt) self.scaler.update() self.opt.zero_grad() if self.step % self.update_ema_every == 0: self.step_ema() if self.step != 0 and self.step % self.save_and_sample_every == 0: self.ema_model.eval() milestone = self.step // self.save_and_sample_every batches = num_to_groups(36, self.batch_size) all_images_list = list(map(lambda n: self.ema_model.sample(batch_size=n), batches)) all_images = torch.cat(all_images_list, dim=0) utils.save_image(all_images, str(self.results_folder / f"sample-{milestone}.png"), nrow=6) self.save(milestone) self.step += 1 pbar.update(1) print("training complete")