import copy
import os
from functools import partial

import numpy as np
import torch
import torch.nn.functional as F
from base_model import BaseModel
from einops import rearrange, repeat
from sgm.modules.diffusionmodules.openaimodel import Timestep
from sgm.modules.diffusionmodules.util import linear, timestep_embedding
from sgm.util import instantiate_from_config
from torch import nn

from sat.model.base_model import non_conflict
from sat.model.mixins import BaseMixin
from sat.mpu.layers import ColumnParallelLinear
from sat.ops.layernorm import LayerNorm, RMSNorm
from sat.transformer_defaults import HOOKS_DEFAULT, attention_fn_default


class ImagePatchEmbeddingMixin(BaseMixin):
    def __init__(
        self,
        in_channels,
        hidden_size,
        patch_size,
        bias=True,
        text_hidden_size=None,
    ):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=bias)
        if text_hidden_size is not None:
            self.text_proj = nn.Linear(text_hidden_size, hidden_size)
        else:
            self.text_proj = None

    def word_embedding_forward(self, input_ids, **kwargs):
        # patchify each frame of the input video with a 2D convolution
        images = kwargs['images']  # (b,t,c,h,w)
        B, T = images.shape[:2]
        emb = images.view(-1, *images.shape[2:])
        emb = self.proj(emb)  # ((b t),d,h/p,w/p)
        emb = emb.view(B, T, *emb.shape[1:])
        emb = emb.flatten(3).transpose(2, 3)  # (b,t,n,d)
        emb = rearrange(emb, 'b t n d -> b (t n) d')

        if self.text_proj is not None:
            text_emb = self.text_proj(kwargs['encoder_outputs'])
            emb = torch.cat((text_emb, emb), dim=1)  # (b,n_t+t*n_i,d)

        emb = emb.contiguous()
        return emb  # (b,n_t+t*n_i,d)

    def reinit(self, parent_model=None):
        w = self.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.proj.bias, 0)
        del self.transformer.word_embeddings


def get_3d_sincos_pos_embed(
    embed_dim,
    grid_height,
    grid_width,
    t_size,
    cls_token=False,
    height_interpolation=1.0,
    width_interpolation=1.0,
    time_interpolation=1.0,
):
    """
    grid_height, grid_width: spatial grid size
    t_size: temporal size
    return:
    pos_embed: [t_size, grid_height*grid_width, embed_dim]
    """
    assert embed_dim % 4 == 0
    embed_dim_spatial = embed_dim // 4 * 3
    embed_dim_temporal = embed_dim // 4

    # spatial
    grid_h = np.arange(grid_height, dtype=np.float32) / height_interpolation
    grid_w = np.arange(grid_width, dtype=np.float32) / width_interpolation
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)
    grid = grid.reshape([2, 1, grid_height, grid_width])
    pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)

    # temporal
    grid_t = np.arange(t_size, dtype=np.float32) / time_interpolation
    pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)

    # concatenate in [T, H, W] order
    pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
    pos_embed_temporal = np.repeat(pos_embed_temporal, grid_height * grid_width, axis=1)  # [T, H*W, D // 4]
    pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
    pos_embed_spatial = np.repeat(pos_embed_spatial, t_size, axis=0)  # [T, H*W, D // 4 * 3]

    pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)
    # pos_embed = pos_embed.reshape([-1, embed_dim])  # [T*H*W, D]

    return pos_embed  # [T, H*W, D]
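# Illustrative sketch (not part of the original module): checks the shapes produced by
# get_3d_sincos_pos_embed on a small grid. The sizes below are arbitrary example values.
def _demo_3d_sincos_shapes():
    embed_dim, h, w, t = 64, 4, 6, 3
    pos = get_3d_sincos_pos_embed(embed_dim, h, w, t)
    # One row per (frame, patch) pair: the first D/4 channels encode the frame index,
    # the remaining 3D/4 channels encode the 2D spatial position.
    assert pos.shape == (t, h * w, embed_dim)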
def get_2d_sincos_pos_embed(embed_dim, grid_height, grid_width, cls_token=False, extra_tokens=0):
    """
    grid_height, grid_width: spatial grid size
    return:
    pos_embed: [grid_height*grid_width, embed_dim] or
               [extra_tokens+grid_height*grid_width, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_height, dtype=np.float32)
    grid_w = np.arange(grid_width, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)
    grid = grid.reshape([2, 1, grid_height, grid_width])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of the dimensions to encode each grid axis
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded, size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


class Basic2DPositionEmbeddingMixin(BaseMixin):
    def __init__(self, height, width, compressed_num_frames, hidden_size, text_length=0):
        super().__init__()
        self.height = height
        self.width = width
        self.spatial_length = height * width
        self.pos_embedding = nn.Parameter(
            torch.zeros(1, int(text_length + self.spatial_length), int(hidden_size)),
            requires_grad=False,
        )

    def position_embedding_forward(self, position_ids, **kwargs):
        return self.pos_embedding

    def reinit(self, parent_model=None):
        del self.transformer.position_embeddings
        pos_embed = get_2d_sincos_pos_embed(self.pos_embedding.shape[-1], self.height, self.width)
        self.pos_embedding.data[:, -self.spatial_length:].copy_(
            torch.from_numpy(pos_embed).float().unsqueeze(0))


class Basic3DPositionEmbeddingMixin(BaseMixin):
    def __init__(
        self,
        height,
        width,
        compressed_num_frames,
        hidden_size,
        text_length=0,
        height_interpolation=1.0,
        width_interpolation=1.0,
        time_interpolation=1.0,
    ):
        super().__init__()
        self.height = height
        self.width = width
        self.text_length = text_length
        self.compressed_num_frames = compressed_num_frames
        self.spatial_length = height * width
        self.num_patches = height * width * compressed_num_frames
        self.pos_embedding = nn.Parameter(
            torch.zeros(1, int(text_length + self.num_patches), int(hidden_size)),
            requires_grad=False,
        )
        self.height_interpolation = height_interpolation
        self.width_interpolation = width_interpolation
        self.time_interpolation = time_interpolation

    def position_embedding_forward(self, position_ids, **kwargs):
        if kwargs['images'].shape[1] == 1:
            return self.pos_embedding[:, :self.text_length + self.spatial_length]
        return self.pos_embedding[:, :self.text_length + kwargs['seq_length']]

    def reinit(self, parent_model=None):
        del self.transformer.position_embeddings
        pos_embed = get_3d_sincos_pos_embed(
            self.pos_embedding.shape[-1],
            self.height,
            self.width,
            self.compressed_num_frames,
            height_interpolation=self.height_interpolation,
            width_interpolation=self.width_interpolation,
            time_interpolation=self.time_interpolation,
        )
        pos_embed = torch.from_numpy(pos_embed).float()
        pos_embed = rearrange(pos_embed, 't n d -> (t n) d')
        self.pos_embedding.data[:, -self.num_patches:].copy_(pos_embed)
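# Illustrative sketch (not part of the original module): the 1D table above is the
# standard transformer encoding emb(m) = [sin(m*w_k), cos(m*w_k)] with w_k = 10000**(-2k/D);
# the 2D variant concatenates two such tables, one per grid axis, each D/2 wide.
# Example sizes are arbitrary.
def _demo_sincos_shapes():
    embed_dim, h, w = 32, 5, 7
    pos = get_2d_sincos_pos_embed(embed_dim, h, w)
    assert pos.shape == (h * w, embed_dim)
    one_d = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, np.arange(4, dtype=np.float32))
    assert one_d.shape == (4, embed_dim // 2)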
def broadcat(tensors, dim=-1):
    num_tensors = len(tensors)
    shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
    assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
    shape_len = list(shape_lens)[0]
    dim = (dim + shape_len) if dim < 0 else dim
    dims = list(zip(*map(lambda t: list(t.shape), tensors)))
    expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
    assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)
                ]), 'invalid dimensions for broadcastable concatenation'
    max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
    expanded_dims = list(map(lambda t: (t[0], (t[1], ) * num_tensors), max_dims))
    expanded_dims.insert(dim, (dim, dims[dim]))
    expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
    tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
    return torch.cat(tensors, dim=dim)


def rotate_half(x):
    x = rearrange(x, '... (d r) -> ... d r', r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, '... d r -> ... (d r)')
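# Illustrative sketch (not part of the original module): rotate_half maps each adjacent
# channel pair (x0, x1) to (-x1, x0), so t * cos + rotate_half(t) * sin is a per-pair 2D
# rotation -- the rotary position embedding applied in Rotary3DPositionEmbeddingMixin.rotary.
def _demo_rotate_half():
    t = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
    assert torch.equal(rotate_half(t), torch.tensor([[-2.0, 1.0, -4.0, 3.0]]))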
class Rotary3DPositionEmbeddingMixin(BaseMixin):
    def __init__(
        self,
        height,
        width,
        compressed_num_frames,
        hidden_size,
        hidden_size_head,
        text_length,
        theta=10000,
        h_interp_ratio=1.0,
        w_interp_ratio=1.0,
        t_interp_ratio=1.0,
        rot_v=False,
        learnable_pos_embed=False,
    ):
        super().__init__()
        self.rot_v = rot_v
        print(f'theta is {theta}')

        # split the per-head dimension into temporal and spatial frequency bands
        dim_t = hidden_size_head // 4
        dim_h = hidden_size_head // 8 * 3
        dim_w = hidden_size_head // 8 * 3

        freqs_t = 1.0 / (theta**(torch.arange(0, dim_t, 2)[:(dim_t // 2)].float() / dim_t))
        freqs_h = 1.0 / (theta**(torch.arange(0, dim_h, 2)[:(dim_h // 2)].float() / dim_h))
        freqs_w = 1.0 / (theta**(torch.arange(0, dim_w, 2)[:(dim_w // 2)].float() / dim_w))

        self.compressed_num_frames = compressed_num_frames
        self.height = height
        self.width = width

        grid_t = torch.arange(compressed_num_frames, dtype=torch.float32)
        grid_h = torch.arange(height, dtype=torch.float32)
        grid_w = torch.arange(width, dtype=torch.float32)

        if t_interp_ratio > 1.0:
            print(f't_interp_ratio is {t_interp_ratio}')
            grid_t = grid_t / t_interp_ratio
        if h_interp_ratio > 1.0:
            print(f'h_interp_ratio is {h_interp_ratio}')
            grid_h = grid_h / h_interp_ratio
        if w_interp_ratio > 1.0:
            print(f'w_interp_ratio is {w_interp_ratio}')
            grid_w = grid_w / w_interp_ratio

        freqs_t = torch.einsum('..., f -> ... f', grid_t, freqs_t)
        freqs_h = torch.einsum('..., f -> ... f', grid_h, freqs_h)
        freqs_w = torch.einsum('..., f -> ... f', grid_w, freqs_w)

        freqs_t = repeat(freqs_t, '... n -> ... (n r)', r=2)
        freqs_h = repeat(freqs_h, '... n -> ... (n r)', r=2)
        freqs_w = repeat(freqs_w, '... n -> ... (n r)', r=2)

        freqs = broadcat(
            (freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]),
            dim=-1)
        freqs = rearrange(freqs, 't h w d -> (t h w) d')

        freqs = freqs.contiguous()
        freqs_sin = freqs.sin()
        freqs_cos = freqs.cos()
        self.register_buffer('freqs_sin', freqs_sin)
        self.register_buffer('freqs_cos', freqs_cos)

        self.text_length = text_length
        if learnable_pos_embed:
            num_patches = height * width * compressed_num_frames + text_length
            self.pos_embedding = nn.Parameter(
                torch.zeros(1, num_patches, int(hidden_size)), requires_grad=True)
        else:
            self.pos_embedding = None

    def rotary(self, t, **kwargs):
        seq_len = t.shape[2]
        freqs_cos = self.freqs_cos[:seq_len].unsqueeze(0).unsqueeze(0)
        freqs_sin = self.freqs_sin[:seq_len].unsqueeze(0).unsqueeze(0)
        return t * freqs_cos + rotate_half(t) * freqs_sin

    def position_embedding_forward(self, position_ids, **kwargs):
        return None

    def attention_fn(
        self,
        query_layer,
        key_layer,
        value_layer,
        attention_mask,
        attention_dropout=None,
        log_attention_weights=None,
        scaling_attention_score=True,
        **kwargs,
    ):
        attention_fn_default = HOOKS_DEFAULT['attention_fn']

        # apply RoPE only to the video tokens; text tokens are left unchanged
        query_layer[:, :, self.text_length:] = self.rotary(query_layer[:, :, self.text_length:])
        key_layer[:, :, self.text_length:] = self.rotary(key_layer[:, :, self.text_length:])
        if self.rot_v:
            value_layer[:, :, self.text_length:] = self.rotary(value_layer[:, :, self.text_length:])

        return attention_fn_default(
            query_layer,
            key_layer,
            value_layer,
            attention_mask,
            attention_dropout=attention_dropout,
            log_attention_weights=log_attention_weights,
            scaling_attention_score=scaling_attention_score,
            **kwargs,
        )


def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


def unpatchify(x, c, p, w, h, rope_position_ids=None, **kwargs):
    """
    x: (N, T * S, patch_size**2 * C), with S patches per frame
    imgs: (N, T, C, H, W)
    """
    if rope_position_ids is not None:
        # pix2struct-style unpatchify; not supported for this model
        raise NotImplementedError
        L = x.shape[1]
        x = x.reshape(shape=(x.shape[0], L, p, p, c))
        x = torch.einsum('nlpqc->ncplq', x)
        imgs = x.reshape(shape=(x.shape[0], c, p, L * p))
    else:
        b = x.shape[0]
        imgs = rearrange(x, 'b (t h w) (c p q) -> b t c (h p) (w q)',
                         b=b, h=h, w=w, c=c, p=p, q=p)
    return imgs
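# Illustrative sketch (not part of the original module): round-trips the shape contract of
# unpatchify -- a (b, t*h*w, c*p*p) token sequence becomes a (b, t, c, h*p, w*p) latent video.
# Example sizes are arbitrary.
def _demo_unpatchify_shapes():
    b, t, c, p, h, w = 2, 3, 4, 2, 5, 6
    tokens = torch.randn(b, t * h * w, c * p * p)
    video = unpatchify(tokens, c=c, p=p, w=w, h=h)
    assert video.shape == (b, t, c, h * p, w * p)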
class FinalLayerMixin(BaseMixin):
    def __init__(
        self,
        hidden_size,
        time_embed_dim,
        patch_size,
        out_channels,
        latent_width,
        latent_height,
        elementwise_affine,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.out_channels = out_channels
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=elementwise_affine, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True))

        self.spatial_length = latent_width * latent_height // patch_size**2
        self.latent_width = latent_width
        self.latent_height = latent_height

    def final_forward(self, logits, **kwargs):
        x, emb = logits[:, kwargs['text_length']:, :], kwargs['emb']  # x:(b,(t n),d)
        shift, scale = self.adaLN_modulation(emb).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)

        if hasattr(self, 'share_cache') and 'mode' in self.share_cache:
            mode = self.share_cache['mode']
            if mode == 'r':
                t, h, w = self.share_cache['shape_info']
                target_h = h
                target_w = w
            else:
                assert mode == 'w'
                t, h, w = self.share_cache['ref_shape_info']
                target_h = h
                target_w = w
        elif hasattr(self, 'share_cache') and 'shape_info' in self.share_cache:
            t, h, w = self.share_cache['shape_info']
            target_h = h
            target_w = w
        else:
            target_h = self.latent_height // self.patch_size
            target_w = self.latent_width // self.patch_size

        return unpatchify(
            x,
            c=self.out_channels,
            p=self.patch_size,
            w=target_w,
            h=target_h,
            rope_position_ids=kwargs.get('rope_position_ids', None),
            **kwargs,
        )

    def reinit(self, parent_model=None):
        nn.init.xavier_uniform_(self.linear.weight)
        nn.init.constant_(self.linear.bias, 0)


class SwiGLUMixin(BaseMixin):
    def __init__(self, num_layers, in_features, hidden_features, bias=False):
        super().__init__()
        self.w2 = nn.ModuleList([
            ColumnParallelLinear(
                in_features,
                hidden_features,
                gather_output=False,
                bias=bias,
                module=self,
                name='dense_h_to_4h_gate',
            ) for i in range(num_layers)
        ])

    def mlp_forward(self, hidden_states, **kw_args):
        x = hidden_states
        origin = self.transformer.layers[kw_args['layer_id']].mlp
        x1 = origin.dense_h_to_4h(x)
        x2 = self.w2[kw_args['layer_id']](x)
        hidden = origin.activation_func(x2) * x1
        x = origin.dense_4h_to_h(hidden)
        return x
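# Illustrative sketch (not part of the original module): SwiGLUMixin gates the original
# dense_h_to_4h branch with a parallel activated projection, i.e.
# out = W_down(act(W_gate x) * W_up x). A plain nn.Linear stand-in (no model parallelism)
# for clarity; the layer names below are assumptions for the demo only.
class _ToySwiGLU(nn.Module):
    def __init__(self, dim, hidden):
        super().__init__()
        self.up = nn.Linear(dim, hidden, bias=False)    # plays dense_h_to_4h
        self.gate = nn.Linear(dim, hidden, bias=False)  # plays dense_h_to_4h_gate (w2)
        self.down = nn.Linear(hidden, dim, bias=False)  # plays dense_4h_to_h

    def forward(self, x):
        return self.down(F.silu(self.gate(x)) * self.up(x))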
class AdaLNMixin(BaseMixin):
    def __init__(
        self,
        width,
        height,
        hidden_size,
        num_layers,
        time_embed_dim,
        compressed_num_frames,
        qk_ln=True,
        hidden_size_head=None,
        elementwise_affine=True,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.width = width
        self.height = height
        self.compressed_num_frames = compressed_num_frames

        self.adaLN_modulations = nn.ModuleList([
            nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 12 * hidden_size))
            for _ in range(num_layers)
        ])

        self.qk_ln = qk_ln
        if qk_ln:
            self.query_layernorm_list = nn.ModuleList([
                LayerNorm(hidden_size_head, eps=1e-6, elementwise_affine=elementwise_affine)
                for _ in range(num_layers)
            ])
            self.key_layernorm_list = nn.ModuleList([
                LayerNorm(hidden_size_head, eps=1e-6, elementwise_affine=elementwise_affine)
                for _ in range(num_layers)
            ])

    def layer_forward(
        self,
        hidden_states,
        mask,
        *args,
        **kwargs,
    ):
        # full attention over the concatenated text + video token sequence
        text_length = kwargs['text_length']
        # hidden_states (b,(n_t+t*n_i),d)
        text_hidden_states = hidden_states[:, :text_length]  # (b,n,d)
        img_hidden_states = hidden_states[:, text_length:]  # (b,(t n),d)
        layer = self.transformer.layers[kwargs['layer_id']]
        # if os.environ.get("DEBUGINFO", False):
        #     print(f"in forward layer_id is {kwargs['layer_id']}", flush=True)
        adaLN_modulation = self.adaLN_modulations[kwargs['layer_id']]

        emb = kwargs['emb']
        # if "size_emb" in self.share_cache:
        #     size_emb = self.share_cache["size_emb"]
        #     emb = emb + size_emb

        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
            text_shift_msa,
            text_scale_msa,
            text_gate_msa,
            text_shift_mlp,
            text_scale_mlp,
            text_gate_mlp,
        ) = adaLN_modulation(emb).chunk(12, dim=1)
        gate_msa, gate_mlp, text_gate_msa, text_gate_mlp = (
            gate_msa.unsqueeze(1),
            gate_mlp.unsqueeze(1),
            text_gate_msa.unsqueeze(1),
            text_gate_mlp.unsqueeze(1),
        )

        # self full attention (b,(t n),d)
        img_attention_input = layer.input_layernorm(img_hidden_states)
        text_attention_input = layer.input_layernorm(text_hidden_states)
        img_attention_input = modulate(img_attention_input, shift_msa, scale_msa)
        text_attention_input = modulate(text_attention_input, text_shift_msa, text_scale_msa)

        attention_input = torch.cat((text_attention_input, img_attention_input), dim=1)  # (b,n_t+t*n_i,d)
        attention_output = layer.attention(attention_input, mask, **kwargs)
        text_attention_output = attention_output[:, :text_length]  # (b,n,d)
        img_attention_output = attention_output[:, text_length:]  # (b,(t n),d)

        if self.transformer.layernorm_order == 'sandwich':
            text_attention_output = layer.third_layernorm(text_attention_output)
            img_attention_output = layer.third_layernorm(img_attention_output)
        img_hidden_states = img_hidden_states + gate_msa * img_attention_output  # (b,(t n),d)
        text_hidden_states = text_hidden_states + text_gate_msa * text_attention_output  # (b,n,d)

        # mlp (b,(t n),d)
        img_mlp_input = layer.post_attention_layernorm(img_hidden_states)  # vision (b,(t n),d)
        text_mlp_input = layer.post_attention_layernorm(text_hidden_states)  # language (b,n,d)
        img_mlp_input = modulate(img_mlp_input, shift_mlp, scale_mlp)
        text_mlp_input = modulate(text_mlp_input, text_shift_mlp, text_scale_mlp)
        mlp_input = torch.cat((text_mlp_input, img_mlp_input), dim=1)  # (b,(n_t+t*n_i),d)
        mlp_output = layer.mlp(mlp_input, **kwargs)
        img_mlp_output = mlp_output[:, text_length:]  # vision (b,(t n),d)
        text_mlp_output = mlp_output[:, :text_length]  # language (b,n,d)
        if self.transformer.layernorm_order == 'sandwich':
            text_mlp_output = layer.fourth_layernorm(text_mlp_output)
            img_mlp_output = layer.fourth_layernorm(img_mlp_output)

        img_hidden_states = img_hidden_states + gate_mlp * img_mlp_output  # vision (b,(t n),d)
        text_hidden_states = text_hidden_states + text_gate_mlp * text_mlp_output  # language (b,n,d)

        hidden_states = torch.cat((text_hidden_states, img_hidden_states), dim=1)
        if 'scale_embedding' in self.share_cache:
            scaler_norm_layer = self.share_cache[f'scale_norm_layer_{kwargs["layer_id"]}']
            hidden_states = scaler_norm_layer(hidden_states, emb)
        # (b,(n_t+t*n_i),d)
        return hidden_states

    def reinit(self, parent_model=None):
        for layer in self.adaLN_modulations:
            nn.init.constant_(layer[-1].weight, 0)
            nn.init.constant_(layer[-1].bias, 0)

    @non_conflict
    def attention_fn(
        self,
        query_layer,
        key_layer,
        value_layer,
        attention_mask,
        attention_dropout=None,
        log_attention_weights=None,
        scaling_attention_score=True,
        old_impl=attention_fn_default,
        **kwargs,
    ):
        self.share_cache['temp_layer_id'] = kwargs['layer_id']
        if self.qk_ln:
            query_layernorm = self.query_layernorm_list[kwargs['layer_id']]
            key_layernorm = self.key_layernorm_list[kwargs['layer_id']]
            query_layer = query_layernorm(query_layer)
            key_layer = key_layernorm(key_layer)
        # old_impl is attention_fn of Rotary3DPositionEmbeddingMixin
        return old_impl(
            query_layer,
            key_layer,
            value_layer,
            attention_mask,
            attention_dropout=attention_dropout,
            log_attention_weights=log_attention_weights,
            scaling_attention_score=scaling_attention_score,
            **kwargs,
        )


str_to_dtype = {'fp32': torch.float32, 'fp16': torch.float16, 'bf16': torch.bfloat16}
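# Illustrative sketch (not part of the original module): each AdaLN block above projects the
# timestep embedding into 12 vectors -- (shift, scale, gate) for the attention and MLP branches
# of the image stream and of the text stream -- and applies
# modulate(x, shift, scale) = x * (1 + scale) + shift per token, followed by a gated residual.
# Example sizes are arbitrary.
def _demo_adaln_modulate():
    b, n, d = 2, 8, 16
    x = torch.randn(b, n, d)
    emb = torch.randn(b, 3 * d)
    shift, scale, gate = emb.chunk(3, dim=1)
    out = modulate(x, shift, scale)        # shift/scale broadcast over the token axis
    assert out.shape == (b, n, d)
    gated = x + gate.unsqueeze(1) * out    # gated residual, as in layer_forward
    assert gated.shape == (b, n, d)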
class DiffusionTransformer(BaseModel):
    def __init__(
        self,
        transformer_args,
        num_frames,
        time_compressed_rate,
        latent_width,
        latent_height,
        patch_size,
        in_channels,
        out_channels,
        hidden_size,
        num_layers,
        num_attention_heads,
        elementwise_affine,
        time_embed_dim=None,
        num_classes=None,
        modules={},
        input_time='adaln',
        adm_in_channels=None,
        parallel_output=True,
        height_interpolation=1.0,
        width_interpolation=1.0,
        time_interpolation=1.0,
        use_SwiGLU=False,
        use_RMSNorm=False,
        zero_init_y_embed=False,
        **kwargs,
    ):
        self.latent_width = latent_width
        self.latent_height = latent_height
        self.patch_size = patch_size
        self.num_frames = num_frames
        self.time_compressed_rate = time_compressed_rate
        self.spatial_length = latent_width * latent_height // patch_size**2
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_size = hidden_size
        self.model_channels = hidden_size
        self.time_embed_dim = time_embed_dim if time_embed_dim is not None else hidden_size
        self.num_classes = num_classes
        self.adm_in_channels = adm_in_channels
        self.input_time = input_time
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads
        self.is_decoder = transformer_args.is_decoder
        self.elementwise_affine = elementwise_affine
        self.height_interpolation = height_interpolation
        self.width_interpolation = width_interpolation
        self.time_interpolation = time_interpolation
        self.inner_hidden_size = hidden_size * 4
        self.zero_init_y_embed = zero_init_y_embed
        try:
            self.dtype = str_to_dtype[kwargs.pop('dtype')]
        except KeyError:
            self.dtype = torch.float32

        if use_SwiGLU:
            kwargs['activation_func'] = F.silu
        elif 'activation_func' not in kwargs:
            approx_gelu = nn.GELU(approximate='tanh')
            kwargs['activation_func'] = approx_gelu

        if use_RMSNorm:
            kwargs['layernorm'] = RMSNorm
        else:
            kwargs['layernorm'] = partial(LayerNorm, elementwise_affine=elementwise_affine, eps=1e-6)

        transformer_args.num_layers = num_layers
        transformer_args.hidden_size = hidden_size
        transformer_args.num_attention_heads = num_attention_heads
        transformer_args.parallel_output = parallel_output
        super().__init__(args=transformer_args, transformer=None, **kwargs)

        module_configs = modules
        self._build_modules(module_configs)

        if use_SwiGLU:
            self.add_mixin(
                'swiglu',
                SwiGLUMixin(num_layers, hidden_size, self.inner_hidden_size, bias=False),
                reinit=True,
            )

    def _build_modules(self, module_configs):
        model_channels = self.hidden_size
        # time_embed_dim = model_channels * 4
        time_embed_dim = self.time_embed_dim
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
            elif self.num_classes == 'continuous':
                print('setting up linear c_adm embedding layer')
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == 'timestep':
                self.label_emb = nn.Sequential(
                    Timestep(model_channels),
                    nn.Sequential(
                        linear(model_channels, time_embed_dim),
                        nn.SiLU(),
                        linear(time_embed_dim, time_embed_dim),
                    ),
                )
            elif self.num_classes == 'sequential':
                assert self.adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        linear(self.adm_in_channels, time_embed_dim),
                        nn.SiLU(),
                        linear(time_embed_dim, time_embed_dim),
                    ))
                if self.zero_init_y_embed:
                    nn.init.constant_(self.label_emb[0][2].weight, 0)
                    nn.init.constant_(self.label_emb[0][2].bias, 0)
            else:
                raise ValueError()

        pos_embed_config = module_configs['pos_embed_config']
        self.add_mixin(
            'pos_embed',
            instantiate_from_config(
                pos_embed_config,
                height=self.latent_height // self.patch_size,
                width=self.latent_width // self.patch_size,
                compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
                hidden_size=self.hidden_size,
            ),
            reinit=True,
        )

        patch_embed_config = module_configs['patch_embed_config']
        self.add_mixin(
            'patch_embed',
            instantiate_from_config(
                patch_embed_config,
                patch_size=self.patch_size,
                hidden_size=self.hidden_size,
                in_channels=self.in_channels,
            ),
            reinit=True,
        )

        if self.input_time == 'adaln':
            adaln_layer_config = module_configs['adaln_layer_config']
            self.add_mixin(
                'adaln_layer',
                instantiate_from_config(
                    adaln_layer_config,
                    height=self.latent_height // self.patch_size,
                    width=self.latent_width // self.patch_size,
                    hidden_size=self.hidden_size,
                    num_layers=self.num_layers,
                    compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
                    hidden_size_head=self.hidden_size // self.num_attention_heads,
                    time_embed_dim=self.time_embed_dim,
                    elementwise_affine=self.elementwise_affine,
                ),
            )
        else:
            raise NotImplementedError
        final_layer_config = module_configs['final_layer_config']
        self.add_mixin(
            'final_layer',
            instantiate_from_config(
                final_layer_config,
                hidden_size=self.hidden_size,
                patch_size=self.patch_size,
                out_channels=self.out_channels,
                time_embed_dim=self.time_embed_dim,
                latent_width=self.latent_width,
                latent_height=self.latent_height,
                elementwise_affine=self.elementwise_affine,
            ),
            reinit=True,
        )

        if 'lora_config' in module_configs:
            lora_config = module_configs['lora_config']
            self.add_mixin(
                'lora',
                instantiate_from_config(lora_config, layer_num=self.num_layers),
                reinit=True,
            )
        return

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        b, t, d, h, w = x.shape
        if x.dtype != self.dtype:
            x = x.to(self.dtype)

        if 'ref_x' in self.share_cache:
            ref_x = self.share_cache['ref_x']
            if ref_x.dtype != self.dtype:
                ref_x = ref_x.to(self.dtype)
            self.share_cache['ref_x'] = ref_x

        # This path is not used during inference
        if 'concat_images' in kwargs and kwargs['concat_images'] is not None:
            if kwargs['concat_images'].shape[0] != x.shape[0]:
                concat_images = kwargs['concat_images'].repeat(2, 1, 1, 1, 1)
            else:
                concat_images = kwargs['concat_images']
            x = torch.cat([x, concat_images], dim=2)

        assert (y is not None) == (
            self.num_classes is not None
        ), 'must specify y if and only if the model is class-conditional'
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            # assert y.shape[0] == x.shape[0]
            assert x.shape[0] % y.shape[0] == 0
            y = y.repeat_interleave(x.shape[0] // y.shape[0], dim=0)
            emb = emb + self.label_emb(y)

        self.share_cache['shape_info'] = (t, h // self.patch_size, w // self.patch_size)
        self.share_cache['timesteps'] = timesteps

        kwargs['seq_length'] = t * h * w // (self.patch_size**2)
        kwargs['images'] = x
        kwargs['emb'] = emb
        kwargs['encoder_outputs'] = context
        kwargs['text_length'] = context.shape[1]

        kwargs['input_ids'] = kwargs['position_ids'] = kwargs['attention_mask'] = torch.ones((1, 1)).to(x.dtype)
        output = super().forward(**kwargs)[0]
        return output
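# Illustrative sketch (not part of the original module): how a DiffusionTransformer forward
# pass is typically driven. The shapes and the helper name are assumptions for demonstration;
# real values come from the training/inference config.
def _demo_forward_call(model, latents, prompt_embeds):
    # latents: (b, t, c, h, w) latent video; prompt_embeds: (b, n_text, text_hidden_size)
    timesteps = torch.randint(0, 1000, (latents.shape[0],), device=latents.device)
    noise_pred = model(latents, timesteps=timesteps, context=prompt_embeds)
    # the final layer drops the text tokens and unpatchifies back to the latent video shape,
    # with out_channels channels
    return noise_pred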
class RefDiffusionTransformer(DiffusionTransformer):
    def register_new_modules(self):
        all_layers = []
        for n, m in self.named_modules():
            if hasattr(m, 'attention'):
                all_layers.append(m)
        for m in all_layers:
            m.ref_query_key_value = copy.deepcopy(m.attention.query_key_value)
            m.ref_dense = copy.deepcopy(m.attention.dense)
            m.ref_attention_dropout = copy.deepcopy(m.attention.attention_dropout)
            m.ref_output_dropout = copy.deepcopy(m.attention.output_dropout)

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        b, t, d, h, w = x.shape
        if x.dtype != self.dtype:
            x = x.to(self.dtype)

        assert (y is not None) == (
            self.num_classes is not None
        ), 'must specify y if and only if the model is class-conditional'
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype)
        # the reference branch always runs at timestep 0
        ref_t_emb = timestep_embedding(torch.zeros_like(timesteps), self.model_channels,
                                       repeat_only=False, dtype=self.dtype)
        emb = self.time_embed(t_emb)
        ref_t_emb = self.time_embed(ref_t_emb)

        assert self.num_classes is None

        self.share_cache['timesteps'] = timesteps
        self.share_cache['shape_info'] = (t, h // self.patch_size, w // self.patch_size)

        ref_x = self.share_cache['ref_x']
        ref_b, ref_t, ref_d, ref_h, ref_w = ref_x.shape
        self.share_cache['ref_shape_info'] = (ref_t, ref_h // self.patch_size, ref_w // self.patch_size)

        idx = kwargs.pop('idx')

        kwargs['seq_length'] = t * h * w // (self.patch_size**2)
        kwargs['images'] = x
        kwargs['emb'] = emb
        kwargs['encoder_outputs'] = context
        kwargs['text_length'] = context.shape[1]
        kwargs['input_ids'] = kwargs['position_ids'] = kwargs['attention_mask'] = torch.ones((1, 1)).to(x.dtype)

        ref_kwargs = dict()
        ref_kwargs['seq_length'] = ref_t * ref_h * ref_w // (self.patch_size**2)
        ref_kwargs['images'] = ref_x
        ref_kwargs['emb'] = ref_t_emb
        ref_kwargs['encoder_outputs'] = context
        ref_kwargs['text_length'] = context.shape[1]
        ref_kwargs['input_ids'] = ref_kwargs['position_ids'] = ref_kwargs['attention_mask'] = torch.ones((1, 1)).to(x.dtype)

        # first pass over the reference latents ('w' mode), then the denoising pass ('r' mode)
        self.share_cache['mode'] = 'w'
        super(DiffusionTransformer, self).forward(**ref_kwargs)[0]

        self.share_cache['mode'] = 'r'
        output = super(DiffusionTransformer, self).forward(**kwargs)[0]
        return output