Commit 1ad55bb4 authored by mashun1

i2vgen-xl
import math
import torch
def beta_schedule(schedule='cosine',
num_timesteps=1000,
zero_terminal_snr=False,
**kwargs):
# compute betas
betas = {
# 'logsnr_cosine_interp': logsnr_cosine_interp_schedule,
'linear': linear_schedule,
'linear_sd': linear_sd_schedule,
'quadratic': quadratic_schedule,
'cosine': cosine_schedule
}[schedule](num_timesteps, **kwargs)
if zero_terminal_snr and abs(betas.max() - 1.0) > 0.0001:
betas = rescale_zero_terminal_snr(betas)
return betas
def sigma_schedule(schedule='cosine',
num_timesteps=1000,
zero_terminal_snr=False,
**kwargs):
# compute betas
betas = {
'logsnr_cosine_interp': logsnr_cosine_interp_schedule,
'linear': linear_schedule,
'linear_sd': linear_sd_schedule,
'quadratic': quadratic_schedule,
'cosine': cosine_schedule
}[schedule](num_timesteps, **kwargs)
if schedule == 'logsnr_cosine_interp':
sigma = betas
else:
sigma = betas_to_sigmas(betas)
if zero_terminal_snr and abs(sigma.max() - 1.0) > 0.0001:
sigma = rescale_zero_terminal_snr(sigma)
return sigma
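# Usage sketch (illustrative values; assumes only the helpers defined in this file).
# beta_schedule() returns per-step betas, while sigma_schedule() returns cumulative
# noise levels; for 'logsnr_cosine_interp' the sigmas come straight from the log-SNR curve.
def _demo_schedules(num_timesteps=1000):
    betas = beta_schedule('cosine', num_timesteps, zero_terminal_snr=True)
    sigmas = sigma_schedule('logsnr_cosine_interp', num_timesteps)
    # both are 1-D tensors of length num_timesteps
    return betas, sigmas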
def linear_schedule(num_timesteps, init_beta, last_beta, **kwargs):
scale = 1000.0 / num_timesteps
init_beta = init_beta or scale * 0.0001
last_beta = last_beta or scale * 0.02
return torch.linspace(init_beta, last_beta, num_timesteps, dtype=torch.float64)
def logsnr_cosine_interp_schedule(
num_timesteps,
scale_min=2,
scale_max=4,
logsnr_min=-15,
logsnr_max=15,
**kwargs):
return logsnrs_to_sigmas(
_logsnr_cosine_interp(num_timesteps, logsnr_min, logsnr_max, scale_min, scale_max))
def linear_sd_schedule(num_timesteps, init_beta, last_beta, **kwargs):
return torch.linspace(init_beta ** 0.5, last_beta ** 0.5, num_timesteps, dtype=torch.float64) ** 2
def quadratic_schedule(num_timesteps, init_beta, last_beta, **kwargs):
init_beta = init_beta or 0.0015
last_beta = last_beta or 0.0195
return torch.linspace(init_beta ** 0.5, last_beta ** 0.5, num_timesteps, dtype=torch.float64) ** 2
def cosine_schedule(num_timesteps, cosine_s=0.008, **kwargs):
betas = []
for step in range(num_timesteps):
t1 = step / num_timesteps
t2 = (step + 1) / num_timesteps
fn = lambda u: math.cos((u + cosine_s) / (1 + cosine_s) * math.pi / 2) ** 2
betas.append(min(1.0 - fn(t2) / fn(t1), 0.999))
return torch.tensor(betas, dtype=torch.float64)
# def cosine_schedule(n, cosine_s=0.008, **kwargs):
# ramp = torch.linspace(0, 1, n + 1)
# square_alphas = torch.cos((ramp + cosine_s) / (1 + cosine_s) * torch.pi / 2) ** 2
# betas = (1 - square_alphas[1:] / square_alphas[:-1]).clamp(max=0.999)
# return betas_to_sigmas(betas)
def betas_to_sigmas(betas):
return torch.sqrt(1 - torch.cumprod(1 - betas, dim=0))
def sigmas_to_betas(sigmas):
square_alphas = 1 - sigmas**2
betas = 1 - torch.cat(
[square_alphas[:1], square_alphas[1:] / square_alphas[:-1]])
return betas
def sigmas_to_logsnrs(sigmas):
square_sigmas = sigmas**2
return torch.log(square_sigmas / (1 - square_sigmas))
def _logsnr_cosine(n, logsnr_min=-15, logsnr_max=15):
t_min = math.atan(math.exp(-0.5 * logsnr_min))
t_max = math.atan(math.exp(-0.5 * logsnr_max))
t = torch.linspace(1, 0, n)
logsnrs = -2 * torch.log(torch.tan(t_min + t * (t_max - t_min)))
return logsnrs
def _logsnr_cosine_shifted(n, logsnr_min=-15, logsnr_max=15, scale=2):
logsnrs = _logsnr_cosine(n, logsnr_min, logsnr_max)
logsnrs += 2 * math.log(1 / scale)
return logsnrs
def karras_schedule(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
ramp = torch.linspace(1, 0, n)
min_inv_rho = sigma_min**(1 / rho)
max_inv_rho = sigma_max**(1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**rho
sigmas = torch.sqrt(sigmas**2 / (1 + sigmas**2))
return sigmas
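# Note: karras_schedule() interpolates between sigma_min and sigma_max in sigma**(1/rho)
# space (the rho-spacing of Karras et al., 2022) and then converts each variance-exploding
# sigma to a variance-preserving one via sigma / sqrt(1 + sigma**2), so the returned
# values lie in (0, 1).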
def _logsnr_cosine_interp(n,
logsnr_min=-15,
logsnr_max=15,
scale_min=2,
scale_max=4):
t = torch.linspace(1, 0, n)
logsnrs_min = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_min)
logsnrs_max = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_max)
logsnrs = t * logsnrs_min + (1 - t) * logsnrs_max
return logsnrs
def logsnrs_to_sigmas(logsnrs):
return torch.sqrt(torch.sigmoid(-logsnrs))
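# Note: logsnrs_to_sigmas() assumes the convention logSNR = log(alpha**2 / sigma**2)
# = log((1 - sigma**2) / sigma**2), under which sigma**2 = sigmoid(-logSNR).
# sigmas_to_logsnrs() above returns log(sigma**2 / (1 - sigma**2)), i.e. the negated
# quantity, so the two helpers follow opposite sign conventions.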
def rescale_zero_terminal_snr(betas):
"""
Rescale Schedule to Zero Terminal SNR
"""
# Convert betas to alphas_bar_sqrt
alphas = 1 - betas
alphas_bar = alphas.cumprod(0)
alphas_bar_sqrt = alphas_bar.sqrt()
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
# Shift so last timestep is zero.
alphas_bar_sqrt -= alphas_bar_sqrt_T
# Scale so first timestep is back to old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt ** 2
alphas = alphas_bar[1:] / alphas_bar[:-1]
alphas = torch.cat([alphas_bar[0:1], alphas])
betas = 1 - alphas
return betas
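# Minimal sanity check (illustrative sketch; the 0.00085/0.012 beta range is an assumed
# SD-style example, not a value taken from this repo). After rescaling, the cumulative
# signal coefficient sqrt(prod(1 - beta)) of the final timestep is zero, i.e. the
# terminal SNR is zero.
def _check_zero_terminal_snr(num_timesteps=1000):
    betas = rescale_zero_terminal_snr(linear_schedule(num_timesteps, 0.00085, 0.012))
    alphas_bar_sqrt = (1 - betas).cumprod(0).sqrt()
    assert alphas_bar_sqrt[-1].abs() < 1e-8
    return betas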
from .unet_i2vgen import *
from .unet_t2v import *
from .unet_higen import UNetSD_HiGen
from .unet_sr600 import UNetSD_SR600
import torch
import torch.nn as nn
import torch.cuda.amp as amp
import torch.nn.functional as F
import math
import os
import time
import numpy as np
import random
from flash_attn.flash_attention import FlashAttention  # required by FlashAttentionBlock below
class FlashAttentionBlock(nn.Module):
def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None, batch_size=4):
# consider head_dim first, then num_heads
num_heads = dim // head_dim if head_dim else num_heads
head_dim = dim // num_heads
assert num_heads * head_dim == dim
super(FlashAttentionBlock, self).__init__()
self.dim = dim
self.context_dim = context_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.scale = math.pow(head_dim, -0.25)
# layers
self.norm = nn.GroupNorm(32, dim)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
if context_dim is not None:
self.context_kv = nn.Linear(context_dim, dim * 2)
self.proj = nn.Conv2d(dim, dim, 1)
if self.head_dim <= 128 and (self.head_dim % 8) == 0:
new_scale = math.pow(head_dim, -0.5)
self.flash_attn = FlashAttention(softmax_scale=None, attention_dropout=0.0)
# zero out the last layer params
nn.init.zeros_(self.proj.weight)
# self.apply(self._init_weight)
def _init_weight(self, module):
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=0.15)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Conv2d):
module.weight.data.normal_(mean=0.0, std=0.15)
if module.bias is not None:
module.bias.data.zero_()
def forward(self, x, context=None):
r"""x: [B, C, H, W].
context: [B, L, C] or None.
"""
identity = x
b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
x = self.norm(x)
q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
if context is not None:
ck, cv = self.context_kv(context).reshape(b, -1, n * 2, d).permute(0, 2, 3, 1).chunk(2, dim=1)
k = torch.cat([ck, k], dim=-1)
v = torch.cat([cv, v], dim=-1)
cq = torch.zeros([b, n, d, 4], dtype=q.dtype, device=q.device)
q = torch.cat([q, cq], dim=-1)
qkv = torch.cat([q,k,v], dim=1)
origin_dtype = qkv.dtype
qkv = qkv.permute(0, 3, 1, 2).reshape(b, -1, 3, n, d).half().contiguous()
out, _ = self.flash_attn(qkv)
out = out.to(origin_dtype)  # .to() is out-of-place; assign the result back
if context is not None:
out = out[:, :-4, :, :]
out = out.permute(0, 2, 3, 1).reshape(b, c, h, w)
# output
x = self.proj(out)
return x + identity
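# Fallback sketch (an assumption, not part of this module): if the flash_attn package is
# unavailable, the same computation on the packed qkv tensor handed to self.flash_attn
# ([B, S, 3, heads, head_dim]) can be expressed with PyTorch's built-in scaled dot-product
# attention (torch >= 2.0). Shown only to document what the flash kernel computes.
def _sdpa_fallback(qkv):
    q, k, v = qkv.unbind(dim=2)                       # each [B, S, H, D]
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # [B, H, S, D]
    out = F.scaled_dot_product_attention(q, k, v)     # softmax(q k^T / sqrt(D)) v
    return out.transpose(1, 2)                        # back to [B, S, H, D], like flash_attn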
if __name__ == '__main__':
batch_size = 8
flash_net = FlashAttentionBlock(dim=1280, context_dim=512, num_heads=None, head_dim=64, batch_size=batch_size).cuda()
x = torch.randn([batch_size, 1280, 32, 32], dtype=torch.float32).cuda()
context = torch.randn([batch_size, 4, 512], dtype=torch.float32).cuda()
# context = None
flash_net.eval()
with amp.autocast(enabled=True):
# warm up
for i in range(5):
y = flash_net(x, context)
torch.cuda.synchronize()
s1 = time.time()
for i in range(10):
y = flash_net(x, context)
torch.cuda.synchronize()
s2 = time.time()
print(f'Average cost time {(s2-s1)*1000/10} ms')
'''
/*
*Copyright (c) 2021, Alibaba Group;
*Licensed under the Apache License, Version 2.0 (the "License");
*you may not use this file except in compliance with the License.
*You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
*Unless required by applicable law or agreed to in writing, software
*distributed under the License is distributed on an "AS IS" BASIS,
*WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*See the License for the specific language governing permissions and
*limitations under the License.
*/
'''
import math
import torch
import xformers
import xformers.ops
import torch.nn as nn
from einops import rearrange
import torch.nn.functional as F
from rotary_embedding_torch import RotaryEmbedding
from fairscale.nn.checkpoint import checkpoint_wrapper
from .util import *
from .mha_flash import FlashAttentionBlock
from utils.registry_class import MODEL
USE_TEMPORAL_TRANSFORMER = True
class ResBlockWoImg(ResBlock):
"""
A residual block that can optionally change the number of channels.
:param channels: the number of input channels.
:param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout.
:param out_channels: if specified, the number of out channels.
:param use_conv: if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the
channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param use_checkpoint: if True, use gradient checkpointing on this module.
:param up: if True, use this block for upsampling.
:param down: if True, use this block for downsampling.
"""
def __init__(
self,
channels,
emb_channels,
dropout,
out_channels=None,
use_conv=False,
use_scale_shift_norm=False,
dims=2,
up=False,
down=False,
use_temporal_conv=True,
use_image_dataset=False,
):
super().__init__(channels, emb_channels, dropout, out_channels, use_conv, use_scale_shift_norm, dims, up, down, use_temporal_conv, use_image_dataset)
if self.use_temporal_conv:
self.temopral_conv = TemporalConvBlock_v2WoImg(self.out_channels, self.out_channels, dropout=0.1, use_image_dataset=use_image_dataset)
# self.temopral_conv_2 = TemporalConvBlock(self.out_channels, self.out_channels, dropout=0.1, use_image_dataset=use_image_dataset)
class TemporalConvBlock_v2WoImg(TemporalConvBlock_v2):
def __init__(self, in_dim, out_dim=None, dropout=0.0, use_image_dataset=False):
super(TemporalConvBlock_v2WoImg, self).__init__(in_dim, out_dim, dropout, use_image_dataset)
def forward(self, x):
identity = x
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
if x.size(2) == 1:
x = identity + 0.0 * x
else:
x = identity + x
return x
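# Note: when the temporal axis has a single frame (x.size(2) == 1, i.e. image data), the
# block output is multiplied by zero so the result reduces to the identity, while the
# temporal convolutions stay in the graph; with more than one frame the usual residual
# connection is applied.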
class TemporalTransformerWoImg(TemporalTransformer):
"""
Transformer block for image-like data.
First, project the input (aka embedding)
and reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape to image
"""
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None,
disable_self_attn=False, use_linear=False,
use_checkpoint=True, only_self_att=True, multiply_zero=False):
super().__init__(in_channels, n_heads, d_head,
depth, dropout, context_dim,
disable_self_attn, use_linear,
use_checkpoint, only_self_att, multiply_zero)
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
if self.only_self_att:
context = None
if not isinstance(context, list):
context = [context]
b, c, f, h, w = x.shape
x_in = x
x = self.norm(x)
if not self.use_linear:
x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous()
x = self.proj_in(x)
# [16384, 16, 320]
if self.use_linear:
x = rearrange(x, '(b f) c h w -> b (h w) f c', f=self.frames).contiguous()
x = self.proj_in(x)
if self.only_self_att:
x = rearrange(x, 'bhw c f -> bhw f c').contiguous()
for i, block in enumerate(self.transformer_blocks):
x = block(x)
x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous()
else:
x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous()
for i, block in enumerate(self.transformer_blocks):
# context[i] = repeat(context[i], '(b f) l con -> b (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous()
context[i] = rearrange(context[i], '(b f) l con -> b f l con', f=self.frames).contiguous()
# process each batch element separately (some attention kernels cannot handle a dimension larger than 65,535)
for j in range(b):
context_i_j = repeat(context[i][j], 'f l con -> (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous()
x[j] = block(x[j], context=context_i_j)
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, 'b (h w) f c -> b f c h w', h=h, w=w).contiguous()
if not self.use_linear:
# x = rearrange(x, 'bhw f c -> bhw c f').contiguous()
x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous()
x = self.proj_out(x)
x = rearrange(x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous()
if x.size(2) == 1:
x = 0.0 * x + x_in
else:
x = x + x_in
return x
class TextContextCrossTransformerMultiLayer(nn.Module):
def __init__(self, y_dim, depth, embed_dim, context_dim, num_tokens):
super(TextContextCrossTransformerMultiLayer, self).__init__()
self.context_transformer = nn.ModuleList(
[BasicTransformerBlock(embed_dim, n_heads=8, d_head=embed_dim//8, dropout=0.0, context_dim=embed_dim,
disable_self_attn=True, checkpoint=True)
for d in range(depth)]
)
self.input_mapping = nn.Linear(y_dim, embed_dim)
self.output_mapping = nn.Linear(embed_dim, context_dim)
scale = embed_dim ** -0.5
self.tokens = nn.Parameter(scale * torch.randn(1, num_tokens, embed_dim))
def forward(self, x):
x = self.input_mapping(x)
out = self.tokens.repeat(x.size(0), 1, 1)
for transformer in self.context_transformer:
out = transformer(out, context=x)
return self.output_mapping(out)
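# Shape sketch (dimensions are illustrative): the module turns a batch of y_dim features
# into num_tokens learned context tokens of size context_dim by cross-attending the
# learned tokens onto the projected input.
#   x:      [B, L, y_dim]  -> input_mapping -> [B, L, embed_dim]
#   tokens: [B, num_tokens, embed_dim]  (learned, repeated per batch element)
#   each block: tokens = CrossAttnBlock(tokens, context=x)
#   return: output_mapping(tokens) -> [B, num_tokens, context_dim]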
@MODEL.register_class()
class UNetSD_HiGen(nn.Module):
def __init__(self,
config=None,
in_dim=4,
dim=512,
y_dim=512,
context_dim=512,
hist_dim = 156,
dim_condition=4,
out_dim=6,
num_tokens=4,
dim_mult=[1, 2, 3, 4],
num_heads=None,
head_dim=64,
num_res_blocks=3,
attn_scales=[1 / 2, 1 / 4, 1 / 8],
use_scale_shift_norm=True,
dropout=0.1,
temporal_attn_times=1,
temporal_attention = True,
use_checkpoint=False,
use_image_dataset=False,
use_sim_mask = False,
training=True,
inpainting=True,
use_fps_condition=False,
p_all_zero=0.1,
p_all_keep=0.1,
zero_y=None,
adapter_transformer_layers=1,
context_embedding_depth=4,
**kwargs):
super(UNetSD_HiGen, self).__init__()
embed_dim = dim * 4
num_heads=num_heads if num_heads else dim//32
self.zero_y = zero_y
self.in_dim = in_dim
self.dim = dim
self.y_dim = y_dim
self.num_tokens = num_tokens
self.context_dim = context_dim
self.hist_dim = hist_dim
self.embed_dim = embed_dim
self.out_dim = out_dim
self.dim_mult = dim_mult
### for temporal attention
self.num_heads = num_heads
### for spatial attention
self.head_dim = head_dim
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.use_scale_shift_norm = use_scale_shift_norm
self.temporal_attn_times = temporal_attn_times
self.temporal_attention = temporal_attention
self.use_checkpoint = use_checkpoint
self.use_image_dataset = use_image_dataset
self.use_sim_mask = use_sim_mask
self.training=training
self.inpainting = inpainting
self.p_all_zero = p_all_zero
self.p_all_keep = p_all_keep
self.use_fps_condition = use_fps_condition
use_linear_in_temporal = False
transformer_depth = 1
disabled_sa = False
# params
enc_dims = [dim * u for u in [1] + dim_mult]
dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
shortcut_dims = []
scale = 1.0
# Embedding
self.time_embed = nn.Sequential(
nn.Linear(dim, embed_dim), # [320,1280]
nn.SiLU(),
nn.Linear(embed_dim, embed_dim))
if self.use_fps_condition:
self.fps_embedding = nn.Sequential(
nn.Linear(dim, embed_dim),
nn.SiLU(),
nn.Linear(embed_dim, embed_dim))
nn.init.zeros_(self.fps_embedding[-1].weight)
nn.init.zeros_(self.fps_embedding[-1].bias)
self.context_embedding = TextContextCrossTransformerMultiLayer(y_dim, context_embedding_depth, embed_dim, context_dim, num_tokens=self.num_tokens)
self.asim_embedding = nn.Sequential(
nn.Linear(32, embed_dim),
nn.SiLU(),
nn.Linear(embed_dim, embed_dim))
nn.init.zeros_(self.asim_embedding[-1].weight)
nn.init.zeros_(self.asim_embedding[-1].bias)
self.msim_embedding = nn.Sequential(
nn.Linear(dim, embed_dim),
nn.SiLU(),
nn.Linear(embed_dim, embed_dim))
nn.init.zeros_(self.msim_embedding[-1].weight)
nn.init.zeros_(self.msim_embedding[-1].bias)
self.img_embedding = nn.Conv2d(self.in_dim, dim, 3, padding=1)
nn.init.zeros_(self.img_embedding.weight)
nn.init.zeros_(self.img_embedding.bias)
if temporal_attention and not USE_TEMPORAL_TRANSFORMER:
self.rotary_emb = RotaryEmbedding(min(32, head_dim))
self.time_rel_pos_bias = RelativePositionBias(heads = num_heads, max_distance = 32)
# encoder
self.input_blocks = nn.ModuleList()
init_block = nn.ModuleList([nn.Conv2d(self.in_dim, dim, 3, padding=1)])
if temporal_attention:
if USE_TEMPORAL_TRANSFORMER:
init_block.append(TemporalTransformerWoImg(dim, num_heads, head_dim, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, multiply_zero=use_image_dataset))
else:
init_block.append(TemporalAttentionMultiBlock(dim, num_heads, head_dim, rotary_emb=self.rotary_emb, temporal_attn_times=temporal_attn_times, use_image_dataset=use_image_dataset))
self.input_blocks.append(init_block)
shortcut_dims.append(dim)
for i, (in_dim, out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
for j in range(num_res_blocks):
block = nn.ModuleList([ResBlockWoImg(in_dim, embed_dim, dropout, out_channels=out_dim, use_scale_shift_norm=False, use_image_dataset=use_image_dataset)])
if scale in attn_scales:
block.append(
SpatialTransformer(
out_dim, out_dim // head_dim, head_dim, depth=1, context_dim=self.context_dim,
disable_self_attn=False, use_linear=True
)
)
if self.temporal_attention:
if USE_TEMPORAL_TRANSFORMER:
block.append(TemporalTransformerWoImg(out_dim, out_dim // head_dim, head_dim, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, multiply_zero=use_image_dataset))
else:
block.append(TemporalAttentionMultiBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb, use_image_dataset=use_image_dataset, use_sim_mask=use_sim_mask, temporal_attn_times=temporal_attn_times))
in_dim = out_dim
self.input_blocks.append(block)
shortcut_dims.append(out_dim)
# downsample
if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
downsample = Downsample(
out_dim, True, dims=2, out_channels=out_dim
)
shortcut_dims.append(out_dim)
scale /= 2.0
self.input_blocks.append(downsample)
self.middle_block = nn.ModuleList([
ResBlockWoImg(out_dim, embed_dim, dropout, use_scale_shift_norm=False, use_image_dataset=use_image_dataset,),
SpatialTransformer(
out_dim, out_dim // head_dim, head_dim, depth=1, context_dim=self.context_dim,
disable_self_attn=False, use_linear=True
)])
if self.temporal_attention:
if USE_TEMPORAL_TRANSFORMER:
self.middle_block.append(
TemporalTransformerWoImg(
out_dim, out_dim // head_dim, head_dim, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal,
multiply_zero=use_image_dataset,
)
)
else:
self.middle_block.append(TemporalAttentionMultiBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb, use_image_dataset=use_image_dataset, use_sim_mask=use_sim_mask, temporal_attn_times=temporal_attn_times))
self.middle_block.append(ResBlockWoImg(out_dim, embed_dim, dropout, use_scale_shift_norm=False))
# decoder
self.output_blocks = nn.ModuleList()
for i, (in_dim, out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
for j in range(num_res_blocks + 1):
block = nn.ModuleList([ResBlockWoImg(in_dim + shortcut_dims.pop(), embed_dim, dropout, out_dim, use_scale_shift_norm=False, use_image_dataset=use_image_dataset, )])
if scale in attn_scales:
block.append(
SpatialTransformer(
out_dim, out_dim // head_dim, head_dim, depth=1, context_dim=1024,
disable_self_attn=False, use_linear=True
)
)
if self.temporal_attention:
if USE_TEMPORAL_TRANSFORMER:
block.append(
TemporalTransformerWoImg(
out_dim, out_dim // head_dim, head_dim, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, multiply_zero=use_image_dataset
)
)
else:
block.append(TemporalAttentionMultiBlock(out_dim, num_heads, head_dim, rotary_emb =self.rotary_emb, use_image_dataset=use_image_dataset, use_sim_mask=use_sim_mask, temporal_attn_times=temporal_attn_times))
in_dim = out_dim
# upsample
if i != len(dim_mult) - 1 and j == num_res_blocks:
upsample = Upsample(out_dim, True, dims=2.0, out_channels=out_dim)
scale *= 2.0
block.append(upsample)
self.output_blocks.append(block)
# head
self.out = nn.Sequential(
nn.GroupNorm(32, out_dim),
nn.SiLU(),
nn.Conv2d(out_dim, self.out_dim, 3, padding=1))
nn.init.zeros_(self.out[-1].weight)
def get_motion_embedding(self, batch, f, motion_cond):
if f > 1:
if motion_cond.size(1) != f:
motion_embedding = sinusoidal_embedding(motion_cond.flatten(0, 1), self.dim).view(batch, f-1, self.dim)
motion_embedding = torch.nn.functional.interpolate(motion_embedding.transpose(1, 2), size=(f), mode='linear').transpose(1, 2)
else:
motion_embedding = sinusoidal_embedding(motion_cond.flatten(0, 1), self.dim).view(batch, f, self.dim)
return self.msim_embedding(motion_embedding).flatten(0, 1)
else:
return self.msim_embedding(sinusoidal_embedding(motion_cond, self.dim))
def get_appearance_embedding(self, batch, f, appearance_cond):
return self.asim_embedding(appearance_cond).flatten(0, 1)
def forward(self,
x,
t,
y = None,
fps = None,
masked = None,
video_mask = None,
spat_prior = None,
motion_cond = None,
appearance_cond = None,
focus_present_mask = None,
prob_focus_present = 0., # probability at which a given batch sample will focus on the present (0. is all off, 1. is completely arrested attention across time)
mask_last_frame_num = 0, # mask last frame num
**kwargs):
assert self.inpainting or masked is None, 'inpainting is not supported'
batch, c, f, h, w= x.shape
device = x.device
self.batch = batch
#### image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored
if mask_last_frame_num > 0:
focus_present_mask = None
video_mask[-mask_last_frame_num:] = False
else:
focus_present_mask = default(focus_present_mask, lambda: prob_mask_like((batch,), prob_focus_present, device = device))
if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER:
time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device = x.device)
else:
time_rel_pos_bias = None
# [Embeddings]
if self.use_fps_condition and fps is not None:
embeddings = self.time_embed(sinusoidal_embedding(t, self.dim)) + self.fps_embedding(sinusoidal_embedding(fps, self.dim))
else:
embeddings = self.time_embed(sinusoidal_embedding(t, self.dim))
embeddings = embeddings.repeat_interleave(repeats=f, dim=0)
embeddings = embeddings + self.get_motion_embedding(batch, f, motion_cond)
embeddings = embeddings + self.get_appearance_embedding(batch, f, appearance_cond)
# [Context]
context = self.context_embedding(y)
context = context.repeat_interleave(repeats=f, dim=0)
x = rearrange(x, 'b c f h w -> (b f) c h w')
xs = []
for block in self.input_blocks:
x = self._forward_single(block, x, embeddings, context, spat_prior, time_rel_pos_bias, focus_present_mask, video_mask)
xs.append(x)
# middle
for block in self.middle_block:
x = self._forward_single(block, x, embeddings, context, spat_prior, time_rel_pos_bias,focus_present_mask, video_mask)
# decoder
for block in self.output_blocks:
x = torch.cat([x, xs.pop()], dim=1)
x = self._forward_single(block, x, embeddings, context, spat_prior, time_rel_pos_bias,focus_present_mask, video_mask, reference=xs[-1] if len(xs) > 0 else None)
# head
x = self.out(x) # [32, 4, 32, 32]
# reshape back to (b c f h w)
x = rearrange(x, '(b f) c h w -> b c f h w', b = batch)
return x
def _forward_single(self, module, x, e, context, spat_prior, time_rel_pos_bias, focus_present_mask, video_mask, reference=None):
if isinstance(module, ResidualBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = x.contiguous()
x = module(x, e, reference)
elif isinstance(module, ResBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = x.contiguous()
x = module(x, e, self.batch)
elif isinstance(module, SpatialTransformer):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, TemporalTransformer):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
x = module(x, context)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, TemporalTransformer_attemask):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
x = module(x, context)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, CrossAttention):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, MemoryEfficientCrossAttention):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, BasicTransformerBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, FeedForward):
# module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, Upsample):
# module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x)
elif isinstance(module, Downsample):
# module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x)
elif isinstance(module, Resample):
# module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x, reference)
elif isinstance(module, TemporalAttentionBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, TemporalAttentionMultiBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, InitTemporalConvBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
x = module(x)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, TemporalConvBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
x = module(x)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, nn.Conv2d) and x.size(1) == self.in_dim:
x = module(x)
f = x.size(0) // self.batch
x = x + self.img_embedding(spat_prior).repeat_interleave(repeats=f, dim=0)
elif isinstance(module, nn.ModuleList):
for block in module:
x = self._forward_single(block, x, e, context, spat_prior, time_rel_pos_bias, focus_present_mask, video_mask, reference)
else:
x = module(x)
return x
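# Input/output contract of UNetSD_HiGen.forward, as implied by the rearranges above:
# x is a video latent [B, C, F, H, W]; it is flattened to [(B*F), C, H, W] for the
# spatial blocks and reshaped to [B, C, F, H, W] inside the temporal blocks; spat_prior
# is injected through img_embedding right after the first encoder convolution; the
# output has shape [B, out_dim, F, H, W].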
import math
import torch
import xformers
import xformers.ops
import torch.nn as nn
from einops import rearrange
import torch.nn.functional as F
from rotary_embedding_torch import RotaryEmbedding
from fairscale.nn.checkpoint import checkpoint_wrapper
from .util import *
# from .mha_flash import FlashAttentionBlock
from utils.registry_class import MODEL
USE_TEMPORAL_TRANSFORMER = True
@MODEL.register_class()
class UNetSD_I2VGen(nn.Module):
def __init__(self,
config=None,
in_dim=7,
dim=512,
y_dim=512,
context_dim=512,
hist_dim = 156,
concat_dim = 8,
dim_condition=4,
out_dim=6,
num_tokens=4,
dim_mult=[1, 2, 3, 4],
num_heads=None,
head_dim=64,
num_res_blocks=3,
attn_scales=[1 / 2, 1 / 4, 1 / 8],
use_scale_shift_norm=True,
dropout=0.1,
temporal_attn_times=1,
temporal_attention = True,
use_checkpoint=False,
use_image_dataset=False,
use_sim_mask = False,
training=True,
inpainting=True,
p_all_zero=0.1,
p_all_keep=0.1,
zero_y = None,
adapter_transformer_layers = 1,
**kwargs):
super(UNetSD_I2VGen, self).__init__()
embed_dim = dim * 4
num_heads=num_heads if num_heads else dim//32
self.zero_y = zero_y
self.in_dim = in_dim
self.dim = dim
self.y_dim = y_dim
self.num_tokens = num_tokens
self.context_dim = context_dim
self.hist_dim = hist_dim
self.concat_dim = concat_dim
self.embed_dim = embed_dim
self.out_dim = out_dim
self.dim_mult = dim_mult
### for temporal attention
self.num_heads = num_heads
### for spatial attention
self.head_dim = head_dim
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.use_scale_shift_norm = use_scale_shift_norm
self.temporal_attn_times = temporal_attn_times
self.temporal_attention = temporal_attention
self.use_checkpoint = use_checkpoint
self.use_image_dataset = use_image_dataset
self.use_sim_mask = use_sim_mask
self.training=training
self.inpainting = inpainting
self.p_all_zero = p_all_zero
self.p_all_keep = p_all_keep
concat_dim = self.in_dim
use_linear_in_temporal = False
transformer_depth = 1
disabled_sa = False
# params
enc_dims = [dim * u for u in [1] + dim_mult]
dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
shortcut_dims = []
scale = 1.0
# Embedding
self.time_embed = nn.Sequential(
nn.Linear(dim, embed_dim), # [320,1280]
nn.SiLU(),
nn.Linear(embed_dim, embed_dim))
self.context_embedding = nn.Sequential(
nn.Linear(y_dim, embed_dim),
nn.SiLU(),
nn.Linear(embed_dim, context_dim * self.num_tokens))
self.fps_embedding = nn.Sequential(
nn.Linear(dim, embed_dim),
nn.SiLU(),
nn.Linear(embed_dim, embed_dim))
nn.init.zeros_(self.fps_embedding[-1].weight)
nn.init.zeros_(self.fps_embedding[-1].bias)
if temporal_attention and not USE_TEMPORAL_TRANSFORMER:
self.rotary_emb = RotaryEmbedding(min(32, head_dim))
self.time_rel_pos_bias = RelativePositionBias(heads = num_heads, max_distance = 32) # realistically will not be able to generate that many frames of video... yet
# [Local Image embeding]
self.local_image_concat = nn.Sequential(
nn.Conv2d(4, concat_dim * 4, 3, padding=1),
nn.SiLU(),
nn.Conv2d(concat_dim * 4, concat_dim * 4, 3, stride=1, padding=1),
nn.SiLU(),
nn.Conv2d(concat_dim * 4, concat_dim, 3, stride=1, padding=1))
self.local_temporal_encoder = TransformerV2(
heads=2, dim=concat_dim, dim_head_k=concat_dim, dim_head_v=concat_dim,
dropout_atte = 0.05, mlp_dim=concat_dim, dropout_ffn = 0.05, depth=adapter_transformer_layers)
self.local_image_embedding = nn.Sequential(
nn.Conv2d(4, concat_dim * 8, 3, padding=1),
nn.SiLU(),
nn.AdaptiveAvgPool2d((32, 32)),
nn.Conv2d(concat_dim * 8, concat_dim * 16, 3, stride=2, padding=1),
nn.SiLU(),
nn.Conv2d(concat_dim * 16, 1024, 3, stride=2, padding=1))
# encoder
self.input_blocks = nn.ModuleList()
# init_block = nn.ModuleList([nn.Conv2d(self.in_dim, dim, 3, padding=1)])
init_block = nn.ModuleList([nn.Conv2d(self.in_dim + concat_dim, dim, 3, padding=1)])
####need an initial temporal attention?
if temporal_attention:
if USE_TEMPORAL_TRANSFORMER:
init_block.append(TemporalTransformer(dim, num_heads, head_dim, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, multiply_zero=use_image_dataset))
else:
init_block.append(TemporalAttentionMultiBlock(dim, num_heads, head_dim, rotary_emb=self.rotary_emb, temporal_attn_times=temporal_attn_times, use_image_dataset=use_image_dataset))
# elif temporal_conv:
# init_block.append(InitTemporalConvBlock(dim,dropout=dropout,use_image_dataset=use_image_dataset))
self.input_blocks.append(init_block)
shortcut_dims.append(dim)
for i, (in_dim, out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
for j in range(num_res_blocks):
block = nn.ModuleList([ResBlock(in_dim, embed_dim, dropout, out_channels=out_dim, use_scale_shift_norm=False, use_image_dataset=use_image_dataset,)])
if scale in attn_scales:
# block.append(FlashAttentionBlock(out_dim, context_dim, num_heads, head_dim))
block.append(
SpatialTransformer(
out_dim, out_dim // head_dim, head_dim, depth=1, context_dim=self.context_dim,
disable_self_attn=False, use_linear=True
)
)
if self.temporal_attention:
if USE_TEMPORAL_TRANSFORMER:
block.append(TemporalTransformer(out_dim, out_dim // head_dim, head_dim, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, multiply_zero=use_image_dataset))
else:
block.append(TemporalAttentionMultiBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb, use_image_dataset=use_image_dataset, use_sim_mask=use_sim_mask, temporal_attn_times=temporal_attn_times))
in_dim = out_dim
self.input_blocks.append(block)
shortcut_dims.append(out_dim)
# downsample
if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
# block = nn.ModuleList([ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 'downsample')])
downsample = Downsample(
out_dim, True, dims=2, out_channels=out_dim
)
shortcut_dims.append(out_dim)
scale /= 2.0
# block.append(TemporalConvBlock(out_dim,dropout=dropout,use_image_dataset=use_image_dataset))
self.input_blocks.append(downsample)
self.middle_block = nn.ModuleList([
ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False, use_image_dataset=use_image_dataset,),
SpatialTransformer(
out_dim, out_dim // head_dim, head_dim, depth=1, context_dim=self.context_dim,
disable_self_attn=False, use_linear=True
)])
if self.temporal_attention:
if USE_TEMPORAL_TRANSFORMER:
self.middle_block.append(
TemporalTransformer(
out_dim, out_dim // head_dim, head_dim, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal,
multiply_zero=use_image_dataset,
)
)
else:
self.middle_block.append(TemporalAttentionMultiBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb, use_image_dataset=use_image_dataset, use_sim_mask=use_sim_mask, temporal_attn_times=temporal_attn_times))
self.middle_block.append(ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False))
# decoder
self.output_blocks = nn.ModuleList()
for i, (in_dim, out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
for j in range(num_res_blocks + 1):
block = nn.ModuleList([ResBlock(in_dim + shortcut_dims.pop(), embed_dim, dropout, out_dim, use_scale_shift_norm=False, use_image_dataset=use_image_dataset, )])
if scale in attn_scales:
block.append(
SpatialTransformer(
out_dim, out_dim // head_dim, head_dim, depth=1, context_dim=1024,
disable_self_attn=False, use_linear=True
)
)
if self.temporal_attention:
if USE_TEMPORAL_TRANSFORMER:
block.append(
TemporalTransformer(
out_dim, out_dim // head_dim, head_dim, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, multiply_zero=use_image_dataset
)
)
else:
block.append(TemporalAttentionMultiBlock(out_dim, num_heads, head_dim, rotary_emb =self.rotary_emb, use_image_dataset=use_image_dataset, use_sim_mask=use_sim_mask, temporal_attn_times=temporal_attn_times))
in_dim = out_dim
# upsample
if i != len(dim_mult) - 1 and j == num_res_blocks:
upsample = Upsample(out_dim, True, dims=2.0, out_channels=out_dim)
scale *= 2.0
block.append(upsample)
self.output_blocks.append(block)
# head
self.out = nn.Sequential(
nn.GroupNorm(32, out_dim),
nn.SiLU(),
nn.Conv2d(out_dim, self.out_dim, 3, padding=1))
# zero out the last layer params
nn.init.zeros_(self.out[-1].weight)
def forward(self,
x,
t,
y = None,
image = None,
local_image = None,
masked = None,
fps = None,
video_mask = None,
focus_present_mask = None,
prob_focus_present = 0., # probability at which a given batch sample will focus on the present (0. is all off, 1. is completely arrested attention across time)
mask_last_frame_num = 0, # mask last frame num
**kwargs):
assert self.inpainting or masked is None, 'inpainting is not supported'
batch, c, f, h, w= x.shape
device = x.device
self.batch = batch
if local_image.ndim == 5 and local_image.size(2) > 1:
local_image = local_image[:, :, :1, ...]
elif local_image.ndim != 5:
local_image = local_image.unsqueeze(2)
#### image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored
if mask_last_frame_num > 0:
focus_present_mask = None
video_mask[-mask_last_frame_num:] = False
else:
focus_present_mask = default(focus_present_mask, lambda: prob_mask_like((batch,), prob_focus_present, device = device))
if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER:
time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device = x.device)
else:
time_rel_pos_bias = None
# [Concat]
concat = x.new_zeros(batch, self.concat_dim, f, h, w)
if f > 1:
mask_pos = torch.cat([(torch.ones(local_image[:,:,:1].size())*( (tpos+1)/(f-1) )).cuda() for tpos in range(f-1)], dim=2)
_ximg = torch.cat([local_image[:,:,:1], mask_pos], dim=2)
_ximg = rearrange(_ximg, 'b c f h w -> (b f) c h w')
else:
_ximg = rearrange(local_image, 'b c f h w -> (b f) c h w')
_ximg = self.local_image_concat(_ximg)
_h = _ximg.shape[2]
_ximg = rearrange(_ximg, '(b f) c h w -> (b h w) f c', b = batch)
_ximg = self.local_temporal_encoder(_ximg)
_ximg = rearrange(_ximg, '(b h w) f c -> b c f h w', b = batch, h = _h)
concat += _ximg
concat += _ximg # TODO: This is a bug, but it doesn't matter.
# [Embeddings]
embeddings = self.time_embed(sinusoidal_embedding(t, self.dim)) + self.fps_embedding(sinusoidal_embedding(fps, self.dim))
embeddings = embeddings.repeat_interleave(repeats=f, dim=0)
# [Context]
# [C] for text input
context = x.new_zeros(batch, 0, self.context_dim)
if y is not None:
y_context = y
context = torch.cat([context, y_context], dim=1)
else:
y_context = self.zero_y.repeat(batch, 1, 1)[:, :1, :]
context = torch.cat([context, y_context], dim=1)
# [C] for local input
local_context = rearrange(local_image, 'b c f h w -> (b f) c h w')
local_context = self.local_image_embedding(local_context)
h = local_context.shape[2]
local_context = rearrange(local_context, 'b c h w -> b (h w) c', b = batch, h = h) # [12, 64, 1024]
context = torch.cat([context, local_context], dim=1)
# [C] for global input
if image is not None:
image_context = self.context_embedding(image)
image_context = image_context.view(-1, self.num_tokens, self.context_dim)
context = torch.cat([context, image_context], dim=1)
context = context.repeat_interleave(repeats=f, dim=0)
x = torch.cat([x, concat], dim=1)
x = rearrange(x, 'b c f h w -> (b f) c h w')
xs = []
for block in self.input_blocks:
x = self._forward_single(block, x, embeddings, context, time_rel_pos_bias, focus_present_mask, video_mask)
xs.append(x)
# middle
for block in self.middle_block:
x = self._forward_single(block, x, embeddings, context, time_rel_pos_bias,focus_present_mask, video_mask)
# decoder
for block in self.output_blocks:
x = torch.cat([x, xs.pop()], dim=1)
x = self._forward_single(block, x, embeddings, context, time_rel_pos_bias,focus_present_mask, video_mask, reference=xs[-1] if len(xs) > 0 else None)
# head
x = self.out(x) # [32, 4, 32, 32]
# reshape back to (b c f h w)
x = rearrange(x, '(b f) c h w -> b c f h w', b = batch)
return x
def _forward_single(self, module, x, e, context, time_rel_pos_bias, focus_present_mask, video_mask, reference=None):
if isinstance(module, ResidualBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = x.contiguous()
x = module(x, e, reference)
elif isinstance(module, ResBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = x.contiguous()
x = module(x, e, self.batch)
elif isinstance(module, SpatialTransformer):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, TemporalTransformer):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
x = module(x, context)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, TemporalTransformer_attemask):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
x = module(x, context)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, CrossAttention):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, MemoryEfficientCrossAttention):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, BasicTransformerBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, FeedForward):
# module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, Upsample):
# module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x)
elif isinstance(module, Downsample):
# module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x)
elif isinstance(module, Resample):
# module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x, reference)
elif isinstance(module, TemporalAttentionBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, TemporalAttentionMultiBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, InitTemporalConvBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
x = module(x)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, TemporalConvBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
x = module(x)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, nn.ModuleList):
for block in module:
x = self._forward_single(block, x, e, context, time_rel_pos_bias, focus_present_mask, video_mask, reference)
else:
x = module(x)
return x
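# Input/output contract of UNetSD_I2VGen.forward, as implied by the code above:
# x is the noisy video latent [B, C, F, H, W]; local_image (the conditioning frame
# latent) is injected twice, once as extra input channels via local_image_concat +
# local_temporal_encoder ('concat') and once as extra context tokens via
# local_image_embedding; y supplies text context tokens and image supplies num_tokens
# global tokens through context_embedding; the output has shape [B, out_dim, F, H, W].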
'''
/*
*Copyright (c) 2021, Alibaba Group;
*Licensed under the Apache License, Version 2.0 (the "License");
*you may not use this file except in compliance with the License.
*You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
*Unless required by applicable law or agreed to in writing, software
*distributed under the License is distributed on an "AS IS" BASIS,
*WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*See the License for the specific language governing permissions and
*limitations under the License.
*/
'''
import math
import torch
import torch.nn as nn
from einops import rearrange
import torch.nn.functional as F
from .util import *
from utils.registry_class import MODEL
import torch.fft as fft
USE_TEMPORAL_TRANSFORMER = True
def Fourier_filter(x, threshold, scale):
dtype = x.dtype
x = x.type(torch.float32)
# FFT
x_freq = fft.fftn(x, dim=(-2, -1))
x_freq = fft.fftshift(x_freq, dim=(-2, -1))
B, C, H, W = x_freq.shape
mask = torch.ones((B, C, H, W)).cuda()
crow, ccol = H // 2, W //2
mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
x_freq = x_freq * mask
# IFFT
x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real
x_filtered = x_filtered.type(dtype)
return x_filtered
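# Usage sketch (illustrative threshold/scale; not values taken from a config in this
# repo). Fourier_filter() damps the low-frequency band of each feature map: FFT bins
# within `threshold` of the spectrum centre are multiplied by `scale`, the rest are left
# untouched. As written the mask is created with .cuda(), so the sketch assumes a GPU.
def _demo_fourier_filter():
    feats = torch.randn(2, 320, 32, 32, device='cuda')
    damped = Fourier_filter(feats, threshold=1, scale=0.6)
    return damped.shape  # torch.Size([2, 320, 32, 32]), same as the input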
@MODEL.register_class()
class UNetSD_SR600(nn.Module):
def __init__(self,
in_dim=7,
dim=512,
y_dim=512,
context_dim=512,
out_dim=6,
dim_mult=[1, 2, 3, 4],
num_heads=None,
head_dim=64,
num_res_blocks=3,
attn_scales=[1 / 2, 1 / 4, 1 / 8],
use_scale_shift_norm=True,
dropout=0.1,
temporal_attn_times=1,
temporal_attention = True,
use_checkpoint=False,
use_image_dataset=False,
use_sim_mask = False,
inpainting=True,
**kwargs):
embed_dim = dim * 4
num_heads=num_heads if num_heads else dim//32
super(UNetSD_SR600, self).__init__()
self.in_dim = in_dim # 4
self.dim = dim # 320
self.y_dim = y_dim # 768
self.context_dim = context_dim # 1024
self.embed_dim = embed_dim # 1280
self.out_dim = out_dim # 4
self.dim_mult = dim_mult # [1, 2, 4, 4]
### for temporal attention
self.num_heads = num_heads # 8
### for spatial attention
self.head_dim = head_dim # 64
self.num_res_blocks = num_res_blocks # 2
self.attn_scales = attn_scales # [1.0, 0.5, 0.25]
self.use_scale_shift_norm = use_scale_shift_norm # True
self.temporal_attn_times = temporal_attn_times # 1
self.temporal_attention = temporal_attention # True
self.use_checkpoint = use_checkpoint # True
self.use_image_dataset = use_image_dataset # False
self.use_sim_mask = use_sim_mask # False
self.inpainting = inpainting # True
use_linear_in_temporal = False
transformer_depth = 1
disabled_sa = False
# params
enc_dims = [dim * u for u in [1] + dim_mult] # [320, 320, 640, 1280, 1280]
dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] # [1280, 1280, 1280, 640, 320]
shortcut_dims = []
scale = 1.0
# embeddings
self.time_embed = nn.Sequential(
nn.Linear(dim, embed_dim), # [320,1280]
nn.SiLU(),
nn.Linear(embed_dim, embed_dim))
# encoder
self.input_blocks = nn.ModuleList()
init_block = nn.ModuleList([nn.Conv2d(self.in_dim, dim, 3, padding=1)])
####need an initial temporal attention?
if temporal_attention:
if USE_TEMPORAL_TRANSFORMER:
init_block.append(TemporalTransformer(dim, num_heads, head_dim, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, multiply_zero=use_image_dataset))
else:
init_block.append(TemporalAttentionMultiBlock(dim, num_heads, head_dim, rotary_emb=self.rotary_emb, temporal_attn_times=temporal_attn_times, use_image_dataset=use_image_dataset))
# elif temporal_conv:
# init_block.append(InitTemporalConvBlock(dim,dropout=dropout,use_image_dataset=use_image_dataset))
self.input_blocks.append(init_block)
shortcut_dims.append(dim)
for i, (in_dim, out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
for j in range(num_res_blocks):
block = nn.ModuleList([ResBlock(in_dim, embed_dim, dropout, out_channels=out_dim, use_scale_shift_norm=False, use_image_dataset=use_image_dataset,)])
if scale in attn_scales:
# block.append(FlashAttentionBlock(out_dim, context_dim, num_heads, head_dim))
block.append(
SpatialTransformer(
out_dim, out_dim // head_dim, head_dim, depth=1, context_dim=self.context_dim,
disable_self_attn=False, use_linear=True
)
)
if self.temporal_attention:
if USE_TEMPORAL_TRANSFORMER:
block.append(TemporalTransformer(out_dim, out_dim // head_dim, head_dim, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, multiply_zero=use_image_dataset))
else:
block.append(TemporalAttentionMultiBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb, use_image_dataset=use_image_dataset, use_sim_mask=use_sim_mask, temporal_attn_times=temporal_attn_times))
in_dim = out_dim
self.input_blocks.append(block)
shortcut_dims.append(out_dim)
# downsample
if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
# block = nn.ModuleList([ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 'downsample')])
downsample = Downsample(
out_dim, True, dims=2, out_channels=out_dim, padding=(2, 1)
)
shortcut_dims.append(out_dim)
scale /= 2.0
# block.append(TemporalConvBlock(out_dim,dropout=dropout,use_image_dataset=use_image_dataset))
self.input_blocks.append(downsample)
self.middle_block = nn.ModuleList([
ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False, use_image_dataset=use_image_dataset,),
SpatialTransformer(
out_dim, out_dim // head_dim, head_dim, depth=1, context_dim=self.context_dim,
disable_self_attn=False, use_linear=True
)])
if self.temporal_attention:
if USE_TEMPORAL_TRANSFORMER:
self.middle_block.append(
TemporalTransformer(
out_dim, out_dim // head_dim, head_dim, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal,
multiply_zero=use_image_dataset,
)
)
else:
self.middle_block.append(TemporalAttentionMultiBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb, use_image_dataset=use_image_dataset, use_sim_mask=use_sim_mask, temporal_attn_times=temporal_attn_times))
self.middle_block.append(ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False))
# decoder
self.output_blocks = nn.ModuleList()
for i, (in_dim, out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
for j in range(num_res_blocks + 1):
block = nn.ModuleList([ResBlock(in_dim + shortcut_dims.pop(), embed_dim, dropout, out_dim, use_scale_shift_norm=False, use_image_dataset=use_image_dataset, )])
if scale in attn_scales:
block.append(
SpatialTransformer(
out_dim, out_dim // head_dim, head_dim, depth=1, context_dim=1024,
disable_self_attn=False, use_linear=True
)
)
if self.temporal_attention:
if USE_TEMPORAL_TRANSFORMER:
block.append(
TemporalTransformer(
out_dim, out_dim // head_dim, head_dim, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, multiply_zero=use_image_dataset
)
)
else:
block.append(TemporalAttentionMultiBlock(out_dim, num_heads, head_dim, rotary_emb =self.rotary_emb, use_image_dataset=use_image_dataset, use_sim_mask=use_sim_mask, temporal_attn_times=temporal_attn_times))
in_dim = out_dim
# upsample
if i != len(dim_mult) - 1 and j == num_res_blocks:
upsample = UpsampleSR600(out_dim, True, dims=2.0, out_channels=out_dim)
scale *= 2.0
block.append(upsample)
self.output_blocks.append(block)
# head
self.out = nn.Sequential(
nn.GroupNorm(32, out_dim),
nn.SiLU(),
nn.Conv2d(out_dim, self.out_dim, 3, padding=1))
# zero out the last layer params
nn.init.zeros_(self.out[-1].weight)
def forward(self,
x,
t,
y,
x_lr=None,
fps=None,
video_mask=None,
focus_present_mask = None,
prob_focus_present = 0., # probability at which a given batch sample will focus on the present (0. is all off, 1. is completely arrested attention across time)
mask_last_frame_num = 0 # mask last frame num
):
batch, x_c, x_f, x_h, x_w= x.shape
device = x.device
self.batch = batch
#### image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored
if mask_last_frame_num > 0:
focus_present_mask = None
video_mask[-mask_last_frame_num:] = False
else:
focus_present_mask = default(focus_present_mask, lambda: prob_mask_like((batch,), prob_focus_present, device = device)) # [False, False]
if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER:
time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device = x.device)
else:
time_rel_pos_bias = None
# embeddings
e = self.time_embed(sinusoidal_embedding(t, self.dim)) #+ self.y_embedding(y)
context = y #self.context_embedding(y).view(-1, 4, self.context_dim)
# repeat f times for spatial e and context
e=e.repeat_interleave(repeats=x_f, dim=0)
context=context.repeat_interleave(repeats=x_f, dim=0)
# x = torch.cat([x, temp_x_lr], dim=1)
# x = x + temp_x_lr
## always in shape (b f) c h w, except for temporal layer
x = rearrange(x, 'b c f h w -> (b f) c h w')
# encoder
xs = []
for idx, block in enumerate(self.input_blocks):
x = self._forward_single(block, x, e, context, time_rel_pos_bias, focus_present_mask, video_mask)
xs.append(x)
# print(f"encoder shape: {x.shape}")
# middle
for block in self.middle_block:
x = self._forward_single(block, x, e, context, time_rel_pos_bias,focus_present_mask, video_mask)
# print(f"mid shape: {x.shape}")
# decoder
b_num = 0
for block in self.output_blocks:
# print(f"decoder shape: {x.shape}")
if b_num == 0:
temp_b, temp_c, _, _ = x.size()
x[:,:temp_c//2] = x[:,:temp_c//2] * 1.1
hs_ = xs.pop()
hs_ = Fourier_filter(hs_, threshold=1, scale=0.6)
x = torch.cat([x, hs_], dim=1)
elif b_num == 1:
temp_b, temp_c, _, _ = x.size()
x[:,:temp_c//2] = x[:,:temp_c//2] * 1.2
hs_ = xs.pop()
hs_ = Fourier_filter(hs_, threshold=1, scale=0.4)
x = torch.cat([x, hs_], dim=1)
else:
x = torch.cat([x, xs.pop()], dim=1)
# x = torch.cat([x, xs.pop()], dim=1)
b_num += 1
x = self._forward_single(block, x, e, context, time_rel_pos_bias,focus_present_mask, video_mask, reference=xs[-1] if len(xs) > 0 else None)
# head
x = self.out(x) # [32, 4, 32, 32]
# reshape back to (b c f h w)
x = rearrange(x, '(b f) c h w -> b c f h w', b = batch)
return x
def _forward_single(self, module, x, e, context, time_rel_pos_bias, focus_present_mask, video_mask, reference=None):
if isinstance(module, ResidualBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = x.contiguous()
x = module(x, e, reference)
elif isinstance(module, ResBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = x.contiguous()
x = module(x, e, self.batch)
elif isinstance(module, SpatialTransformer):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, TemporalTransformer):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
x = module(x, context)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, CrossAttention):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, MemoryEfficientCrossAttention):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, BasicTransformerBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, FeedForward):
# module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, Upsample):
# module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x)
elif isinstance(module, Downsample):
# module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x)
elif isinstance(module, Resample):
# module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x, reference)
elif isinstance(module, TemporalAttentionBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, TemporalAttentionMultiBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, InitTemporalConvBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
x = module(x)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, TemporalConvBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
x = module(x)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, nn.ModuleList):
for block in module:
x = self._forward_single(block, x, e, context, time_rel_pos_bias, focus_present_mask, video_mask, reference)
else:
x = module(x)
return x
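# Note: in the decoder loop of UNetSD_SR600.forward, the first two output blocks apply a
# FreeU-style adjustment before concatenating the skip connection: the first half of the
# backbone channels is amplified (x1.1, then x1.2) and the popped skip feature is passed
# through Fourier_filter() to damp its low frequencies (scale 0.6, then 0.4); the
# remaining blocks use plain skip concatenation.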
if __name__ == '__main__':
# [model] unet
sd_model = UNetSD_SR600(
in_dim=4,
dim=320,
y_dim=1024,
context_dim=1024,
out_dim=4,
dim_mult=[1, 2, 4, 4],
num_heads=8,
head_dim=64,
num_res_blocks=2,
attn_scales=[1 / 1, 1 / 2, 1 / 4],
use_scale_shift_norm=True,
dropout=0.1,
temporal_attn_times=1,
temporal_attention = True,
use_checkpoint=True,
use_image_dataset=False,
use_sim_mask = False,
inpainting=True,
training=False)
import math
import torch
import xformers
import xformers.ops
import torch.nn as nn
from einops import rearrange
import torch.nn.functional as F
from rotary_embedding_torch import RotaryEmbedding
from fairscale.nn.checkpoint import checkpoint_wrapper
from .util import *
# from .mha_flash import FlashAttentionBlock
from utils.registry_class import MODEL
USE_TEMPORAL_TRANSFORMER = True
@MODEL.register_class()
class UNetSD_T2VBase(nn.Module):
def __init__(self,
config=None,
in_dim=4,
dim=512,
y_dim=512,
context_dim=512,
hist_dim = 156,
dim_condition=4,
out_dim=6,
num_tokens=4,
dim_mult=[1, 2, 3, 4],
num_heads=None,
head_dim=64,
num_res_blocks=3,
attn_scales=[1 / 2, 1 / 4, 1 / 8],
use_scale_shift_norm=True,
dropout=0.1,
temporal_attn_times=1,
temporal_attention = True,
use_checkpoint=False,
use_image_dataset=False,
use_sim_mask = False,
training=True,
inpainting=True,
use_fps_condition=False,
p_all_zero=0.1,
p_all_keep=0.1,
zero_y = None,
adapter_transformer_layers = 1,
**kwargs):
super(UNetSD_T2VBase, self).__init__()
embed_dim = dim * 4
num_heads=num_heads if num_heads else dim//32
self.zero_y = zero_y
self.in_dim = in_dim
self.dim = dim
self.y_dim = y_dim
self.num_tokens = num_tokens
self.context_dim = context_dim
self.hist_dim = hist_dim
self.embed_dim = embed_dim
self.out_dim = out_dim
self.dim_mult = dim_mult
### for temporal attention
self.num_heads = num_heads
### for spatial attention
self.head_dim = head_dim
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.use_scale_shift_norm = use_scale_shift_norm
self.temporal_attn_times = temporal_attn_times
self.temporal_attention = temporal_attention
self.use_checkpoint = use_checkpoint
self.use_image_dataset = use_image_dataset
self.use_sim_mask = use_sim_mask
self.training=training
self.inpainting = inpainting
self.p_all_zero = p_all_zero
self.p_all_keep = p_all_keep
self.use_fps_condition = use_fps_condition
use_linear_in_temporal = False
transformer_depth = 1
disabled_sa = False
# params
enc_dims = [dim * u for u in [1] + dim_mult]
dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
shortcut_dims = []
scale = 1.0
# Embedding
self.time_embed = nn.Sequential(
nn.Linear(dim, embed_dim), # [320,1280]
nn.SiLU(),
nn.Linear(embed_dim, embed_dim))
if self.use_fps_condition:
self.fps_embedding = nn.Sequential(
nn.Linear(dim, embed_dim),
nn.SiLU(),
nn.Linear(embed_dim, embed_dim))
nn.init.zeros_(self.fps_embedding[-1].weight)
nn.init.zeros_(self.fps_embedding[-1].bias)
if temporal_attention and not USE_TEMPORAL_TRANSFORMER:
self.rotary_emb = RotaryEmbedding(min(32, head_dim))
self.time_rel_pos_bias = RelativePositionBias(heads = num_heads, max_distance = 32)
# encoder
self.input_blocks = nn.ModuleList()
init_block = nn.ModuleList([nn.Conv2d(self.in_dim, dim, 3, padding=1)])
if temporal_attention:
if USE_TEMPORAL_TRANSFORMER:
init_block.append(TemporalTransformer(dim, num_heads, head_dim, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, multiply_zero=use_image_dataset))
else:
init_block.append(TemporalAttentionMultiBlock(dim, num_heads, head_dim, rotary_emb=self.rotary_emb, temporal_attn_times=temporal_attn_times, use_image_dataset=use_image_dataset))
self.input_blocks.append(init_block)
shortcut_dims.append(dim)
for i, (in_dim, out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
for j in range(num_res_blocks):
block = nn.ModuleList([ResBlock(in_dim, embed_dim, dropout, out_channels=out_dim, use_scale_shift_norm=False, use_image_dataset=use_image_dataset)])
if scale in attn_scales:
block.append(
SpatialTransformer(
out_dim, out_dim // head_dim, head_dim, depth=1, context_dim=self.context_dim,
disable_self_attn=False, use_linear=True
)
)
if self.temporal_attention:
if USE_TEMPORAL_TRANSFORMER:
block.append(TemporalTransformer(out_dim, out_dim // head_dim, head_dim, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, multiply_zero=use_image_dataset))
else:
block.append(TemporalAttentionMultiBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb, use_image_dataset=use_image_dataset, use_sim_mask=use_sim_mask, temporal_attn_times=temporal_attn_times))
in_dim = out_dim
self.input_blocks.append(block)
shortcut_dims.append(out_dim)
# downsample
if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
downsample = Downsample(
out_dim, True, dims=2, out_channels=out_dim
)
shortcut_dims.append(out_dim)
scale /= 2.0
self.input_blocks.append(downsample)
self.middle_block = nn.ModuleList([
ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False, use_image_dataset=use_image_dataset,),
SpatialTransformer(
out_dim, out_dim // head_dim, head_dim, depth=1, context_dim=self.context_dim,
disable_self_attn=False, use_linear=True
)])
if self.temporal_attention:
if USE_TEMPORAL_TRANSFORMER:
self.middle_block.append(
TemporalTransformer(
out_dim, out_dim // head_dim, head_dim, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal,
multiply_zero=use_image_dataset,
)
)
else:
self.middle_block.append(TemporalAttentionMultiBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb, use_image_dataset=use_image_dataset, use_sim_mask=use_sim_mask, temporal_attn_times=temporal_attn_times))
self.middle_block.append(ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False))
# decoder
self.output_blocks = nn.ModuleList()
for i, (in_dim, out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
for j in range(num_res_blocks + 1):
block = nn.ModuleList([ResBlock(in_dim + shortcut_dims.pop(), embed_dim, dropout, out_dim, use_scale_shift_norm=False, use_image_dataset=use_image_dataset, )])
if scale in attn_scales:
block.append(
SpatialTransformer(
out_dim, out_dim // head_dim, head_dim, depth=1, context_dim=1024,
disable_self_attn=False, use_linear=True
)
)
if self.temporal_attention:
if USE_TEMPORAL_TRANSFORMER:
block.append(
TemporalTransformer(
out_dim, out_dim // head_dim, head_dim, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, multiply_zero=use_image_dataset
)
)
else:
block.append(TemporalAttentionMultiBlock(out_dim, num_heads, head_dim, rotary_emb =self.rotary_emb, use_image_dataset=use_image_dataset, use_sim_mask=use_sim_mask, temporal_attn_times=temporal_attn_times))
in_dim = out_dim
# upsample
if i != len(dim_mult) - 1 and j == num_res_blocks:
upsample = Upsample(out_dim, True, dims=2, out_channels=out_dim)
scale *= 2.0
block.append(upsample)
self.output_blocks.append(block)
# head
self.out = nn.Sequential(
nn.GroupNorm(32, out_dim),
nn.SiLU(),
nn.Conv2d(out_dim, self.out_dim, 3, padding=1))
nn.init.zeros_(self.out[-1].weight)
def forward(self,
x,
t,
y = None,
fps = None,
masked = None,
video_mask = None,
focus_present_mask = None,
prob_focus_present = 0., # probability at which a given batch sample will focus on the present (0. is all off, 1. is completely arrested attention across time)
mask_last_frame_num = 0, # mask last frame num
**kwargs):
assert self.inpainting or masked is None, 'inpainting is not supported'
batch, c, f, h, w= x.shape
device = x.device
self.batch = batch
#### image and video joint training; if mask_last_frame_num is set, prob_focus_present is ignored
if mask_last_frame_num > 0:
focus_present_mask = None
video_mask[-mask_last_frame_num:] = False
else:
focus_present_mask = default(focus_present_mask, lambda: prob_mask_like((batch,), prob_focus_present, device = device))
if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER:
time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device = x.device)
else:
time_rel_pos_bias = None
# [Embeddings]
if self.use_fps_condition and fps is not None:
embeddings = self.time_embed(sinusoidal_embedding(t, self.dim)) + self.fps_embedding(sinusoidal_embedding(fps, self.dim))
else:
embeddings = self.time_embed(sinusoidal_embedding(t, self.dim))
embeddings = embeddings.repeat_interleave(repeats=f, dim=0)
# [Context]
context = x.new_zeros(batch, 0, self.context_dim)
if y is not None:
y_context = y
context = torch.cat([context, y_context], dim=1)
else:
y_context = self.zero_y.repeat(batch, 1, 1)[:, :1, :]
context = torch.cat([context, y_context], dim=1)
context = context.repeat_interleave(repeats=f, dim=0)
x = rearrange(x, 'b c f h w -> (b f) c h w')
xs = []
for block in self.input_blocks:
x = self._forward_single(block, x, embeddings, context, time_rel_pos_bias, focus_present_mask, video_mask)
xs.append(x)
# middle
for block in self.middle_block:
x = self._forward_single(block, x, embeddings, context, time_rel_pos_bias,focus_present_mask, video_mask)
# decoder
for block in self.output_blocks:
x = torch.cat([x, xs.pop()], dim=1)
x = self._forward_single(block, x, embeddings, context, time_rel_pos_bias,focus_present_mask, video_mask, reference=xs[-1] if len(xs) > 0 else None)
# head
x = self.out(x) # [32, 4, 32, 32]
# reshape back to (b c f h w)
x = rearrange(x, '(b f) c h w -> b c f h w', b = batch)
return x
def _forward_single(self, module, x, e, context, time_rel_pos_bias, focus_present_mask, video_mask, reference=None):
if isinstance(module, ResidualBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = x.contiguous()
x = module(x, e, reference)
elif isinstance(module, ResBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = x.contiguous()
x = module(x, e, self.batch)
elif isinstance(module, SpatialTransformer):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, TemporalTransformer):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
x = module(x, context)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, TemporalTransformer_attemask):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
x = module(x, context)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, CrossAttention):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, MemoryEfficientCrossAttention):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, BasicTransformerBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, FeedForward):
# module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, Upsample):
# module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x)
elif isinstance(module, Downsample):
# module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x)
elif isinstance(module, Resample):
# module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x, reference)
elif isinstance(module, TemporalAttentionBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, TemporalAttentionMultiBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, InitTemporalConvBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
x = module(x)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, TemporalConvBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
x = module(x)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, nn.ModuleList):
for block in module:
x = self._forward_single(block, x, e, context, time_rel_pos_bias, focus_present_mask, video_mask, reference)
else:
x = module(x)
return x
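# Illustrative sketch (not part of the original file): builds a small UNetSD_T2VBase
# and shows the tensor shapes its forward pass expects. The sizes below are assumptions
# chosen for the example; running the forward pass needs a CUDA device with xformers,
# which the attention blocks above rely on.
def _demo_unet_t2v_base():
    model = UNetSD_T2VBase(
        in_dim=4,
        dim=320,
        y_dim=1024,
        context_dim=1024,
        out_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_heads=8,
        head_dim=64,
        num_res_blocks=2,
        attn_scales=[1 / 1, 1 / 2, 1 / 4],
        dropout=0.1,
        temporal_attention=True,
        use_checkpoint=False,
        zero_y=torch.zeros(1, 1, 1024)).cuda()
    x = torch.randn(1, 4, 8, 32, 32).cuda()   # latent video, [b, c, f, h, w]
    t = torch.randint(0, 1000, (1,)).cuda()   # one diffusion timestep per sample
    y = torch.randn(1, 77, 1024).cuda()       # text context, [b, l, context_dim]
    out = model(x, t, y=y)                    # same shape as x: [1, 4, 8, 32, 32]
    return out.shape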
import math
import torch
import xformers
import open_clip
import xformers.ops
import torch.nn as nn
from torch import einsum
from typing import Any, Optional
from einops import rearrange, repeat
from functools import partial
import torch.nn.functional as F
from rotary_embedding_torch import RotaryEmbedding
from fairscale.nn.checkpoint import checkpoint_wrapper
# from .mha_flash import FlashAttentionBlock
from utils.registry_class import MODEL
### load all keys that contain `prefix` and rewrite that prefix to `new_prefix`
def load_Block(state, prefix, new_prefix=None):
if new_prefix is None:
new_prefix = prefix
state_dict = {}
state = {key:value for key,value in state.items() if prefix in key}
for key,value in state.items():
new_key = key.replace(prefix, new_prefix)
state_dict[new_key]=value
return state_dict
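# Illustrative sketch (not part of the original file): load_Block keeps every key that
# contains `prefix` and rewrites that prefix to `new_prefix`, leaving tensors untouched.
def _demo_load_block():
    state = {'encoder.0.weight': torch.zeros(1), 'decoder.0.weight': torch.zeros(1)}
    remapped = load_Block(state, prefix='encoder.0', new_prefix='encoder.0.0')
    assert list(remapped.keys()) == ['encoder.0.0.weight']
    return remapped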
def load_2d_pretrained_state_dict(state,cfg):
new_state_dict = {}
dim = cfg.unet_dim
num_res_blocks = cfg.unet_res_blocks
temporal_attention = cfg.temporal_attention
temporal_conv = cfg.temporal_conv
dim_mult = cfg.unet_dim_mult
attn_scales = cfg.unet_attn_scales
# params
enc_dims = [dim * u for u in [1] + dim_mult]
dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
shortcut_dims = []
scale = 1.0
#embeddings
state_dict = load_Block(state,prefix=f'time_embedding')
new_state_dict.update(state_dict)
state_dict = load_Block(state,prefix=f'y_embedding')
new_state_dict.update(state_dict)
state_dict = load_Block(state,prefix=f'context_embedding')
new_state_dict.update(state_dict)
encoder_idx = 0
### init block
state_dict = load_Block(state,prefix=f'encoder.{encoder_idx}',new_prefix=f'encoder.{encoder_idx}.0')
new_state_dict.update(state_dict)
encoder_idx += 1
shortcut_dims.append(dim)
for i, (in_dim, out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
for j in range(num_res_blocks):
# residual (+attention) blocks
idx = 0
idx_ = 0
# residual (+attention) blocks
state_dict = load_Block(state,prefix=f'encoder.{encoder_idx}.{idx}',new_prefix=f'encoder.{encoder_idx}.{idx_}')
new_state_dict.update(state_dict)
idx += 1
idx_ = 2
if scale in attn_scales:
# block.append(AttentionBlock(out_dim, context_dim, num_heads, head_dim))
state_dict = load_Block(state,prefix=f'encoder.{encoder_idx}.{idx}',new_prefix=f'encoder.{encoder_idx}.{idx_}')
new_state_dict.update(state_dict)
# if temporal_attention:
# block.append(TemporalAttentionBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb))
in_dim = out_dim
encoder_idx += 1
shortcut_dims.append(out_dim)
# downsample
if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
# downsample = ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 0.5, dropout)
state_dict = load_Block(state,prefix=f'encoder.{encoder_idx}',new_prefix=f'encoder.{encoder_idx}.0')
new_state_dict.update(state_dict)
shortcut_dims.append(out_dim)
scale /= 2.0
encoder_idx += 1
# middle
# self.middle = nn.ModuleList([
# ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 'none'),
# TemporalConvBlock(out_dim),
# AttentionBlock(out_dim, context_dim, num_heads, head_dim)])
# if temporal_attention:
# self.middle.append(TemporalAttentionBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb))
# elif temporal_conv:
# self.middle.append(TemporalConvBlock(out_dim,dropout=dropout))
# self.middle.append(ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 'none'))
# self.middle.append(TemporalConvBlock(out_dim))
# middle
middle_idx = 0
# self.middle = nn.ModuleList([
# ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 1.0, dropout),
# AttentionBlock(out_dim, context_dim, num_heads, head_dim)])
state_dict = load_Block(state,prefix=f'middle.{middle_idx}')
new_state_dict.update(state_dict)
middle_idx += 2
state_dict = load_Block(state,prefix=f'middle.1',new_prefix=f'middle.{middle_idx}')
new_state_dict.update(state_dict)
middle_idx += 1
for _ in range(cfg.temporal_attn_times):
# self.middle.append(TemporalAttentionBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb))
middle_idx += 1
# self.middle.append(ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 1.0, dropout))
state_dict = load_Block(state,prefix=f'middle.2',new_prefix=f'middle.{middle_idx}')
new_state_dict.update(state_dict)
middle_idx += 2
decoder_idx = 0
for i, (in_dim, out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
for j in range(num_res_blocks + 1):
idx = 0
idx_ = 0
# residual (+attention) blocks
# block = nn.ModuleList([ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim, out_dim, use_scale_shift_norm, 1.0, dropout)])
state_dict = load_Block(state,prefix=f'decoder.{decoder_idx}.{idx}',new_prefix=f'decoder.{decoder_idx}.{idx_}')
new_state_dict.update(state_dict)
idx += 1
idx_ += 2
if scale in attn_scales:
# block.append(AttentionBlock(out_dim, context_dim, num_heads, head_dim))
state_dict = load_Block(state,prefix=f'decoder.{decoder_idx}.{idx}',new_prefix=f'decoder.{decoder_idx}.{idx_}')
new_state_dict.update(state_dict)
idx += 1
idx_ += 1
for _ in range(cfg.temporal_attn_times):
# block.append(TemporalAttentionBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb))
idx_ +=1
in_dim = out_dim
# upsample
if i != len(dim_mult) - 1 and j == num_res_blocks:
# upsample = ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 2.0, dropout)
state_dict = load_Block(state,prefix=f'decoder.{decoder_idx}.{idx}',new_prefix=f'decoder.{decoder_idx}.{idx_}')
new_state_dict.update(state_dict)
idx += 1
idx_ += 2
scale *= 2.0
# block.append(upsample)
# self.decoder.append(block)
decoder_idx += 1
# head
# self.head = nn.Sequential(
# nn.GroupNorm(32, out_dim),
# nn.SiLU(),
# nn.Conv3d(out_dim, self.out_dim, (1,3,3), padding=(0,1,1)))
state_dict = load_Block(state,prefix=f'head')
new_state_dict.update(state_dict)
return new_state_dict
def sinusoidal_embedding(timesteps, dim):
# check input
half = dim // 2
timesteps = timesteps.float()
# compute sinusoidal embedding
sinusoid = torch.outer(
timesteps,
torch.pow(10000, -torch.arange(half).to(timesteps).div(half)))
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
if dim % 2 != 0:
x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
return x
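# Illustrative sketch (not part of the original file): each timestep maps to a
# `dim`-dimensional embedding, cosine components first and sine components second.
def _demo_sinusoidal_embedding():
    t = torch.tensor([0, 10, 999])
    emb = sinusoidal_embedding(t, 320)
    assert emb.shape == (3, 320)
    return emb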
def exists(x):
return x is not None
def default(val, d):
if exists(val):
return val
return d() if callable(d) else d
def prob_mask_like(shape, prob, device):
if prob == 1:
return torch.ones(shape, device = device, dtype = torch.bool)
elif prob == 0:
return torch.zeros(shape, device = device, dtype = torch.bool)
else:
mask = torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
### avoid masking all samples, which would cause a find_unused_parameters error
if mask.all():
mask[0]=False
return mask
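# Illustrative sketch (not part of the original file): prob_mask_like draws a boolean
# mask with one entry per batch sample; an all-True draw is avoided by forcing the
# first entry to False.
def _demo_prob_mask_like():
    mask = prob_mask_like((8,), 0.5, device=torch.device('cpu'))
    assert mask.shape == (8,) and mask.dtype == torch.bool
    return mask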
class MemoryEfficientCrossAttention(nn.Module):
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
def __init__(self, query_dim, max_bs=4096, context_dim=None, heads=8, dim_head=64, dropout=0.0):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.max_bs = max_bs
self.heads = heads
self.dim_head = dim_head
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None
def forward(self, x, context=None, mask=None):
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
b, _, _ = q.shape
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, t.shape[1], self.heads, self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b * self.heads, t.shape[1], self.dim_head)
.contiguous(),
(q, k, v),
)
# actually compute the attention, what we cannot get enough of
if q.shape[0] > self.max_bs:
q_list = torch.chunk(q, q.shape[0] // self.max_bs, dim=0)
k_list = torch.chunk(k, k.shape[0] // self.max_bs, dim=0)
v_list = torch.chunk(v, v.shape[0] // self.max_bs, dim=0)
out_list = []
for q_1, k_1, v_1 in zip(q_list, k_list, v_list):
out = xformers.ops.memory_efficient_attention(
q_1, k_1, v_1, attn_bias=None, op=self.attention_op)
out_list.append(out)
out = torch.cat(out_list, dim=0)
else:
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
if exists(mask):
raise NotImplementedError
out = (
out.unsqueeze(0)
.reshape(b, self.heads, out.shape[1], self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b, out.shape[1], self.heads * self.dim_head)
)
return self.to_out(out)
class RelativePositionBias(nn.Module):
def __init__(
self,
heads = 8,
num_buckets = 32,
max_distance = 128
):
super().__init__()
self.num_buckets = num_buckets
self.max_distance = max_distance
self.relative_attention_bias = nn.Embedding(num_buckets, heads)
@staticmethod
def _relative_position_bucket(relative_position, num_buckets = 32, max_distance = 128):
ret = 0
n = -relative_position
num_buckets //= 2
ret += (n < 0).long() * num_buckets
n = torch.abs(n)
max_exact = num_buckets // 2
is_small = n < max_exact
val_if_large = max_exact + (
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
).long()
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
ret += torch.where(is_small, n, val_if_large)
return ret
def forward(self, n, device):
q_pos = torch.arange(n, dtype = torch.long, device = device)
k_pos = torch.arange(n, dtype = torch.long, device = device)
rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
values = self.relative_attention_bias(rp_bucket)
return rearrange(values, 'i j h -> h i j')
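# Illustrative sketch (not part of the original file): the temporal relative-position
# bias yields one [n, n] bias map per attention head, which is added to the attention
# logits over frames. Sizes are assumptions.
def _demo_relative_position_bias():
    bias_module = RelativePositionBias(heads=8, num_buckets=32, max_distance=32)
    bias = bias_module(16, device=torch.device('cpu'))  # 16 frames
    assert bias.shape == (8, 16, 16)
    return bias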
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data.
First, project the input (aka embedding)
and reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape to image
NEW: use_linear for more efficiency instead of the 1x1 convs
"""
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None,
disable_self_attn=False, use_linear=False,
use_checkpoint=True):
super().__init__()
if exists(context_dim) and not isinstance(context_dim, list):
context_dim = [context_dim]
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
if not use_linear:
self.proj_in = nn.Conv2d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0)
else:
self.proj_in = nn.Linear(in_channels, inner_dim)
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
for d in range(depth)]
)
if not use_linear:
self.proj_out = zero_module(nn.Conv2d(inner_dim,
in_channels,
kernel_size=1,
stride=1,
padding=0))
else:
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
self.use_linear = use_linear
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
if not isinstance(context, list):
context = [context]
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
if self.use_linear:
x = self.proj_in(x)
for i, block in enumerate(self.transformer_blocks):
x = block(x, context=context[i])
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
if not self.use_linear:
x = self.proj_out(x)
return x + x_in
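# Illustrative sketch (not part of the original file): SpatialTransformer keeps the
# spatial shape and adds a residual, so a [b, c, h, w] feature map comes back unchanged
# in shape. Sizes are assumptions; the cross-attention inside uses xformers, so this is
# expected to run on a CUDA device with xformers installed.
def _demo_spatial_transformer():
    block = SpatialTransformer(320, n_heads=5, d_head=64, depth=1,
                               context_dim=1024, use_linear=True).cuda()
    x = torch.randn(2, 320, 32, 32).cuda()
    context = torch.randn(2, 77, 1024).cuda()  # e.g. text embeddings per sample
    out = block(x, context)
    assert out.shape == x.shape
    return out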
import os
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head ** -0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
)
def forward(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
# force cast to fp32 to avoid overflowing
if _ATTN_PRECISION =="fp32":
with torch.autocast(enabled=False, device_type = 'cuda'):
q, k = q.float(), k.float()
sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
else:
sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
del q, k
if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
out = torch.einsum('b i j, b j d -> b i d', sim, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return self.to_out(out)
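# Illustrative sketch (not part of the original file): the vanilla CrossAttention maps
# [b, n, query_dim] tokens against a [b, m, context_dim] context and returns
# [b, n, query_dim]; with context=None it reduces to self-attention. It uses plain
# einsum attention rather than an xformers kernel. Sizes are assumptions.
def _demo_cross_attention():
    attn = CrossAttention(query_dim=320, context_dim=1024, heads=8, dim_head=40)
    x = torch.randn(2, 64, 320)
    context = torch.randn(2, 77, 1024)
    out = attn(x, context)
    assert out.shape == (2, 64, 320)
    return out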
class MemoryEfficientCrossAttention_attemask(nn.Module):
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.heads = heads
self.dim_head = dim_head
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None
def forward(self, x, context=None, mask=None):
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
b, _, _ = q.shape
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, t.shape[1], self.heads, self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b * self.heads, t.shape[1], self.dim_head)
.contiguous(),
(q, k, v),
)
# actually compute the attention, what we cannot get enough of
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=xformers.ops.LowerTriangularMask(), op=self.attention_op)
if exists(mask):
raise NotImplementedError
out = (
out.unsqueeze(0)
.reshape(b, self.heads, out.shape[1], self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b, out.shape[1], self.heads * self.dim_head)
)
return self.to_out(out)
class BasicTransformerBlock_attemask(nn.Module):
# ATTENTION_MODES = {
# "softmax": CrossAttention, # vanilla attention
# "softmax-xformers": MemoryEfficientCrossAttention
# }
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
disable_self_attn=False):
super().__init__()
# attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
# assert attn_mode in self.ATTENTION_MODES
# attn_cls = CrossAttention
attn_cls = MemoryEfficientCrossAttention_attemask
self.disable_self_attn = disable_self_attn
self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward_(self, x, context=None):
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
def forward(self, x, context=None):
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x
class BasicTransformerBlock(nn.Module):
# ATTENTION_MODES = {
# "softmax": CrossAttention, # vanilla attention
# "softmax-xformers": MemoryEfficientCrossAttention
# }
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
disable_self_attn=False):
super().__init__()
# attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
# assert attn_mode in self.ATTENTION_MODES
# attn_cls = CrossAttention
attn_cls = MemoryEfficientCrossAttention
self.disable_self_attn = disable_self_attn
self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward_(self, x, context=None):
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
def forward(self, x, context=None):
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
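# Illustrative sketch (not part of the original file): with glu=True FeedForward uses
# the GEGLU gate above (one Linear producing value and gate halves), and the token
# dimension is preserved. Sizes are assumptions.
def _demo_feedforward_geglu():
    ff = FeedForward(320, mult=4, glu=True, dropout=0.0)
    x = torch.randn(2, 64, 320)
    out = ff(x)
    assert out.shape == x.shape
    return out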
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=padding)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
)
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
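# Illustrative sketch (not part of the original file): Upsample doubles the spatial
# resolution with nearest-neighbour interpolation and, when use_conv=True, refines the
# result with a 3x3 conv. Sizes are assumptions.
def _demo_upsample():
    up = Upsample(320, use_conv=True, dims=2, out_channels=320)
    x = torch.randn(2, 320, 16, 16)
    out = up(x)
    assert out.shape == (2, 320, 32, 32)
    return out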
class UpsampleSR600(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=padding)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
)
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
# TODO: to match the encoder feature size, drop the first and last rows after upsampling
x = x[..., 1:-1, :]
if self.use_conv:
x = self.conv(x)
return x
class ResBlock(nn.Module):
"""
A residual block that can optionally change the number of channels.
:param channels: the number of input channels.
:param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout.
:param out_channels: if specified, the number of out channels.
:param use_conv: if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the
channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param use_checkpoint: if True, use gradient checkpointing on this module.
:param up: if True, use this block for upsampling.
:param down: if True, use this block for downsampling.
"""
def __init__(
self,
channels,
emb_channels,
dropout,
out_channels=None,
use_conv=False,
use_scale_shift_norm=False,
dims=2,
up=False,
down=False,
use_temporal_conv=True,
use_image_dataset=False,
):
super().__init__()
self.channels = channels
self.emb_channels = emb_channels
self.dropout = dropout
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_scale_shift_norm = use_scale_shift_norm
self.use_temporal_conv = use_temporal_conv
self.in_layers = nn.Sequential(
nn.GroupNorm(32, channels),
nn.SiLU(),
nn.Conv2d(channels, self.out_channels, 3, padding=1),
)
self.updown = up or down
if up:
self.h_upd = Upsample(channels, False, dims)
self.x_upd = Upsample(channels, False, dims)
elif down:
self.h_upd = Downsample(channels, False, dims)
self.x_upd = Downsample(channels, False, dims)
else:
self.h_upd = self.x_upd = nn.Identity()
self.emb_layers = nn.Sequential(
nn.SiLU(),
nn.Linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
),
)
self.out_layers = nn.Sequential(
nn.GroupNorm(32, self.out_channels),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)
),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(
dims, channels, self.out_channels, 3, padding=1
)
else:
self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)
if self.use_temporal_conv:
self.temopral_conv = TemporalConvBlock_v2(self.out_channels, self.out_channels, dropout=0.1, use_image_dataset=use_image_dataset)
# self.temopral_conv_2 = TemporalConvBlock(self.out_channels, self.out_channels, dropout=0.1, use_image_dataset=use_image_dataset)
def forward(self, x, emb, batch_size):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
:param x: an [N x C x ...] Tensor of features.
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
return self._forward(x, emb, batch_size)
def _forward(self, x, emb, batch_size):
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
h = self.h_upd(h)
x = self.x_upd(x)
h = in_conv(h)
else:
h = self.in_layers(x)
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = torch.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
h = out_rest(h)
else:
h = h + emb_out
h = self.out_layers(h)
h = self.skip_connection(x) + h
if self.use_temporal_conv:
h = rearrange(h, '(b f) c h w -> b c f h w', b=batch_size)
h = self.temopral_conv(h)
# h = self.temopral_conv_2(h)
h = rearrange(h, 'b c f h w -> (b f) c h w')
return h
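# Illustrative sketch (not part of the original file): ResBlock consumes flattened video
# frames [(b*f), c, h, w] plus one timestep embedding per frame and, since
# use_temporal_conv=True by default, mixes information across frames again with the
# temporal conv block defined further below. Channel counts are assumptions.
def _demo_resblock():
    block = ResBlock(channels=320, emb_channels=1280, dropout=0.1, out_channels=320)
    batch, frames = 2, 4
    x = torch.randn(batch * frames, 320, 16, 16)
    emb = torch.randn(batch * frames, 1280)
    out = block(x, emb, batch_size=batch)
    assert out.shape == x.shape
    return out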
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
return self.op(x)
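# Illustrative sketch (not part of the original file): the convolutional Downsample
# halves the spatial resolution with a stride-2 3x3 conv. Sizes are assumptions.
def _demo_downsample():
    down = Downsample(320, use_conv=True, dims=2, out_channels=320)
    x = torch.randn(2, 320, 32, 32)
    out = down(x)
    assert out.shape == (2, 320, 16, 16)
    return out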
class Resample(nn.Module):
def __init__(self, in_dim, out_dim, mode):
assert mode in ['none', 'upsample', 'downsample']
super(Resample, self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.mode = mode
def forward(self, x, reference=None):
if self.mode == 'upsample':
assert reference is not None
x = F.interpolate(x, size=reference.shape[-2:], mode='nearest')
elif self.mode == 'downsample':
x = F.adaptive_avg_pool2d(x, output_size=tuple(u // 2 for u in x.shape[-2:]))
return x
class ResidualBlock(nn.Module):
def __init__(self, in_dim, embed_dim, out_dim, use_scale_shift_norm=True,
mode='none', dropout=0.0):
super(ResidualBlock, self).__init__()
self.in_dim = in_dim
self.embed_dim = embed_dim
self.out_dim = out_dim
self.use_scale_shift_norm = use_scale_shift_norm
self.mode = mode
# layers
self.layer1 = nn.Sequential(
nn.GroupNorm(32, in_dim),
nn.SiLU(),
nn.Conv2d(in_dim, out_dim, 3, padding=1))
self.resample = Resample(in_dim, in_dim, mode)
self.embedding = nn.Sequential(
nn.SiLU(),
nn.Linear(embed_dim, out_dim * 2 if use_scale_shift_norm else out_dim))
self.layer2 = nn.Sequential(
nn.GroupNorm(32, out_dim),
nn.SiLU(),
nn.Dropout(dropout),
nn.Conv2d(out_dim, out_dim, 3, padding=1))
self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d(in_dim, out_dim, 1)
# zero out the last layer params
nn.init.zeros_(self.layer2[-1].weight)
def forward(self, x, e, reference=None):
identity = self.resample(x, reference)
x = self.layer1[-1](self.resample(self.layer1[:-1](x), reference))
e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype)
if self.use_scale_shift_norm:
scale, shift = e.chunk(2, dim=1)
x = self.layer2[0](x) * (1 + scale) + shift
x = self.layer2[1:](x)
else:
x = x + e
x = self.layer2(x)
x = x + self.shortcut(identity)
return x
class AttentionBlock(nn.Module):
def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None):
# consider head_dim first, then num_heads
num_heads = dim // head_dim if head_dim else num_heads
head_dim = dim // num_heads
assert num_heads * head_dim == dim
super(AttentionBlock, self).__init__()
self.dim = dim
self.context_dim = context_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.scale = math.pow(head_dim, -0.25)
# layers
self.norm = nn.GroupNorm(32, dim)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
if context_dim is not None:
self.context_kv = nn.Linear(context_dim, dim * 2)
self.proj = nn.Conv2d(dim, dim, 1)
# zero out the last layer params
nn.init.zeros_(self.proj.weight)
def forward(self, x, context=None):
r"""x: [B, C, H, W].
context: [B, L, C] or None.
"""
identity = x
b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
x = self.norm(x)
q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
if context is not None:
ck, cv = self.context_kv(context).reshape(b, -1, n * 2, d).permute(0, 2, 3, 1).chunk(2, dim=1)
k = torch.cat([ck, k], dim=-1)
v = torch.cat([cv, v], dim=-1)
# compute attention
attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale)
attn = F.softmax(attn, dim=-1)
# gather context
x = torch.matmul(v, attn.transpose(-1, -2))
x = x.reshape(b, c, h, w)
# output
x = self.proj(x)
return x + identity
class TemporalAttentionBlock(nn.Module):
def __init__(
self,
dim,
heads = 4,
dim_head = 32,
rotary_emb = None,
use_image_dataset = False,
use_sim_mask = False
):
super().__init__()
# consider num_heads first, as pos_bias needs fixed num_heads
# heads = dim // dim_head if dim_head else heads
dim_head = dim // heads
assert heads * dim_head == dim
self.use_image_dataset = use_image_dataset
self.use_sim_mask = use_sim_mask
self.scale = dim_head ** -0.5
self.heads = heads
hidden_dim = dim_head * heads
self.norm = nn.GroupNorm(32, dim)
self.rotary_emb = rotary_emb
self.to_qkv = nn.Linear(dim, hidden_dim * 3)#, bias = False)
self.to_out = nn.Linear(hidden_dim, dim)#, bias = False)
# nn.init.zeros_(self.to_out.weight)
# nn.init.zeros_(self.to_out.bias)
def forward(
self,
x,
pos_bias = None,
focus_present_mask = None,
video_mask = None
):
identity = x
n, height, device = x.shape[2], x.shape[-2], x.device
x = self.norm(x)
x = rearrange(x, 'b c f h w -> b (h w) f c')
qkv = self.to_qkv(x).chunk(3, dim = -1)
if exists(focus_present_mask) and focus_present_mask.all():
# if all batch samples are focusing on present
# it would be equivalent to passing that token's values (v=qkv[-1]) through to the output
values = qkv[-1]
out = self.to_out(values)
out = rearrange(out, 'b (h w) f c -> b c f h w', h = height)
return out + identity
# split out heads
# q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h = self.heads)
# shape [b (hw) h n c/h], n=f
q= rearrange(qkv[0], '... n (h d) -> ... h n d', h = self.heads)
k= rearrange(qkv[1], '... n (h d) -> ... h n d', h = self.heads)
v= rearrange(qkv[2], '... n (h d) -> ... h n d', h = self.heads)
# scale
q = q * self.scale
# rotate positions into queries and keys for time attention
if exists(self.rotary_emb):
q = self.rotary_emb.rotate_queries_or_keys(q)
k = self.rotary_emb.rotate_queries_or_keys(k)
# similarity
# shape [b (hw) h n n], n=f
sim = torch.einsum('... h i d, ... h j d -> ... h i j', q, k)
# relative positional bias
if exists(pos_bias):
# print(sim.shape,pos_bias.shape)
sim = sim + pos_bias
if (focus_present_mask is None and video_mask is not None):
#video_mask: [B, n]
mask = video_mask[:, None, :] * video_mask[:, :, None] # [b,n,n]
mask = mask.unsqueeze(1).unsqueeze(1) #[b,1,1,n,n]
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
elif exists(focus_present_mask) and not (~focus_present_mask).all():
attend_all_mask = torch.ones((n, n), device = device, dtype = torch.bool)
attend_self_mask = torch.eye(n, device = device, dtype = torch.bool)
mask = torch.where(
rearrange(focus_present_mask, 'b -> b 1 1 1 1'),
rearrange(attend_self_mask, 'i j -> 1 1 1 i j'),
rearrange(attend_all_mask, 'i j -> 1 1 1 i j'),
)
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
if self.use_sim_mask:
sim_mask = torch.tril(torch.ones((n, n), device = device, dtype = torch.bool), diagonal=0)
sim = sim.masked_fill(~sim_mask, -torch.finfo(sim.dtype).max)
# numerical stability
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
attn = sim.softmax(dim = -1)
# aggregate values
out = torch.einsum('... h i j, ... h j d -> ... h i d', attn, v)
out = rearrange(out, '... h n d -> ... n (h d)')
out = self.to_out(out)
out = rearrange(out, 'b (h w) f c -> b c f h w', h = height)
if self.use_image_dataset:
out = identity + 0*out
else:
out = identity + out
return out
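# Illustrative sketch (not part of the original file): TemporalAttentionBlock attends
# across the frame axis of a [b, c, f, h, w] feature map and returns the same shape,
# using plain einsum attention (no xformers kernel). Sizes are assumptions.
def _demo_temporal_attention_block():
    block = TemporalAttentionBlock(dim=320, heads=8)
    x = torch.randn(1, 320, 8, 16, 16)
    out = block(x)
    assert out.shape == x.shape
    return out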
class TemporalTransformer(nn.Module):
"""
Transformer block for image-like data.
First, project the input (aka embedding)
and reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape to image
"""
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None,
disable_self_attn=False, use_linear=False,
use_checkpoint=True, only_self_att=True, multiply_zero=False):
super().__init__()
self.multiply_zero = multiply_zero
self.only_self_att = only_self_att
self.use_adaptor = False
if self.only_self_att:
context_dim = None
if not isinstance(context_dim, list):
context_dim = [context_dim]
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
if not use_linear:
self.proj_in = nn.Conv1d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0)
else:
self.proj_in = nn.Linear(in_channels, inner_dim)
if self.use_adaptor:
self.adaptor_in = nn.Linear(frames, frames)
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
checkpoint=use_checkpoint)
for d in range(depth)]
)
if not use_linear:
self.proj_out = zero_module(nn.Conv1d(inner_dim,
in_channels,
kernel_size=1,
stride=1,
padding=0))
else:
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
if self.use_adaptor:
self.adaptor_out = nn.Linear(frames, frames)
self.use_linear = use_linear
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
if self.only_self_att:
context = None
if not isinstance(context, list):
context = [context]
b, c, f, h, w = x.shape
x_in = x
x = self.norm(x)
if not self.use_linear:
x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous()
x = self.proj_in(x)
# [16384, 16, 320]
if self.use_linear:
x = rearrange(x, '(b f) c h w -> b (h w) f c', f=self.frames).contiguous()
x = self.proj_in(x)
if self.only_self_att:
x = rearrange(x, 'bhw c f -> bhw f c').contiguous()
for i, block in enumerate(self.transformer_blocks):
x = block(x)
x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous()
else:
x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous()
for i, block in enumerate(self.transformer_blocks):
# context[i] = repeat(context[i], '(b f) l con -> b (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous()
context[i] = rearrange(context[i], '(b f) l con -> b f l con', f=self.frames).contiguous()
# process each batch element separately (some attention kernels cannot handle a leading dimension larger than 65,535)
for j in range(b):
context_i_j = repeat(context[i][j], 'f l con -> (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous()
x[j] = block(x[j], context=context_i_j)
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, 'b (h w) f c -> b f c h w', h=h, w=w).contiguous()
if not self.use_linear:
# x = rearrange(x, 'bhw f c -> bhw c f').contiguous()
x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous()
x = self.proj_out(x)
x = rearrange(x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous()
if self.multiply_zero:
x = 0.0 * x + x_in
else:
x = x + x_in
return x
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
b, n, _, h = *x.shape, self.heads
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = self.attend(dots)
out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class PreNormattention(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs) + x
class TransformerV2(nn.Module):
def __init__(self, heads=8, dim=2048, dim_head_k=256, dim_head_v=256, dropout_atte = 0.05, mlp_dim=2048, dropout_ffn = 0.05, depth=1):
super().__init__()
self.layers = nn.ModuleList([])
self.depth = depth
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNormattention(dim, Attention(dim, heads = heads, dim_head = dim_head_k, dropout = dropout_atte)),
FeedForward(dim, mlp_dim, dropout = dropout_ffn),
]))
def forward(self, x):
# if self.depth
for attn, ff in self.layers[:1]:
x = attn(x)
x = ff(x) + x
if self.depth > 1:
for attn, ff in self.layers[1:]:
x = attn(x)
x = ff(x) + x
return x
class TemporalTransformer_attemask(nn.Module):
"""
Transformer block for image-like data.
First, project the input (aka embedding)
and reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape to image
"""
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None,
disable_self_attn=False, use_linear=False,
use_checkpoint=True, only_self_att=True, multiply_zero=False):
super().__init__()
self.multiply_zero = multiply_zero
self.only_self_att = only_self_att
self.use_adaptor = False
if self.only_self_att:
context_dim = None
if not isinstance(context_dim, list):
context_dim = [context_dim]
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
if not use_linear:
self.proj_in = nn.Conv1d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0)
else:
self.proj_in = nn.Linear(in_channels, inner_dim)
if self.use_adaptor:
self.adaptor_in = nn.Linear(frames, frames)
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock_attemask(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
checkpoint=use_checkpoint)
for d in range(depth)]
)
if not use_linear:
self.proj_out = zero_module(nn.Conv1d(inner_dim,
in_channels,
kernel_size=1,
stride=1,
padding=0))
else:
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
if self.use_adaptor:
self.adaptor_out = nn.Linear(frames, frames)
self.use_linear = use_linear
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
if self.only_self_att:
context = None
if not isinstance(context, list):
context = [context]
b, c, f, h, w = x.shape
x_in = x
x = self.norm(x)
if not self.use_linear:
x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous()
x = self.proj_in(x)
# [16384, 16, 320]
if self.use_linear:
x = rearrange(x, '(b f) c h w -> b (h w) f c', f=self.frames).contiguous()
x = self.proj_in(x)
if self.only_self_att:
x = rearrange(x, 'bhw c f -> bhw f c').contiguous()
for i, block in enumerate(self.transformer_blocks):
x = block(x)
x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous()
else:
x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous()
for i, block in enumerate(self.transformer_blocks):
# context[i] = repeat(context[i], '(b f) l con -> b (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous()
context[i] = rearrange(context[i], '(b f) l con -> b f l con', f=self.frames).contiguous()
# process each batch element separately (some attention kernels cannot handle a leading dimension larger than 65,535)
for j in range(b):
context_i_j = repeat(context[i][j], 'f l con -> (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous()
x[j] = block(x[j], context=context_i_j)
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, 'b (h w) f c -> b f c h w', h=h, w=w).contiguous()
if not self.use_linear:
# x = rearrange(x, 'bhw f c -> bhw c f').contiguous()
x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous()
x = self.proj_out(x)
x = rearrange(x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous()
if self.multiply_zero:
x = 0.0 * x + x_in
else:
x = x + x_in
return x
class TemporalAttentionMultiBlock(nn.Module):
def __init__(
self,
dim,
heads=4,
dim_head=32,
rotary_emb=None,
use_image_dataset=False,
use_sim_mask=False,
temporal_attn_times=1,
):
super().__init__()
self.att_layers = nn.ModuleList(
[TemporalAttentionBlock(dim, heads, dim_head, rotary_emb, use_image_dataset, use_sim_mask)
for _ in range(temporal_attn_times)]
)
def forward(
self,
x,
pos_bias = None,
focus_present_mask = None,
video_mask = None
):
for layer in self.att_layers:
x = layer(x, pos_bias, focus_present_mask, video_mask)
return x
class InitTemporalConvBlock(nn.Module):
def __init__(self, in_dim, out_dim=None, dropout=0.0,use_image_dataset=False):
super(InitTemporalConvBlock, self).__init__()
if out_dim is None:
out_dim = in_dim#int(1.5*in_dim)
self.in_dim = in_dim
self.out_dim = out_dim
self.use_image_dataset = use_image_dataset
# conv layers
self.conv = nn.Sequential(
nn.GroupNorm(32, out_dim),
nn.SiLU(),
nn.Dropout(dropout),
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding = (1, 0, 0)))
# zero out the last layer params, so the conv block starts as an identity
# nn.init.zeros_(self.conv1[-1].weight)
# nn.init.zeros_(self.conv1[-1].bias)
nn.init.zeros_(self.conv[-1].weight)
nn.init.zeros_(self.conv[-1].bias)
def forward(self, x):
identity = x
x = self.conv(x)
if self.use_image_dataset:
x = identity + 0*x
else:
x = identity + x
return x
class TemporalConvBlock(nn.Module):
def __init__(self, in_dim, out_dim=None, dropout=0.0, use_image_dataset= False):
super(TemporalConvBlock, self).__init__()
if out_dim is None:
out_dim = in_dim#int(1.5*in_dim)
self.in_dim = in_dim
self.out_dim = out_dim
self.use_image_dataset = use_image_dataset
# conv layers
self.conv1 = nn.Sequential(
nn.GroupNorm(32, in_dim),
nn.SiLU(),
nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding = (1, 0, 0)))
self.conv2 = nn.Sequential(
nn.GroupNorm(32, out_dim),
nn.SiLU(),
nn.Dropout(dropout),
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding = (1, 0, 0)))
# zero out the last layer params, so the conv block starts as an identity
# nn.init.zeros_(self.conv1[-1].weight)
# nn.init.zeros_(self.conv1[-1].bias)
nn.init.zeros_(self.conv2[-1].weight)
nn.init.zeros_(self.conv2[-1].bias)
def forward(self, x):
identity = x
x = self.conv1(x)
x = self.conv2(x)
if self.use_image_dataset:
x = identity + 0*x
else:
x = identity + x
return x
class TemporalConvBlock_v2(nn.Module):
def __init__(self, in_dim, out_dim=None, dropout=0.0, use_image_dataset=False):
super(TemporalConvBlock_v2, self).__init__()
if out_dim is None:
out_dim = in_dim # int(1.5*in_dim)
self.in_dim = in_dim
self.out_dim = out_dim
self.use_image_dataset = use_image_dataset
# conv layers
self.conv1 = nn.Sequential(
nn.GroupNorm(32, in_dim),
nn.SiLU(),
nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding = (1, 0, 0)))
self.conv2 = nn.Sequential(
nn.GroupNorm(32, out_dim),
nn.SiLU(),
nn.Dropout(dropout),
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding = (1, 0, 0)))
self.conv3 = nn.Sequential(
nn.GroupNorm(32, out_dim),
nn.SiLU(),
nn.Dropout(dropout),
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding = (1, 0, 0)))
self.conv4 = nn.Sequential(
nn.GroupNorm(32, out_dim),
nn.SiLU(),
nn.Dropout(dropout),
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding = (1, 0, 0)))
# zero out the last layer params, so the conv block starts as an identity
nn.init.zeros_(self.conv4[-1].weight)
nn.init.zeros_(self.conv4[-1].bias)
def forward(self, x):
identity = x
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
if self.use_image_dataset:
x = identity + 0.0 * x
else:
x = identity + x
return x
class DropPath(nn.Module):
r"""DropPath but without rescaling and supports optional all-zero and/or all-keep.
"""
def __init__(self, p):
super(DropPath, self).__init__()
self.p = p
def forward(self, *args, zero=None, keep=None):
if not self.training:
return args[0] if len(args) == 1 else args
# params
x = args[0]
b = x.size(0)
n = (torch.rand(b) < self.p).sum()
# non-zero and non-keep mask
mask = x.new_ones(b, dtype=torch.bool)
if keep is not None:
mask[keep] = False
if zero is not None:
mask[zero] = False
# drop-path index
index = torch.where(mask)[0]
index = index[torch.randperm(len(index))[:n]]
if zero is not None:
index = torch.cat([index, torch.where(zero)[0]], dim=0)
# drop-path multiplier
multiplier = x.new_ones(b)
multiplier[index] = 0.0
output = tuple(u * self.broadcast(multiplier, u) for u in args)
return output[0] if len(args) == 1 else output
def broadcast(self, src, dst):
assert src.size(0) == dst.size(0)
shape = (dst.size(0), ) + (1, ) * (dst.ndim - 1)
return src.view(shape)
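# Illustrative sketch (not part of the original file): in training mode DropPath zeroes
# a random subset of samples in the batch (with no rescaling); in eval mode it is a
# pass-through. The optional `zero` / `keep` index arguments force specific samples to
# be dropped or preserved.
def _demo_drop_path():
    drop = DropPath(p=0.5)
    x = torch.randn(8, 320)
    drop.train()
    dropped = drop(x)      # a random subset of rows is zeroed
    drop.eval()
    untouched = drop(x)    # identical to x
    assert dropped.shape == x.shape and torch.equal(untouched, x)
    return dropped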
from .train_t2v_enterance import *
import os
import os.path as osp
import sys
sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-4]))
import json
import math
import random
import torch
import logging
import datetime
import numpy as np
from PIL import Image
import torch.optim as optim
from einops import rearrange
import torch.cuda.amp as amp
from importlib import reload
from copy import deepcopy, copy
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler
import utils.transforms as data
from utils.util import to_device
from ..modules.config import cfg
from utils.seed import setup_seed
from utils.optim import AnnealingLR
from utils.multi_port import find_free_port
from utils.distributed import generalized_all_gather, all_reduce
from utils.registry_class import ENGINE, MODEL, DATASETS, EMBEDDER, AUTO_ENCODER, DISTRIBUTION, VISUAL, DIFFUSION, PRETRAIN
@ENGINE.register_function()
def train_t2v_entrance(cfg_update, **kwargs):
for k, v in cfg_update.items():
if isinstance(v, dict) and k in cfg:
cfg[k].update(v)
else:
cfg[k] = v
if 'MASTER_ADDR' not in os.environ:
os.environ['MASTER_ADDR']='localhost'
os.environ['MASTER_PORT']= find_free_port()
cfg.pmi_rank = int(os.getenv('RANK', 0)) # 0
cfg.pmi_world_size = int(os.getenv('WORLD_SIZE', 1))
setup_seed(cfg.seed)
if cfg.debug:
cfg.gpus_per_machine = 1
cfg.world_size = 1
else:
cfg.gpus_per_machine = torch.cuda.device_count()
cfg.world_size = cfg.pmi_world_size * cfg.gpus_per_machine
if cfg.world_size == 1:
worker(0, cfg)
else:
mp.spawn(worker, nprocs=cfg.gpus_per_machine, args=(cfg, ))
return cfg
def worker(gpu, cfg):
'''
Training worker for each gpu
'''
cfg.gpu = gpu
cfg.rank = cfg.pmi_rank * cfg.gpus_per_machine + gpu
if not cfg.debug:
torch.cuda.set_device(gpu)
torch.backends.cudnn.benchmark = True
dist.init_process_group(backend='nccl', world_size=cfg.world_size, rank=cfg.rank)
# [Log] Save logging
log_dir = generalized_all_gather(cfg.log_dir)[0]
exp_name = osp.basename(cfg.cfg_file).split('.')[0]
cfg.log_dir = osp.join(cfg.log_dir, exp_name)
os.makedirs(cfg.log_dir, exist_ok=True)
if cfg.rank == 0:
log_file = osp.join(cfg.log_dir, 'log.txt')
cfg.log_file = log_file
reload(logging)
logging.basicConfig(
level=logging.INFO,
format='[%(asctime)s] %(levelname)s: %(message)s',
handlers=[
logging.FileHandler(filename=log_file),
logging.StreamHandler(stream=sys.stdout)])
logging.info(cfg)
        logging.info(f'Saving all files to dir {cfg.log_dir}')
logging.info(f"Going into i2v_img_fullid_vidcom function on {gpu} gpu")
# [Diffusion] build diffusion settings
diffusion = DIFFUSION.build(cfg.Diffusion)
# [Dataset] imagedataset and videodataset
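    # Frame-length / fps bucketing: every rank picks one entry from cfg.frame_lens and
    # cfg.sample_fps according to its rank, so different ranks can train on different
    # clip lengths (with the matching per-length batch size) within the same job.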
len_frames = len(cfg.frame_lens)
len_fps = len(cfg.sample_fps)
cfg.max_frames = cfg.frame_lens[cfg.rank % len_frames]
cfg.batch_size = cfg.batch_sizes[str(cfg.max_frames)]
cfg.sample_fps = cfg.sample_fps[cfg.rank % len_fps]
if cfg.rank == 0:
        logging.info(f'Current worker with max_frames={cfg.max_frames}, batch_size={cfg.batch_size}, sample_fps={cfg.sample_fps}')
train_trans = data.Compose([
data.CenterCropWide(size=cfg.resolution),
data.ToTensor(),
data.Normalize(mean=cfg.mean, std=cfg.std)])
vit_trans = data.Compose([
data.CenterCropWide(size=(cfg.resolution[0], cfg.resolution[0])) if cfg.resolution[0]>cfg.vit_resolution[0] else data.CenterCropWide(size=cfg.vit_resolution),
data.Resize(cfg.vit_resolution),
data.ToTensor(),
data.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)])
if cfg.max_frames == 1:
cfg.sample_fps = 1
dataset = DATASETS.build(cfg.img_dataset, transforms=train_trans, vit_transforms=vit_trans)
else:
dataset = DATASETS.build(cfg.vid_dataset, sample_fps=cfg.sample_fps, transforms=train_trans, vit_transforms=vit_trans, max_frames=cfg.max_frames)
sampler = DistributedSampler(dataset, num_replicas=cfg.world_size, rank=cfg.rank) if (cfg.world_size > 1 and not cfg.debug) else None
dataloader = DataLoader(
dataset,
sampler=sampler,
batch_size=cfg.batch_size,
num_workers=cfg.num_workers,
pin_memory=True,
prefetch_factor=cfg.prefetch_factor)
rank_iter = iter(dataloader)
# [Model] embedder
clip_encoder = EMBEDDER.build(cfg.embedder)
clip_encoder.model.to(gpu)
_, _, zero_y = clip_encoder(text="")
_, _, zero_y_negative = clip_encoder(text=cfg.negative_prompt)
zero_y, zero_y_negative = zero_y.detach(), zero_y_negative.detach()
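    # The empty prompt and cfg.negative_prompt are embedded once and detached; these act
    # as the unconditional / negative text context used for prompt dropout below and for
    # classifier-free guidance when sampling.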
    # [Model] autoencoder
autoencoder = AUTO_ENCODER.build(cfg.auto_encoder)
autoencoder.eval() # freeze
for param in autoencoder.parameters():
param.requires_grad = False
autoencoder.cuda()
# [Model] UNet
model = MODEL.build(cfg.UNet, zero_y=zero_y_negative)
model = model.to(gpu)
resume_step = 1
model, resume_step = PRETRAIN.build(cfg.Pretrain, model=model)
torch.cuda.empty_cache()
if cfg.use_ema:
ema = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
ema = type(ema)([(k, ema[k].data.clone()) for k in list(ema.keys())[cfg.rank::cfg.world_size]])
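        # The EMA copy is sharded: each rank keeps only every world_size-th parameter key
        # (offset by its rank), so no single process holds the full EMA state and EMA
        # checkpoints are later written per rank.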
# optimizer
optimizer = optim.AdamW(params=model.parameters(),
lr=cfg.lr, weight_decay=cfg.weight_decay)
scaler = amp.GradScaler(enabled=cfg.use_fp16)
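    # NOTE: FSDP is used below but not imported in this file. The keyword arguments
    # (a boolean mixed_precision plus compute_dtype) match fairscale's API, so the
    # missing import is presumably something like
    #   from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
    # -- an assumption; use whichever FSDP implementation the project actually targets.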
if cfg.use_fsdp:
config = {}
config['compute_dtype'] = torch.float32
config['mixed_precision'] = True
model = FSDP(model, **config)
else:
model = DistributedDataParallel(model, device_ids=[gpu]) if not cfg.debug else model.to(gpu)
# scheduler
scheduler = AnnealingLR(
optimizer=optimizer,
base_lr=cfg.lr,
warmup_steps=cfg.warmup_steps, # 10
total_steps=cfg.num_steps, # 200000
decay_mode=cfg.decay_mode) # 'cosine'
# [Visual]
viz_num = min(cfg.batch_size, 8)
visual_func = VISUAL.build(
cfg.visual_train,
cfg_global=cfg,
viz_num=viz_num,
diffusion=diffusion,
autoencoder=autoencoder)
for step in range(resume_step, cfg.num_steps + 1):
model.train()
try:
batch = next(rank_iter)
except StopIteration:
rank_iter = iter(dataloader)
batch = next(rank_iter)
batch = to_device(batch, gpu, non_blocking=True)
ref_frame, _, video_data, captions, video_key = batch
batch_size, frames_num, _, _, _ = video_data.shape
video_data = rearrange(video_data, 'b f c h w -> (b f) c h w')
fps_tensor = torch.tensor([cfg.sample_fps] * batch_size, dtype=torch.long, device=gpu)
video_data_list = torch.chunk(video_data, video_data.shape[0]//cfg.chunk_size,dim=0)
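        # The VAE encodes frames in chunks of cfg.chunk_size to bound peak memory; the
        # chunked latents are then concatenated and reshaped back to
        # (batch, channels, frames, height, width) for the 3D UNet.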
with torch.no_grad():
decode_data = []
for chunk_data in video_data_list:
latent_z = autoencoder.encode_firsr_stage(chunk_data, cfg.scale_factor).detach()
decode_data.append(latent_z) # [B, 4, 32, 56]
video_data = torch.cat(decode_data,dim=0)
video_data = rearrange(video_data, '(b f) c h w -> b c f h w', b = batch_size) # [B, 4, 16, 32, 56]
opti_timesteps = getattr(cfg, 'opti_timesteps', cfg.Diffusion.schedule_param.num_timesteps)
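        # One diffusion timestep is drawn uniformly from [0, opti_timesteps) per sample;
        # opti_timesteps falls back to the schedule's num_timesteps when not set in cfg.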
t_round = torch.randint(0, opti_timesteps, (batch_size, ), dtype=torch.long, device=gpu) # 8
# preprocess
with torch.no_grad():
_, _, y_words = clip_encoder(text=captions) # bs * 1 *1024 [B, 1, 1024]
y_words_0 = y_words.clone()
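        # Prompt dropout: with probability cfg.p_zero a caption embedding is replaced by
        # the negative-prompt embedding, the usual recipe for enabling classifier-free
        # guidance at sampling time; the except branch simply skips this step if the
        # shapes do not line up.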
try:
y_words[torch.rand(y_words.size(0)) < cfg.p_zero, :] = zero_y_negative
        except Exception:
pass
# forward
model_kwargs = {'y': y_words, 'fps': fps_tensor}
if cfg.use_fsdp:
loss = diffusion.loss(x0=video_data,
t=t_round, model=model, model_kwargs=model_kwargs,
use_div_loss=cfg.use_div_loss)
loss = loss.mean()
else:
with amp.autocast(enabled=cfg.use_fp16):
loss = diffusion.loss(
x0=video_data,
t=t_round,
model=model,
model_kwargs=model_kwargs,
use_div_loss=cfg.use_div_loss) # cfg.use_div_loss: False loss: [80]
loss = loss.mean()
# backward
if cfg.use_fsdp:
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_value_(model.parameters(), 0.05)
optimizer.step()
else:
optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
if not cfg.use_fsdp:
scheduler.step()
# ema update
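        # lerp-based EMA: v_new = current + decay * (v - current)
        #                       = decay * v + (1 - decay) * current,
        # i.e. a standard exponential moving average with cfg.ema_decay as the decay rate.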
if cfg.use_ema:
temp_state_dict = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
for k, v in ema.items():
v.copy_(temp_state_dict[k].lerp(v, cfg.ema_decay))
all_reduce(loss)
loss = loss / cfg.world_size
if cfg.rank == 0 and step % cfg.log_interval == 0: # cfg.log_interval: 100
logging.info(f'Step: {step}/{cfg.num_steps} Loss: {loss.item():.3f} scale: {scaler.get_scale():.1f} LR: {scheduler.get_lr():.7f}')
# Visualization
if step == resume_step or step == cfg.num_steps or step % cfg.viz_interval == 0:
with torch.no_grad():
try:
visual_kwards = [
{
'y': y_words_0[:viz_num],
'fps': fps_tensor[:viz_num],
},
{
'y': zero_y_negative.repeat(y_words_0.size(0), 1, 1),
'fps': fps_tensor[:viz_num],
}
]
input_kwards = {
'model': model, 'video_data': video_data[:viz_num], 'step': step,
'ref_frame': ref_frame[:viz_num], 'captions': captions[:viz_num]}
visual_func.run(visual_kwards=visual_kwards, **input_kwards)
except Exception as e:
logging.info(f'Save videos with exception {e}')
# Save checkpoint
if step == cfg.num_steps or step % cfg.save_ckp_interval == 0 or step == resume_step:
os.makedirs(osp.join(cfg.log_dir, 'checkpoints'), exist_ok=True)
if cfg.use_ema:
local_ema_model_path = osp.join(cfg.log_dir, f'checkpoints/ema_{step:08d}_rank{cfg.rank:04d}.pth')
save_dict = {
'state_dict': ema.module.state_dict() if hasattr(ema, 'module') else ema,
'step': step}
torch.save(save_dict, local_ema_model_path)
if cfg.rank == 0:
logging.info(f'Begin to Save ema model to {local_ema_model_path}')
if cfg.rank == 0:
local_model_path = osp.join(cfg.log_dir, f'checkpoints/non_ema_{step:08d}.pth')
logging.info(f'Begin to Save model to {local_model_path}')
save_dict = {
'state_dict': model.module.state_dict() if not cfg.debug else model.state_dict(),
'step': step}
torch.save(save_dict, local_model_path)
logging.info(f'Save model to {local_model_path}')
if cfg.rank == 0:
logging.info('Congratulations! The training is completed!')
# synchronize to finish some processes
if not cfg.debug:
torch.cuda.synchronize()
dist.barrier()
import os
import sys
import copy
import json
import math
import random
import logging
import itertools
import numpy as np
from utils.config import Config
from utils.registry_class import ENGINE
from tools import *
if __name__ == '__main__':
cfg_update = Config(load=True)
ENGINE.build(dict(type=cfg_update.TASK_TYPE), cfg_update=cfg_update.cfg_dict)