import math from inspect import isfunction import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin from .attention import AttentionBlock from .embeddings import get_timestep_embedding from .resnet import Downsample, ResnetBlock, TimestepBlock, Upsample # from .resnet import ResBlock def exists(val): return val is not None def uniq(arr): return {el: True for el in arr}.keys() def default(val, d): if exists(val): return val return d() if isfunction(d) else d def max_neg_value(t): return -torch.finfo(t.dtype).max def init_(tensor): dim = tensor.shape[-1] std = 1 / math.sqrt(dim) tensor.uniform_(-std, std) return tensor # feedforward class GEGLU(nn.Module): def __init__(self, dim_in, dim_out): super().__init__() self.proj = nn.Linear(dim_in, dim_out * 2) def forward(self, x): x, gate = self.proj(x).chunk(2, dim=-1) return x * F.gelu(gate) class FeedForward(nn.Module): def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) def forward(self, x): return self.net(x) def zero_module(module): """ Zero out the parameters of a module and return it. """ for p in module.parameters(): p.detach().zero_() return module def Normalize(in_channels): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) def convert_module_to_f16(l): """ Convert primitive modules to float16. """ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): l.weight.data = l.weight.data.half() if l.bias is not None: l.bias.data = l.bias.data.half() def convert_module_to_f32(l): """ Convert primitive modules to float32, undoing convert_module_to_f16(). """ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): l.weight.data = l.weight.data.float() if l.bias is not None: l.bias.data = l.bias.data.float() def conv_nd(dims, *args, **kwargs): """ Create a 1D, 2D, or 3D convolution module. """ if dims == 1: return nn.Conv1d(*args, **kwargs) elif dims == 2: return nn.Conv2d(*args, **kwargs) elif dims == 3: return nn.Conv3d(*args, **kwargs) raise ValueError(f"unsupported dimensions: {dims}") def linear(*args, **kwargs): """ Create a linear module. """ return nn.Linear(*args, **kwargs) class GroupNorm32(nn.GroupNorm): def __init__(self, num_groups, num_channels, swish, eps=1e-5): super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps) self.swish = swish def forward(self, x): y = super().forward(x.float()).to(x.dtype) if self.swish == 1.0: y = F.silu(y) elif self.swish: y = y * F.sigmoid(y * float(self.swish)) return y def normalization(channels, swish=0.0): """ Make a standard normalization layer, with an optional swish activation. :param channels: number of input channels. :return: an nn.Module for normalization. """ return GroupNorm32(num_channels=channels, num_groups=32, swish=swish) class TimestepEmbedSequential(nn.Sequential, TimestepBlock): """ A sequential module that passes timestep embeddings to the children that support it as an extra input. """ def forward(self, x, emb, context=None): for layer in self: if isinstance(layer, TimestepBlock) or isinstance(layer, ResnetBlock): x = layer(x, emb) elif isinstance(layer, SpatialTransformer): x = layer(x, context) else: x = layer(x) return x def count_flops_attn(model, _x, y): """ A counter for the `thop` package to count the operations in an attention operation. Meant to be used like: macs, params = thop.profile( model, inputs=(inputs, timestamps), custom_ops={QKVAttention: QKVAttention.count_flops}, ) """ b, c, *spatial = y[0].shape num_spatial = int(np.prod(spatial)) # We perform two matmuls with the same number of ops. # The first computes the weight matrix, the second computes # the combination of the value vectors. matmul_ops = 2 * b * (num_spatial**2) * c model.total_ops += torch.DoubleTensor([matmul_ops]) class UNetLDMModel(ModelMixin, ConfigMixin): """ The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param model_channels: base channel count for the model. :param out_channels: channels in the output Tensor. :param num_res_blocks: number of residual blocks per downsample. :param attention_resolutions: a collection of downsample rates at which attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x downsampling, attention will be used. :param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param conv_resample: if True, use learned convolutions for upsampling and downsampling. :param dims: determines if the signal is 1D, 2D, or 3D. :param num_classes: if specified (as an int), then this model will be class-conditional with `num_classes` classes. :param use_checkpoint: use gradient checkpointing to reduce memory usage. :param num_heads: the number of attention heads in each attention layer. :param num_heads_channels: if specified, ignore num_heads and instead use a fixed channel width per attention head. :param num_heads_upsample: works with num_heads to set a different number of heads for upsampling. Deprecated. :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks for up/downsampling. :param use_new_attention_order: use a different attention pattern for potentially increased efficiency. """ def __init__( self, image_size, in_channels, model_channels, out_channels, num_res_blocks, attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, dims=2, num_classes=None, use_checkpoint=False, use_fp16=False, num_heads=-1, num_head_channels=-1, num_heads_upsample=-1, use_scale_shift_norm=False, resblock_updown=False, use_new_attention_order=False, use_spatial_transformer=False, # custom transformer support transformer_depth=1, # custom transformer support context_dim=None, # custom transformer support n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model legacy=True, ): super().__init__() # register all __init__ params with self.register self.register_to_config( image_size=image_size, in_channels=in_channels, model_channels=model_channels, out_channels=out_channels, num_res_blocks=num_res_blocks, attention_resolutions=attention_resolutions, dropout=dropout, channel_mult=channel_mult, conv_resample=conv_resample, dims=dims, num_classes=num_classes, use_checkpoint=use_checkpoint, use_fp16=use_fp16, num_heads=num_heads, num_head_channels=num_head_channels, num_heads_upsample=num_heads_upsample, use_scale_shift_norm=use_scale_shift_norm, resblock_updown=resblock_updown, use_new_attention_order=use_new_attention_order, use_spatial_transformer=use_spatial_transformer, transformer_depth=transformer_depth, context_dim=context_dim, n_embed=n_embed, legacy=legacy, ) if use_spatial_transformer: assert ( context_dim is not None ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..." if context_dim is not None: assert ( use_spatial_transformer ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." if num_heads_upsample == -1: num_heads_upsample = num_heads if num_heads == -1: assert num_head_channels != -1, "Either num_heads or num_head_channels has to be set" if num_head_channels == -1: assert num_heads != -1, "Either num_heads or num_head_channels has to be set" self.image_size = image_size self.in_channels = in_channels self.model_channels = model_channels self.out_channels = out_channels self.num_res_blocks = num_res_blocks self.attention_resolutions = attention_resolutions self.dropout = dropout self.channel_mult = channel_mult self.conv_resample = conv_resample self.num_classes = num_classes self.use_checkpoint = use_checkpoint self.dtype_ = torch.float16 if use_fp16 else torch.float32 self.num_heads = num_heads self.num_head_channels = num_head_channels self.num_heads_upsample = num_heads_upsample self.predict_codebook_ids = n_embed is not None time_embed_dim = model_channels * 4 self.time_embed = nn.Sequential( linear(model_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim), ) if self.num_classes is not None: self.label_emb = nn.Embedding(num_classes, time_embed_dim) self.input_blocks = nn.ModuleList( [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))] ) self._feature_size = model_channels input_block_chans = [model_channels] ch = model_channels ds = 1 for level, mult in enumerate(channel_mult): for _ in range(num_res_blocks): layers = [ ResnetBlock( in_channels=ch, out_channels=mult * model_channels, dropout=dropout, temb_channels=time_embed_dim, eps=1e-5, non_linearity="silu", overwrite_for_ldm=True, ) ] ch = mult * model_channels if ds in attention_resolutions: if num_head_channels == -1: dim_head = ch // num_heads else: num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: # num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels layers.append( AttentionBlock( ch, use_checkpoint=use_checkpoint, num_heads=num_heads, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, ) if not use_spatial_transformer else SpatialTransformer( ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch input_block_chans.append(ch) if level != len(channel_mult) - 1: out_ch = ch self.input_blocks.append( TimestepEmbedSequential( Downsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op") ) ) ch = out_ch input_block_chans.append(ch) ds *= 2 self._feature_size += ch if num_head_channels == -1: dim_head = ch // num_heads else: num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: # num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels self.middle_block = TimestepEmbedSequential( ResnetBlock( in_channels=ch, out_channels=None, dropout=dropout, temb_channels=time_embed_dim, eps=1e-5, non_linearity="silu", overwrite_for_ldm=True, ), AttentionBlock( ch, use_checkpoint=use_checkpoint, num_heads=num_heads, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, ) if not use_spatial_transformer else SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim), ResnetBlock( in_channels=ch, out_channels=None, dropout=dropout, temb_channels=time_embed_dim, eps=1e-5, non_linearity="silu", overwrite_for_ldm=True, ), ) self._feature_size += ch self.output_blocks = nn.ModuleList([]) for level, mult in list(enumerate(channel_mult))[::-1]: for i in range(num_res_blocks + 1): ich = input_block_chans.pop() layers = [ ResnetBlock( in_channels=ch + ich, out_channels=model_channels * mult, dropout=dropout, temb_channels=time_embed_dim, eps=1e-5, non_linearity="silu", overwrite_for_ldm=True, ), ] ch = model_channels * mult if ds in attention_resolutions: if num_head_channels == -1: dim_head = ch // num_heads else: num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: # num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels layers.append( AttentionBlock( ch, use_checkpoint=use_checkpoint, num_heads=num_heads_upsample, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, ) if not use_spatial_transformer else SpatialTransformer( ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim ) ) if level and i == num_res_blocks: out_ch = ch layers.append(Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)) ds //= 2 self.output_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch self.out = nn.Sequential( normalization(ch), nn.SiLU(), zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), ) if self.predict_codebook_ids: self.id_predictor = nn.Sequential( normalization(ch), conv_nd(dims, model_channels, n_embed, 1), # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits ) def convert_to_fp16(self): """ Convert the torso of the model to float16. """ self.input_blocks.apply(convert_module_to_f16) self.middle_block.apply(convert_module_to_f16) self.output_blocks.apply(convert_module_to_f16) def convert_to_fp32(self): """ Convert the torso of the model to float32. """ self.input_blocks.apply(convert_module_to_f32) self.middle_block.apply(convert_module_to_f32) self.output_blocks.apply(convert_module_to_f32) def forward(self, x, timesteps=None, context=None, y=None, **kwargs): """ Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. :param timesteps: a 1-D batch of timesteps. :param context: conditioning plugged in via crossattn :param y: an [N] Tensor of labels, if class-conditional. :return: an [N x C x ...] Tensor of outputs. """ assert (y is not None) == ( self.num_classes is not None ), "must specify y if and only if the model is class-conditional" hs = [] if not torch.is_tensor(timesteps): timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device) t_emb = get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0) emb = self.time_embed(t_emb) if self.num_classes is not None: assert y.shape == (x.shape[0],) emb = emb + self.label_emb(y) h = x.type(self.dtype_) for module in self.input_blocks: h = module(h, emb, context) hs.append(h) h = self.middle_block(h, emb, context) for module in self.output_blocks: h = torch.cat([h, hs.pop()], dim=1) h = module(h, emb, context) h = h.type(x.dtype) if self.predict_codebook_ids: return self.id_predictor(h) else: return self.out(h) class SpatialTransformer(nn.Module): """ Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply standard transformer action. Finally, reshape to image """ def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None): super().__init__() self.in_channels = in_channels inner_dim = n_heads * d_head self.norm = Normalize(in_channels) self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) self.transformer_blocks = nn.ModuleList( [ BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) for d in range(depth) ] ) self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) def forward(self, x, context=None): # note: if no context is given, cross-attention defaults to self-attention b, c, h, w = x.shape x_in = x x = self.norm(x) x = self.proj_in(x) x = x.permute(0, 2, 3, 1).reshape(b, h * w, c) for block in self.transformer_blocks: x = block(x, context=context) x = x.reshape(b, h, w, c).permute(0, 3, 1, 2) x = self.proj_out(x) return x + x_in class BasicTransformerBlock(nn.Module): def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True): super().__init__() self.attn1 = CrossAttention( query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout ) # is a self-attention self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) self.attn2 = CrossAttention( query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout ) # is self-attn if context is none self.norm1 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim) self.norm3 = nn.LayerNorm(dim) self.checkpoint = checkpoint def forward(self, x, context=None): x = self.attn1(self.norm1(x)) + x x = self.attn2(self.norm2(x), context=context) + x x = self.ff(self.norm3(x)) + x return x class CrossAttention(nn.Module): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) self.scale = dim_head**-0.5 self.heads = heads self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_k = nn.Linear(context_dim, inner_dim, bias=False) self.to_v = nn.Linear(context_dim, inner_dim, bias=False) self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) def reshape_heads_to_batch_dim(self, tensor): batch_size, seq_len, dim = tensor.shape head_size = self.heads tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) return tensor def reshape_batch_dim_to_heads(self, tensor): batch_size, seq_len, dim = tensor.shape head_size = self.heads tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) return tensor def forward(self, x, context=None, mask=None): batch_size, sequence_length, dim = x.shape h = self.heads q = self.to_q(x) context = default(context, x) k = self.to_k(context) v = self.to_v(context) q = self.reshape_heads_to_batch_dim(q) k = self.reshape_heads_to_batch_dim(k) v = self.reshape_heads_to_batch_dim(v) sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale if exists(mask): mask = mask.reshape(batch_size, -1) max_neg_value = -torch.finfo(sim.dtype).max mask = mask[:, None, :].repeat(h, 1, 1) sim.masked_fill_(~mask, max_neg_value) # attention, what we cannot get enough of attn = sim.softmax(dim=-1) out = torch.einsum("b i j, b j d -> b i d", attn, v) out = self.reshape_batch_dim_to_heads(out) return self.to_out(out)