Commit 1f5da520 authored by yangzhong

git init

import functools
import torch.nn as nn
from ..util import ActNorm
def weights_init(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
try:
nn.init.normal_(m.weight.data, 0.0, 0.02)
        except AttributeError:  # some wrapped conv modules keep the underlying conv under `.conv`
nn.init.normal_(m.conv.weight.data, 0.0, 0.02)
elif classname.find("BatchNorm") != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
class NLayerDiscriminator(nn.Module):
"""Defines a PatchGAN discriminator as in Pix2Pix
--> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
"""
def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
"""Construct a PatchGAN discriminator
Parameters:
input_nc (int) -- the number of channels in input images
            ndf (int)          -- the number of filters in the first conv layer
            n_layers (int)     -- the number of conv layers in the discriminator
            use_actnorm (bool) -- whether to use ActNorm instead of BatchNorm2d
"""
super(NLayerDiscriminator, self).__init__()
if not use_actnorm:
norm_layer = nn.BatchNorm2d
else:
norm_layer = ActNorm
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
use_bias = norm_layer.func != nn.BatchNorm2d
else:
use_bias = norm_layer != nn.BatchNorm2d
kw = 4
padw = 1
sequence = [
nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
nn.LeakyReLU(0.2, True),
]
nf_mult = 1
nf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
nf_mult_prev = nf_mult
nf_mult = min(2**n, 8)
sequence += [
nn.Conv2d(
ndf * nf_mult_prev,
ndf * nf_mult,
kernel_size=kw,
stride=2,
padding=padw,
bias=use_bias,
),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True),
]
nf_mult_prev = nf_mult
nf_mult = min(2**n_layers, 8)
sequence += [
nn.Conv2d(
ndf * nf_mult_prev,
ndf * nf_mult,
kernel_size=kw,
stride=1,
padding=padw,
bias=use_bias,
),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True),
]
sequence += [
nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
] # output 1 channel prediction map
self.main = nn.Sequential(*sequence)
def forward(self, input):
"""Standard forward."""
return self.main(input)
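# A minimal usage sketch (illustrative shapes): the PatchGAN discriminator returns a
# grid of per-patch logits rather than a single score per image.
def _nlayer_discriminator_demo():
    import torch
    disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3)
    disc.apply(weights_init)
    images = torch.randn(2, 3, 256, 256)  # dummy batch of RGB images
    logits = disc(images)                 # one logit per receptive-field patch
    return logits.shape                   # torch.Size([2, 1, 30, 30])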
import hashlib
import os
import requests
import torch
import torch.nn as nn
from tqdm import tqdm
URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}
CKPT_MAP = {"vgg_lpips": "vgg.pth"}
MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}
def download(url, local_path, chunk_size=1024):
os.makedirs(os.path.split(local_path)[0], exist_ok=True)
with requests.get(url, stream=True) as r:
total_size = int(r.headers.get("content-length", 0))
with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
with open(local_path, "wb") as f:
for data in r.iter_content(chunk_size=chunk_size):
if data:
f.write(data)
                        pbar.update(len(data))  # the last chunk may be smaller than chunk_size
def md5_hash(path):
with open(path, "rb") as f:
content = f.read()
return hashlib.md5(content).hexdigest()
def get_ckpt_path(name, root, check=False):
assert name in URL_MAP
path = os.path.join(root, CKPT_MAP[name])
if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
download(URL_MAP[name], path)
md5 = md5_hash(path)
assert md5 == MD5_MAP[name], md5
return path
class ActNorm(nn.Module):
def __init__(self, num_features, logdet=False, affine=True, allow_reverse_init=False):
assert affine
super().__init__()
self.logdet = logdet
self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
self.allow_reverse_init = allow_reverse_init
self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
def initialize(self, input):
with torch.no_grad():
flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
mean = flatten.mean(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3)
std = flatten.std(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3)
self.loc.data.copy_(-mean)
self.scale.data.copy_(1 / (std + 1e-6))
def forward(self, input, reverse=False):
if reverse:
return self.reverse(input)
if len(input.shape) == 2:
input = input[:, :, None, None]
squeeze = True
else:
squeeze = False
_, _, height, width = input.shape
if self.training and self.initialized.item() == 0:
self.initialize(input)
self.initialized.fill_(1)
h = self.scale * (input + self.loc)
if squeeze:
h = h.squeeze(-1).squeeze(-1)
if self.logdet:
log_abs = torch.log(torch.abs(self.scale))
logdet = height * width * torch.sum(log_abs)
logdet = logdet * torch.ones(input.shape[0]).to(input)
return h, logdet
return h
def reverse(self, output):
if self.training and self.initialized.item() == 0:
if not self.allow_reverse_init:
raise RuntimeError(
"Initializing ActNorm in reverse direction is "
"disabled by default. Use allow_reverse_init=True to enable."
)
else:
self.initialize(output)
self.initialized.fill_(1)
if len(output.shape) == 2:
output = output[:, :, None, None]
squeeze = True
else:
squeeze = False
h = output / self.scale - self.loc
if squeeze:
h = h.squeeze(-1).squeeze(-1)
return h
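# A minimal usage sketch: ActNorm initializes its shift/scale from the statistics of the
# first batch seen in training mode, then acts as an invertible per-channel affine.
def _actnorm_demo():
    layer = ActNorm(num_features=16, logdet=True)
    layer.train()
    x = torch.randn(4, 16, 8, 8)
    h, logdet = layer(x)                # the first call triggers data-dependent init
    assert layer.initialized.item() == 1
    x_rec = layer(h, reverse=True)      # reverse() inverts the affine transform
    return (x - x_rec).abs().max()      # ~0 up to floating point error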
import torch
import torch.nn.functional as F
def hinge_d_loss(logits_real, logits_fake):
loss_real = torch.mean(F.relu(1.0 - logits_real))
loss_fake = torch.mean(F.relu(1.0 + logits_fake))
d_loss = 0.5 * (loss_real + loss_fake)
return d_loss
def vanilla_d_loss(logits_real, logits_fake):
d_loss = 0.5 * (
torch.mean(torch.nn.functional.softplus(-logits_real)) + torch.mean(torch.nn.functional.softplus(logits_fake))
)
return d_loss
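# A minimal usage sketch: both losses take raw discriminator logits for real and fake
# (reconstructed) samples; hinge ignores easy examples, vanilla is a softplus/BCE-style loss.
def _d_loss_demo():
    logits_real = torch.randn(8, 1, 30, 30)
    logits_fake = torch.randn(8, 1, 30, 30)
    return hinge_d_loss(logits_real, logits_fake), vanilla_d_loss(logits_real, logits_fake)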
import copy
from pathlib import Path
from math import log2, ceil, sqrt
from functools import wraps, partial
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList
from torch.autograd import grad as torch_grad
import torchvision
from torchvision.models import VGG16_Weights
from collections import namedtuple
# from vector_quantize_pytorch import LFQ, FSQ
from .regularizers.finite_scalar_quantization import FSQ
from .regularizers.lookup_free_quantization import LFQ
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange
from beartype import beartype
from beartype.typing import Union, Tuple, Optional, List
from magvit2_pytorch.attend import Attend
from magvit2_pytorch.version import __version__
from gateloop_transformer import SimpleGateLoopLayer
from taylor_series_linear_attention import TaylorSeriesLinearAttn
from kornia.filters import filter3d
import pickle
# helper
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
def safe_get_index(it, ind, default=None):
if ind < len(it):
return it[ind]
return default
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def identity(t, *args, **kwargs):
return t
def divisible_by(num, den):
return (num % den) == 0
def pack_one(t, pattern):
return pack([t], pattern)
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
def append_dims(t, ndims: int):
return t.reshape(*t.shape, *((1,) * ndims))
def is_odd(n):
return not divisible_by(n, 2)
def maybe_del_attr_(o, attr):
if hasattr(o, attr):
delattr(o, attr)
def cast_tuple(t, length=1):
return t if isinstance(t, tuple) else ((t,) * length)
# tensor helpers
def l2norm(t):
return F.normalize(t, dim=-1, p=2)
def pad_at_dim(t, pad, dim=-1, value=0.0):
dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
zeros = (0, 0) * dims_from_right
return F.pad(t, (*zeros, *pad), value=value)
def pick_video_frame(video, frame_indices):
batch, device = video.shape[0], video.device
video = rearrange(video, "b c f ... -> b f c ...")
batch_indices = torch.arange(batch, device=device)
batch_indices = rearrange(batch_indices, "b -> b 1")
images = video[batch_indices, frame_indices]
images = rearrange(images, "b 1 c ... -> b c ...")
return images
# gan related
def gradient_penalty(images, output):
batch_size = images.shape[0]
gradients = torch_grad(
outputs=output,
inputs=images,
grad_outputs=torch.ones(output.size(), device=images.device),
create_graph=True,
retain_graph=True,
only_inputs=True,
)[0]
gradients = rearrange(gradients, "b ... -> b (...)")
return ((gradients.norm(2, dim=1) - 1) ** 2).mean()
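# A minimal sketch of how gradient_penalty is meant to be called: `images` must require
# grad and `output` must be computed from them. The tiny critic below is only a stand-in,
# not the discriminator defined later in this file.
def _gradient_penalty_demo():
    critic = nn.Sequential(nn.Flatten(), nn.Linear(3 * 64 * 64, 1))
    images = torch.randn(2, 3, 64, 64, requires_grad=True)
    output = critic(images)
    return gradient_penalty(images, output)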
def leaky_relu(p=0.1):
return nn.LeakyReLU(p)
def hinge_discr_loss(fake, real):
return (F.relu(1 + fake) + F.relu(1 - real)).mean()
def hinge_gen_loss(fake):
return -fake.mean()
@autocast(enabled=False)
@beartype
def grad_layer_wrt_loss(loss: Tensor, layer: nn.Parameter):
return torch_grad(outputs=loss, inputs=layer, grad_outputs=torch.ones_like(loss), retain_graph=True)[0].detach()
# helper decorators
def remove_vgg(fn):
@wraps(fn)
def inner(self, *args, **kwargs):
has_vgg = hasattr(self, "vgg")
if has_vgg:
vgg = self.vgg
delattr(self, "vgg")
out = fn(self, *args, **kwargs)
if has_vgg:
self.vgg = vgg
return out
return inner
# helper classes
def Sequential(*modules):
modules = [*filter(exists, modules)]
if len(modules) == 0:
return nn.Identity()
return nn.Sequential(*modules)
class Residual(Module):
@beartype
def __init__(self, fn: Module):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x
# for a bunch of tensor operations to change tensor to (batch, time, feature dimension) and back
class ToTimeSequence(Module):
@beartype
def __init__(self, fn: Module):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
x = rearrange(x, "b c f ... -> b ... f c")
x, ps = pack_one(x, "* n c")
o = self.fn(x, **kwargs)
o = unpack_one(o, ps, "* n c")
return rearrange(o, "b ... f c -> b c f ...")
class SqueezeExcite(Module):
# global context network - attention-esque squeeze-excite variant (https://arxiv.org/abs/2012.13375)
def __init__(self, dim, *, dim_out=None, dim_hidden_min=16, init_bias=-10):
super().__init__()
dim_out = default(dim_out, dim)
self.to_k = nn.Conv2d(dim, 1, 1)
dim_hidden = max(dim_hidden_min, dim_out // 2)
self.net = nn.Sequential(
nn.Conv2d(dim, dim_hidden, 1), nn.LeakyReLU(0.1), nn.Conv2d(dim_hidden, dim_out, 1), nn.Sigmoid()
)
nn.init.zeros_(self.net[-2].weight)
nn.init.constant_(self.net[-2].bias, init_bias)
def forward(self, x):
orig_input, batch = x, x.shape[0]
is_video = x.ndim == 5
if is_video:
x = rearrange(x, "b c f h w -> (b f) c h w")
context = self.to_k(x)
context = rearrange(context, "b c h w -> b c (h w)").softmax(dim=-1)
spatial_flattened_input = rearrange(x, "b c h w -> b c (h w)")
out = einsum("b i n, b c n -> b c i", context, spatial_flattened_input)
out = rearrange(out, "... -> ... 1")
gates = self.net(out)
if is_video:
gates = rearrange(gates, "(b f) c h w -> b c f h w", b=batch)
return gates * orig_input
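# A minimal usage sketch: the gate is computed from a softmax-pooled global context and
# broadcast back over space (and time, for 5-D video inputs).
def _squeeze_excite_demo():
    se = SqueezeExcite(dim=32)
    video = torch.randn(1, 32, 4, 16, 16)  # (batch, channels, frames, height, width)
    return se(video).shape                 # same shape as the input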
# token shifting
class TokenShift(Module):
@beartype
def __init__(self, fn: Module):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
x, x_shift = x.chunk(2, dim=1)
x_shift = pad_at_dim(x_shift, (1, -1), dim=2) # shift time dimension
x = torch.cat((x, x_shift), dim=1)
return self.fn(x, **kwargs)
# rmsnorm
class RMSNorm(Module):
def __init__(self, dim, channel_first=False, images=False, bias=False):
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
class AdaptiveRMSNorm(Module):
def __init__(self, dim, *, dim_cond, channel_first=False, images=False, bias=False):
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.dim_cond = dim_cond
self.channel_first = channel_first
self.scale = dim**0.5
self.to_gamma = nn.Linear(dim_cond, dim)
self.to_bias = nn.Linear(dim_cond, dim) if bias else None
nn.init.zeros_(self.to_gamma.weight)
nn.init.ones_(self.to_gamma.bias)
if bias:
nn.init.zeros_(self.to_bias.weight)
nn.init.zeros_(self.to_bias.bias)
@beartype
def forward(self, x: Tensor, *, cond: Tensor):
batch = x.shape[0]
assert cond.shape == (batch, self.dim_cond)
gamma = self.to_gamma(cond)
bias = 0.0
if exists(self.to_bias):
bias = self.to_bias(cond)
if self.channel_first:
gamma = append_dims(gamma, x.ndim - 2)
if exists(self.to_bias):
bias = append_dims(bias, x.ndim - 2)
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * gamma + bias
# attention
class Attention(Module):
@beartype
def __init__(
self,
*,
dim,
dim_cond: Optional[int] = None,
causal=False,
dim_head=32,
heads=8,
flash=False,
dropout=0.0,
num_memory_kv=4,
):
super().__init__()
dim_inner = dim_head * heads
self.need_cond = exists(dim_cond)
if self.need_cond:
self.norm = AdaptiveRMSNorm(dim, dim_cond=dim_cond)
else:
self.norm = RMSNorm(dim)
self.to_qkv = nn.Sequential(
nn.Linear(dim, dim_inner * 3, bias=False), Rearrange("b n (qkv h d) -> qkv b h n d", qkv=3, h=heads)
)
assert num_memory_kv > 0
self.mem_kv = nn.Parameter(torch.randn(2, heads, num_memory_kv, dim_head))
self.attend = Attend(causal=causal, dropout=dropout, flash=flash)
self.to_out = nn.Sequential(Rearrange("b h n d -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False))
@beartype
def forward(self, x, mask: Optional[Tensor] = None, cond: Optional[Tensor] = None):
maybe_cond_kwargs = dict(cond=cond) if self.need_cond else dict()
x = self.norm(x, **maybe_cond_kwargs)
q, k, v = self.to_qkv(x)
mk, mv = map(lambda t: repeat(t, "h n d -> b h n d", b=q.shape[0]), self.mem_kv)
k = torch.cat((mk, k), dim=-2)
v = torch.cat((mv, v), dim=-2)
out = self.attend(q, k, v, mask=mask)
return self.to_out(out)
class LinearAttention(Module):
"""
using the specific linear attention proposed in https://arxiv.org/abs/2106.09681
"""
@beartype
def __init__(self, *, dim, dim_cond: Optional[int] = None, dim_head=8, heads=8, dropout=0.0):
super().__init__()
dim_inner = dim_head * heads
self.need_cond = exists(dim_cond)
if self.need_cond:
self.norm = AdaptiveRMSNorm(dim, dim_cond=dim_cond)
else:
self.norm = RMSNorm(dim)
self.attn = TaylorSeriesLinearAttn(dim=dim, dim_head=dim_head, heads=heads)
def forward(self, x, cond: Optional[Tensor] = None):
maybe_cond_kwargs = dict(cond=cond) if self.need_cond else dict()
x = self.norm(x, **maybe_cond_kwargs)
return self.attn(x)
class LinearSpaceAttention(LinearAttention):
def forward(self, x, *args, **kwargs):
x = rearrange(x, "b c ... h w -> b ... h w c")
x, batch_ps = pack_one(x, "* h w c")
x, seq_ps = pack_one(x, "b * c")
x = super().forward(x, *args, **kwargs)
x = unpack_one(x, seq_ps, "b * c")
x = unpack_one(x, batch_ps, "* h w c")
return rearrange(x, "b ... h w c -> b c ... h w")
class SpaceAttention(Attention):
def forward(self, x, *args, **kwargs):
x = rearrange(x, "b c t h w -> b t h w c")
x, batch_ps = pack_one(x, "* h w c")
x, seq_ps = pack_one(x, "b * c")
x = super().forward(x, *args, **kwargs)
x = unpack_one(x, seq_ps, "b * c")
x = unpack_one(x, batch_ps, "* h w c")
return rearrange(x, "b t h w c -> b c t h w")
class TimeAttention(Attention):
def forward(self, x, *args, **kwargs):
x = rearrange(x, "b c t h w -> b h w t c")
x, batch_ps = pack_one(x, "* t c")
x = super().forward(x, *args, **kwargs)
x = unpack_one(x, batch_ps, "* t c")
return rearrange(x, "b h w t c -> b c t h w")
class GEGLU(Module):
def forward(self, x):
x, gate = x.chunk(2, dim=1)
return F.gelu(gate) * x
class FeedForward(Module):
@beartype
def __init__(self, dim, *, dim_cond: Optional[int] = None, mult=4, images=False):
super().__init__()
conv_klass = nn.Conv2d if images else nn.Conv3d
rmsnorm_klass = RMSNorm if not exists(dim_cond) else partial(AdaptiveRMSNorm, dim_cond=dim_cond)
maybe_adaptive_norm_klass = partial(rmsnorm_klass, channel_first=True, images=images)
dim_inner = int(dim * mult * 2 / 3)
self.norm = maybe_adaptive_norm_klass(dim)
self.net = Sequential(conv_klass(dim, dim_inner * 2, 1), GEGLU(), conv_klass(dim_inner, dim, 1))
@beartype
def forward(self, x: Tensor, *, cond: Optional[Tensor] = None):
maybe_cond_kwargs = dict(cond=cond) if exists(cond) else dict()
x = self.norm(x, **maybe_cond_kwargs)
return self.net(x)
# discriminator with anti-aliased downsampling (blurpool Zhang et al.)
class Blur(Module):
def __init__(self):
super().__init__()
f = torch.Tensor([1, 2, 1])
self.register_buffer("f", f)
def forward(self, x, space_only=False, time_only=False):
assert not (space_only and time_only)
f = self.f
if space_only:
f = einsum("i, j -> i j", f, f)
f = rearrange(f, "... -> 1 1 ...")
elif time_only:
f = rearrange(f, "f -> 1 f 1 1")
else:
f = einsum("i, j, k -> i j k", f, f, f)
f = rearrange(f, "... -> 1 ...")
is_images = x.ndim == 4
if is_images:
x = rearrange(x, "b c h w -> b c 1 h w")
out = filter3d(x, f, normalized=True)
if is_images:
out = rearrange(out, "b c 1 h w -> b c h w")
return out
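# A minimal usage sketch: the [1, 2, 1] binomial filter is expanded to 2-D (space_only),
# 1-D over time (time_only), or full 3-D, and applied via kornia's filter3d.
def _blur_demo():
    blur = Blur()
    video = torch.randn(1, 8, 4, 16, 16)
    image = torch.randn(1, 8, 16, 16)
    return blur(video).shape, blur(image, space_only=True).shape  # shapes are preserved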
class DiscriminatorBlock(Module):
def __init__(self, input_channels, filters, downsample=True, antialiased_downsample=True):
super().__init__()
self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))
self.net = nn.Sequential(
nn.Conv2d(input_channels, filters, 3, padding=1),
leaky_relu(),
nn.Conv2d(filters, filters, 3, padding=1),
leaky_relu(),
)
self.maybe_blur = Blur() if antialiased_downsample else None
self.downsample = (
nn.Sequential(
Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2), nn.Conv2d(filters * 4, filters, 1)
)
if downsample
else None
)
def forward(self, x):
res = self.conv_res(x)
x = self.net(x)
if exists(self.downsample):
if exists(self.maybe_blur):
x = self.maybe_blur(x, space_only=True)
x = self.downsample(x)
x = (x + res) * (2**-0.5)
return x
class Discriminator(Module):
@beartype
def __init__(
self,
*,
dim,
image_size,
channels=3,
max_dim=512,
attn_heads=8,
attn_dim_head=32,
linear_attn_dim_head=8,
linear_attn_heads=16,
ff_mult=4,
antialiased_downsample=False,
):
super().__init__()
image_size = pair(image_size)
min_image_resolution = min(image_size)
num_layers = int(log2(min_image_resolution) - 2)
blocks = []
layer_dims = [channels] + [(dim * 4) * (2**i) for i in range(num_layers + 1)]
layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims]
layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:]))
blocks = []
attn_blocks = []
image_resolution = min_image_resolution
for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out):
num_layer = ind + 1
is_not_last = ind != (len(layer_dims_in_out) - 1)
block = DiscriminatorBlock(
in_chan, out_chan, downsample=is_not_last, antialiased_downsample=antialiased_downsample
)
attn_block = Sequential(
Residual(LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head)),
Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
)
blocks.append(ModuleList([block, attn_block]))
image_resolution //= 2
self.blocks = ModuleList(blocks)
dim_last = layer_dims[-1]
downsample_factor = 2**num_layers
last_fmap_size = tuple(map(lambda n: n // downsample_factor, image_size))
latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last
self.to_logits = Sequential(
nn.Conv2d(dim_last, dim_last, 3, padding=1),
leaky_relu(),
Rearrange("b ... -> b (...)"),
nn.Linear(latent_dim, 1),
Rearrange("b 1 -> b"),
)
def forward(self, x):
for block, attn_block in self.blocks:
x = block(x)
x = attn_block(x)
return self.to_logits(x)
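# A minimal usage sketch (illustrative sizes): the discriminator is built for a fixed
# image resolution and returns one real/fake logit per image.
def _discriminator_demo():
    discr = Discriminator(dim=64, image_size=64, channels=3)
    images = torch.randn(2, 3, 64, 64)
    return discr(images).shape  # torch.Size([2])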
# modulatable conv from Karras et al. Stylegan2
# for conditioning on latents
class Conv3DMod(Module):
@beartype
def __init__(
self, dim, *, spatial_kernel, time_kernel, causal=True, dim_out=None, demod=True, eps=1e-8, pad_mode="zeros"
):
super().__init__()
dim_out = default(dim_out, dim)
self.eps = eps
assert is_odd(spatial_kernel) and is_odd(time_kernel)
self.spatial_kernel = spatial_kernel
self.time_kernel = time_kernel
time_padding = (time_kernel - 1, 0) if causal else ((time_kernel // 2,) * 2)
self.pad_mode = pad_mode
self.padding = (*((spatial_kernel // 2,) * 4), *time_padding)
self.weights = nn.Parameter(torch.randn((dim_out, dim, time_kernel, spatial_kernel, spatial_kernel)))
self.demod = demod
nn.init.kaiming_normal_(self.weights, a=0, mode="fan_in", nonlinearity="selu")
@beartype
def forward(self, fmap, cond: Tensor):
"""
notation
b - batch
n - convs
o - output
i - input
k - kernel
"""
b = fmap.shape[0]
# prepare weights for modulation
weights = self.weights
# do the modulation, demodulation, as done in stylegan2
cond = rearrange(cond, "b i -> b 1 i 1 1 1")
weights = weights * (cond + 1)
if self.demod:
inv_norm = reduce(weights**2, "b o i k0 k1 k2 -> b o 1 1 1 1", "sum").clamp(min=self.eps).rsqrt()
weights = weights * inv_norm
fmap = rearrange(fmap, "b c t h w -> 1 (b c) t h w")
weights = rearrange(weights, "b o ... -> (b o) ...")
fmap = F.pad(fmap, self.padding, mode=self.pad_mode)
fmap = F.conv3d(fmap, weights, groups=b)
return rearrange(fmap, "1 (b o) ... -> b o ...", b=b)
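# A minimal usage sketch: the convolution weights are modulated (and optionally
# demodulated) per sample by a conditioning vector, StyleGAN2-style, by folding the
# batch into the groups of a single conv3d call.
def _conv3dmod_demo():
    conv = Conv3DMod(dim=16, spatial_kernel=3, time_kernel=3)
    fmap = torch.randn(2, 16, 5, 8, 8)  # (batch, channels, time, height, width)
    cond = torch.randn(2, 16)           # one modulation vector per sample
    return conv(fmap, cond=cond).shape  # same shape as fmap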
# strided conv downsamples
class SpatialDownsample2x(Module):
def __init__(self, dim, dim_out=None, kernel_size=3, antialias=False):
super().__init__()
dim_out = default(dim_out, dim)
self.maybe_blur = Blur() if antialias else identity
self.conv = nn.Conv2d(dim, dim_out, kernel_size, stride=2, padding=kernel_size // 2)
def forward(self, x):
x = self.maybe_blur(x, space_only=True)
x = rearrange(x, "b c t h w -> b t c h w")
x, ps = pack_one(x, "* c h w")
out = self.conv(x)
out = unpack_one(out, ps, "* c h w")
out = rearrange(out, "b t c h w -> b c t h w")
return out
class TimeDownsample2x(Module):
def __init__(self, dim, dim_out=None, kernel_size=3, antialias=False):
super().__init__()
dim_out = default(dim_out, dim)
self.maybe_blur = Blur() if antialias else identity
self.time_causal_padding = (kernel_size - 1, 0)
self.conv = nn.Conv1d(dim, dim_out, kernel_size, stride=2)
def forward(self, x):
x = self.maybe_blur(x, time_only=True)
x = rearrange(x, "b c t h w -> b h w c t")
x, ps = pack_one(x, "* c t")
x = F.pad(x, self.time_causal_padding)
out = self.conv(x)
out = unpack_one(out, ps, "* c t")
out = rearrange(out, "b h w c t -> b c t h w")
return out
# depth to space upsamples
class SpatialUpsample2x(Module):
def __init__(self, dim, dim_out=None):
super().__init__()
dim_out = default(dim_out, dim)
conv = nn.Conv2d(dim, dim_out * 4, 1)
self.net = nn.Sequential(conv, nn.SiLU(), Rearrange("b (c p1 p2) h w -> b c (h p1) (w p2)", p1=2, p2=2))
self.init_conv_(conv)
def init_conv_(self, conv):
o, i, h, w = conv.weight.shape
conv_weight = torch.empty(o // 4, i, h, w)
nn.init.kaiming_uniform_(conv_weight)
conv_weight = repeat(conv_weight, "o ... -> (o 4) ...")
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def forward(self, x):
x = rearrange(x, "b c t h w -> b t c h w")
x, ps = pack_one(x, "* c h w")
out = self.net(x)
out = unpack_one(out, ps, "* c h w")
out = rearrange(out, "b t c h w -> b c t h w")
return out
class TimeUpsample2x(Module):
def __init__(self, dim, dim_out=None):
super().__init__()
dim_out = default(dim_out, dim)
conv = nn.Conv1d(dim, dim_out * 2, 1)
self.net = nn.Sequential(conv, nn.SiLU(), Rearrange("b (c p) t -> b c (t p)", p=2))
self.init_conv_(conv)
def init_conv_(self, conv):
o, i, t = conv.weight.shape
conv_weight = torch.empty(o // 2, i, t)
nn.init.kaiming_uniform_(conv_weight)
conv_weight = repeat(conv_weight, "o ... -> (o 2) ...")
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def forward(self, x):
x = rearrange(x, "b c t h w -> b h w c t")
x, ps = pack_one(x, "* c t")
out = self.net(x)
out = unpack_one(out, ps, "* c t")
out = rearrange(out, "b h w c t -> b c t h w")
return out
# autoencoder - only the best variant is offered here, with causal conv3d
def SameConv2d(dim_in, dim_out, kernel_size):
kernel_size = cast_tuple(kernel_size, 2)
padding = [k // 2 for k in kernel_size]
return nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size, padding=padding)
class CausalConv3d(Module):
@beartype
def __init__(
self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], pad_mode="constant", **kwargs
):
super().__init__()
kernel_size = cast_tuple(kernel_size, 3)
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
dilation = kwargs.pop("dilation", 1)
stride = kwargs.pop("stride", 1)
self.pad_mode = pad_mode
time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
height_pad = height_kernel_size // 2
width_pad = width_kernel_size // 2
self.time_pad = time_pad
self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
stride = (stride, 1, 1)
dilation = (dilation, 1, 1)
self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
pad_mode = self.pad_mode if self.time_pad < x.shape[2] else "constant"
x = F.pad(x, self.time_causal_padding, mode=pad_mode)
return self.conv(x)
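# A minimal usage sketch: padding is applied only on the "past" side of the time axis,
# so the temporal length is preserved and no output frame depends on future frames.
def _causal_conv3d_demo():
    conv = CausalConv3d(chan_in=8, chan_out=8, kernel_size=(3, 3, 3))
    video = torch.randn(1, 8, 5, 16, 16)
    return conv(video).shape  # time dimension stays at 5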
@beartype
def ResidualUnit(dim, kernel_size: Union[int, Tuple[int, int, int]], pad_mode: str = "constant"):
net = Sequential(
CausalConv3d(dim, dim, kernel_size, pad_mode=pad_mode),
nn.ELU(),
nn.Conv3d(dim, dim, 1),
nn.ELU(),
SqueezeExcite(dim),
)
return Residual(net)
@beartype
class ResidualUnitMod(Module):
def __init__(
self, dim, kernel_size: Union[int, Tuple[int, int, int]], *, dim_cond, pad_mode: str = "constant", demod=True
):
super().__init__()
kernel_size = cast_tuple(kernel_size, 3)
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
assert height_kernel_size == width_kernel_size
self.to_cond = nn.Linear(dim_cond, dim)
self.conv = Conv3DMod(
dim=dim,
spatial_kernel=height_kernel_size,
time_kernel=time_kernel_size,
causal=True,
demod=demod,
pad_mode=pad_mode,
)
self.conv_out = nn.Conv3d(dim, dim, 1)
@beartype
def forward(
self,
x,
cond: Tensor,
):
res = x
cond = self.to_cond(cond)
x = self.conv(x, cond=cond)
x = F.elu(x)
x = self.conv_out(x)
x = F.elu(x)
return x + res
class CausalConvTranspose3d(Module):
def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], *, time_stride, **kwargs):
super().__init__()
kernel_size = cast_tuple(kernel_size, 3)
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
self.upsample_factor = time_stride
height_pad = height_kernel_size // 2
width_pad = width_kernel_size // 2
stride = (time_stride, 1, 1)
padding = (0, height_pad, width_pad)
self.conv = nn.ConvTranspose3d(chan_in, chan_out, kernel_size, stride, padding=padding, **kwargs)
def forward(self, x):
assert x.ndim == 5
t = x.shape[2]
out = self.conv(x)
out = out[..., : (t * self.upsample_factor), :, :]
return out
# video tokenizer class
LossBreakdown = namedtuple(
"LossBreakdown",
[
"recon_loss",
"lfq_aux_loss",
"quantizer_loss_breakdown",
"perceptual_loss",
"adversarial_gen_loss",
"adaptive_adversarial_weight",
"multiscale_gen_losses",
"multiscale_gen_adaptive_weights",
],
)
DiscrLossBreakdown = namedtuple("DiscrLossBreakdown", ["discr_loss", "multiscale_discr_losses", "gradient_penalty"])
class VideoTokenizer(Module):
@beartype
def __init__(
self,
*,
image_size,
layers: Tuple[Union[str, Tuple[str, int]], ...] = ("residual", "residual", "residual"),
residual_conv_kernel_size=3,
num_codebooks=1,
codebook_size: Optional[int] = None,
channels=3,
init_dim=64,
max_dim=float("inf"),
dim_cond=None,
dim_cond_expansion_factor=4.0,
input_conv_kernel_size: Tuple[int, int, int] = (7, 7, 7),
output_conv_kernel_size: Tuple[int, int, int] = (3, 3, 3),
pad_mode: str = "constant",
lfq_entropy_loss_weight=0.1,
lfq_commitment_loss_weight=1.0,
lfq_diversity_gamma=2.5,
quantizer_aux_loss_weight=1.0,
lfq_activation=nn.Identity(),
use_fsq=False,
fsq_levels: Optional[List[int]] = None,
attn_dim_head=32,
attn_heads=8,
attn_dropout=0.0,
linear_attn_dim_head=8,
linear_attn_heads=16,
vgg: Optional[Module] = None,
vgg_weights: VGG16_Weights = VGG16_Weights.DEFAULT,
perceptual_loss_weight=1e-1,
discr_kwargs: Optional[dict] = None,
multiscale_discrs: Tuple[Module, ...] = tuple(),
use_gan=True,
adversarial_loss_weight=1.0,
grad_penalty_loss_weight=10.0,
multiscale_adversarial_loss_weight=1.0,
flash_attn=True,
separate_first_frame_encoding=False,
):
super().__init__()
# for autosaving the config
_locals = locals()
_locals.pop("self", None)
_locals.pop("__class__", None)
self._configs = pickle.dumps(_locals)
# image size
self.channels = channels
self.image_size = image_size
# initial encoder
self.conv_in = CausalConv3d(channels, init_dim, input_conv_kernel_size, pad_mode=pad_mode)
# whether to encode the first frame separately or not
self.conv_in_first_frame = nn.Identity()
self.conv_out_first_frame = nn.Identity()
if separate_first_frame_encoding:
self.conv_in_first_frame = SameConv2d(channels, init_dim, input_conv_kernel_size[-2:])
self.conv_out_first_frame = SameConv2d(init_dim, channels, output_conv_kernel_size[-2:])
self.separate_first_frame_encoding = separate_first_frame_encoding
# encoder and decoder layers
self.encoder_layers = ModuleList([])
self.decoder_layers = ModuleList([])
self.conv_out = CausalConv3d(init_dim, channels, output_conv_kernel_size, pad_mode=pad_mode)
dim = init_dim
dim_out = dim
layer_fmap_size = image_size
time_downsample_factor = 1
has_cond_across_layers = []
for layer_def in layers:
layer_type, *layer_params = cast_tuple(layer_def)
has_cond = False
if layer_type == "residual":
encoder_layer = ResidualUnit(dim, residual_conv_kernel_size)
decoder_layer = ResidualUnit(dim, residual_conv_kernel_size)
elif layer_type == "consecutive_residual":
(num_consecutive,) = layer_params
encoder_layer = Sequential(
*[ResidualUnit(dim, residual_conv_kernel_size) for _ in range(num_consecutive)]
)
decoder_layer = Sequential(
*[ResidualUnit(dim, residual_conv_kernel_size) for _ in range(num_consecutive)]
)
elif layer_type == "cond_residual":
assert exists(
dim_cond
), "dim_cond must be passed into VideoTokenizer, if tokenizer is to be conditioned"
has_cond = True
encoder_layer = ResidualUnitMod(
dim, residual_conv_kernel_size, dim_cond=int(dim_cond * dim_cond_expansion_factor)
)
decoder_layer = ResidualUnitMod(
dim, residual_conv_kernel_size, dim_cond=int(dim_cond * dim_cond_expansion_factor)
)
dim_out = dim
elif layer_type == "compress_space":
dim_out = safe_get_index(layer_params, 0)
dim_out = default(dim_out, dim * 2)
dim_out = min(dim_out, max_dim)
encoder_layer = SpatialDownsample2x(dim, dim_out)
decoder_layer = SpatialUpsample2x(dim_out, dim)
assert layer_fmap_size > 1
layer_fmap_size //= 2
elif layer_type == "compress_time":
dim_out = safe_get_index(layer_params, 0)
dim_out = default(dim_out, dim * 2)
dim_out = min(dim_out, max_dim)
encoder_layer = TimeDownsample2x(dim, dim_out)
decoder_layer = TimeUpsample2x(dim_out, dim)
time_downsample_factor *= 2
elif layer_type == "attend_space":
attn_kwargs = dict(
dim=dim, dim_head=attn_dim_head, heads=attn_heads, dropout=attn_dropout, flash=flash_attn
)
encoder_layer = Sequential(Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim)))
decoder_layer = Sequential(Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim)))
elif layer_type == "linear_attend_space":
linear_attn_kwargs = dict(dim=dim, dim_head=linear_attn_dim_head, heads=linear_attn_heads)
encoder_layer = Sequential(
Residual(LinearSpaceAttention(**linear_attn_kwargs)), Residual(FeedForward(dim))
)
decoder_layer = Sequential(
Residual(LinearSpaceAttention(**linear_attn_kwargs)), Residual(FeedForward(dim))
)
elif layer_type == "gateloop_time":
gateloop_kwargs = dict(use_heinsen=False)
encoder_layer = ToTimeSequence(Residual(SimpleGateLoopLayer(dim=dim)))
decoder_layer = ToTimeSequence(Residual(SimpleGateLoopLayer(dim=dim)))
elif layer_type == "attend_time":
attn_kwargs = dict(
dim=dim,
dim_head=attn_dim_head,
heads=attn_heads,
dropout=attn_dropout,
causal=True,
flash=flash_attn,
)
encoder_layer = Sequential(
Residual(TokenShift(TimeAttention(**attn_kwargs))),
Residual(TokenShift(FeedForward(dim, dim_cond=dim_cond))),
)
decoder_layer = Sequential(
Residual(TokenShift(TimeAttention(**attn_kwargs))),
Residual(TokenShift(FeedForward(dim, dim_cond=dim_cond))),
)
elif layer_type == "cond_attend_space":
has_cond = True
attn_kwargs = dict(
dim=dim,
dim_cond=dim_cond,
dim_head=attn_dim_head,
heads=attn_heads,
dropout=attn_dropout,
flash=flash_attn,
)
encoder_layer = Sequential(Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim)))
decoder_layer = Sequential(Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim)))
elif layer_type == "cond_linear_attend_space":
has_cond = True
attn_kwargs = dict(
dim=dim,
dim_cond=dim_cond,
dim_head=attn_dim_head,
heads=attn_heads,
dropout=attn_dropout,
flash=flash_attn,
)
encoder_layer = Sequential(
Residual(LinearSpaceAttention(**attn_kwargs)), Residual(FeedForward(dim, dim_cond=dim_cond))
)
decoder_layer = Sequential(
Residual(LinearSpaceAttention(**attn_kwargs)), Residual(FeedForward(dim, dim_cond=dim_cond))
)
elif layer_type == "cond_attend_time":
has_cond = True
attn_kwargs = dict(
dim=dim,
dim_cond=dim_cond,
dim_head=attn_dim_head,
heads=attn_heads,
dropout=attn_dropout,
causal=True,
flash=flash_attn,
)
encoder_layer = Sequential(
Residual(TokenShift(TimeAttention(**attn_kwargs))),
Residual(TokenShift(FeedForward(dim, dim_cond=dim_cond))),
)
decoder_layer = Sequential(
Residual(TokenShift(TimeAttention(**attn_kwargs))),
Residual(TokenShift(FeedForward(dim, dim_cond=dim_cond))),
)
else:
raise ValueError(f"unknown layer type {layer_type}")
self.encoder_layers.append(encoder_layer)
self.decoder_layers.insert(0, decoder_layer)
dim = dim_out
has_cond_across_layers.append(has_cond)
# add a final norm just before quantization layer
self.encoder_layers.append(
Sequential(
Rearrange("b c ... -> b ... c"),
nn.LayerNorm(dim),
Rearrange("b ... c -> b c ..."),
)
)
self.time_downsample_factor = time_downsample_factor
self.time_padding = time_downsample_factor - 1
self.fmap_size = layer_fmap_size
# use a MLP stem for conditioning, if needed
self.has_cond_across_layers = has_cond_across_layers
self.has_cond = any(has_cond_across_layers)
self.encoder_cond_in = nn.Identity()
self.decoder_cond_in = nn.Identity()
        if self.has_cond:
self.dim_cond = dim_cond
self.encoder_cond_in = Sequential(
nn.Linear(dim_cond, int(dim_cond * dim_cond_expansion_factor)), nn.SiLU()
)
self.decoder_cond_in = Sequential(
nn.Linear(dim_cond, int(dim_cond * dim_cond_expansion_factor)), nn.SiLU()
)
# quantizer related
self.use_fsq = use_fsq
if not use_fsq:
assert exists(codebook_size) and not exists(
fsq_levels
), "if use_fsq is set to False, `codebook_size` must be set (and not `fsq_levels`)"
            # lookup free quantizer(s) - multiple codebooks are possible
# each codebook will get its own entropy regularization
self.quantizers = LFQ(
dim=dim,
codebook_size=codebook_size,
num_codebooks=num_codebooks,
entropy_loss_weight=lfq_entropy_loss_weight,
commitment_loss_weight=lfq_commitment_loss_weight,
diversity_gamma=lfq_diversity_gamma,
)
else:
assert (
not exists(codebook_size) and exists(fsq_levels)
), "if use_fsq is set to True, `fsq_levels` must be set (and not `codebook_size`). the effective codebook size is the cumulative product of all the FSQ levels"
self.quantizers = FSQ(fsq_levels, dim=dim, num_codebooks=num_codebooks)
self.quantizer_aux_loss_weight = quantizer_aux_loss_weight
# dummy loss
self.register_buffer("zero", torch.tensor(0.0), persistent=False)
# perceptual loss related
use_vgg = channels in {1, 3, 4} and perceptual_loss_weight > 0.0
self.vgg = None
self.perceptual_loss_weight = perceptual_loss_weight
if use_vgg:
if not exists(vgg):
vgg = torchvision.models.vgg16(weights=vgg_weights)
vgg.classifier = Sequential(*vgg.classifier[:-2])
self.vgg = vgg
self.use_vgg = use_vgg
# main flag for whether to use GAN at all
self.use_gan = use_gan
# discriminator
discr_kwargs = default(discr_kwargs, dict(dim=dim, image_size=image_size, channels=channels, max_dim=512))
self.discr = Discriminator(**discr_kwargs)
self.adversarial_loss_weight = adversarial_loss_weight
self.grad_penalty_loss_weight = grad_penalty_loss_weight
self.has_gan = use_gan and adversarial_loss_weight > 0.0
# multi-scale discriminators
self.has_multiscale_gan = use_gan and multiscale_adversarial_loss_weight > 0.0
self.multiscale_discrs = ModuleList([*multiscale_discrs])
self.multiscale_adversarial_loss_weight = multiscale_adversarial_loss_weight
self.has_multiscale_discrs = (
use_gan and multiscale_adversarial_loss_weight > 0.0 and len(multiscale_discrs) > 0
)
@property
def device(self):
return self.zero.device
@classmethod
def init_and_load_from(cls, path, strict=True):
path = Path(path)
assert path.exists()
pkg = torch.load(str(path), map_location="cpu")
assert "config" in pkg, "model configs were not found in this saved checkpoint"
config = pickle.loads(pkg["config"])
tokenizer = cls(**config)
tokenizer.load(path, strict=strict)
return tokenizer
def parameters(self):
return [
*self.conv_in.parameters(),
*self.conv_in_first_frame.parameters(),
*self.conv_out_first_frame.parameters(),
*self.conv_out.parameters(),
*self.encoder_layers.parameters(),
*self.decoder_layers.parameters(),
*self.encoder_cond_in.parameters(),
*self.decoder_cond_in.parameters(),
*self.quantizers.parameters(),
]
def discr_parameters(self):
return self.discr.parameters()
def copy_for_eval(self):
device = self.device
vae_copy = copy.deepcopy(self.cpu())
maybe_del_attr_(vae_copy, "discr")
maybe_del_attr_(vae_copy, "vgg")
maybe_del_attr_(vae_copy, "multiscale_discrs")
vae_copy.eval()
return vae_copy.to(device)
@remove_vgg
def state_dict(self, *args, **kwargs):
return super().state_dict(*args, **kwargs)
@remove_vgg
def load_state_dict(self, *args, **kwargs):
return super().load_state_dict(*args, **kwargs)
def save(self, path, overwrite=True):
path = Path(path)
assert overwrite or not path.exists(), f"{str(path)} already exists"
pkg = dict(model_state_dict=self.state_dict(), version=__version__, config=self._configs)
torch.save(pkg, str(path))
def load(self, path, strict=True):
path = Path(path)
assert path.exists()
pkg = torch.load(str(path))
state_dict = pkg.get("model_state_dict")
version = pkg.get("version")
assert exists(state_dict)
if exists(version):
print(f"loading checkpointed tokenizer from version {version}")
self.load_state_dict(state_dict, strict=strict)
@beartype
def encode(self, video: Tensor, quantize=False, cond: Optional[Tensor] = None, video_contains_first_frame=True):
encode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame
# whether to pad video or not
if video_contains_first_frame:
video_len = video.shape[2]
video = pad_at_dim(video, (self.time_padding, 0), value=0.0, dim=2)
video_packed_shape = [torch.Size([self.time_padding]), torch.Size([]), torch.Size([video_len - 1])]
# conditioning, if needed
assert (not self.has_cond) or exists(
cond
), "`cond` must be passed into tokenizer forward method since conditionable layers were specified"
if exists(cond):
assert cond.shape == (video.shape[0], self.dim_cond)
cond = self.encoder_cond_in(cond)
cond_kwargs = dict(cond=cond)
# initial conv
# taking into account whether to encode first frame separately
if encode_first_frame_separately:
pad, first_frame, video = unpack(video, video_packed_shape, "b c * h w")
first_frame = self.conv_in_first_frame(first_frame)
video = self.conv_in(video)
if encode_first_frame_separately:
video, _ = pack([first_frame, video], "b c * h w")
video = pad_at_dim(video, (self.time_padding, 0), dim=2)
# encoder layers
for fn, has_cond in zip(self.encoder_layers, self.has_cond_across_layers):
layer_kwargs = dict()
if has_cond:
layer_kwargs = cond_kwargs
video = fn(video, **layer_kwargs)
maybe_quantize = identity if not quantize else self.quantizers
return maybe_quantize(video)
@beartype
def decode_from_code_indices(self, codes: Tensor, cond: Optional[Tensor] = None, video_contains_first_frame=True):
assert codes.dtype in (torch.long, torch.int32)
if codes.ndim == 2:
video_code_len = codes.shape[-1]
assert divisible_by(
video_code_len, self.fmap_size**2
), f"flattened video ids must have a length ({video_code_len}) that is divisible by the fmap size ({self.fmap_size}) squared ({self.fmap_size ** 2})"
codes = rearrange(codes, "b (f h w) -> b f h w", h=self.fmap_size, w=self.fmap_size)
quantized = self.quantizers.indices_to_codes(codes)
return self.decode(quantized, cond=cond, video_contains_first_frame=video_contains_first_frame)
@beartype
def decode(self, quantized: Tensor, cond: Optional[Tensor] = None, video_contains_first_frame=True):
decode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame
batch = quantized.shape[0]
# conditioning, if needed
assert (not self.has_cond) or exists(
cond
), "`cond` must be passed into tokenizer forward method since conditionable layers were specified"
if exists(cond):
assert cond.shape == (batch, self.dim_cond)
cond = self.decoder_cond_in(cond)
cond_kwargs = dict(cond=cond)
# decoder layers
x = quantized
for fn, has_cond in zip(self.decoder_layers, reversed(self.has_cond_across_layers)):
layer_kwargs = dict()
if has_cond:
layer_kwargs = cond_kwargs
x = fn(x, **layer_kwargs)
# to pixels
if decode_first_frame_separately:
left_pad, xff, x = (
x[:, :, : self.time_padding],
x[:, :, self.time_padding],
x[:, :, (self.time_padding + 1) :],
)
out = self.conv_out(x)
outff = self.conv_out_first_frame(xff)
video, _ = pack([outff, out], "b c * h w")
else:
video = self.conv_out(x)
# if video were padded, remove padding
if video_contains_first_frame:
video = video[:, :, self.time_padding :]
return video
@torch.no_grad()
def tokenize(self, video):
self.eval()
return self.forward(video, return_codes=True)
@beartype
def forward(
self,
video_or_images: Tensor,
cond: Optional[Tensor] = None,
return_loss=False,
return_codes=False,
return_recon=False,
return_discr_loss=False,
return_recon_loss_only=False,
apply_gradient_penalty=True,
video_contains_first_frame=True,
adversarial_loss_weight=None,
multiscale_adversarial_loss_weight=None,
):
adversarial_loss_weight = default(adversarial_loss_weight, self.adversarial_loss_weight)
multiscale_adversarial_loss_weight = default(
multiscale_adversarial_loss_weight, self.multiscale_adversarial_loss_weight
)
assert (return_loss + return_codes + return_discr_loss) <= 1
assert video_or_images.ndim in {4, 5}
assert video_or_images.shape[-2:] == (self.image_size, self.image_size)
# accept images for image pretraining (curriculum learning from images to video)
is_image = video_or_images.ndim == 4
if is_image:
video = rearrange(video_or_images, "b c ... -> b c 1 ...")
video_contains_first_frame = True
else:
video = video_or_images
batch, channels, frames = video.shape[:3]
assert divisible_by(
frames - int(video_contains_first_frame), self.time_downsample_factor
), f"number of frames {frames} minus the first frame ({frames - int(video_contains_first_frame)}) must be divisible by the total downsample factor across time {self.time_downsample_factor}"
# encoder
x = self.encode(video, cond=cond, video_contains_first_frame=video_contains_first_frame)
# lookup free quantization
if self.use_fsq:
quantized, codes = self.quantizers(x)
aux_losses = self.zero
quantizer_loss_breakdown = None
else:
(quantized, codes, aux_losses), quantizer_loss_breakdown = self.quantizers(x, return_loss_breakdown=True)
if return_codes and not return_recon:
return codes
# decoder
recon_video = self.decode(quantized, cond=cond, video_contains_first_frame=video_contains_first_frame)
if return_codes:
return codes, recon_video
# reconstruction loss
if not (return_loss or return_discr_loss or return_recon_loss_only):
return recon_video
recon_loss = F.mse_loss(video, recon_video)
# for validation, only return recon loss
if return_recon_loss_only:
return recon_loss, recon_video
# gan discriminator loss
if return_discr_loss:
assert self.has_gan
assert exists(self.discr)
# pick a random frame for image discriminator
frame_indices = torch.randn((batch, frames)).topk(1, dim=-1).indices
real = pick_video_frame(video, frame_indices)
if apply_gradient_penalty:
real = real.requires_grad_()
fake = pick_video_frame(recon_video, frame_indices)
real_logits = self.discr(real)
fake_logits = self.discr(fake.detach())
discr_loss = hinge_discr_loss(fake_logits, real_logits)
# multiscale discriminators
multiscale_discr_losses = []
if self.has_multiscale_discrs:
for discr in self.multiscale_discrs:
multiscale_real_logits = discr(video)
multiscale_fake_logits = discr(recon_video.detach())
multiscale_discr_loss = hinge_discr_loss(multiscale_fake_logits, multiscale_real_logits)
multiscale_discr_losses.append(multiscale_discr_loss)
else:
multiscale_discr_losses.append(self.zero)
# gradient penalty
if apply_gradient_penalty:
gradient_penalty_loss = gradient_penalty(real, real_logits)
else:
gradient_penalty_loss = self.zero
# total loss
total_loss = (
discr_loss
+ gradient_penalty_loss * self.grad_penalty_loss_weight
+ sum(multiscale_discr_losses) * self.multiscale_adversarial_loss_weight
)
discr_loss_breakdown = DiscrLossBreakdown(discr_loss, multiscale_discr_losses, gradient_penalty_loss)
return total_loss, discr_loss_breakdown
# perceptual loss
if self.use_vgg:
frame_indices = torch.randn((batch, frames)).topk(1, dim=-1).indices
input_vgg_input = pick_video_frame(video, frame_indices)
recon_vgg_input = pick_video_frame(recon_video, frame_indices)
if channels == 1:
input_vgg_input = repeat(input_vgg_input, "b 1 h w -> b c h w", c=3)
recon_vgg_input = repeat(recon_vgg_input, "b 1 h w -> b c h w", c=3)
elif channels == 4:
input_vgg_input = input_vgg_input[:, :3]
recon_vgg_input = recon_vgg_input[:, :3]
input_vgg_feats = self.vgg(input_vgg_input)
recon_vgg_feats = self.vgg(recon_vgg_input)
perceptual_loss = F.mse_loss(input_vgg_feats, recon_vgg_feats)
else:
perceptual_loss = self.zero
# get gradient with respect to perceptual loss for last decoder layer
# needed for adaptive weighting
last_dec_layer = self.conv_out.conv.weight
norm_grad_wrt_perceptual_loss = None
if self.training and self.use_vgg and (self.has_gan or self.has_multiscale_discrs):
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p=2)
# per-frame image discriminator
recon_video_frames = None
if self.has_gan:
frame_indices = torch.randn((batch, frames)).topk(1, dim=-1).indices
recon_video_frames = pick_video_frame(recon_video, frame_indices)
fake_logits = self.discr(recon_video_frames)
gen_loss = hinge_gen_loss(fake_logits)
adaptive_weight = 1.0
if exists(norm_grad_wrt_perceptual_loss):
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p=2)
adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-3)
adaptive_weight.clamp_(max=1e3)
if torch.isnan(adaptive_weight).any():
adaptive_weight = 1.0
else:
gen_loss = self.zero
adaptive_weight = 0.0
# multiscale discriminator losses
multiscale_gen_losses = []
multiscale_gen_adaptive_weights = []
if self.has_multiscale_gan and self.has_multiscale_discrs:
if not exists(recon_video_frames):
recon_video_frames = pick_video_frame(recon_video, frame_indices)
for discr in self.multiscale_discrs:
                fake_logits = discr(recon_video)  # score the reconstruction with each multiscale discriminator
multiscale_gen_loss = hinge_gen_loss(fake_logits)
multiscale_gen_losses.append(multiscale_gen_loss)
multiscale_adaptive_weight = 1.0
if exists(norm_grad_wrt_perceptual_loss):
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(multiscale_gen_loss, last_dec_layer).norm(p=2)
multiscale_adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-5)
multiscale_adaptive_weight.clamp_(max=1e3)
multiscale_gen_adaptive_weights.append(multiscale_adaptive_weight)
# calculate total loss
total_loss = (
recon_loss
+ aux_losses * self.quantizer_aux_loss_weight
+ perceptual_loss * self.perceptual_loss_weight
+ gen_loss * adaptive_weight * adversarial_loss_weight
)
if self.has_multiscale_discrs:
weighted_multiscale_gen_losses = sum(
loss * weight for loss, weight in zip(multiscale_gen_losses, multiscale_gen_adaptive_weights)
)
total_loss = total_loss + weighted_multiscale_gen_losses * multiscale_adversarial_loss_weight
# loss breakdown
loss_breakdown = LossBreakdown(
recon_loss,
aux_losses,
quantizer_loss_breakdown,
perceptual_loss,
gen_loss,
adaptive_weight,
multiscale_gen_losses,
multiscale_gen_adaptive_weights,
)
return total_loss, loss_breakdown
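# A minimal usage sketch of the tokenizer (illustrative layer stack and sizes, with the
# GAN and perceptual losses disabled; not a recommended training configuration).
def _video_tokenizer_demo():
    tokenizer = VideoTokenizer(
        image_size=64,
        init_dim=32,
        codebook_size=2**12,
        layers=(
            "residual",
            ("compress_space", 64),
            "residual",
            ("compress_time", 64),
            "residual",
        ),
        use_gan=False,
        perceptual_loss_weight=0.0,
    )
    # 1 leading frame + 4 frames, so (frames - 1) is divisible by the time downsample factor of 2
    video = torch.randn(1, 3, 5, 64, 64)
    codes = tokenizer.tokenize(video)                           # discrete token indices
    total_loss, breakdown = tokenizer(video, return_loss=True)  # reconstruction + quantizer losses
    return codes.shape, total_loss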
# main class
class MagViT2(Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x
from abc import abstractmethod
from typing import Any, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from ....modules.distributions.distributions import DiagonalGaussianDistribution
from .base import AbstractRegularizer
class DiagonalGaussianRegularizer(AbstractRegularizer):
def __init__(self, sample: bool = True):
super().__init__()
self.sample = sample
def get_trainable_parameters(self) -> Any:
yield from ()
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
log = dict()
posterior = DiagonalGaussianDistribution(z)
if self.sample:
z = posterior.sample()
else:
z = posterior.mode()
kl_loss = posterior.kl()
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
log["kl_loss"] = kl_loss
return z, log
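# A minimal usage sketch, assuming the standard DiagonalGaussianDistribution that splits
# the channel dimension into mean and logvar halves.
def _diagonal_gaussian_regularizer_demo():
    reg = DiagonalGaussianRegularizer(sample=True)
    z = torch.randn(2, 8, 16, 16)   # channels hold concatenated mean and logvar
    z_sampled, log = reg(z)         # z_sampled: (2, 4, 16, 16)
    return z_sampled.shape, log["kl_loss"]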
from abc import abstractmethod
from typing import Any, Tuple
import torch
import torch.nn.functional as F
from torch import nn
class AbstractRegularizer(nn.Module):
def __init__(self):
super().__init__()
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
raise NotImplementedError()
@abstractmethod
def get_trainable_parameters(self) -> Any:
raise NotImplementedError()
class IdentityRegularizer(AbstractRegularizer):
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
return z, dict()
def get_trainable_parameters(self) -> Any:
yield from ()
def measure_perplexity(predicted_indices: torch.Tensor, num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]:
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
encodings = F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
avg_probs = encodings.mean(0)
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
cluster_use = torch.sum(avg_probs > 0)
return perplexity, cluster_use
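# A minimal usage sketch: perplexity approaches num_centroids when codes are used
# uniformly and approaches 1 when a single code dominates.
def _measure_perplexity_demo():
    predicted_indices = torch.randint(0, 512, (4, 1024))
    perplexity, cluster_use = measure_perplexity(predicted_indices, num_centroids=512)
    return perplexity, cluster_use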
"""
Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
Code adapted from Jax version in Appendix A.1
"""
from typing import List, Optional
import torch
import torch.nn as nn
from torch.nn import Module
from torch import Tensor, int32
from torch.cuda.amp import autocast
from einops import rearrange, pack, unpack
# helper functions
def exists(v):
return v is not None
def default(*args):
for arg in args:
if exists(arg):
return arg
return None
def pack_one(t, pattern):
return pack([t], pattern)
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
# tensor helpers
def round_ste(z: Tensor) -> Tensor:
"""Round with straight through gradients."""
zhat = z.round()
return z + (zhat - z).detach()
# main class
class FSQ(Module):
def __init__(
self,
levels: List[int],
dim: Optional[int] = None,
num_codebooks=1,
keep_num_codebooks_dim: Optional[bool] = None,
scale: Optional[float] = None,
):
super().__init__()
_levels = torch.tensor(levels, dtype=int32)
self.register_buffer("_levels", _levels, persistent=False)
_basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
self.register_buffer("_basis", _basis, persistent=False)
self.scale = scale
codebook_dim = len(levels)
self.codebook_dim = codebook_dim
effective_codebook_dim = codebook_dim * num_codebooks
self.num_codebooks = num_codebooks
self.effective_codebook_dim = effective_codebook_dim
keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
self.keep_num_codebooks_dim = keep_num_codebooks_dim
self.dim = default(dim, len(_levels) * num_codebooks)
has_projections = self.dim != effective_codebook_dim
self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity()
self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity()
self.has_projections = has_projections
self.codebook_size = self._levels.prod().item()
implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out=False)
self.register_buffer("implicit_codebook", implicit_codebook, persistent=False)
def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor:
"""Bound `z`, an array of shape (..., d)."""
half_l = (self._levels - 1) * (1 + eps) / 2
offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
shift = (offset / half_l).atanh()
return (z + shift).tanh() * half_l - offset
def quantize(self, z: Tensor) -> Tensor:
"""Quantizes z, returns quantized zhat, same shape as z."""
quantized = round_ste(self.bound(z))
half_width = self._levels // 2 # Renormalize to [-1, 1].
return quantized / half_width
def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor:
half_width = self._levels // 2
return (zhat_normalized * half_width) + half_width
def _scale_and_shift_inverse(self, zhat: Tensor) -> Tensor:
half_width = self._levels // 2
return (zhat - half_width) / half_width
def codes_to_indices(self, zhat: Tensor) -> Tensor:
"""Converts a `code` to an index in the codebook."""
assert zhat.shape[-1] == self.codebook_dim
zhat = self._scale_and_shift(zhat)
return (zhat * self._basis).sum(dim=-1).to(int32)
def indices_to_codes(self, indices: Tensor, project_out=True) -> Tensor:
"""Inverse of `codes_to_indices`."""
is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
indices = rearrange(indices, "... -> ... 1")
codes_non_centered = (indices // self._basis) % self._levels
codes = self._scale_and_shift_inverse(codes_non_centered)
if self.keep_num_codebooks_dim:
codes = rearrange(codes, "... c d -> ... (c d)")
if project_out:
codes = self.project_out(codes)
if is_img_or_video:
codes = rearrange(codes, "b ... d -> b d ...")
return codes
@autocast(enabled=False)
def forward(self, z: Tensor) -> Tensor:
"""
einstein notation
b - batch
n - sequence (or flattened spatial dimensions)
d - feature dimension
c - number of codebook dim
"""
is_img_or_video = z.ndim >= 4
# standardize image or video into (batch, seq, dimension)
if is_img_or_video:
z = rearrange(z, "b d ... -> b ... d")
z, ps = pack_one(z, "b * d")
assert z.shape[-1] == self.dim, f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"
z = self.project_in(z)
z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks)
codes = self.quantize(z)
indices = self.codes_to_indices(codes)
codes = rearrange(codes, "b n c d -> b n (c d)")
out = self.project_out(codes)
# reconstitute image or video dimensions
if is_img_or_video:
out = unpack_one(out, ps, "b * d")
out = rearrange(out, "b ... d -> b d ...")
indices = unpack_one(indices, ps, "b * c")
if not self.keep_num_codebooks_dim:
indices = rearrange(indices, "... 1 -> ...")
return out, indices
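# A minimal usage sketch: with levels [8, 5, 5, 5] the effective codebook size is
# 8 * 5 * 5 * 5 = 1000, and indices round-trip through indices_to_codes.
def _fsq_demo():
    fsq = FSQ(levels=[8, 5, 5, 5], dim=32)
    z = torch.randn(2, 32, 4, 4)           # (batch, dim, height, width)
    out, indices = fsq(z)                  # out has the same shape as z
    codes = fsq.indices_to_codes(indices)  # maps indices back to quantized vectors
    return out.shape, indices.shape, codes.shape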
"""
Lookup Free Quantization
Proposed in https://arxiv.org/abs/2310.05737
In the simplest setup, each dimension is quantized into {-1, 1}.
An entropy penalty is used to encourage utilization.
"""
from math import log2, ceil
from collections import namedtuple
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.nn import Module
from torch.cuda.amp import autocast
from einops import rearrange, reduce, pack, unpack
# constants
Return = namedtuple("Return", ["quantized", "indices", "entropy_aux_loss"])
LossBreakdown = namedtuple("LossBreakdown", ["per_sample_entropy", "batch_entropy", "commitment"])
# helper functions
def exists(v):
return v is not None
def default(*args):
for arg in args:
if exists(arg):
return arg() if callable(arg) else arg
return None
def pack_one(t, pattern):
return pack([t], pattern)
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
# entropy
def log(t, eps=1e-5):
return t.clamp(min=eps).log()
def entropy(prob):
return (-prob * log(prob)).sum(dim=-1)
# class
class LFQ(Module):
def __init__(
self,
*,
dim=None,
codebook_size=None,
entropy_loss_weight=0.1,
commitment_loss_weight=0.25,
diversity_gamma=1.0,
straight_through_activation=nn.Identity(),
num_codebooks=1,
keep_num_codebooks_dim=None,
codebook_scale=1.0, # for residual LFQ, codebook scaled down by 2x at each layer
frac_per_sample_entropy=1.0, # make less than 1. to only use a random fraction of the probs for per sample entropy
):
super().__init__()
# some assert validations
assert exists(dim) or exists(codebook_size), "either dim or codebook_size must be specified for LFQ"
assert (
not exists(codebook_size) or log2(codebook_size).is_integer()
), f"your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})"
codebook_size = default(codebook_size, lambda: 2**dim)
codebook_dim = int(log2(codebook_size))
codebook_dims = codebook_dim * num_codebooks
dim = default(dim, codebook_dims)
has_projections = dim != codebook_dims
self.project_in = nn.Linear(dim, codebook_dims) if has_projections else nn.Identity()
self.project_out = nn.Linear(codebook_dims, dim) if has_projections else nn.Identity()
self.has_projections = has_projections
self.dim = dim
self.codebook_dim = codebook_dim
self.num_codebooks = num_codebooks
keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
self.keep_num_codebooks_dim = keep_num_codebooks_dim
# straight through activation
self.activation = straight_through_activation
# entropy aux loss related weights
assert 0 < frac_per_sample_entropy <= 1.0
self.frac_per_sample_entropy = frac_per_sample_entropy
self.diversity_gamma = diversity_gamma
self.entropy_loss_weight = entropy_loss_weight
# codebook scale
self.codebook_scale = codebook_scale
# commitment loss
self.commitment_loss_weight = commitment_loss_weight
# for no auxiliary loss, during inference
self.register_buffer("mask", 2 ** torch.arange(codebook_dim - 1, -1, -1))
self.register_buffer("zero", torch.tensor(0.0), persistent=False)
# codes
all_codes = torch.arange(codebook_size)
bits = ((all_codes[..., None].int() & self.mask) != 0).float()
codebook = self.bits_to_codes(bits)
self.register_buffer("codebook", codebook, persistent=False)
def bits_to_codes(self, bits):
return bits * self.codebook_scale * 2 - self.codebook_scale
@property
def dtype(self):
return self.codebook.dtype
def indices_to_codes(self, indices, project_out=True):
is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
if not self.keep_num_codebooks_dim:
indices = rearrange(indices, "... -> ... 1")
# indices to codes, which are bits of either -1 or 1
bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype)
codes = self.bits_to_codes(bits)
codes = rearrange(codes, "... c d -> ... (c d)")
# whether to project codes out to original dimensions
# if the input feature dimensions were not log2(codebook size)
if project_out:
codes = self.project_out(codes)
# rearrange codes back to original shape
if is_img_or_video:
codes = rearrange(codes, "b ... d -> b d ...")
return codes
@autocast(enabled=False)
def forward(
self,
x,
inv_temperature=100.0,
return_loss_breakdown=False,
mask=None,
):
"""
einstein notation
b - batch
n - sequence (or flattened spatial dimensions)
d - feature dimension, which is also log2(codebook size)
c - number of codebook dim
"""
x = x.float()
is_img_or_video = x.ndim >= 4
# standardize image or video into (batch, seq, dimension)
if is_img_or_video:
x = rearrange(x, "b d ... -> b ... d")
x, ps = pack_one(x, "b * d")
assert x.shape[-1] == self.dim, f"expected dimension of {self.dim} but received {x.shape[-1]}"
x = self.project_in(x)
# split out number of codebooks
x = rearrange(x, "b n (c d) -> b n c d", c=self.num_codebooks)
# quantize by eq 3.
original_input = x
codebook_value = torch.ones_like(x) * self.codebook_scale
quantized = torch.where(x > 0, codebook_value, -codebook_value)
# use straight-through gradients (optionally with custom activation fn) if training
if self.training:
x = self.activation(x)
x = x + (quantized - x).detach()
else:
x = quantized
# calculate indices
indices = reduce((x > 0).int() * self.mask.int(), "b n c d -> b n c", "sum")
# entropy aux loss
if self.training:
# the same as euclidean distance up to a constant
distance = -2 * einsum("... i d, j d -> ... i j", original_input, self.codebook)
prob = (-distance * inv_temperature).softmax(dim=-1)
# account for mask
if exists(mask):
prob = prob[mask]
else:
prob = rearrange(prob, "b n ... -> (b n) ...")
# whether to only use a fraction of probs, for reducing memory
if self.frac_per_sample_entropy < 1.0:
num_tokens = prob.shape[0]
num_sampled_tokens = int(num_tokens * self.frac_per_sample_entropy)
rand_mask = torch.randn(num_tokens, device=prob.device).argsort(dim=-1) < num_sampled_tokens  # keep the mask on the same device as prob
per_sample_probs = prob[rand_mask]
else:
per_sample_probs = prob
# calculate per sample entropy
per_sample_entropy = entropy(per_sample_probs).mean()
# distribution over all available tokens in the batch
avg_prob = reduce(per_sample_probs, "... c d -> c d", "mean")
codebook_entropy = entropy(avg_prob).mean()
# 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions
# 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch
entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
else:
# if not training, just return dummy 0
entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero
# commit loss
if self.training:
commit_loss = F.mse_loss(original_input, quantized.detach(), reduction="none")
if exists(mask):
commit_loss = commit_loss[mask]
commit_loss = commit_loss.mean()
else:
commit_loss = self.zero
# merge back codebook dim
x = rearrange(x, "b n c d -> b n (c d)")
# project out to feature dimension if needed
x = self.project_out(x)
# reconstitute image or video dimensions
if is_img_or_video:
x = unpack_one(x, ps, "b * d")
x = rearrange(x, "b ... d -> b d ...")
indices = unpack_one(indices, ps, "b * c")
# whether to remove single codebook dim
if not self.keep_num_codebooks_dim:
indices = rearrange(indices, "... 1 -> ...")
# complete aux loss
aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight
ret = Return(x, indices, aux_loss)
if not return_loss_breakdown:
return ret
return ret, LossBreakdown(per_sample_entropy, codebook_entropy, commit_loss)
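# --- Illustrative usage sketch (not part of the original file) ---------------------
# Minimal example of the module above with hypothetical sizes: quantize a
# (batch, seq, feature) tensor and unpack the Return namedtuple. With dim != log2 of
# the codebook size, project_in/project_out are linear layers; indices hold one
# integer code per position.
if __name__ == "__main__":
    lfq = LFQ(codebook_size=2**8, dim=16)  # 8 bits per code, projected from 16 features
    feats = torch.randn(2, 32, 16)
    quantized, indices, aux_loss = lfq(feats)
    assert quantized.shape == feats.shape
    assert indices.shape == (2, 32)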
import logging
from abc import abstractmethod
from typing import Dict, Iterator, Literal, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import einsum
from .base import AbstractRegularizer, measure_perplexity
logpy = logging.getLogger(__name__)
class AbstractQuantizer(AbstractRegularizer):
def __init__(self):
super().__init__()
# Define these in your init
# shape (N,)
self.used: Optional[torch.Tensor]
self.re_embed: int
self.unknown_index: Union[Literal["random"], int]
def remap_to_used(self, inds: torch.Tensor) -> torch.Tensor:
assert self.used is not None, "You need to define used indices for remap"
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
used = self.used.to(inds)
match = (inds[:, :, None] == used[None, None, ...]).long()
new = match.argmax(-1)
unknown = match.sum(2) < 1
if self.unknown_index == "random":
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
else:
new[unknown] = self.unknown_index
return new.reshape(ishape)
def unmap_to_all(self, inds: torch.Tensor) -> torch.Tensor:
assert self.used is not None, "You need to define used indices for remap"
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
used = self.used.to(inds)
if self.re_embed > self.used.shape[0]: # extra token
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
return back.reshape(ishape)
@abstractmethod
def get_codebook_entry(self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None) -> torch.Tensor:
raise NotImplementedError()
def get_trainable_parameters(self) -> Iterator[torch.nn.Parameter]:
yield from self.parameters()
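# Illustrative sketch (not part of the original file): `remap_to_used` maps raw codebook
# ids onto positions within a reduced set of "used" ids (unknown ids go to a random slot
# or a dedicated extra index), and `unmap_to_all` inverts the mapping. The same lookup
# with plain tensors and hypothetical ids, torch only:
if __name__ == "__main__":
    used = torch.tensor([3, 7, 42])
    inds = torch.tensor([[42, 3, 7, 3]])
    remapped = (inds[:, :, None] == used[None, None, :]).long().argmax(-1)
    assert torch.equal(remapped, torch.tensor([[2, 0, 1, 0]]))
    assert torch.equal(used[remapped], inds)  # round trip back to the original ids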
class GumbelQuantizer(AbstractQuantizer):
"""
credit to @karpathy:
https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
Gumbel Softmax trick quantizer
Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
https://arxiv.org/abs/1611.01144
"""
def __init__(
self,
num_hiddens: int,
embedding_dim: int,
n_embed: int,
straight_through: bool = True,
kl_weight: float = 5e-4,
temp_init: float = 1.0,
remap: Optional[str] = None,
unknown_index: str = "random",
loss_key: str = "loss/vq",
) -> None:
super().__init__()
self.loss_key = loss_key
self.embedding_dim = embedding_dim
self.n_embed = n_embed
self.straight_through = straight_through
self.temperature = temp_init
self.kl_weight = kl_weight
self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
self.embed = nn.Embedding(n_embed, embedding_dim)
self.remap = remap
if self.remap is not None:
self.register_buffer("used", torch.tensor(np.load(self.remap)))
self.re_embed = self.used.shape[0]
else:
self.used = None
self.re_embed = n_embed
if unknown_index == "extra":
self.unknown_index = self.re_embed
self.re_embed = self.re_embed + 1
else:
assert unknown_index == "random" or isinstance(
unknown_index, int
), "unknown index needs to be 'random', 'extra' or any integer"
self.unknown_index = unknown_index # "random" or "extra" or integer
if self.remap is not None:
logpy.info(
f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
f"Using {self.unknown_index} for unknown indices."
)
def forward(
self, z: torch.Tensor, temp: Optional[float] = None, return_logits: bool = False
) -> Tuple[torch.Tensor, Dict]:
# force hard = True when we are in eval mode, as we must quantize.
# actually, always true seems to work
hard = self.straight_through if self.training else True
temp = self.temperature if temp is None else temp
out_dict = {}
logits = self.proj(z)
if self.remap is not None:
# continue only with used logits
full_zeros = torch.zeros_like(logits)
logits = logits[:, self.used, ...]
soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
if self.remap is not None:
# go back to all entries but unused set to zero
full_zeros[:, self.used, ...] = soft_one_hot
soft_one_hot = full_zeros
z_q = einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
# + kl divergence to the prior loss
qy = F.softmax(logits, dim=1)
diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
out_dict[self.loss_key] = diff
ind = soft_one_hot.argmax(dim=1)
out_dict["indices"] = ind
if self.remap is not None:
ind = self.remap_to_used(ind)
if return_logits:
out_dict["logits"] = logits
return z_q, out_dict
def get_codebook_entry(self, indices, shape):
# TODO: shape not yet optional
b, h, w, c = shape
assert b * h * w == indices.shape[0]
indices = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w)
if self.remap is not None:
indices = self.unmap_to_all(indices)
one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
z_q = einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight)
return z_q
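# Illustrative sketch (not part of the original file): the quantizer above relies on
# F.gumbel_softmax; with hard=True the forward sample is exactly one-hot along the code
# dimension while gradients flow through the soft relaxation (straight-through). A
# minimal check with hypothetical shapes, torch only:
if __name__ == "__main__":
    logits = torch.randn(2, 8, 4, 4, requires_grad=True)  # (b, n_embed, h, w)
    sample = F.gumbel_softmax(logits, tau=1.0, dim=1, hard=True)
    assert torch.allclose(sample.sum(dim=1), torch.ones(2, 4, 4))
    sample.sum().backward()
    assert logits.grad is not None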
class VectorQuantizer(AbstractQuantizer):
"""
____________________________________________
Discretization bottleneck part of the VQ-VAE.
Inputs:
- n_e : number of embeddings
- e_dim : dimension of embedding
- beta : commitment cost used in loss term,
beta * ||z_e(x)-sg[e]||^2
_____________________________________________
"""
def __init__(
self,
n_e: int,
e_dim: int,
beta: float = 0.25,
remap: Optional[str] = None,
unknown_index: str = "random",
sane_index_shape: bool = False,
log_perplexity: bool = False,
embedding_weight_norm: bool = False,
loss_key: str = "loss/vq",
):
super().__init__()
self.n_e = n_e
self.e_dim = e_dim
self.beta = beta
self.loss_key = loss_key
if not embedding_weight_norm:
self.embedding = nn.Embedding(self.n_e, self.e_dim)
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
else:
self.embedding = torch.nn.utils.weight_norm(nn.Embedding(self.n_e, self.e_dim), dim=1)
self.remap = remap
if self.remap is not None:
self.register_buffer("used", torch.tensor(np.load(self.remap)))
self.re_embed = self.used.shape[0]
else:
self.used = None
self.re_embed = n_e
if unknown_index == "extra":
self.unknown_index = self.re_embed
self.re_embed = self.re_embed + 1
else:
assert unknown_index == "random" or isinstance(
unknown_index, int
), "unknown index needs to be 'random', 'extra' or any integer"
self.unknown_index = unknown_index # "random" or "extra" or integer
if self.remap is not None:
logpy.info(
f"Remapping {self.n_e} indices to {self.re_embed} indices. "
f"Using {self.unknown_index} for unknown indices."
)
self.sane_index_shape = sane_index_shape
self.log_perplexity = log_perplexity
def forward(
self,
z: torch.Tensor,
) -> Tuple[torch.Tensor, Dict]:
do_reshape = z.ndim == 4
if do_reshape:
# # reshape z -> (batch, height, width, channel) and flatten
z = rearrange(z, "b c h w -> b h w c").contiguous()
else:
assert z.ndim < 4, "No reshaping strategy for inputs > 4 dimensions defined"
z = z.contiguous()
z_flattened = z.view(-1, self.e_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = (
torch.sum(z_flattened**2, dim=1, keepdim=True)
+ torch.sum(self.embedding.weight**2, dim=1)
- 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n"))
)
min_encoding_indices = torch.argmin(d, dim=1)
z_q = self.embedding(min_encoding_indices).view(z.shape)
loss_dict = {}
if self.log_perplexity:
perplexity, cluster_usage = measure_perplexity(min_encoding_indices.detach(), self.n_e)
loss_dict.update({"perplexity": perplexity, "cluster_usage": cluster_usage})
# compute loss for embedding
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
loss_dict[self.loss_key] = loss
# preserve gradients
z_q = z + (z_q - z).detach()
# reshape back to match original input shape
if do_reshape:
z_q = rearrange(z_q, "b h w c -> b c h w").contiguous()
if self.remap is not None:
min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
min_encoding_indices = self.remap_to_used(min_encoding_indices)
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
if self.sane_index_shape:
if do_reshape:
min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
else:
min_encoding_indices = rearrange(min_encoding_indices, "(b s) 1 -> b s", b=z_q.shape[0])
loss_dict["min_encoding_indices"] = min_encoding_indices
return z_q, loss_dict
def get_codebook_entry(self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None) -> torch.Tensor:
# shape specifying (batch, height, width, channel)
if self.remap is not None:
assert shape is not None, "Need to give shape for remap"
indices = indices.reshape(shape[0], -1) # add batch axis
indices = self.unmap_to_all(indices)
indices = indices.reshape(-1) # flatten again
# get quantized latent vectors
z_q = self.embedding(indices)
if shape is not None:
z_q = z_q.view(shape)
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q
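# Illustrative sketch (not part of the original file): the nearest-code search above uses
# the expansion ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z.e instead of materialising all
# pairwise differences. A small numeric check that both forms agree, torch only:
if __name__ == "__main__":
    z = torch.randn(6, 4)
    e = torch.randn(10, 4)
    d_expanded = z.pow(2).sum(1, keepdim=True) + e.pow(2).sum(1) - 2 * z @ e.t()
    d_direct = torch.cdist(z, e).pow(2)
    assert torch.allclose(d_expanded, d_direct, atol=1e-4)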
class EmbeddingEMA(nn.Module):
def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
super().__init__()
self.decay = decay
self.eps = eps
weight = torch.randn(num_tokens, codebook_dim)
self.weight = nn.Parameter(weight, requires_grad=False)
self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
self.update = True
def forward(self, embed_id):
return F.embedding(embed_id, self.weight)
def cluster_size_ema_update(self, new_cluster_size):
self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
def embed_avg_ema_update(self, new_embed_avg):
self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
def weight_update(self, num_tokens):
n = self.cluster_size.sum()
smoothed_cluster_size = (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
# normalize embedding average with smoothed cluster size
embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
self.weight.data.copy_(embed_normalized)
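# Illustrative sketch (not part of the original file): the EMA codebook above tracks
# moving averages of cluster sizes N_i and embedding sums m_i and sets each code to
# m_i / N_i, Laplace-smoothing N_i so that rarely used codes never divide by ~zero.
# A tiny version of the smoothing step with hypothetical counts, torch only:
if __name__ == "__main__":
    cluster_size = torch.tensor([100.0, 0.0])  # one busy code, one dead code
    eps, num_tokens = 1e-5, 2
    n = cluster_size.sum()
    smoothed = (cluster_size + eps) / (n + num_tokens * eps) * n
    assert smoothed.min() > 0                   # the dead code gets a small positive mass
    assert torch.allclose(smoothed.sum(), n)    # total mass is preserved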
class EMAVectorQuantizer(AbstractQuantizer):
def __init__(
self,
n_embed: int,
embedding_dim: int,
beta: float,
decay: float = 0.99,
eps: float = 1e-5,
remap: Optional[str] = None,
unknown_index: str = "random",
loss_key: str = "loss/vq",
):
super().__init__()
self.codebook_dim = embedding_dim
self.num_tokens = n_embed
self.beta = beta
self.loss_key = loss_key
self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
self.remap = remap
if self.remap is not None:
self.register_buffer("used", torch.tensor(np.load(self.remap)))
self.re_embed = self.used.shape[0]
else:
self.used = None
self.re_embed = n_embed
if unknown_index == "extra":
self.unknown_index = self.re_embed
self.re_embed = self.re_embed + 1
else:
assert unknown_index == "random" or isinstance(
unknown_index, int
), "unknown index needs to be 'random', 'extra' or any integer"
self.unknown_index = unknown_index # "random" or "extra" or integer
if self.remap is not None:
logpy.info(
f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
f"Using {self.unknown_index} for unknown indices."
)
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
# reshape z -> (batch, height, width, channel) and flatten
# z, 'b c h w -> b h w c'
z = rearrange(z, "b c h w -> b h w c")
z_flattened = z.reshape(-1, self.codebook_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = (
z_flattened.pow(2).sum(dim=1, keepdim=True)
+ self.embedding.weight.pow(2).sum(dim=1)
- 2 * torch.einsum("bd,nd->bn", z_flattened, self.embedding.weight)
) # 'n d -> d n'
encoding_indices = torch.argmin(d, dim=1)
z_q = self.embedding(encoding_indices).view(z.shape)
encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
avg_probs = torch.mean(encodings, dim=0)
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
if self.training and self.embedding.update:
# EMA cluster size
encodings_sum = encodings.sum(0)
self.embedding.cluster_size_ema_update(encodings_sum)
# EMA embedding average
embed_sum = encodings.transpose(0, 1) @ z_flattened
self.embedding.embed_avg_ema_update(embed_sum)
# normalize embed_avg and update weight
self.embedding.weight_update(self.num_tokens)
# compute loss for embedding
loss = self.beta * F.mse_loss(z_q.detach(), z)
# preserve gradients
z_q = z + (z_q - z).detach()
# reshape back to match original input shape
# z_q, 'b h w c -> b c h w'
z_q = rearrange(z_q, "b h w c -> b c h w")
out_dict = {
self.loss_key: loss,
"encodings": encodings,
"encoding_indices": encoding_indices,
"perplexity": perplexity,
}
return z_q, out_dict
class VectorQuantizerWithInputProjection(VectorQuantizer):
def __init__(
self,
input_dim: int,
n_codes: int,
codebook_dim: int,
beta: float = 1.0,
output_dim: Optional[int] = None,
**kwargs,
):
super().__init__(n_codes, codebook_dim, beta, **kwargs)
self.proj_in = nn.Linear(input_dim, codebook_dim)
self.output_dim = output_dim
if output_dim is not None:
self.proj_out = nn.Linear(codebook_dim, output_dim)
else:
self.proj_out = nn.Identity()
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
rearr = False
in_shape = z.shape
if z.ndim > 3:
rearr = self.output_dim is not None
z = rearrange(z, "b c ... -> b (...) c")
z = self.proj_in(z)
z_q, loss_dict = super().forward(z)
z_q = self.proj_out(z_q)
if rearr:
if len(in_shape) == 4:
z_q = rearrange(z_q, "b (h w) c -> b c h w ", w=in_shape[-1])
elif len(in_shape) == 5:
z_q = rearrange(z_q, "b (t h w) c -> b c t h w ", w=in_shape[-1], h=in_shape[-2])
else:
raise NotImplementedError(f"rearranging not available for {len(in_shape)}-dimensional input.")
return z_q, loss_dict
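# --- Illustrative usage sketch (not part of the original file) ---------------------
# VectorQuantizerWithInputProjection projects features into the codebook space,
# quantizes with the parent VectorQuantizer, and (optionally) projects back out while
# restoring the spatial layout. Shapes below are hypothetical; the loss dict key
# defaults to "loss/vq".
if __name__ == "__main__":
    vq = VectorQuantizerWithInputProjection(input_dim=32, n_codes=128, codebook_dim=8, output_dim=32)
    feats = torch.randn(2, 32, 4, 4)  # (b, c, h, w)
    z_q, loss_dict = vq(feats)
    assert z_q.shape == feats.shape
    assert "loss/vq" in loss_dict and "min_encoding_indices" in loss_dict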
from typing import Callable, Iterable, Union
import torch
from einops import rearrange, repeat
from sgm.modules.diffusionmodules.model import (
XFORMERS_IS_AVAILABLE,
AttnBlock,
Decoder,
MemoryEfficientAttnBlock,
ResnetBlock,
)
from sgm.modules.diffusionmodules.openaimodel import ResBlock, timestep_embedding
from sgm.modules.video_attention import VideoTransformerBlock
from sgm.util import partialclass
class VideoResBlock(ResnetBlock):
def __init__(
self,
out_channels,
*args,
dropout=0.0,
video_kernel_size=3,
alpha=0.0,
merge_strategy="learned",
**kwargs,
):
super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
if video_kernel_size is None:
video_kernel_size = [3, 1, 1]
self.time_stack = ResBlock(
channels=out_channels,
emb_channels=0,
dropout=dropout,
dims=3,
use_scale_shift_norm=False,
use_conv=False,
up=False,
down=False,
kernel_size=video_kernel_size,
use_checkpoint=False,
skip_t_emb=True,
)
self.merge_strategy = merge_strategy
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif self.merge_strategy == "learned":
self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def get_alpha(self, bs):
if self.merge_strategy == "fixed":
return self.mix_factor
elif self.merge_strategy == "learned":
return torch.sigmoid(self.mix_factor)
else:
raise NotImplementedError()
def forward(self, x, temb, skip_video=False, timesteps=None):
if timesteps is None:
timesteps = self.timesteps
b, c, h, w = x.shape
x = super().forward(x, temb)
if not skip_video:
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
x = self.time_stack(x, temb)
alpha = self.get_alpha(bs=b // timesteps)
x = alpha * x + (1.0 - alpha) * x_mix
x = rearrange(x, "b c t h w -> (b t) c h w")
return x
class AE3DConv(torch.nn.Conv2d):
def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
super().__init__(in_channels, out_channels, *args, **kwargs)
if isinstance(video_kernel_size, Iterable):
padding = [int(k // 2) for k in video_kernel_size]
else:
padding = int(video_kernel_size // 2)
self.time_mix_conv = torch.nn.Conv3d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=video_kernel_size,
padding=padding,
)
def forward(self, input, timesteps, skip_video=False):
x = super().forward(input)
if skip_video:
return x
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
x = self.time_mix_conv(x)
return rearrange(x, "b c t h w -> (b t) c h w")
class VideoBlock(AttnBlock):
def __init__(self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"):
super().__init__(in_channels)
# no context, single headed, as in base class
self.time_mix_block = VideoTransformerBlock(
dim=in_channels,
n_heads=1,
d_head=in_channels,
checkpoint=False,
ff_in=True,
attn_mode="softmax",
)
time_embed_dim = self.in_channels * 4
self.video_time_embed = torch.nn.Sequential(
torch.nn.Linear(self.in_channels, time_embed_dim),
torch.nn.SiLU(),
torch.nn.Linear(time_embed_dim, self.in_channels),
)
self.merge_strategy = merge_strategy
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif self.merge_strategy == "learned":
self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def forward(self, x, timesteps, skip_video=False):
if skip_video:
return super().forward(x)
x_in = x
x = self.attention(x)
h, w = x.shape[2:]
x = rearrange(x, "b c h w -> b (h w) c")
x_mix = x
num_frames = torch.arange(timesteps, device=x.device)
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
num_frames = rearrange(num_frames, "b t -> (b t)")
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
emb = self.video_time_embed(t_emb) # b, n_channels
emb = emb[:, None, :]
x_mix = x_mix + emb
alpha = self.get_alpha()
x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
x = self.proj_out(x)
return x_in + x
def get_alpha(
self,
):
if self.merge_strategy == "fixed":
return self.mix_factor
elif self.merge_strategy == "learned":
return torch.sigmoid(self.mix_factor)
else:
raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock):
def __init__(self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"):
super().__init__(in_channels)
# no context, single headed, as in base class
self.time_mix_block = VideoTransformerBlock(
dim=in_channels,
n_heads=1,
d_head=in_channels,
checkpoint=False,
ff_in=True,
attn_mode="softmax-xformers",
)
time_embed_dim = self.in_channels * 4
self.video_time_embed = torch.nn.Sequential(
torch.nn.Linear(self.in_channels, time_embed_dim),
torch.nn.SiLU(),
torch.nn.Linear(time_embed_dim, self.in_channels),
)
self.merge_strategy = merge_strategy
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif self.merge_strategy == "learned":
self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def forward(self, x, timesteps, skip_time_block=False):
if skip_time_block:
return super().forward(x)
x_in = x
x = self.attention(x)
h, w = x.shape[2:]
x = rearrange(x, "b c h w -> b (h w) c")
x_mix = x
num_frames = torch.arange(timesteps, device=x.device)
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
num_frames = rearrange(num_frames, "b t -> (b t)")
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
emb = self.video_time_embed(t_emb) # b, n_channels
emb = emb[:, None, :]
x_mix = x_mix + emb
alpha = self.get_alpha()
x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
x = self.proj_out(x)
return x_in + x
def get_alpha(
self,
):
if self.merge_strategy == "fixed":
return self.mix_factor
elif self.merge_strategy == "learned":
return torch.sigmoid(self.mix_factor)
else:
raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
def make_time_attn(
in_channels,
attn_type="vanilla",
attn_kwargs=None,
alpha: float = 0,
merge_strategy: str = "learned",
):
assert attn_type in [
"vanilla",
"vanilla-xformers",
], f"attn_type {attn_type} not supported for spatio-temporal attention"
print(f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels")
if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers":
print(
f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. "
f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
)
attn_type = "vanilla"
if attn_type == "vanilla":
assert attn_kwargs is None
return partialclass(VideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy)
elif attn_type == "vanilla-xformers":
print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
return partialclass(
MemoryEfficientVideoBlock,
in_channels,
alpha=alpha,
merge_strategy=merge_strategy,
)
else:
raise NotImplementedError()
class Conv2DWrapper(torch.nn.Conv2d):
def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
return super().forward(input)
class VideoDecoder(Decoder):
available_time_modes = ["all", "conv-only", "attn-only"]
def __init__(
self,
*args,
video_kernel_size: Union[int, list] = 3,
alpha: float = 0.0,
merge_strategy: str = "learned",
time_mode: str = "conv-only",
**kwargs,
):
self.video_kernel_size = video_kernel_size
self.alpha = alpha
self.merge_strategy = merge_strategy
self.time_mode = time_mode
assert (
self.time_mode in self.available_time_modes
), f"time_mode parameter has to be in {self.available_time_modes}"
super().__init__(*args, **kwargs)
def get_last_layer(self, skip_time_mix=False, **kwargs):
if self.time_mode == "attn-only":
raise NotImplementedError("TODO")
else:
return self.conv_out.time_mix_conv.weight if not skip_time_mix else self.conv_out.weight
def _make_attn(self) -> Callable:
if self.time_mode not in ["conv-only", "only-last-conv"]:
return partialclass(
make_time_attn,
alpha=self.alpha,
merge_strategy=self.merge_strategy,
)
else:
return super()._make_attn()
def _make_conv(self) -> Callable:
if self.time_mode != "attn-only":
return partialclass(AE3DConv, video_kernel_size=self.video_kernel_size)
else:
return Conv2DWrapper
def _make_resblock(self) -> Callable:
if self.time_mode not in ["attn-only", "only-last-conv"]:
return partialclass(
VideoResBlock,
video_kernel_size=self.video_kernel_size,
alpha=self.alpha,
merge_strategy=self.merge_strategy,
)
else:
return super()._make_resblock()
# pytorch_diffusion + derived encoder decoder
import math
import torch
import torch.nn as nn
import numpy as np
from einops import rearrange
from .movq_enc_3d import CausalConv3d, Upsample3D, DownSample3D
def cast_tuple(t, length=1):
return t if isinstance(t, tuple) else ((t,) * length)
def divisible_by(num, den):
return (num % den) == 0
def is_odd(n):
return not divisible_by(n, 2)
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
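# Illustrative sanity check (not part of the original file): the embedding above returns
# one row per timestep with `embedding_dim` columns and zero-pads the last column when
# the dimension is odd. Torch only:
if __name__ == "__main__":
    emb = get_timestep_embedding(torch.arange(4), 9)
    assert emb.shape == (4, 9)
    assert torch.all(emb[:, -1] == 0)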
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
class SpatialNorm3D(nn.Module):
def __init__(
self,
f_channels,
zq_channels,
norm_layer=nn.GroupNorm,
freeze_norm_layer=False,
add_conv=False,
pad_mode="constant",
**norm_layer_params,
):
super().__init__()
self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
if freeze_norm_layer:
for p in self.norm_layer.parameters():
p.requires_grad = False
self.add_conv = add_conv
if self.add_conv:
self.conv = CausalConv3d(zq_channels, zq_channels, kernel_size=3, pad_mode=pad_mode)
self.conv_y = CausalConv3d(zq_channels, f_channels, kernel_size=1, pad_mode=pad_mode)
self.conv_b = CausalConv3d(zq_channels, f_channels, kernel_size=1, pad_mode=pad_mode)
def forward(self, f, zq):
if zq.shape[2] > 1:
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
zq_first = torch.nn.functional.interpolate(zq_first, size=f_first_size, mode="nearest")
zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest")
zq = torch.cat([zq_first, zq_rest], dim=2)
else:
zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest")
if self.add_conv:
zq = self.conv(zq)
norm_f = self.norm_layer(f)
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
return new_f
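# Illustrative sketch (not part of the original file): SpatialNorm3D is MoVQ-style
# conditioning, out = GroupNorm(f) * conv_y(zq) + conv_b(zq), i.e. a group-normalised
# feature map modulated by scale/shift maps predicted from the quantised latents.
# A 2D analogue with hypothetical shapes, torch only:
if __name__ == "__main__":
    f = torch.randn(1, 32, 8, 8)
    scale, shift = torch.randn(1, 32, 8, 8), torch.randn(1, 32, 8, 8)
    out = nn.GroupNorm(8, 32)(f) * scale + shift
    assert out.shape == f.shape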
def Normalize3D(in_channels, zq_ch, add_conv):
return SpatialNorm3D(
in_channels,
zq_ch,
norm_layer=nn.GroupNorm,
freeze_norm_layer=False,
add_conv=add_conv,
num_groups=32,
eps=1e-6,
affine=True,
)
class ResnetBlock3D(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout,
temb_channels=512,
zq_ch=None,
add_conv=False,
pad_mode="constant",
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize3D(in_channels, zq_ch, add_conv=add_conv)
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize3D(out_channels, zq_ch, add_conv=add_conv)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
else:
self.nin_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, temb, zq):
h = x
h = self.norm1(h, zq)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
h = self.norm2(h, zq)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
class AttnBlock2D(nn.Module):
def __init__(self, in_channels, zq_ch=None, add_conv=False):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize3D(in_channels, zq_ch, add_conv=add_conv)
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, zq):
h_ = x
h_ = self.norm(h_, zq)
t = h_.shape[2]
h_ = rearrange(h_, "b c t h w -> (b t) c h w")
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h * w)
q = q.permute(0, 2, 1) # b,hw,c
k = k.reshape(b, c, h * w) # b,c,hw
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h * w)
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
h_ = rearrange(h_, "(b t) c h w -> b c t h w", t=t)
return x + h_
class MOVQDecoder3D(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
zq_ch=None,
add_conv=False,
pad_mode="first",
temporal_compress_times=4,
**ignorekwargs,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
# log2 of temporal_compress_times
self.temporal_compress_level = int(np.log2(temporal_compress_times))
if zq_ch is None:
zq_ch = z_channels
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3, pad_mode=pad_mode)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
pad_mode=pad_mode,
)
self.mid.block_2 = ResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
pad_mode=pad_mode,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ResnetBlock3D(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
pad_mode=pad_mode,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttnBlock2D(block_in, zq_ch, add_conv=add_conv))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
if i_level < self.num_resolutions - self.temporal_compress_level:
up.upsample = Upsample3D(block_in, resamp_with_conv, compress_time=False)
else:
up.upsample = Upsample3D(block_in, resamp_with_conv, compress_time=True)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv)
self.conv_out = CausalConv3d(block_in, out_ch, kernel_size=3, pad_mode=pad_mode)
def forward(self, z, use_cp=False):
self.last_z_shape = z.shape
# timestep embedding
temb = None
t = z.shape[2]
# z to block_in
zq = z
h = self.conv_in(z)
# middle
h = self.mid.block_1(h, temb, zq)
# h = self.mid.attn_1(h, zq)
h = self.mid.block_2(h, temb, zq)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb, zq)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, zq)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h, zq)
h = nonlinearity(h)
h = self.conv_out(h)
return h
def get_last_layer(self):
return self.conv_out.conv.weight
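# Illustrative sketch (not part of the original file): the decoder above walks the
# resolutions from coarse to fine. With hypothetical ch=128, ch_mult=(1, 2, 4, 8) and
# resolution=256, the block widths seen during upsampling are 1024 -> 512 -> 256 -> 128
# while the spatial size doubles at every level except the last. Plain-Python schedule:
if __name__ == "__main__":
    ch, ch_mult, resolution = 128, (1, 2, 4, 8), 256
    curr_res = resolution // 2 ** (len(ch_mult) - 1)
    widths, resolutions = [], []
    for i_level in reversed(range(len(ch_mult))):
        widths.append(ch * ch_mult[i_level])
        resolutions.append(curr_res)
        if i_level != 0:
            curr_res *= 2
    assert widths == [1024, 512, 256, 128]
    assert resolutions == [32, 64, 128, 256]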
class NewDecoder3D(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
zq_ch=None,
add_conv=False,
pad_mode="first",
temporal_compress_times=4,
post_quant_conv=False,
**ignorekwargs,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
# log2 of temporal_compress_times
self.temporal_compress_level = int(np.log2(temporal_compress_times))
if zq_ch is None:
zq_ch = z_channels
if post_quant_conv:
self.post_quant_conv = CausalConv3d(zq_ch, z_channels, kernel_size=3, pad_mode=pad_mode)
else:
self.post_quant_conv = None
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1,) + tuple(ch_mult)
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
# z to block_in
# self.conv_in = torch.nn.Conv3d(z_channels,
# block_in,
# kernel_size=3,
# stride=1,
# padding=1)
self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3, pad_mode=pad_mode)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
pad_mode=pad_mode,
)
# remove attention block
# self.mid.attn_1 = AttnBlock2D(block_in, zq_ch, add_conv=add_conv)
self.mid.block_2 = ResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
pad_mode=pad_mode,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ResnetBlock3D(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
pad_mode=pad_mode,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttnBlock2D(block_in, zq_ch, add_conv=add_conv))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
if i_level < self.num_resolutions - self.temporal_compress_level:
up.upsample = Upsample3D(block_in, resamp_with_conv, compress_time=False)
else:
up.upsample = Upsample3D(block_in, resamp_with_conv, compress_time=True)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv)
# self.conv_out = torch.nn.Conv3d(block_in,
# out_ch,
# kernel_size=3,
# stride=1,
# padding=1)
self.conv_out = CausalConv3d(block_in, out_ch, kernel_size=3, pad_mode=pad_mode)
def forward(self, z):
# assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
# timestep embedding
temb = None
t = z.shape[2]
# z to block_in
zq = z
if self.post_quant_conv is not None:
z = self.post_quant_conv(z)
h = self.conv_in(z)
# middle
h = self.mid.block_1(h, temb, zq)
# h = self.mid.attn_1(h, zq)
h = self.mid.block_2(h, temb, zq)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb, zq)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, zq)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h, zq)
h = nonlinearity(h)
h = self.conv_out(h)
return h
def get_last_layer(self):
return self.conv_out.conv.weight
# pytorch_diffusion + derived encoder decoder
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from beartype import beartype
from beartype.typing import Union, Tuple, Optional, List
from einops import rearrange
from .movq_enc_3d import CausalConv3d, Upsample3D, DownSample3D
def cast_tuple(t, length=1):
return t if isinstance(t, tuple) else ((t,) * length)
def divisible_by(num, den):
return (num % den) == 0
def is_odd(n):
return not divisible_by(n, 2)
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
class SpatialNorm3D(nn.Module):
def __init__(
self,
f_channels,
zq_channels,
norm_layer=nn.GroupNorm,
freeze_norm_layer=False,
add_conv=False,
pad_mode="constant",
**norm_layer_params,
):
super().__init__()
self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
if freeze_norm_layer:
for p in self.norm_layer.parameters():
p.requires_grad = False
self.add_conv = add_conv
if self.add_conv:
# self.conv = nn.Conv3d(zq_channels, zq_channels, kernel_size=3, stride=1, padding=1)
self.conv = CausalConv3d(zq_channels, zq_channels, kernel_size=3, pad_mode=pad_mode)
# self.conv_y = nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
# self.conv_b = nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
self.conv_y = CausalConv3d(zq_channels, f_channels, kernel_size=1, pad_mode=pad_mode)
self.conv_b = CausalConv3d(zq_channels, f_channels, kernel_size=1, pad_mode=pad_mode)
def forward(self, f, zq):
if zq.shape[2] > 1:
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
zq_first = torch.nn.functional.interpolate(zq_first, size=f_first_size, mode="nearest")
zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest")
zq = torch.cat([zq_first, zq_rest], dim=2)
else:
zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest")
if self.add_conv:
zq = self.conv(zq)
norm_f = self.norm_layer(f)
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
return new_f
def Normalize3D(in_channels, zq_ch, add_conv):
return SpatialNorm3D(
in_channels,
zq_ch,
norm_layer=nn.GroupNorm,
freeze_norm_layer=False,
add_conv=add_conv,
num_groups=32,
eps=1e-6,
affine=True,
)
class ResnetBlock3D(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout,
temb_channels=512,
zq_ch=None,
add_conv=False,
pad_mode="constant",
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize3D(in_channels, zq_ch, add_conv=add_conv)
# self.conv1 = torch.nn.Conv3d(in_channels,
# out_channels,
# kernel_size=3,
# stride=1,
# padding=1)
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize3D(out_channels, zq_ch, add_conv=add_conv)
self.dropout = torch.nn.Dropout(dropout)
# self.conv2 = torch.nn.Conv3d(out_channels,
# out_channels,
# kernel_size=3,
# stride=1,
# padding=1)
self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
# self.conv_shortcut = torch.nn.Conv3d(in_channels,
# out_channels,
# kernel_size=3,
# stride=1,
# padding=1)
self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
else:
self.nin_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
# self.nin_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, pad_mode=pad_mode)
def forward(self, x, temb, zq):
h = x
h = self.norm1(h, zq)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
h = self.norm2(h, zq)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
class AttnBlock2D(nn.Module):
def __init__(self, in_channels, zq_ch=None, add_conv=False):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize3D(in_channels, zq_ch, add_conv=add_conv)
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, zq):
h_ = x
h_ = self.norm(h_, zq)
t = h_.shape[2]
h_ = rearrange(h_, "b c t h w -> (b t) c h w")
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h * w)
q = q.permute(0, 2, 1) # b,hw,c
k = k.reshape(b, c, h * w) # b,c,hw
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h * w)
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
h_ = rearrange(h_, "(b t) c h w -> b c t h w", t=t)
return x + h_
class MOVQDecoder3D(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
zq_ch=None,
add_conv=False,
pad_mode="first",
temporal_compress_times=4,
**ignorekwargs,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
# log2 of temporal_compress_times
self.temporal_compress_level = int(np.log2(temporal_compress_times))
if zq_ch is None:
zq_ch = z_channels
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1,) + tuple(ch_mult)
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
# z to block_in
# self.conv_in = torch.nn.Conv3d(z_channels,
# block_in,
# kernel_size=3,
# stride=1,
# padding=1)
self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3, pad_mode=pad_mode)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
pad_mode=pad_mode,
)
# remove attention block
# self.mid.attn_1 = AttnBlock2D(block_in, zq_ch, add_conv=add_conv)
self.mid.block_2 = ResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
pad_mode=pad_mode,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ResnetBlock3D(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
pad_mode=pad_mode,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttnBlock2D(block_in, zq_ch, add_conv=add_conv))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
if i_level < self.num_resolutions - self.temporal_compress_level:
up.upsample = Upsample3D(block_in, resamp_with_conv, compress_time=False)
else:
up.upsample = Upsample3D(block_in, resamp_with_conv, compress_time=True)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv)
# self.conv_out = torch.nn.Conv3d(block_in,
# out_ch,
# kernel_size=3,
# stride=1,
# padding=1)
self.conv_out = CausalConv3d(block_in, out_ch, kernel_size=3, pad_mode=pad_mode)
def forward(self, z, use_cp=False):
# assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
# timestep embedding
temb = None
t = z.shape[2]
# z to block_in
zq = z
h = self.conv_in(z)
# middle
h = self.mid.block_1(h, temb, zq)
# h = self.mid.attn_1(h, zq)
h = self.mid.block_2(h, temb, zq)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb, zq)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, zq)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h, zq)
h = nonlinearity(h)
h = self.conv_out(h)
return h
def get_last_layer(self):
return self.conv_out.conv.weight
class NewDecoder3D(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
zq_ch=None,
add_conv=False,
pad_mode="first",
temporal_compress_times=4,
post_quant_conv=False,
**ignorekwargs,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
# log2 of temporal_compress_times
self.temporal_compress_level = int(np.log2(temporal_compress_times))
if zq_ch is None:
zq_ch = z_channels
if post_quant_conv:
self.post_quant_conv = CausalConv3d(zq_ch, z_channels, kernel_size=3, pad_mode=pad_mode)
else:
self.post_quant_conv = None
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1,) + tuple(ch_mult)
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
# z to block_in
# self.conv_in = torch.nn.Conv3d(z_channels,
# block_in,
# kernel_size=3,
# stride=1,
# padding=1)
self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3, pad_mode=pad_mode)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
pad_mode=pad_mode,
)
# remove attention block
# self.mid.attn_1 = AttnBlock2D(block_in, zq_ch, add_conv=add_conv)
self.mid.block_2 = ResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
pad_mode=pad_mode,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ResnetBlock3D(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
pad_mode=pad_mode,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttnBlock2D(block_in, zq_ch, add_conv=add_conv))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
if i_level < self.num_resolutions - self.temporal_compress_level:
up.upsample = Upsample3D(block_in, resamp_with_conv, compress_time=False)
else:
up.upsample = Upsample3D(block_in, resamp_with_conv, compress_time=True)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv)
# self.conv_out = torch.nn.Conv3d(block_in,
# out_ch,
# kernel_size=3,
# stride=1,
# padding=1)
self.conv_out = CausalConv3d(block_in, out_ch, kernel_size=3, pad_mode=pad_mode)
def forward(self, z):
# assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
# timestep embedding
temb = None
t = z.shape[2]
# z to block_in
zq = z
if self.post_quant_conv is not None:
z = self.post_quant_conv(z)
h = self.conv_in(z)
# middle
h = self.mid.block_1(h, temb, zq)
# h = self.mid.attn_1(h, zq)
h = self.mid.block_2(h, temb, zq)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb, zq)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, zq)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h, zq)
h = nonlinearity(h)
h = self.conv_out(h)
return h
def get_last_layer(self):
return self.conv_out.conv.weight
# pytorch_diffusion + derived encoder decoder
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from beartype import beartype
from beartype.typing import Union, Tuple, Optional, List
from einops import rearrange
def cast_tuple(t, length=1):
return t if isinstance(t, tuple) else ((t,) * length)
def divisible_by(num, den):
return (num % den) == 0
def is_odd(n):
return not divisible_by(n, 2)
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
class CausalConv3d(nn.Module):
@beartype
def __init__(
self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], pad_mode="constant", **kwargs
):
super().__init__()
kernel_size = cast_tuple(kernel_size, 3)
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
dilation = kwargs.pop("dilation", 1)
stride = kwargs.pop("stride", 1)
self.pad_mode = pad_mode
time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
height_pad = height_kernel_size // 2
width_pad = width_kernel_size // 2
self.height_pad = height_pad
self.width_pad = width_pad
self.time_pad = time_pad
self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
stride = (stride, 1, 1)
dilation = (dilation, 1, 1)
self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
if self.pad_mode == "constant":
# F.pad pads from the last dim backwards: (W_left, W_right, H_left, H_right, T_left, T_right),
# so use the causal padding prepared in __init__ (symmetric in H/W, left-only in time)
x = F.pad(x, self.time_causal_padding, mode="constant", value=0)
elif self.pad_mode == "first":
pad_x = torch.cat([x[:, :, :1]] * self.time_pad, dim=2)
x = torch.cat([pad_x, x], dim=2)
causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
x = F.pad(x, causal_padding_2d, mode="constant", value=0)
elif self.pad_mode == "reflect":
# reflect padding
reflect_x = x[:, :, 1 : self.time_pad + 1, :, :].flip(dims=[2])
if reflect_x.shape[2] < self.time_pad:
reflect_x = torch.cat(
[torch.zeros_like(x[:, :, :1, :, :])] * (self.time_pad - reflect_x.shape[2]) + [reflect_x], dim=2
)
x = torch.cat([reflect_x, x], dim=2)
causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
x = F.pad(x, causal_padding_2d, mode="constant", value=0)
else:
raise ValueError("Invalid pad mode")
return self.conv(x)
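# --- Illustrative sketch (not part of the original file) ---------------------
# With pad_mode="first" the causal conv keeps (T, H, W) and the first output frame
# depends only on the first input frame; shapes and the helper name are assumed.
def _demo_causal_conv3d():
    torch.manual_seed(0)
    conv = CausalConv3d(2, 4, kernel_size=3, pad_mode="first")
    x = torch.randn(1, 2, 5, 8, 8)  # (B, C, T, H, W)
    y = conv(x)
    assert y.shape == (1, 4, 5, 8, 8)
    x_future = x.clone()
    x_future[:, :, 1:] += 1.0       # perturb only future frames
    assert torch.allclose(y[:, :, 0], conv(x_future)[:, :, 0])
    return y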
def Normalize3D(in_channels): # same for 3D and 2D
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class Upsample3D(nn.Module):
def __init__(self, in_channels, with_conv, compress_time=False):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
self.compress_time = compress_time
def forward(self, x):
if self.compress_time:
if x.shape[2] > 1:
# split first frame
x_first, x_rest = x[:, :, 0], x[:, :, 1:]
x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
else:
x = x.squeeze(2)
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
x = x[:, :, None, :, :]
else:
# only interpolate 2D
t = x.shape[2]
x = rearrange(x, "b c t h w -> (b t) c h w")
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
if self.with_conv:
t = x.shape[2]
x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.conv(x)
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
return x
class DownSample3D(nn.Module):
def __init__(self, in_channels, with_conv, compress_time=False, out_channels=None):
super().__init__()
self.with_conv = with_conv
if out_channels is None:
out_channels = in_channels
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
self.compress_time = compress_time
def forward(self, x):
if self.compress_time:
h, w = x.shape[-2:]
x = rearrange(x, "b c t h w -> (b h w) c t")
# split first frame
x_first, x_rest = x[..., 0], x[..., 1:]
if x_rest.shape[-1] > 0:
x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
x = torch.cat([x_first[..., None], x_rest], dim=-1)
x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
if self.with_conv:
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
t = x.shape[2]
x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.conv(x)
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
else:
t = x.shape[2]
x = rearrange(x, "b c t h w -> (b t) c h w")
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
return x
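# --- Illustrative sketch (not part of the original file) ---------------------
# Temporal compression keeps the first frame and average-pools the rest, while the
# strided conv halves H and W; the toy sizes below are assumptions for illustration.
def _demo_downsample3d():
    torch.manual_seed(0)
    down = DownSample3D(in_channels=8, with_conv=True, compress_time=True)
    x = torch.randn(1, 8, 5, 16, 16)  # T = 1 first frame + 4 following frames
    y = down(x)
    assert y.shape == (1, 8, 3, 8, 8)  # T: 1 + 4 / 2 = 3, H and W halved
    return y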
class ResnetBlock3D(nn.Module):
def __init__(
self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512, pad_mode="constant"
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize3D(in_channels)
# self.conv1 = torch.nn.Conv3d(in_channels,
# out_channels,
# kernel_size=3,
# stride=1,
# padding=1)
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize3D(out_channels)
self.dropout = torch.nn.Dropout(dropout)
# self.conv2 = torch.nn.Conv3d(out_channels,
# out_channels,
# kernel_size=3,
# stride=1,
# padding=1)
self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
# self.conv_shortcut = torch.nn.Conv3d(in_channels,
# out_channels,
# kernel_size=3,
# stride=1,
# padding=1)
self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
else:
self.nin_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
# self.nin_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, pad_mode=pad_mode)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
class AttnBlock2D(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize3D(in_channels)
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
t = h_.shape[2]
h_ = rearrange(h_, "b c t h w -> (b t) c h w")
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h * w)
q = q.permute(0, 2, 1) # b,hw,c
k = k.reshape(b, c, h * w) # b,c,hw
# # original version, nan in fp16
# w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
# w_ = w_ * (int(c)**(-0.5))
# # implement c**-0.5 on q
q = q * (int(c) ** (-0.5))
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h * w)
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
h_ = rearrange(h_, "(b t) c h w -> b c t h w", t=t)
return x + h_
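# --- Illustrative sketch (not part of the original file) ---------------------
# The block above scales q by c**-0.5 before the matmul instead of scaling the
# q @ k logits; both orderings give the same softmax weights, the pre-scaling is
# only there to keep the logits small enough for fp16. Toy sizes are assumed.
def _demo_attention_scaling_equivalence():
    torch.manual_seed(0)
    b, c, hw = 2, 16, 32
    q, k = torch.randn(b, hw, c), torch.randn(b, c, hw)
    w_pre = torch.nn.functional.softmax(torch.bmm(q * c ** -0.5, k), dim=2)
    w_post = torch.nn.functional.softmax(torch.bmm(q, k) * c ** -0.5, dim=2)
    assert torch.allclose(w_pre, w_post, atol=1e-6)
    return w_pre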
class Encoder3D(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
double_z=True,
pad_mode="first",
temporal_compress_times=4,
**ignore_kwargs,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# log2 of temporal_compress_times
self.temporal_compress_level = int(np.log2(temporal_compress_times))
# downsampling
# self.conv_in = torch.nn.Conv3d(in_channels,
# self.ch,
# kernel_size=3,
# stride=1,
# padding=1)
self.conv_in = CausalConv3d(in_channels, self.ch, kernel_size=3, pad_mode=pad_mode)
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock3D(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
pad_mode=pad_mode,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttnBlock2D(block_in))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
if i_level < self.temporal_compress_level:
down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=True)
else:
down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=False)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock3D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, pad_mode=pad_mode
)
# remove attention block
# self.mid.attn_1 = AttnBlock2D(block_in)
self.mid.block_2 = ResnetBlock3D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, pad_mode=pad_mode
)
# end
self.norm_out = Normalize3D(block_in)
# self.conv_out = torch.nn.Conv3d(block_in,
# 2*z_channels if double_z else z_channels,
# kernel_size=3,
# stride=1,
# padding=1)
self.conv_out = CausalConv3d(
block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, pad_mode=pad_mode
)
def forward(self, x, use_cp=False):
# assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
# timestep embedding
temb = None
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
# h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
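# --- Illustrative sketch (not part of the original file) ---------------------
# Rough shape walk-through of Encoder3D under an assumed toy config: two resolution
# levels give one spatial 2x downsample, temporal_compress_times=2 gives one temporal
# downsample, and double_z doubles the latent channels.
def _demo_encoder3d_shapes():
    torch.manual_seed(0)
    enc = Encoder3D(
        ch=32, out_ch=3, ch_mult=(1, 2), num_res_blocks=1, attn_resolutions=[],
        in_channels=3, resolution=32, z_channels=4, double_z=True,
        temporal_compress_times=2,
    )
    x = torch.randn(1, 3, 5, 32, 32)  # (B, C, T, H, W)
    z = enc(x)
    assert z.shape == (1, 8, 3, 16, 16)  # (B, 2 * z_channels, T', H / 2, W / 2)
    return z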
# pytorch_diffusion + derived encoder decoder
import math
import torch
import torch.nn as nn
import numpy as np
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
class SpatialNorm(nn.Module):
def __init__(
self,
f_channels,
zq_channels,
norm_layer=nn.GroupNorm,
freeze_norm_layer=False,
add_conv=False,
**norm_layer_params,
):
super().__init__()
self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
if freeze_norm_layer:
for p in self.norm_layer.parameters():
p.requires_grad = False
self.add_conv = add_conv
if self.add_conv:
self.conv = nn.Conv2d(zq_channels, zq_channels, kernel_size=3, stride=1, padding=1)
self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
def forward(self, f, zq):
f_size = f.shape[-2:]
zq = torch.nn.functional.interpolate(zq, size=f_size, mode="nearest")
if self.add_conv:
zq = self.conv(zq)
norm_f = self.norm_layer(f)
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
return new_f
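# --- Illustrative sketch (not part of the original file) ---------------------
# SpatialNorm normalizes f with GroupNorm and then rescales/shifts it with 1x1-conv
# projections of the quantized code zq (upsampled to f's spatial size), so the output
# keeps f's shape. Channel counts below are assumptions for illustration only.
def _demo_spatial_norm():
    torch.manual_seed(0)
    norm = SpatialNorm(f_channels=64, zq_channels=8, num_groups=32, eps=1e-6, affine=True)
    f = torch.randn(1, 64, 16, 16)
    zq = torch.randn(1, 8, 4, 4)  # lower-resolution code, upsampled inside forward
    out = norm(f, zq)
    assert out.shape == f.shape
    return out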
def Normalize(in_channels, zq_ch, add_conv):
return SpatialNorm(
in_channels,
zq_ch,
norm_layer=nn.GroupNorm,
freeze_norm_layer=False,
add_conv=add_conv,
num_groups=32,
eps=1e-6,
affine=True,
)
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x):
if self.with_conv:
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
class ResnetBlock(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout,
temb_channels=512,
zq_ch=None,
add_conv=False,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels, zq_ch, add_conv=add_conv)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize(out_channels, zq_ch, add_conv=add_conv)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
else:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, temb, zq):
h = x
h = self.norm1(h, zq)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
h = self.norm2(h, zq)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
class AttnBlock(nn.Module):
def __init__(self, in_channels, zq_ch=None, add_conv=False):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels, zq_ch, add_conv=add_conv)
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, zq):
h_ = x
h_ = self.norm(h_, zq)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h * w)
q = q.permute(0, 2, 1) # b,hw,c
k = k.reshape(b, c, h * w) # b,c,hw
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h * w)
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
return x + h_
class MOVQDecoder(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
zq_ch=None,
add_conv=False,
**ignorekwargs,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1,) + tuple(ch_mult)
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
# z to block_in
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
)
self.mid.attn_1 = AttnBlock(block_in, zq_ch, add_conv=add_conv)
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttnBlock(block_in, zq_ch, add_conv=add_conv))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in, zq_ch, add_conv=add_conv)
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, z, zq):
# assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
# timestep embedding
temb = None
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h, temb, zq)
h = self.mid.attn_1(h, zq)
h = self.mid.block_2(h, temb, zq)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb, zq)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, zq)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h, zq)
h = nonlinearity(h)
h = self.conv_out(h)
return h
def forward_with_features_output(self, z, zq):
# assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
# timestep embedding
temb = None
output_features = {}
# z to block_in
h = self.conv_in(z)
output_features["conv_in"] = h
# middle
h = self.mid.block_1(h, temb, zq)
output_features["mid_block_1"] = h
h = self.mid.attn_1(h, zq)
output_features["mid_attn_1"] = h
h = self.mid.block_2(h, temb, zq)
output_features["mid_block_2"] = h
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb, zq)
output_features[f"up_{i_level}_block_{i_block}"] = h
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, zq)
output_features[f"up_{i_level}_attn_{i_block}"] = h
if i_level != 0:
h = self.up[i_level].upsample(h)
output_features[f"up_{i_level}_upsample"] = h
# end
if self.give_pre_end:
return h
h = self.norm_out(h, zq)
output_features["norm_out"] = h
h = nonlinearity(h)
output_features["nonlinearity"] = h
h = self.conv_out(h)
output_features["conv_out"] = h
return h, output_features
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch import einsum
from einops import rearrange
class VectorQuantizer2(nn.Module):
"""
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
"""
# NOTE: due to a bug the beta term was applied to the wrong term. for
# backwards compatibility we use the buggy version by default, but you can
# specify legacy=False to fix it.
def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
super().__init__()
self.n_e = n_e
self.e_dim = e_dim
self.beta = beta
self.legacy = legacy
self.embedding = nn.Embedding(self.n_e, self.e_dim)
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
self.remap = remap
if self.remap is not None:
self.register_buffer("used", torch.tensor(np.load(self.remap)))
self.re_embed = self.used.shape[0]
self.unknown_index = unknown_index # "random" or "extra" or integer
if self.unknown_index == "extra":
self.unknown_index = self.re_embed
self.re_embed = self.re_embed + 1
print(
f"Remapping {self.n_e} indices to {self.re_embed} indices. "
f"Using {self.unknown_index} for unknown indices."
)
else:
self.re_embed = n_e
self.sane_index_shape = sane_index_shape
def remap_to_used(self, inds):
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
used = self.used.to(inds)
match = (inds[:, :, None] == used[None, None, ...]).long()
new = match.argmax(-1)
unknown = match.sum(2) < 1
if self.unknown_index == "random":
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
else:
new[unknown] = self.unknown_index
return new.reshape(ishape)
def unmap_to_all(self, inds):
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
used = self.used.to(inds)
if self.re_embed > self.used.shape[0]: # extra token
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
return back.reshape(ishape)
def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
assert rescale_logits == False, "Only for interface compatible with Gumbel"
assert return_logits == False, "Only for interface compatible with Gumbel"
# reshape z -> (batch, height, width, channel) and flatten
z = rearrange(z, "b c h w -> b h w c").contiguous()
z_flattened = z.view(-1, self.e_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = (
torch.sum(z_flattened**2, dim=1, keepdim=True)
+ torch.sum(self.embedding.weight**2, dim=1)
- 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n"))
)
min_encoding_indices = torch.argmin(d, dim=1)
z_q = self.embedding(min_encoding_indices).view(z.shape)
perplexity = None
min_encodings = None
# compute loss for embedding
if not self.legacy:
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
else:
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
# preserve gradients
z_q = z + (z_q - z).detach()
# reshape back to match original input shape
z_q = rearrange(z_q, "b h w c -> b c h w").contiguous()
if self.remap is not None:
min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
min_encoding_indices = self.remap_to_used(min_encoding_indices)
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
if self.sane_index_shape:
min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
def get_codebook_entry(self, indices, shape):
# shape specifying (batch, height, width, channel)
if self.remap is not None:
indices = indices.reshape(shape[0], -1) # add batch axis
indices = self.unmap_to_all(indices)
indices = indices.reshape(-1) # flatten again
# get quantized latent vectors
z_q = self.embedding(indices)
if shape is not None:
z_q = z_q.view(shape)
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q
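# --- Illustrative sketch (not part of the original file) ---------------------
# Quantize a toy feature map and check the straight-through estimator: z_q takes its
# values from the codebook, but gradients flow back to the encoder output z. Shapes
# and the helper name are assumptions for illustration only.
def _demo_vector_quantizer2():
    torch.manual_seed(0)
    vq = VectorQuantizer2(n_e=16, e_dim=8, beta=0.25, sane_index_shape=True)
    z = torch.randn(2, 8, 4, 4, requires_grad=True)  # (B, e_dim, H, W)
    z_q, loss, (_, _, indices) = vq(z)
    assert z_q.shape == z.shape and indices.shape == (2, 4, 4)
    z_q.sum().backward()  # straight-through: d(z_q)/d(z) behaves like identity
    assert z.grad is not None
    return z_q, loss, indices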
class GumbelQuantize(nn.Module):
"""
credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
Gumbel Softmax trick quantizer
Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
https://arxiv.org/abs/1611.01144
"""
def __init__(
self,
num_hiddens,
embedding_dim,
n_embed,
straight_through=True,
kl_weight=5e-4,
temp_init=1.0,
use_vqinterface=True,
remap=None,
unknown_index="random",
):
super().__init__()
self.embedding_dim = embedding_dim
self.n_embed = n_embed
self.straight_through = straight_through
self.temperature = temp_init
self.kl_weight = kl_weight
self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
self.embed = nn.Embedding(n_embed, embedding_dim)
self.use_vqinterface = use_vqinterface
self.remap = remap
if self.remap is not None:
self.register_buffer("used", torch.tensor(np.load(self.remap)))
self.re_embed = self.used.shape[0]
self.unknown_index = unknown_index # "random" or "extra" or integer
if self.unknown_index == "extra":
self.unknown_index = self.re_embed
self.re_embed = self.re_embed + 1
print(
f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
f"Using {self.unknown_index} for unknown indices."
)
else:
self.re_embed = n_embed
def remap_to_used(self, inds):
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
used = self.used.to(inds)
match = (inds[:, :, None] == used[None, None, ...]).long()
new = match.argmax(-1)
unknown = match.sum(2) < 1
if self.unknown_index == "random":
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
else:
new[unknown] = self.unknown_index
return new.reshape(ishape)
def unmap_to_all(self, inds):
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
used = self.used.to(inds)
if self.re_embed > self.used.shape[0]: # extra token
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
return back.reshape(ishape)
def forward(self, z, temp=None, return_logits=False):
# force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work
hard = self.straight_through if self.training else True
temp = self.temperature if temp is None else temp
logits = self.proj(z)
if self.remap is not None:
# continue only with used logits
full_zeros = torch.zeros_like(logits)
logits = logits[:, self.used, ...]
soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
if self.remap is not None:
# go back to all entries but unused set to zero
full_zeros[:, self.used, ...] = soft_one_hot
soft_one_hot = full_zeros
z_q = einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
# + kl divergence to the prior loss
qy = F.softmax(logits, dim=1)
diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
ind = soft_one_hot.argmax(dim=1)
if self.remap is not None:
ind = self.remap_to_used(ind)
if self.use_vqinterface:
if return_logits:
return z_q, diff, (None, None, ind), logits
return z_q, diff, (None, None, ind)
return z_q, diff, ind
def get_codebook_entry(self, indices, shape):
b, h, w, c = shape
assert b * h * w == indices.shape[0]
indices = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w)
if self.remap is not None:
indices = self.unmap_to_all(indices)
one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
z_q = einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight)
return z_q
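# --- Illustrative sketch (not part of the original file) ---------------------
# Gumbel-Softmax quantization of toy encoder features: straight-through samples
# during training, hard one-hot codes at eval time. Sizes are assumed for illustration.
def _demo_gumbel_quantize():
    torch.manual_seed(0)
    gq = GumbelQuantize(num_hiddens=16, embedding_dim=8, n_embed=32)
    h = torch.randn(2, 16, 4, 4)  # (B, num_hiddens, H, W)
    z_q, kl_loss, (_, _, ind) = gq(h)
    assert z_q.shape == (2, 8, 4, 4) and ind.shape == (2, 4, 4)
    gq.eval()
    z_q_eval, _, _ = gq(h)
    assert z_q_eval.shape == (2, 8, 4, 4)
    return z_q, kl_loss, ind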
# pytorch_diffusion + derived encoder decoder
import math
import torch
import torch.nn as nn
import numpy as np
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x):
if self.with_conv:
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
else:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h * w)
q = q.permute(0, 2, 1) # b,hw,c
k = k.reshape(b, c, h * w) # b,c,hw
# # original version, nan in fp16
# w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
# w_ = w_ * (int(c)**(-0.5))
# # implement c**-0.5 on q
q = q * (int(c) ** (-0.5))
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h * w)
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
return x + h_
class Encoder(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
double_z=True,
**ignore_kwargs,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# downsampling
self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock(
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttnBlock(block_in))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1
)
def forward(self, x):
# assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
# timestep embedding
temb = None
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
def forward_with_features_output(self, x):
# assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
# timestep embedding
temb = None
output_features = {}
# downsampling
hs = [self.conv_in(x)]
output_features["conv_in"] = hs[-1]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
output_features["down{}_block{}".format(i_level, i_block)] = h
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
output_features["down{}_attn{}".format(i_level, i_block)] = h
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
output_features["down{}_downsample".format(i_level)] = hs[-1]
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
output_features["mid_block_1"] = h
h = self.mid.attn_1(h)
output_features["mid_attn_1"] = h
h = self.mid.block_2(h, temb)
output_features["mid_block_2"] = h
# end
h = self.norm_out(h)
output_features["norm_out"] = h
h = nonlinearity(h)
output_features["nonlinearity"] = h
h = self.conv_out(h)
output_features["conv_out"] = h
return h, output_features
class Decoder(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
**ignorekwargs,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1,) + tuple(ch_mult)
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
# z to block_in
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ResnetBlock(
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttnBlock(block_in))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, z):
# assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
# timestep embedding
temb = None
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
import math
import torch
import torch.distributed
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from beartype import beartype
from beartype.typing import Union, Tuple, Optional, List
from einops import rearrange
from ..util import (
get_context_parallel_group,
get_context_parallel_rank,
get_context_parallel_world_size,
get_context_parallel_group_rank,
)
# try:
from ..util import SafeConv3d as Conv3d
# except:
# # Degrade to normal Conv3d if SafeConv3d is not available
# from torch.nn import Conv3d
_USE_CP = True
def cast_tuple(t, length=1):
return t if isinstance(t, tuple) else ((t,) * length)
def divisible_by(num, den):
return (num % den) == 0
def is_odd(n):
return not divisible_by(n, 2)
def exists(v):
return v is not None
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
def leaky_relu(p=0.1):
return nn.LeakyReLU(p)
def _split(input_, dim):
cp_world_size = get_context_parallel_world_size()
if cp_world_size == 1:
return input_
cp_rank = get_context_parallel_rank()
# print('in _split, cp_rank:', cp_rank, 'input_size:', input_.shape)
input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
dim_size = input_.size()[dim] // cp_world_size
input_list = torch.split(input_, dim_size, dim=dim)
output = input_list[cp_rank]
if cp_rank == 0:
output = torch.cat([input_first_frame_, output], dim=dim)
output = output.contiguous()
# print('out _split, cp_rank:', cp_rank, 'output_size:', output.shape)
return output
def _gather(input_, dim):
cp_world_size = get_context_parallel_world_size()
# Bypass the function if context parallel is 1
if cp_world_size == 1:
return input_
group = get_context_parallel_group()
cp_rank = get_context_parallel_rank()
# print('in _gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
if cp_rank == 0:
input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
tensor_list = [torch.empty_like(torch.cat([input_first_frame_, input_], dim=dim))] + [
torch.empty_like(input_) for _ in range(cp_world_size - 1)
]
if cp_rank == 0:
input_ = torch.cat([input_first_frame_, input_], dim=dim)
tensor_list[cp_rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=group)
output = torch.cat(tensor_list, dim=dim).contiguous()
# print('out _gather, cp_rank:', cp_rank, 'output_size:', output.shape)
return output
def _conv_split(input_, dim, kernel_size):
cp_world_size = get_context_parallel_world_size()
# Bypass the function if context parallel is 1
if cp_world_size == 1:
return input_
# print('in _conv_split, cp_rank:', cp_rank, 'input_size:', input_.shape)
cp_rank = get_context_parallel_rank()
dim_size = (input_.size()[dim] - kernel_size) // cp_world_size
if cp_rank == 0:
output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0)
else:
output = input_.transpose(dim, 0)[cp_rank * dim_size + 1 : (cp_rank + 1) * dim_size + kernel_size].transpose(
dim, 0
)
output = output.contiguous()
# print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape)
return output
def _conv_gather(input_, dim, kernel_size):
cp_world_size = get_context_parallel_world_size()
# Bypass the function if context parallel is 1
if cp_world_size == 1:
return input_
group = get_context_parallel_group()
cp_rank = get_context_parallel_rank()
# print('in _conv_gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(0, dim).contiguous()
if cp_rank == 0:
input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous()
else:
input_ = input_.transpose(0, dim)[kernel_size - 1 :].transpose(0, dim).contiguous()
tensor_list = [torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))] + [
torch.empty_like(input_) for _ in range(cp_world_size - 1)
]
if cp_rank == 0:
input_ = torch.cat([input_first_kernel_, input_], dim=dim)
tensor_list[cp_rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=group)
# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=dim).contiguous()
# print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape)
return output
def _pass_from_previous_rank(input_, dim, kernel_size):
# Bypass the function if kernel size is 1
if kernel_size == 1:
return input_
group = get_context_parallel_group()
cp_rank = get_context_parallel_rank()
cp_group_rank = get_context_parallel_group_rank()
cp_world_size = get_context_parallel_world_size()
# print('in _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape)
global_rank = torch.distributed.get_rank()
global_world_size = torch.distributed.get_world_size()
input_ = input_.transpose(0, dim)
# pass from last rank
send_rank = global_rank + 1
recv_rank = global_rank - 1
if send_rank % cp_world_size == 0:
send_rank -= cp_world_size
if recv_rank % cp_world_size == cp_world_size - 1:
recv_rank += cp_world_size
if cp_rank < cp_world_size - 1:
req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group)
if cp_rank > 0:
recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
if cp_rank == 0:
input_ = torch.cat([input_[:1]] * (kernel_size - 1) + [input_], dim=0)
else:
req_recv.wait()
input_ = torch.cat([recv_buffer, input_], dim=0)
input_ = input_.transpose(0, dim).contiguous()
# print('out _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape)
return input_
def _drop_from_previous_rank(input_, dim, kernel_size):
input_ = input_.transpose(0, dim)[kernel_size - 1 :].transpose(0, dim)
return input_
class _ConvolutionScatterToContextParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, dim, kernel_size):
ctx.dim = dim
ctx.kernel_size = kernel_size
return _conv_split(input_, dim, kernel_size)
@staticmethod
def backward(ctx, grad_output):
return _conv_gather(grad_output, ctx.dim, ctx.kernel_size), None, None
class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, dim, kernel_size):
ctx.dim = dim
ctx.kernel_size = kernel_size
return _conv_gather(input_, dim, kernel_size)
@staticmethod
def backward(ctx, grad_output):
return _conv_split(grad_output, ctx.dim, ctx.kernel_size), None, None
class _ConvolutionPassFromPreviousRank(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, dim, kernel_size):
ctx.dim = dim
ctx.kernel_size = kernel_size
return _pass_from_previous_rank(input_, dim, kernel_size)
@staticmethod
def backward(ctx, grad_output):
return _drop_from_previous_rank(grad_output, ctx.dim, ctx.kernel_size), None, None
def conv_scatter_to_context_parallel_region(input_, dim, kernel_size):
return _ConvolutionScatterToContextParallelRegion.apply(input_, dim, kernel_size)
def conv_gather_from_context_parallel_region(input_, dim, kernel_size):
return _ConvolutionGatherFromContextParallelRegion.apply(input_, dim, kernel_size)
def conv_pass_from_last_rank(input_, dim, kernel_size):
return _ConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size)
class ContextParallelCausalConv3d(nn.Module):
def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], stride=1, **kwargs):
super().__init__()
kernel_size = cast_tuple(kernel_size, 3)
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
time_pad = time_kernel_size - 1
height_pad = height_kernel_size // 2
width_pad = width_kernel_size // 2
self.height_pad = height_pad
self.width_pad = width_pad
self.time_pad = time_pad
self.time_kernel_size = time_kernel_size
self.temporal_dim = 2
stride = (stride, stride, stride)
dilation = (1, 1, 1)
self.conv = Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, input_):
# temporal padding inside
if _USE_CP:
input_parallel = conv_pass_from_last_rank(input_, self.temporal_dim, self.time_kernel_size)
else:
input_ = input_.transpose(0, self.temporal_dim)
input_parallel = torch.cat([input_[:1]] * (self.time_kernel_size - 1) + [input_], dim=0)
input_parallel = input_parallel.transpose(0, self.temporal_dim)
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0)
output_parallel = self.conv(input_parallel)
output = output_parallel
return output
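# --- Illustrative sketch (not part of the original file) ---------------------
# With context parallelism disabled the conv falls back to purely local causal padding
# (replicating the first frame). The sketch temporarily flips the module-level _USE_CP
# flag and assumes SafeConv3d behaves like a plain nn.Conv3d for small inputs.
def _demo_context_parallel_causal_conv3d_serial():
    global _USE_CP
    prev, _USE_CP = _USE_CP, False
    try:
        torch.manual_seed(0)
        conv = ContextParallelCausalConv3d(chan_in=2, chan_out=4, kernel_size=3)
        x = torch.randn(1, 2, 5, 8, 8)  # (B, C, T, H, W)
        y = conv(x)
        assert y.shape == (1, 4, 5, 8, 8)
        x_future = x.clone()
        x_future[:, :, 1:] += 1.0  # the first output frame ignores future frames
        assert torch.allclose(y[:, :, 0], conv(x_future)[:, :, 0])
    finally:
        _USE_CP = prev
    return y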
class ContextParallelGroupNorm(torch.nn.GroupNorm):
def forward(self, input_):
if _USE_CP:
input_ = conv_gather_from_context_parallel_region(input_, dim=2, kernel_size=1)
output = super().forward(input_)
if _USE_CP:
output = conv_scatter_to_context_parallel_region(output, dim=2, kernel_size=1)
return output
def Normalize(in_channels, gather=False, **kwargs): # same for 3D and 2D
if gather:
return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
else:
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class SpatialNorm3D(nn.Module):
def __init__(
self,
f_channels,
zq_channels,
freeze_norm_layer=False,
add_conv=False,
pad_mode="constant",
gather=False,
**norm_layer_params,
):
super().__init__()
if gather:
self.norm_layer = ContextParallelGroupNorm(num_channels=f_channels, **norm_layer_params)
else:
self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, **norm_layer_params)
# self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
if freeze_norm_layer:
for p in self.norm_layer.parameters():
p.requires_grad = False
self.add_conv = add_conv
if add_conv:
self.conv = ContextParallelCausalConv3d(
chan_in=zq_channels,
chan_out=zq_channels,
kernel_size=3,
)
self.conv_y = ContextParallelCausalConv3d(
chan_in=zq_channels,
chan_out=f_channels,
kernel_size=1,
)
self.conv_b = ContextParallelCausalConv3d(
chan_in=zq_channels,
chan_out=f_channels,
kernel_size=1,
)
def forward(self, f, zq):
if f.shape[2] == 1 and not _USE_CP:
zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest")
elif get_context_parallel_rank() == 0:
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
zq_first = torch.nn.functional.interpolate(zq_first, size=f_first_size, mode="nearest")
zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest")
zq = torch.cat([zq_first, zq_rest], dim=2)
else:
zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest")
if self.add_conv:
zq = self.conv(zq)
# f = conv_gather_from_context_parallel_region(f, dim=2, kernel_size=1)
norm_f = self.norm_layer(f)
# norm_f = conv_scatter_to_context_parallel_region(norm_f, dim=2, kernel_size=1)
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
return new_f
def Normalize3D(
in_channels,
zq_ch,
add_conv,
gather=False,
):
return SpatialNorm3D(
in_channels,
zq_ch,
gather=gather,
# norm_layer=nn.GroupNorm,
freeze_norm_layer=False,
add_conv=add_conv,
num_groups=32,
eps=1e-6,
affine=True,
)
class Upsample3D(nn.Module):
def __init__(
self,
in_channels,
with_conv,
compress_time=False,
):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
self.compress_time = compress_time
def forward(self, x):
if self.compress_time:
if x.shape[2] == 1 and not _USE_CP:
x = torch.nn.functional.interpolate(x[:, :, 0], scale_factor=2.0, mode="nearest")[:, :, None, :, :]
elif get_context_parallel_rank() == 0:
# split first frame
x_first, x_rest = x[:, :, 0], x[:, :, 1:]
x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
else:
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
else:
# only interpolate 2D
t = x.shape[2]
x = rearrange(x, "b c t h w -> (b t) c h w")
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
if self.with_conv:
t = x.shape[2]
x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.conv(x)
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
return x
class DownSample3D(nn.Module):
def __init__(self, in_channels, with_conv, compress_time=False, out_channels=None):
super().__init__()
self.with_conv = with_conv
if out_channels is None:
out_channels = in_channels
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
self.compress_time = compress_time
def forward(self, x):
if self.compress_time and x.shape[2] > 1:
h, w = x.shape[-2:]
x = rearrange(x, "b c t h w -> (b h w) c t")
if x.shape[-1] % 2 == 1:
# split first frame
x_first, x_rest = x[..., 0], x[..., 1:]
if x_rest.shape[-1] > 0:
x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
x = torch.cat([x_first[..., None], x_rest], dim=-1)
x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
else:
x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
if self.with_conv:
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
t = x.shape[2]
x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.conv(x)
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
else:
t = x.shape[2]
x = rearrange(x, "b c t h w -> (b t) c h w")
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
return x
class ContextParallelResnetBlock3D(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout,
temb_channels=512,
zq_ch=None,
add_conv=False,
gather_norm=False,
normalization=Normalize,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = normalization(
in_channels,
zq_ch=zq_ch,
add_conv=add_conv,
gather=gather_norm,
)
self.conv1 = ContextParallelCausalConv3d(
chan_in=in_channels,
chan_out=out_channels,
kernel_size=3,
)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = normalization(
out_channels,
zq_ch=zq_ch,
add_conv=add_conv,
gather=gather_norm,
)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = ContextParallelCausalConv3d(
chan_in=out_channels,
chan_out=out_channels,
kernel_size=3,
)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = ContextParallelCausalConv3d(
chan_in=in_channels,
chan_out=out_channels,
kernel_size=3,
)
else:
self.nin_shortcut = Conv3d(
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
)
def forward(self, x, temb, zq=None):
h = x
# if isinstance(self.norm1, torch.nn.GroupNorm):
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
if zq is not None:
h = self.norm1(h, zq)
else:
h = self.norm1(h)
# if isinstance(self.norm1, torch.nn.GroupNorm):
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
# if isinstance(self.norm2, torch.nn.GroupNorm):
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
if zq is not None:
h = self.norm2(h, zq)
else:
h = self.norm2(h)
# if isinstance(self.norm2, torch.nn.GroupNorm):
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
class ContextParallelEncoder3D(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
double_z=True,
pad_mode="first",
temporal_compress_times=4,
gather_norm=False,
**ignore_kwargs,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# log2 of temporal_compress_times
self.temporal_compress_level = int(np.log2(temporal_compress_times))
self.conv_in = ContextParallelCausalConv3d(
chan_in=in_channels,
chan_out=self.ch,
kernel_size=3,
)
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ContextParallelResnetBlock3D(
in_channels=block_in,
out_channels=block_out,
dropout=dropout,
temb_channels=self.temb_ch,
gather_norm=gather_norm,
)
)
block_in = block_out
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
if i_level < self.temporal_compress_level:
down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=True)
else:
down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=False)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ContextParallelResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
gather_norm=gather_norm,
)
self.mid.block_2 = ContextParallelResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
gather_norm=gather_norm,
)
# end
self.norm_out = Normalize(block_in, gather=gather_norm)
self.conv_out = ContextParallelCausalConv3d(
chan_in=block_in,
chan_out=2 * z_channels if double_z else z_channels,
kernel_size=3,
)
def forward(self, x, use_cp=True):
global _USE_CP
_USE_CP = use_cp
# timestep embedding
temb = None
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
h = self.mid.block_2(h, temb)
# end
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
h = self.norm_out(h)
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class ContextParallelDecoder3D(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
zq_ch=None,
add_conv=False,
pad_mode="first",
temporal_compress_times=4,
gather_norm=False,
        **ignore_kwargs,  # extra config keys are accepted and ignored
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
# log2 of temporal_compress_times
self.temporal_compress_level = int(np.log2(temporal_compress_times))
if zq_ch is None:
zq_ch = z_channels
# compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,) + tuple(ch_mult)  # kept for symmetry with the encoder; not used below
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
self.conv_in = ContextParallelCausalConv3d(
chan_in=z_channels,
chan_out=block_in,
kernel_size=3,
)
# middle
self.mid = nn.Module()
self.mid.block_1 = ContextParallelResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
normalization=Normalize3D,
gather_norm=gather_norm,
)
self.mid.block_2 = ContextParallelResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
normalization=Normalize3D,
gather_norm=gather_norm,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ContextParallelResnetBlock3D(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
normalization=Normalize3D,
gather_norm=gather_norm,
)
)
block_in = block_out
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
if i_level < self.num_resolutions - self.temporal_compress_level:
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
else:
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True)
self.up.insert(0, up)
self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv, gather=gather_norm)
self.conv_out = ContextParallelCausalConv3d(
chan_in=block_in,
chan_out=out_ch,
kernel_size=3,
)
def forward(self, z, use_cp=True):
global _USE_CP
_USE_CP = use_cp
self.last_z_shape = z.shape
# timestep embedding
temb = None
        t = z.shape[2]  # latent temporal length (not used further in this forward)
# z to block_in
zq = z
h = self.conv_in(z)
# middle
h = self.mid.block_1(h, temb, zq)
h = self.mid.block_2(h, temb, zq)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb, zq)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, zq)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h, zq)
h = nonlinearity(h)
h = self.conv_out(h)
        _USE_CP = True  # reset the global context-parallel flag after decoding
return h
def get_last_layer(self):
return self.conv_out.conv.weight
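# Illustrative usage sketch: decoding a latent back to pixel space, mirroring
# the encoder example above. Shapes and hyperparameters are illustrative
# assumptions and must match the paired encoder in practice.
#
#   decoder = ContextParallelDecoder3D(
#       ch=128,
#       out_ch=3,
#       ch_mult=(1, 2, 2, 4),
#       num_res_blocks=2,
#       attn_resolutions=[],
#       in_channels=3,
#       resolution=256,
#       z_channels=16,
#       temporal_compress_times=4,
#   )
#   z = torch.randn(1, 16, 3, 32, 32)    # (B, z_channels, T', H', W')
#   x_rec = decoder(z, use_cp=False)     # -> (B, out_ch, T, H, W) video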
from .denoiser import Denoiser
from .discretizer import Discretization
from .model import Decoder, Encoder, Model
from .openaimodel import UNetModel
from .sampling import BaseDiffusionSampler
from .wrappers import OpenAIWrapper