# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# helper functions
import torch
from torch import nn

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
from .resnet import Downsample, Upsample
from .attention2d import AttentionBlock


def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                # 1x1 "network-in-network" projection to match channel counts on the residual path
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        # add the projected timestep embedding as a per-channel bias
        h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h
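
# Minimal shape sketch for ResnetBlock (illustrative comment only, not executed at import
# time; the channel counts and sizes below are made-up example values, not a real config):
#
#     block = ResnetBlock(in_channels=64, out_channels=128, dropout=0.0, temb_channels=512)
#     x = torch.randn(2, 64, 32, 32)
#     temb = torch.randn(2, 512)
#     out = block(x, temb)  # -> shape (2, 128, 32, 32); the 1x1 nin_shortcut aligns channels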

class UNetModel(ModelMixin, ConfigMixin):
    def __init__(
        self,
        ch=128,
        out_ch=3,
        ch_mult=(1, 1, 2, 2, 4, 4),
        num_res_blocks=2,
        attn_resolutions=(16,),
        dropout=0.0,
        resamp_with_conv=True,
        in_channels=3,
        resolution=256,
    ):
        super().__init__()
        self.register_to_config(
            ch=ch,
            out_ch=out_ch,
            ch_mult=ch_mult,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            in_channels=in_channels,
            resolution=resolution,
        )
        ch_mult = tuple(ch_mult)
        self.ch = ch
        self.temb_ch = self.ch * 4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # timestep embedding
        self.temb = nn.Module()
        self.temb.dense = nn.ModuleList(
            [
                torch.nn.Linear(self.ch, self.temb_ch),
                torch.nn.Linear(self.temb_ch, self.temb_ch),
            ]
        )

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttentionBlock(block_in, overwrite_qkv=True))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
        )
        self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            skip_in = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                if i_block == self.num_res_blocks:
                    skip_in = ch * in_ch_mult[i_level]
                block.append(
                    ResnetBlock(
                        in_channels=block_in + skip_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttentionBlock(block_in, overwrite_qkv=True))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, use_conv=resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, x, timesteps):
        assert x.shape[2] == x.shape[3] == self.resolution

        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device)

        # timestep embedding
        temb = get_timestep_embedding(timesteps, self.ch)
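        # get_timestep_embedding returns a (batch, ch) sinusoidal embedding; the two
        # dense layers below project it up to temb_ch = 4 * ch before each ResnetBlock
        # consumes it (standard DDPM practice)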
        temb = self.temb.dense[0](temb)
        temb = nonlinearity(temb)
        temb = self.temb.dense[1](temb)

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                # concatenate the matching skip connection from the down path
                h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
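

# Minimal smoke test for the full model: a sketch, not part of the library API. It assumes
# the sibling Downsample/Upsample/AttentionBlock modules and get_timestep_embedding behave
# as imported above, and uses a small made-up config instead of the 256x256 default.
if __name__ == "__main__":
    model = UNetModel(
        ch=32,
        out_ch=3,
        ch_mult=(1, 2),
        num_res_blocks=1,
        attn_resolutions=(16,),
        in_channels=3,
        resolution=32,
    )
    sample = torch.randn(1, 3, 32, 32)
    timesteps = torch.tensor([10])
    with torch.no_grad():
        out = model(sample, timesteps)
    # the UNet is shape-preserving: input and output are both (1, 3, 32, 32)
    assert out.shape == sample.shape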