Commit 3b804999 authored by chenzk
v1.0
Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
from collections import namedtuple
import torch
import torch.nn as nn
from torchvision import models
from ..util import get_ckpt_path
class LPIPS(nn.Module):
# Learned perceptual metric
def __init__(self, use_dropout=True):
super().__init__()
self.scaling_layer = ScalingLayer()
        self.chns = [64, 128, 256, 512, 512]  # vgg16 features
self.net = vgg16(pretrained=True, requires_grad=False)
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
self.load_from_pretrained()
for param in self.parameters():
param.requires_grad = False
def load_from_pretrained(self, name='vgg_lpips'):
ckpt = get_ckpt_path(name, 'sgm/modules/autoencoding/lpips/loss')
self.load_state_dict(torch.load(ckpt,
map_location=torch.device('cpu')),
strict=False)
print(f'loaded pretrained LPIPS loss from {ckpt}')
@classmethod
def from_pretrained(cls, name='vgg_lpips'):
if name != 'vgg_lpips':
raise NotImplementedError
model = cls()
ckpt = get_ckpt_path(name)
model.load_state_dict(torch.load(ckpt,
map_location=torch.device('cpu')),
strict=False)
return model
def forward(self, input, target):
in0_input, in1_input = (self.scaling_layer(input),
self.scaling_layer(target))
outs0, outs1 = self.net(in0_input), self.net(in1_input)
feats0, feats1, diffs = {}, {}, {}
lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
for kk in range(len(self.chns)):
feats0[kk], feats1[kk] = normalize_tensor(
outs0[kk]), normalize_tensor(outs1[kk])
diffs[kk] = (feats0[kk] - feats1[kk])**2
res = [
spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
for kk in range(len(self.chns))
]
val = res[0]
for l in range(1, len(self.chns)):
val += res[l]
return val
class ScalingLayer(nn.Module):
def __init__(self):
super().__init__()
self.register_buffer(
'shift',
torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None])
self.register_buffer(
'scale',
torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None])
def forward(self, inp):
return (inp - self.shift) / self.scale
class NetLinLayer(nn.Module):
"""A single linear layer which does a 1x1 conv"""
def __init__(self, chn_in, chn_out=1, use_dropout=False):
super().__init__()
layers = ([
nn.Dropout(),
] if (use_dropout) else [])
layers += [
nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
]
self.model = nn.Sequential(*layers)
class vgg16(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super().__init__()
vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.N_slices = 5
for x in range(4):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(4, 9):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(9, 16):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(16, 23):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(23, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1_2 = h
h = self.slice2(h)
h_relu2_2 = h
h = self.slice3(h)
h_relu3_3 = h
h = self.slice4(h)
h_relu4_3 = h
h = self.slice5(h)
h_relu5_3 = h
vgg_outputs = namedtuple(
'VggOutputs',
['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3,
h_relu5_3)
return out
def normalize_tensor(x, eps=1e-10):
norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
return x / (norm_factor + eps)
def spatial_average(x, keepdim=True):
return x.mean([2, 3], keepdim=keepdim)
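# Usage sketch (not part of the original file): a rough illustration of how the
# LPIPS module above might be called on a pair of image batches scaled to
# [-1, 1]. Constructing LPIPS() pulls the pretrained VGG16 weights and the LPIPS
# linear-layer checkpoint, so this assumes network access or cached weights.
def _example_lpips_usage():
    lpips = LPIPS().eval()
    x = torch.rand(2, 3, 64, 64) * 2 - 1   # two random "images" in [-1, 1]
    y = torch.rand(2, 3, 64, 64) * 2 - 1
    with torch.no_grad():
        dist = lpips(x, y)                 # shape (2, 1, 1, 1): one distance per sample
    return dist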
Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------- LICENSE FOR pix2pix --------------------------------
BSD License
For pix2pix software
Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
----------------------------- LICENSE FOR DCGAN --------------------------------
BSD License
For dcgan.torch software
Copyright (c) 2015, Facebook, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import functools
import torch.nn as nn
from ..util import ActNorm
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
        try:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        except AttributeError:  # some wrapped conv modules keep their conv under `.conv`
            nn.init.normal_(m.conv.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm') != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
class NLayerDiscriminator(nn.Module):
"""Defines a PatchGAN discriminator as in Pix2Pix
--> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
"""
def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
"""Construct a PatchGAN discriminator
Parameters:
input_nc (int) -- the number of channels in input images
ndf (int) -- the number of filters in the last conv layer
n_layers (int) -- the number of conv layers in the discriminator
norm_layer -- normalization layer
"""
super().__init__()
if not use_actnorm:
norm_layer = nn.BatchNorm2d
else:
norm_layer = ActNorm
if type(
norm_layer
) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
use_bias = norm_layer.func != nn.BatchNorm2d
else:
use_bias = norm_layer != nn.BatchNorm2d
kw = 4
padw = 1
sequence = [
nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
nn.LeakyReLU(0.2, True),
]
nf_mult = 1
nf_mult_prev = 1
for n in range(1,
n_layers): # gradually increase the number of filters
nf_mult_prev = nf_mult
nf_mult = min(2**n, 8)
sequence += [
nn.Conv2d(
ndf * nf_mult_prev,
ndf * nf_mult,
kernel_size=kw,
stride=2,
padding=padw,
bias=use_bias,
),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True),
]
nf_mult_prev = nf_mult
nf_mult = min(2**n_layers, 8)
sequence += [
nn.Conv2d(
ndf * nf_mult_prev,
ndf * nf_mult,
kernel_size=kw,
stride=1,
padding=padw,
bias=use_bias,
),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True),
]
sequence += [
nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
] # output 1 channel prediction map
self.main = nn.Sequential(*sequence)
def forward(self, input):
"""Standard forward."""
return self.main(input)
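# Usage sketch (illustrative, not from the source): builds the PatchGAN
# discriminator above, applies the custom weight init, and runs a batch of RGB
# images through it to obtain a patch-wise logit map.
def _example_patchgan_usage():
    import torch
    discr = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3)
    discr.apply(weights_init)
    images = torch.randn(4, 3, 256, 256)
    logits = discr(images)   # one logit per overlapping patch, e.g. (4, 1, 30, 30)
    return logits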
import hashlib
import os
import requests
import torch
import torch.nn as nn
from tqdm import tqdm
URL_MAP = {
'vgg_lpips':
'https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1'
}
CKPT_MAP = {'vgg_lpips': 'vgg.pth'}
MD5_MAP = {'vgg_lpips': 'd507d7349b931f0638a25a48a722f98a'}
def download(url, local_path, chunk_size=1024):
os.makedirs(os.path.split(local_path)[0], exist_ok=True)
with requests.get(url, stream=True) as r:
total_size = int(r.headers.get('content-length', 0))
with tqdm(total=total_size, unit='B', unit_scale=True) as pbar:
with open(local_path, 'wb') as f:
for data in r.iter_content(chunk_size=chunk_size):
if data:
f.write(data)
pbar.update(chunk_size)
def md5_hash(path):
with open(path, 'rb') as f:
content = f.read()
return hashlib.md5(content).hexdigest()
def get_ckpt_path(name, root, check=False):
assert name in URL_MAP
path = os.path.join(root, CKPT_MAP[name])
if not os.path.exists(path) or (check
and not md5_hash(path) == MD5_MAP[name]):
print('Downloading {} model from {} to {}'.format(
name, URL_MAP[name], path))
download(URL_MAP[name], path)
md5 = md5_hash(path)
assert md5 == MD5_MAP[name], md5
return path
class ActNorm(nn.Module):
def __init__(self,
num_features,
logdet=False,
affine=True,
allow_reverse_init=False):
assert affine
super().__init__()
self.logdet = logdet
self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
self.allow_reverse_init = allow_reverse_init
self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
def initialize(self, input):
with torch.no_grad():
flatten = input.permute(1, 0, 2,
3).contiguous().view(input.shape[1], -1)
mean = flatten.mean(1).unsqueeze(1).unsqueeze(2).unsqueeze(
3).permute(1, 0, 2, 3)
std = flatten.std(1).unsqueeze(1).unsqueeze(2).unsqueeze(
3).permute(1, 0, 2, 3)
self.loc.data.copy_(-mean)
self.scale.data.copy_(1 / (std + 1e-6))
def forward(self, input, reverse=False):
if reverse:
return self.reverse(input)
if len(input.shape) == 2:
input = input[:, :, None, None]
squeeze = True
else:
squeeze = False
_, _, height, width = input.shape
if self.training and self.initialized.item() == 0:
self.initialize(input)
self.initialized.fill_(1)
h = self.scale * (input + self.loc)
if squeeze:
h = h.squeeze(-1).squeeze(-1)
if self.logdet:
log_abs = torch.log(torch.abs(self.scale))
logdet = height * width * torch.sum(log_abs)
logdet = logdet * torch.ones(input.shape[0]).to(input)
return h, logdet
return h
def reverse(self, output):
if self.training and self.initialized.item() == 0:
if not self.allow_reverse_init:
raise RuntimeError(
'Initializing ActNorm in reverse direction is '
'disabled by default. Use allow_reverse_init=True to enable.'
)
else:
self.initialize(output)
self.initialized.fill_(1)
if len(output.shape) == 2:
output = output[:, :, None, None]
squeeze = True
else:
squeeze = False
h = output / self.scale - self.loc
if squeeze:
h = h.squeeze(-1).squeeze(-1)
return h
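# Usage sketch (illustrative): in training mode ActNorm initializes loc/scale
# from the statistics of the first batch it sees; forward followed by reverse
# then approximately recovers the input.
def _example_actnorm():
    layer = ActNorm(num_features=8)
    layer.train()
    x = torch.randn(4, 8, 16, 16)
    y = layer(x)               # data-dependent initialization happens on this call
    x_rec = layer.reverse(y)   # inverse of the affine transform
    return (x - x_rec).abs().max()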
import torch
import torch.nn.functional as F
def hinge_d_loss(logits_real, logits_fake):
loss_real = torch.mean(F.relu(1.0 - logits_real))
loss_fake = torch.mean(F.relu(1.0 + logits_fake))
d_loss = 0.5 * (loss_real + loss_fake)
return d_loss
def vanilla_d_loss(logits_real, logits_fake):
d_loss = 0.5 * (torch.mean(torch.nn.functional.softplus(-logits_real)) +
torch.mean(torch.nn.functional.softplus(logits_fake)))
return d_loss
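# Usage sketch (illustrative): both losses consume raw discriminator logits for
# real and fake samples; hinge_d_loss applies a margin of 1, while
# vanilla_d_loss is the standard logistic (softplus) formulation.
def _example_discr_losses():
    logits_real = torch.randn(8)
    logits_fake = torch.randn(8)
    return hinge_d_loss(logits_real, logits_fake), vanilla_d_loss(logits_real, logits_fake)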
import copy
import pickle
from collections import namedtuple
from functools import partial, wraps
from math import ceil, log2, sqrt
from pathlib import Path
import torch
import torch.nn.functional as F
import torchvision
from beartype import beartype
from beartype.typing import List, Optional, Tuple, Union
from einops import pack, rearrange, reduce, repeat, unpack
from einops.layers.torch import Rearrange
from gateloop_transformer import SimpleGateLoopLayer
from kornia.filters import filter3d
from magvit2_pytorch.attend import Attend
from magvit2_pytorch.version import __version__
from taylor_series_linear_attention import TaylorSeriesLinearAttn
from torch import Tensor, einsum, nn
from torch.autograd import grad as torch_grad
from torch.cuda.amp import autocast
from torch.nn import Module, ModuleList
from torchvision.models import VGG16_Weights
# from vector_quantize_pytorch import LFQ, FSQ
from .regularizers.finite_scalar_quantization import FSQ
from .regularizers.lookup_free_quantization import LFQ
# helper
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
def safe_get_index(it, ind, default=None):
if ind < len(it):
return it[ind]
return default
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def identity(t, *args, **kwargs):
return t
def divisible_by(num, den):
return (num % den) == 0
def pack_one(t, pattern):
return pack([t], pattern)
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
def append_dims(t, ndims: int):
return t.reshape(*t.shape, *((1, ) * ndims))
def is_odd(n):
return not divisible_by(n, 2)
def maybe_del_attr_(o, attr):
if hasattr(o, attr):
delattr(o, attr)
def cast_tuple(t, length=1):
return t if isinstance(t, tuple) else ((t, ) * length)
# tensor helpers
def l2norm(t):
return F.normalize(t, dim=-1, p=2)
def pad_at_dim(t, pad, dim=-1, value=0.0):
dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
zeros = (0, 0) * dims_from_right
return F.pad(t, (*zeros, *pad), value=value)
def pick_video_frame(video, frame_indices):
batch, device = video.shape[0], video.device
video = rearrange(video, 'b c f ... -> b f c ...')
batch_indices = torch.arange(batch, device=device)
batch_indices = rearrange(batch_indices, 'b -> b 1')
images = video[batch_indices, frame_indices]
images = rearrange(images, 'b 1 c ... -> b c ...')
return images
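# Usage sketch (illustrative): pick_video_frame gathers one frame per batch
# element from a (batch, channels, frames, height, width) video, given one
# frame index per sample.
def _example_pick_video_frame():
    video = torch.randn(2, 3, 8, 16, 16)
    frame_indices = torch.tensor([[0], [5]])   # one index per batch element
    frames = pick_video_frame(video, frame_indices)
    return frames.shape                        # torch.Size([2, 3, 16, 16])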
# gan related
def gradient_penalty(images, output):
batch_size = images.shape[0]
gradients = torch_grad(
outputs=output,
inputs=images,
grad_outputs=torch.ones(output.size(), device=images.device),
create_graph=True,
retain_graph=True,
only_inputs=True,
)[0]
gradients = rearrange(gradients, 'b ... -> b (...)')
return ((gradients.norm(2, dim=1) - 1)**2).mean()
def leaky_relu(p=0.1):
return nn.LeakyReLU(p)
def hinge_discr_loss(fake, real):
return (F.relu(1 + fake) + F.relu(1 - real)).mean()
def hinge_gen_loss(fake):
return -fake.mean()
@autocast(enabled=False)
@beartype
def grad_layer_wrt_loss(loss: Tensor, layer: nn.Parameter):
return torch_grad(outputs=loss,
inputs=layer,
grad_outputs=torch.ones_like(loss),
retain_graph=True)[0].detach()
# helper decorators
def remove_vgg(fn):
@wraps(fn)
def inner(self, *args, **kwargs):
has_vgg = hasattr(self, 'vgg')
if has_vgg:
vgg = self.vgg
delattr(self, 'vgg')
out = fn(self, *args, **kwargs)
if has_vgg:
self.vgg = vgg
return out
return inner
# helper classes
def Sequential(*modules):
modules = [*filter(exists, modules)]
if len(modules) == 0:
return nn.Identity()
return nn.Sequential(*modules)
class Residual(Module):
@beartype
def __init__(self, fn: Module):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x
# wraps a module with the reshapes needed to run it on (batch, time, feature dim) sequences and back
class ToTimeSequence(Module):
@beartype
def __init__(self, fn: Module):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
x = rearrange(x, 'b c f ... -> b ... f c')
x, ps = pack_one(x, '* n c')
o = self.fn(x, **kwargs)
o = unpack_one(o, ps, '* n c')
return rearrange(o, 'b ... f c -> b c f ...')
class SqueezeExcite(Module):
# global context network - attention-esque squeeze-excite variant (https://arxiv.org/abs/2012.13375)
def __init__(self, dim, *, dim_out=None, dim_hidden_min=16, init_bias=-10):
super().__init__()
dim_out = default(dim_out, dim)
self.to_k = nn.Conv2d(dim, 1, 1)
dim_hidden = max(dim_hidden_min, dim_out // 2)
self.net = nn.Sequential(nn.Conv2d(dim, dim_hidden, 1),
nn.LeakyReLU(0.1),
nn.Conv2d(dim_hidden, dim_out, 1),
nn.Sigmoid())
nn.init.zeros_(self.net[-2].weight)
nn.init.constant_(self.net[-2].bias, init_bias)
def forward(self, x):
orig_input, batch = x, x.shape[0]
is_video = x.ndim == 5
if is_video:
x = rearrange(x, 'b c f h w -> (b f) c h w')
context = self.to_k(x)
context = rearrange(context, 'b c h w -> b c (h w)').softmax(dim=-1)
spatial_flattened_input = rearrange(x, 'b c h w -> b c (h w)')
out = einsum('b i n, b c n -> b c i', context, spatial_flattened_input)
out = rearrange(out, '... -> ... 1')
gates = self.net(out)
if is_video:
gates = rearrange(gates, '(b f) c h w -> b c f h w', b=batch)
return gates * orig_input
# token shifting
class TokenShift(Module):
@beartype
def __init__(self, fn: Module):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
x, x_shift = x.chunk(2, dim=1)
x_shift = pad_at_dim(x_shift, (1, -1), dim=2) # shift time dimension
x = torch.cat((x, x_shift), dim=1)
return self.fn(x, **kwargs)
# rmsnorm
class RMSNorm(Module):
def __init__(self, dim, channel_first=False, images=False, bias=False):
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim, )
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return F.normalize(x, dim=(1 if self.channel_first else
-1)) * self.scale * self.gamma + self.bias
class AdaptiveRMSNorm(Module):
def __init__(self,
dim,
*,
dim_cond,
channel_first=False,
images=False,
bias=False):
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim, )
self.dim_cond = dim_cond
self.channel_first = channel_first
self.scale = dim**0.5
self.to_gamma = nn.Linear(dim_cond, dim)
self.to_bias = nn.Linear(dim_cond, dim) if bias else None
nn.init.zeros_(self.to_gamma.weight)
nn.init.ones_(self.to_gamma.bias)
if bias:
nn.init.zeros_(self.to_bias.weight)
nn.init.zeros_(self.to_bias.bias)
@beartype
def forward(self, x: Tensor, *, cond: Tensor):
batch = x.shape[0]
assert cond.shape == (batch, self.dim_cond)
gamma = self.to_gamma(cond)
bias = 0.0
if exists(self.to_bias):
bias = self.to_bias(cond)
if self.channel_first:
gamma = append_dims(gamma, x.ndim - 2)
if exists(self.to_bias):
bias = append_dims(bias, x.ndim - 2)
return F.normalize(x, dim=(1 if self.channel_first else
-1)) * self.scale * gamma + bias
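# Usage sketch (illustrative, not from the source): AdaptiveRMSNorm normalizes
# the feature dimension and rescales it with a per-sample gamma predicted from
# a conditioning vector; here it is used channel-first on image-shaped features.
def _example_adaptive_rmsnorm():
    norm = AdaptiveRMSNorm(dim=32, dim_cond=16, channel_first=True, images=True)
    x = torch.randn(4, 32, 8, 8)       # (batch, channels, height, width)
    cond = torch.randn(4, 16)          # one conditioning vector per sample
    return norm(x, cond=cond).shape    # torch.Size([4, 32, 8, 8])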
# attention
class Attention(Module):
@beartype
def __init__(
self,
*,
dim,
dim_cond: Optional[int] = None,
causal=False,
dim_head=32,
heads=8,
flash=False,
dropout=0.0,
num_memory_kv=4,
):
super().__init__()
dim_inner = dim_head * heads
self.need_cond = exists(dim_cond)
if self.need_cond:
self.norm = AdaptiveRMSNorm(dim, dim_cond=dim_cond)
else:
self.norm = RMSNorm(dim)
self.to_qkv = nn.Sequential(
nn.Linear(dim, dim_inner * 3, bias=False),
Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=heads))
assert num_memory_kv > 0
self.mem_kv = nn.Parameter(
torch.randn(2, heads, num_memory_kv, dim_head))
self.attend = Attend(causal=causal, dropout=dropout, flash=flash)
self.to_out = nn.Sequential(Rearrange('b h n d -> b n (h d)'),
nn.Linear(dim_inner, dim, bias=False))
@beartype
def forward(self,
x,
mask: Optional[Tensor] = None,
cond: Optional[Tensor] = None):
maybe_cond_kwargs = dict(cond=cond) if self.need_cond else dict()
x = self.norm(x, **maybe_cond_kwargs)
q, k, v = self.to_qkv(x)
mk, mv = map(lambda t: repeat(t, 'h n d -> b h n d', b=q.shape[0]),
self.mem_kv)
k = torch.cat((mk, k), dim=-2)
v = torch.cat((mv, v), dim=-2)
out = self.attend(q, k, v, mask=mask)
return self.to_out(out)
class LinearAttention(Module):
"""
using the specific linear attention proposed in https://arxiv.org/abs/2106.09681
"""
@beartype
def __init__(self,
*,
dim,
dim_cond: Optional[int] = None,
dim_head=8,
heads=8,
dropout=0.0):
super().__init__()
dim_inner = dim_head * heads
self.need_cond = exists(dim_cond)
if self.need_cond:
self.norm = AdaptiveRMSNorm(dim, dim_cond=dim_cond)
else:
self.norm = RMSNorm(dim)
self.attn = TaylorSeriesLinearAttn(dim=dim,
dim_head=dim_head,
heads=heads)
def forward(self, x, cond: Optional[Tensor] = None):
maybe_cond_kwargs = dict(cond=cond) if self.need_cond else dict()
x = self.norm(x, **maybe_cond_kwargs)
return self.attn(x)
class LinearSpaceAttention(LinearAttention):
def forward(self, x, *args, **kwargs):
x = rearrange(x, 'b c ... h w -> b ... h w c')
x, batch_ps = pack_one(x, '* h w c')
x, seq_ps = pack_one(x, 'b * c')
x = super().forward(x, *args, **kwargs)
x = unpack_one(x, seq_ps, 'b * c')
x = unpack_one(x, batch_ps, '* h w c')
return rearrange(x, 'b ... h w c -> b c ... h w')
class SpaceAttention(Attention):
def forward(self, x, *args, **kwargs):
x = rearrange(x, 'b c t h w -> b t h w c')
x, batch_ps = pack_one(x, '* h w c')
x, seq_ps = pack_one(x, 'b * c')
x = super().forward(x, *args, **kwargs)
x = unpack_one(x, seq_ps, 'b * c')
x = unpack_one(x, batch_ps, '* h w c')
return rearrange(x, 'b t h w c -> b c t h w')
class TimeAttention(Attention):
def forward(self, x, *args, **kwargs):
x = rearrange(x, 'b c t h w -> b h w t c')
x, batch_ps = pack_one(x, '* t c')
x = super().forward(x, *args, **kwargs)
x = unpack_one(x, batch_ps, '* t c')
return rearrange(x, 'b h w t c -> b c t h w')
class GEGLU(Module):
def forward(self, x):
x, gate = x.chunk(2, dim=1)
return F.gelu(gate) * x
class FeedForward(Module):
@beartype
def __init__(self,
dim,
*,
dim_cond: Optional[int] = None,
mult=4,
images=False):
super().__init__()
conv_klass = nn.Conv2d if images else nn.Conv3d
rmsnorm_klass = RMSNorm if not exists(dim_cond) else partial(
AdaptiveRMSNorm, dim_cond=dim_cond)
maybe_adaptive_norm_klass = partial(rmsnorm_klass,
channel_first=True,
images=images)
dim_inner = int(dim * mult * 2 / 3)
self.norm = maybe_adaptive_norm_klass(dim)
self.net = Sequential(conv_klass(dim, dim_inner * 2, 1), GEGLU(),
conv_klass(dim_inner, dim, 1))
@beartype
def forward(self, x: Tensor, *, cond: Optional[Tensor] = None):
maybe_cond_kwargs = dict(cond=cond) if exists(cond) else dict()
x = self.norm(x, **maybe_cond_kwargs)
return self.net(x)
# discriminator with anti-aliased downsampling (blurpool Zhang et al.)
class Blur(Module):
def __init__(self):
super().__init__()
f = torch.Tensor([1, 2, 1])
self.register_buffer('f', f)
def forward(self, x, space_only=False, time_only=False):
assert not (space_only and time_only)
f = self.f
if space_only:
f = einsum('i, j -> i j', f, f)
f = rearrange(f, '... -> 1 1 ...')
elif time_only:
f = rearrange(f, 'f -> 1 f 1 1')
else:
f = einsum('i, j, k -> i j k', f, f, f)
f = rearrange(f, '... -> 1 ...')
is_images = x.ndim == 4
if is_images:
x = rearrange(x, 'b c h w -> b c 1 h w')
out = filter3d(x, f, normalized=True)
if is_images:
out = rearrange(out, 'b c 1 h w -> b c h w')
return out
class DiscriminatorBlock(Module):
def __init__(self,
input_channels,
filters,
downsample=True,
antialiased_downsample=True):
super().__init__()
self.conv_res = nn.Conv2d(input_channels,
filters,
1,
stride=(2 if downsample else 1))
self.net = nn.Sequential(
nn.Conv2d(input_channels, filters, 3, padding=1),
leaky_relu(),
nn.Conv2d(filters, filters, 3, padding=1),
leaky_relu(),
)
self.maybe_blur = Blur() if antialiased_downsample else None
self.downsample = (nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1=2, p2=2),
nn.Conv2d(filters * 4, filters, 1)) if downsample else None)
def forward(self, x):
res = self.conv_res(x)
x = self.net(x)
if exists(self.downsample):
if exists(self.maybe_blur):
x = self.maybe_blur(x, space_only=True)
x = self.downsample(x)
x = (x + res) * (2**-0.5)
return x
class Discriminator(Module):
@beartype
def __init__(
self,
*,
dim,
image_size,
channels=3,
max_dim=512,
attn_heads=8,
attn_dim_head=32,
linear_attn_dim_head=8,
linear_attn_heads=16,
ff_mult=4,
antialiased_downsample=False,
):
super().__init__()
image_size = pair(image_size)
min_image_resolution = min(image_size)
num_layers = int(log2(min_image_resolution) - 2)
blocks = []
layer_dims = [channels] + [(dim * 4) * (2**i)
for i in range(num_layers + 1)]
layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims]
layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:]))
blocks = []
attn_blocks = []
image_resolution = min_image_resolution
for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out):
num_layer = ind + 1
is_not_last = ind != (len(layer_dims_in_out) - 1)
block = DiscriminatorBlock(
in_chan,
out_chan,
downsample=is_not_last,
antialiased_downsample=antialiased_downsample)
attn_block = Sequential(
Residual(
LinearSpaceAttention(dim=out_chan,
heads=linear_attn_heads,
dim_head=linear_attn_dim_head)),
Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
)
blocks.append(ModuleList([block, attn_block]))
image_resolution //= 2
self.blocks = ModuleList(blocks)
dim_last = layer_dims[-1]
downsample_factor = 2**num_layers
last_fmap_size = tuple(
map(lambda n: n // downsample_factor, image_size))
latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last
self.to_logits = Sequential(
nn.Conv2d(dim_last, dim_last, 3, padding=1),
leaky_relu(),
Rearrange('b ... -> b (...)'),
nn.Linear(latent_dim, 1),
Rearrange('b 1 -> b'),
)
def forward(self, x):
for block, attn_block in self.blocks:
x = block(x)
x = attn_block(x)
return self.to_logits(x)
# modulatable conv from Karras et al. Stylegan2
# for conditioning on latents
class Conv3DMod(Module):
@beartype
def __init__(self,
dim,
*,
spatial_kernel,
time_kernel,
causal=True,
dim_out=None,
demod=True,
eps=1e-8,
pad_mode='zeros'):
super().__init__()
dim_out = default(dim_out, dim)
self.eps = eps
assert is_odd(spatial_kernel) and is_odd(time_kernel)
self.spatial_kernel = spatial_kernel
self.time_kernel = time_kernel
time_padding = (time_kernel - 1,
0) if causal else ((time_kernel // 2, ) * 2)
self.pad_mode = pad_mode
self.padding = (*((spatial_kernel // 2, ) * 4), *time_padding)
self.weights = nn.Parameter(
torch.randn(
(dim_out, dim, time_kernel, spatial_kernel, spatial_kernel)))
self.demod = demod
nn.init.kaiming_normal_(self.weights,
a=0,
mode='fan_in',
nonlinearity='selu')
@beartype
def forward(self, fmap, cond: Tensor):
"""
notation
b - batch
n - convs
o - output
i - input
k - kernel
"""
b = fmap.shape[0]
# prepare weights for modulation
weights = self.weights
# do the modulation, demodulation, as done in stylegan2
cond = rearrange(cond, 'b i -> b 1 i 1 1 1')
weights = weights * (cond + 1)
if self.demod:
inv_norm = reduce(weights**2, 'b o i k0 k1 k2 -> b o 1 1 1 1',
'sum').clamp(min=self.eps).rsqrt()
weights = weights * inv_norm
fmap = rearrange(fmap, 'b c t h w -> 1 (b c) t h w')
weights = rearrange(weights, 'b o ... -> (b o) ...')
fmap = F.pad(fmap, self.padding, mode=self.pad_mode)
fmap = F.conv3d(fmap, weights, groups=b)
return rearrange(fmap, '1 (b o) ... -> b o ...', b=b)
# strided conv downsamples
class SpatialDownsample2x(Module):
def __init__(self, dim, dim_out=None, kernel_size=3, antialias=False):
super().__init__()
dim_out = default(dim_out, dim)
self.maybe_blur = Blur() if antialias else identity
self.conv = nn.Conv2d(dim,
dim_out,
kernel_size,
stride=2,
padding=kernel_size // 2)
def forward(self, x):
x = self.maybe_blur(x, space_only=True)
x = rearrange(x, 'b c t h w -> b t c h w')
x, ps = pack_one(x, '* c h w')
out = self.conv(x)
out = unpack_one(out, ps, '* c h w')
out = rearrange(out, 'b t c h w -> b c t h w')
return out
class TimeDownsample2x(Module):
def __init__(self, dim, dim_out=None, kernel_size=3, antialias=False):
super().__init__()
dim_out = default(dim_out, dim)
self.maybe_blur = Blur() if antialias else identity
self.time_causal_padding = (kernel_size - 1, 0)
self.conv = nn.Conv1d(dim, dim_out, kernel_size, stride=2)
def forward(self, x):
x = self.maybe_blur(x, time_only=True)
x = rearrange(x, 'b c t h w -> b h w c t')
x, ps = pack_one(x, '* c t')
x = F.pad(x, self.time_causal_padding)
out = self.conv(x)
out = unpack_one(out, ps, '* c t')
out = rearrange(out, 'b h w c t -> b c t h w')
return out
# depth to space upsamples
class SpatialUpsample2x(Module):
def __init__(self, dim, dim_out=None):
super().__init__()
dim_out = default(dim_out, dim)
conv = nn.Conv2d(dim, dim_out * 4, 1)
self.net = nn.Sequential(
conv, nn.SiLU(),
Rearrange('b (c p1 p2) h w -> b c (h p1) (w p2)', p1=2, p2=2))
self.init_conv_(conv)
def init_conv_(self, conv):
o, i, h, w = conv.weight.shape
conv_weight = torch.empty(o // 4, i, h, w)
nn.init.kaiming_uniform_(conv_weight)
conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def forward(self, x):
x = rearrange(x, 'b c t h w -> b t c h w')
x, ps = pack_one(x, '* c h w')
out = self.net(x)
out = unpack_one(out, ps, '* c h w')
out = rearrange(out, 'b t c h w -> b c t h w')
return out
class TimeUpsample2x(Module):
def __init__(self, dim, dim_out=None):
super().__init__()
dim_out = default(dim_out, dim)
conv = nn.Conv1d(dim, dim_out * 2, 1)
self.net = nn.Sequential(conv, nn.SiLU(),
Rearrange('b (c p) t -> b c (t p)', p=2))
self.init_conv_(conv)
def init_conv_(self, conv):
o, i, t = conv.weight.shape
conv_weight = torch.empty(o // 2, i, t)
nn.init.kaiming_uniform_(conv_weight)
conv_weight = repeat(conv_weight, 'o ... -> (o 2) ...')
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def forward(self, x):
x = rearrange(x, 'b c t h w -> b h w c t')
x, ps = pack_one(x, '* c t')
out = self.net(x)
out = unpack_one(out, ps, '* c t')
out = rearrange(out, 'b h w c t -> b c t h w')
return out
# autoencoder - only the best variant is offered here, built on causal 3d convolutions
def SameConv2d(dim_in, dim_out, kernel_size):
kernel_size = cast_tuple(kernel_size, 2)
padding = [k // 2 for k in kernel_size]
return nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size, padding=padding)
class CausalConv3d(Module):
@beartype
def __init__(self,
chan_in,
chan_out,
kernel_size: Union[int, Tuple[int, int, int]],
pad_mode='constant',
**kwargs):
super().__init__()
kernel_size = cast_tuple(kernel_size, 3)
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
dilation = kwargs.pop('dilation', 1)
stride = kwargs.pop('stride', 1)
self.pad_mode = pad_mode
time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
height_pad = height_kernel_size // 2
width_pad = width_kernel_size // 2
self.time_pad = time_pad
self.time_causal_padding = (width_pad, width_pad, height_pad,
height_pad, time_pad, 0)
stride = (stride, 1, 1)
dilation = (dilation, 1, 1)
self.conv = nn.Conv3d(chan_in,
chan_out,
kernel_size,
stride=stride,
dilation=dilation,
**kwargs)
def forward(self, x):
pad_mode = self.pad_mode if self.time_pad < x.shape[2] else 'constant'
x = F.pad(x, self.time_causal_padding, mode=pad_mode)
return self.conv(x)
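# Usage sketch (illustrative): CausalConv3d pads only on the left of the time
# axis, so the number of frames is preserved and frame t never attends to
# frames later than t.
def _example_causal_conv3d():
    conv = CausalConv3d(chan_in=8, chan_out=16, kernel_size=(3, 3, 3))
    video = torch.randn(1, 8, 5, 16, 16)   # (batch, channels, frames, height, width)
    out = conv(video)
    return out.shape                        # torch.Size([1, 16, 5, 16, 16])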
@beartype
def ResidualUnit(dim,
kernel_size: Union[int, Tuple[int, int, int]],
pad_mode: str = 'constant'):
net = Sequential(
CausalConv3d(dim, dim, kernel_size, pad_mode=pad_mode),
nn.ELU(),
nn.Conv3d(dim, dim, 1),
nn.ELU(),
SqueezeExcite(dim),
)
return Residual(net)
@beartype
class ResidualUnitMod(Module):
def __init__(self,
dim,
kernel_size: Union[int, Tuple[int, int, int]],
*,
dim_cond,
pad_mode: str = 'constant',
demod=True):
super().__init__()
kernel_size = cast_tuple(kernel_size, 3)
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
assert height_kernel_size == width_kernel_size
self.to_cond = nn.Linear(dim_cond, dim)
self.conv = Conv3DMod(
dim=dim,
spatial_kernel=height_kernel_size,
time_kernel=time_kernel_size,
causal=True,
demod=demod,
pad_mode=pad_mode,
)
self.conv_out = nn.Conv3d(dim, dim, 1)
@beartype
def forward(
self,
x,
cond: Tensor,
):
res = x
cond = self.to_cond(cond)
x = self.conv(x, cond=cond)
x = F.elu(x)
x = self.conv_out(x)
x = F.elu(x)
return x + res
class CausalConvTranspose3d(Module):
def __init__(self, chan_in, chan_out,
kernel_size: Union[int, Tuple[int, int, int]], *, time_stride,
**kwargs):
super().__init__()
kernel_size = cast_tuple(kernel_size, 3)
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
self.upsample_factor = time_stride
height_pad = height_kernel_size // 2
width_pad = width_kernel_size // 2
stride = (time_stride, 1, 1)
padding = (0, height_pad, width_pad)
self.conv = nn.ConvTranspose3d(chan_in,
chan_out,
kernel_size,
stride,
padding=padding,
**kwargs)
def forward(self, x):
assert x.ndim == 5
t = x.shape[2]
out = self.conv(x)
out = out[..., :(t * self.upsample_factor), :, :]
return out
# video tokenizer class
LossBreakdown = namedtuple(
'LossBreakdown',
[
'recon_loss',
'lfq_aux_loss',
'quantizer_loss_breakdown',
'perceptual_loss',
'adversarial_gen_loss',
'adaptive_adversarial_weight',
'multiscale_gen_losses',
'multiscale_gen_adaptive_weights',
],
)
DiscrLossBreakdown = namedtuple(
'DiscrLossBreakdown',
['discr_loss', 'multiscale_discr_losses', 'gradient_penalty'])
class VideoTokenizer(Module):
@beartype
def __init__(
self,
*,
image_size,
layers: Tuple[Union[str, Tuple[str, int]],
...] = ('residual', 'residual', 'residual'),
residual_conv_kernel_size=3,
num_codebooks=1,
codebook_size: Optional[int] = None,
channels=3,
init_dim=64,
max_dim=float('inf'),
dim_cond=None,
dim_cond_expansion_factor=4.0,
input_conv_kernel_size: Tuple[int, int, int] = (7, 7, 7),
output_conv_kernel_size: Tuple[int, int, int] = (3, 3, 3),
pad_mode: str = 'constant',
lfq_entropy_loss_weight=0.1,
lfq_commitment_loss_weight=1.0,
lfq_diversity_gamma=2.5,
quantizer_aux_loss_weight=1.0,
lfq_activation=nn.Identity(),
use_fsq=False,
fsq_levels: Optional[List[int]] = None,
attn_dim_head=32,
attn_heads=8,
attn_dropout=0.0,
linear_attn_dim_head=8,
linear_attn_heads=16,
vgg: Optional[Module] = None,
vgg_weights: VGG16_Weights = VGG16_Weights.DEFAULT,
perceptual_loss_weight=1e-1,
discr_kwargs: Optional[dict] = None,
multiscale_discrs: Tuple[Module, ...] = tuple(),
use_gan=True,
adversarial_loss_weight=1.0,
grad_penalty_loss_weight=10.0,
multiscale_adversarial_loss_weight=1.0,
flash_attn=True,
separate_first_frame_encoding=False,
):
super().__init__()
# for autosaving the config
_locals = locals()
_locals.pop('self', None)
_locals.pop('__class__', None)
self._configs = pickle.dumps(_locals)
# image size
self.channels = channels
self.image_size = image_size
# initial encoder
self.conv_in = CausalConv3d(channels,
init_dim,
input_conv_kernel_size,
pad_mode=pad_mode)
# whether to encode the first frame separately or not
self.conv_in_first_frame = nn.Identity()
self.conv_out_first_frame = nn.Identity()
if separate_first_frame_encoding:
self.conv_in_first_frame = SameConv2d(channels, init_dim,
input_conv_kernel_size[-2:])
self.conv_out_first_frame = SameConv2d(
init_dim, channels, output_conv_kernel_size[-2:])
self.separate_first_frame_encoding = separate_first_frame_encoding
# encoder and decoder layers
self.encoder_layers = ModuleList([])
self.decoder_layers = ModuleList([])
self.conv_out = CausalConv3d(init_dim,
channels,
output_conv_kernel_size,
pad_mode=pad_mode)
dim = init_dim
dim_out = dim
layer_fmap_size = image_size
time_downsample_factor = 1
has_cond_across_layers = []
for layer_def in layers:
layer_type, *layer_params = cast_tuple(layer_def)
has_cond = False
if layer_type == 'residual':
encoder_layer = ResidualUnit(dim, residual_conv_kernel_size)
decoder_layer = ResidualUnit(dim, residual_conv_kernel_size)
elif layer_type == 'consecutive_residual':
(num_consecutive, ) = layer_params
encoder_layer = Sequential(*[
ResidualUnit(dim, residual_conv_kernel_size)
for _ in range(num_consecutive)
])
decoder_layer = Sequential(*[
ResidualUnit(dim, residual_conv_kernel_size)
for _ in range(num_consecutive)
])
elif layer_type == 'cond_residual':
assert exists(
dim_cond
), 'dim_cond must be passed into VideoTokenizer, if tokenizer is to be conditioned'
has_cond = True
encoder_layer = ResidualUnitMod(
dim,
residual_conv_kernel_size,
dim_cond=int(dim_cond * dim_cond_expansion_factor))
decoder_layer = ResidualUnitMod(
dim,
residual_conv_kernel_size,
dim_cond=int(dim_cond * dim_cond_expansion_factor))
dim_out = dim
elif layer_type == 'compress_space':
dim_out = safe_get_index(layer_params, 0)
dim_out = default(dim_out, dim * 2)
dim_out = min(dim_out, max_dim)
encoder_layer = SpatialDownsample2x(dim, dim_out)
decoder_layer = SpatialUpsample2x(dim_out, dim)
assert layer_fmap_size > 1
layer_fmap_size //= 2
elif layer_type == 'compress_time':
dim_out = safe_get_index(layer_params, 0)
dim_out = default(dim_out, dim * 2)
dim_out = min(dim_out, max_dim)
encoder_layer = TimeDownsample2x(dim, dim_out)
decoder_layer = TimeUpsample2x(dim_out, dim)
time_downsample_factor *= 2
elif layer_type == 'attend_space':
attn_kwargs = dict(dim=dim,
dim_head=attn_dim_head,
heads=attn_heads,
dropout=attn_dropout,
flash=flash_attn)
encoder_layer = Sequential(
Residual(SpaceAttention(**attn_kwargs)),
Residual(FeedForward(dim)))
decoder_layer = Sequential(
Residual(SpaceAttention(**attn_kwargs)),
Residual(FeedForward(dim)))
elif layer_type == 'linear_attend_space':
linear_attn_kwargs = dict(dim=dim,
dim_head=linear_attn_dim_head,
heads=linear_attn_heads)
encoder_layer = Sequential(
Residual(LinearSpaceAttention(**linear_attn_kwargs)),
Residual(FeedForward(dim)))
decoder_layer = Sequential(
Residual(LinearSpaceAttention(**linear_attn_kwargs)),
Residual(FeedForward(dim)))
elif layer_type == 'gateloop_time':
gateloop_kwargs = dict(use_heinsen=False)
encoder_layer = ToTimeSequence(
Residual(SimpleGateLoopLayer(dim=dim)))
decoder_layer = ToTimeSequence(
Residual(SimpleGateLoopLayer(dim=dim)))
elif layer_type == 'attend_time':
attn_kwargs = dict(
dim=dim,
dim_head=attn_dim_head,
heads=attn_heads,
dropout=attn_dropout,
causal=True,
flash=flash_attn,
)
encoder_layer = Sequential(
Residual(TokenShift(TimeAttention(**attn_kwargs))),
Residual(TokenShift(FeedForward(dim, dim_cond=dim_cond))),
)
decoder_layer = Sequential(
Residual(TokenShift(TimeAttention(**attn_kwargs))),
Residual(TokenShift(FeedForward(dim, dim_cond=dim_cond))),
)
elif layer_type == 'cond_attend_space':
has_cond = True
attn_kwargs = dict(
dim=dim,
dim_cond=dim_cond,
dim_head=attn_dim_head,
heads=attn_heads,
dropout=attn_dropout,
flash=flash_attn,
)
encoder_layer = Sequential(
Residual(SpaceAttention(**attn_kwargs)),
Residual(FeedForward(dim)))
decoder_layer = Sequential(
Residual(SpaceAttention(**attn_kwargs)),
Residual(FeedForward(dim)))
elif layer_type == 'cond_linear_attend_space':
has_cond = True
attn_kwargs = dict(
dim=dim,
dim_cond=dim_cond,
dim_head=attn_dim_head,
heads=attn_heads,
dropout=attn_dropout,
flash=flash_attn,
)
encoder_layer = Sequential(
Residual(LinearSpaceAttention(**attn_kwargs)),
Residual(FeedForward(dim, dim_cond=dim_cond)))
decoder_layer = Sequential(
Residual(LinearSpaceAttention(**attn_kwargs)),
Residual(FeedForward(dim, dim_cond=dim_cond)))
elif layer_type == 'cond_attend_time':
has_cond = True
attn_kwargs = dict(
dim=dim,
dim_cond=dim_cond,
dim_head=attn_dim_head,
heads=attn_heads,
dropout=attn_dropout,
causal=True,
flash=flash_attn,
)
encoder_layer = Sequential(
Residual(TokenShift(TimeAttention(**attn_kwargs))),
Residual(TokenShift(FeedForward(dim, dim_cond=dim_cond))),
)
decoder_layer = Sequential(
Residual(TokenShift(TimeAttention(**attn_kwargs))),
Residual(TokenShift(FeedForward(dim, dim_cond=dim_cond))),
)
else:
raise ValueError(f'unknown layer type {layer_type}')
self.encoder_layers.append(encoder_layer)
self.decoder_layers.insert(0, decoder_layer)
dim = dim_out
has_cond_across_layers.append(has_cond)
# add a final norm just before quantization layer
self.encoder_layers.append(
Sequential(
Rearrange('b c ... -> b ... c'),
nn.LayerNorm(dim),
Rearrange('b ... c -> b c ...'),
))
self.time_downsample_factor = time_downsample_factor
self.time_padding = time_downsample_factor - 1
self.fmap_size = layer_fmap_size
        # use an MLP stem for conditioning, if needed
self.has_cond_across_layers = has_cond_across_layers
self.has_cond = any(has_cond_across_layers)
self.encoder_cond_in = nn.Identity()
self.decoder_cond_in = nn.Identity()
        if self.has_cond:
self.dim_cond = dim_cond
self.encoder_cond_in = Sequential(
nn.Linear(dim_cond, int(dim_cond * dim_cond_expansion_factor)),
nn.SiLU())
self.decoder_cond_in = Sequential(
nn.Linear(dim_cond, int(dim_cond * dim_cond_expansion_factor)),
nn.SiLU())
# quantizer related
self.use_fsq = use_fsq
if not use_fsq:
assert exists(codebook_size) and not exists(
fsq_levels
), 'if use_fsq is set to False, `codebook_size` must be set (and not `fsq_levels`)'
            # lookup free quantizer(s) - multiple codebooks are possible
            # each codebook will get its own entropy regularization
self.quantizers = LFQ(
dim=dim,
codebook_size=codebook_size,
num_codebooks=num_codebooks,
entropy_loss_weight=lfq_entropy_loss_weight,
commitment_loss_weight=lfq_commitment_loss_weight,
diversity_gamma=lfq_diversity_gamma,
)
else:
assert (
not exists(codebook_size) and exists(fsq_levels)
), 'if use_fsq is set to True, `fsq_levels` must be set (and not `codebook_size`). the effective codebook size is the cumulative product of all the FSQ levels'
self.quantizers = FSQ(fsq_levels,
dim=dim,
num_codebooks=num_codebooks)
self.quantizer_aux_loss_weight = quantizer_aux_loss_weight
# dummy loss
self.register_buffer('zero', torch.tensor(0.0), persistent=False)
# perceptual loss related
use_vgg = channels in {1, 3, 4} and perceptual_loss_weight > 0.0
self.vgg = None
self.perceptual_loss_weight = perceptual_loss_weight
if use_vgg:
if not exists(vgg):
vgg = torchvision.models.vgg16(weights=vgg_weights)
vgg.classifier = Sequential(*vgg.classifier[:-2])
self.vgg = vgg
self.use_vgg = use_vgg
# main flag for whether to use GAN at all
self.use_gan = use_gan
# discriminator
discr_kwargs = default(
discr_kwargs,
dict(dim=dim,
image_size=image_size,
channels=channels,
max_dim=512))
self.discr = Discriminator(**discr_kwargs)
self.adversarial_loss_weight = adversarial_loss_weight
self.grad_penalty_loss_weight = grad_penalty_loss_weight
self.has_gan = use_gan and adversarial_loss_weight > 0.0
# multi-scale discriminators
self.has_multiscale_gan = use_gan and multiscale_adversarial_loss_weight > 0.0
self.multiscale_discrs = ModuleList([*multiscale_discrs])
self.multiscale_adversarial_loss_weight = multiscale_adversarial_loss_weight
self.has_multiscale_discrs = (use_gan and
multiscale_adversarial_loss_weight > 0.0
and len(multiscale_discrs) > 0)
@property
def device(self):
return self.zero.device
@classmethod
def init_and_load_from(cls, path, strict=True):
path = Path(path)
assert path.exists()
pkg = torch.load(str(path), map_location='cpu')
assert 'config' in pkg, 'model configs were not found in this saved checkpoint'
config = pickle.loads(pkg['config'])
tokenizer = cls(**config)
tokenizer.load(path, strict=strict)
return tokenizer
def parameters(self):
return [
*self.conv_in.parameters(),
*self.conv_in_first_frame.parameters(),
*self.conv_out_first_frame.parameters(),
*self.conv_out.parameters(),
*self.encoder_layers.parameters(),
*self.decoder_layers.parameters(),
*self.encoder_cond_in.parameters(),
*self.decoder_cond_in.parameters(),
*self.quantizers.parameters(),
]
def discr_parameters(self):
return self.discr.parameters()
def copy_for_eval(self):
device = self.device
vae_copy = copy.deepcopy(self.cpu())
maybe_del_attr_(vae_copy, 'discr')
maybe_del_attr_(vae_copy, 'vgg')
maybe_del_attr_(vae_copy, 'multiscale_discrs')
vae_copy.eval()
return vae_copy.to(device)
@remove_vgg
def state_dict(self, *args, **kwargs):
return super().state_dict(*args, **kwargs)
@remove_vgg
def load_state_dict(self, *args, **kwargs):
return super().load_state_dict(*args, **kwargs)
def save(self, path, overwrite=True):
path = Path(path)
assert overwrite or not path.exists(), f'{str(path)} already exists'
pkg = dict(model_state_dict=self.state_dict(),
version=__version__,
config=self._configs)
torch.save(pkg, str(path))
def load(self, path, strict=True):
path = Path(path)
assert path.exists()
pkg = torch.load(str(path))
state_dict = pkg.get('model_state_dict')
version = pkg.get('version')
assert exists(state_dict)
if exists(version):
print(f'loading checkpointed tokenizer from version {version}')
self.load_state_dict(state_dict, strict=strict)
@beartype
def encode(self,
video: Tensor,
quantize=False,
cond: Optional[Tensor] = None,
video_contains_first_frame=True):
encode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame
# whether to pad video or not
if video_contains_first_frame:
video_len = video.shape[2]
video = pad_at_dim(video, (self.time_padding, 0), value=0.0, dim=2)
video_packed_shape = [
torch.Size([self.time_padding]),
torch.Size([]),
torch.Size([video_len - 1])
]
# conditioning, if needed
assert (not self.has_cond) or exists(
cond
), '`cond` must be passed into tokenizer forward method since conditionable layers were specified'
if exists(cond):
assert cond.shape == (video.shape[0], self.dim_cond)
cond = self.encoder_cond_in(cond)
cond_kwargs = dict(cond=cond)
# initial conv
# taking into account whether to encode first frame separately
if encode_first_frame_separately:
pad, first_frame, video = unpack(video, video_packed_shape,
'b c * h w')
first_frame = self.conv_in_first_frame(first_frame)
video = self.conv_in(video)
if encode_first_frame_separately:
video, _ = pack([first_frame, video], 'b c * h w')
video = pad_at_dim(video, (self.time_padding, 0), dim=2)
# encoder layers
for fn, has_cond in zip(self.encoder_layers,
self.has_cond_across_layers):
layer_kwargs = dict()
if has_cond:
layer_kwargs = cond_kwargs
video = fn(video, **layer_kwargs)
maybe_quantize = identity if not quantize else self.quantizers
return maybe_quantize(video)
@beartype
def decode_from_code_indices(self,
codes: Tensor,
cond: Optional[Tensor] = None,
video_contains_first_frame=True):
assert codes.dtype in (torch.long, torch.int32)
if codes.ndim == 2:
video_code_len = codes.shape[-1]
assert divisible_by(
video_code_len, self.fmap_size**2
), f'flattened video ids must have a length ({video_code_len}) that is divisible by the fmap size ({self.fmap_size}) squared ({self.fmap_size ** 2})'
codes = rearrange(codes,
'b (f h w) -> b f h w',
h=self.fmap_size,
w=self.fmap_size)
quantized = self.quantizers.indices_to_codes(codes)
return self.decode(
quantized,
cond=cond,
video_contains_first_frame=video_contains_first_frame)
@beartype
def decode(self,
quantized: Tensor,
cond: Optional[Tensor] = None,
video_contains_first_frame=True):
decode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame
batch = quantized.shape[0]
# conditioning, if needed
assert (not self.has_cond) or exists(
cond
), '`cond` must be passed into tokenizer forward method since conditionable layers were specified'
if exists(cond):
assert cond.shape == (batch, self.dim_cond)
cond = self.decoder_cond_in(cond)
cond_kwargs = dict(cond=cond)
# decoder layers
x = quantized
for fn, has_cond in zip(self.decoder_layers,
reversed(self.has_cond_across_layers)):
layer_kwargs = dict()
if has_cond:
layer_kwargs = cond_kwargs
x = fn(x, **layer_kwargs)
# to pixels
if decode_first_frame_separately:
left_pad, xff, x = (
x[:, :, :self.time_padding],
x[:, :, self.time_padding],
x[:, :, (self.time_padding + 1):],
)
out = self.conv_out(x)
outff = self.conv_out_first_frame(xff)
video, _ = pack([outff, out], 'b c * h w')
else:
video = self.conv_out(x)
# if video were padded, remove padding
if video_contains_first_frame:
video = video[:, :, self.time_padding:]
return video
@torch.no_grad()
def tokenize(self, video):
self.eval()
return self.forward(video, return_codes=True)
@beartype
def forward(
self,
video_or_images: Tensor,
cond: Optional[Tensor] = None,
return_loss=False,
return_codes=False,
return_recon=False,
return_discr_loss=False,
return_recon_loss_only=False,
apply_gradient_penalty=True,
video_contains_first_frame=True,
adversarial_loss_weight=None,
multiscale_adversarial_loss_weight=None,
):
adversarial_loss_weight = default(adversarial_loss_weight,
self.adversarial_loss_weight)
multiscale_adversarial_loss_weight = default(
multiscale_adversarial_loss_weight,
self.multiscale_adversarial_loss_weight)
assert (return_loss + return_codes + return_discr_loss) <= 1
assert video_or_images.ndim in {4, 5}
assert video_or_images.shape[-2:] == (self.image_size, self.image_size)
# accept images for image pretraining (curriculum learning from images to video)
is_image = video_or_images.ndim == 4
if is_image:
video = rearrange(video_or_images, 'b c ... -> b c 1 ...')
video_contains_first_frame = True
else:
video = video_or_images
batch, channels, frames = video.shape[:3]
assert divisible_by(
frames - int(video_contains_first_frame),
self.time_downsample_factor
), f'number of frames {frames} minus the first frame ({frames - int(video_contains_first_frame)}) must be divisible by the total downsample factor across time {self.time_downsample_factor}'
# encoder
x = self.encode(video,
cond=cond,
video_contains_first_frame=video_contains_first_frame)
# lookup free quantization
if self.use_fsq:
quantized, codes = self.quantizers(x)
aux_losses = self.zero
quantizer_loss_breakdown = None
else:
(quantized, codes,
aux_losses), quantizer_loss_breakdown = self.quantizers(
x, return_loss_breakdown=True)
if return_codes and not return_recon:
return codes
# decoder
recon_video = self.decode(
quantized,
cond=cond,
video_contains_first_frame=video_contains_first_frame)
if return_codes:
return codes, recon_video
# reconstruction loss
if not (return_loss or return_discr_loss or return_recon_loss_only):
return recon_video
recon_loss = F.mse_loss(video, recon_video)
# for validation, only return recon loss
if return_recon_loss_only:
return recon_loss, recon_video
# gan discriminator loss
if return_discr_loss:
assert self.has_gan
assert exists(self.discr)
# pick a random frame for image discriminator
frame_indices = torch.randn((batch, frames)).topk(1,
dim=-1).indices
real = pick_video_frame(video, frame_indices)
if apply_gradient_penalty:
real = real.requires_grad_()
fake = pick_video_frame(recon_video, frame_indices)
real_logits = self.discr(real)
fake_logits = self.discr(fake.detach())
discr_loss = hinge_discr_loss(fake_logits, real_logits)
# multiscale discriminators
multiscale_discr_losses = []
if self.has_multiscale_discrs:
for discr in self.multiscale_discrs:
multiscale_real_logits = discr(video)
multiscale_fake_logits = discr(recon_video.detach())
multiscale_discr_loss = hinge_discr_loss(
multiscale_fake_logits, multiscale_real_logits)
multiscale_discr_losses.append(multiscale_discr_loss)
else:
multiscale_discr_losses.append(self.zero)
# gradient penalty
if apply_gradient_penalty:
gradient_penalty_loss = gradient_penalty(real, real_logits)
else:
gradient_penalty_loss = self.zero
# total loss
total_loss = (
discr_loss +
gradient_penalty_loss * self.grad_penalty_loss_weight +
sum(multiscale_discr_losses) *
self.multiscale_adversarial_loss_weight)
discr_loss_breakdown = DiscrLossBreakdown(discr_loss,
multiscale_discr_losses,
gradient_penalty_loss)
return total_loss, discr_loss_breakdown
# perceptual loss
if self.use_vgg:
frame_indices = torch.randn((batch, frames)).topk(1,
dim=-1).indices
input_vgg_input = pick_video_frame(video, frame_indices)
recon_vgg_input = pick_video_frame(recon_video, frame_indices)
if channels == 1:
input_vgg_input = repeat(input_vgg_input,
'b 1 h w -> b c h w',
c=3)
recon_vgg_input = repeat(recon_vgg_input,
'b 1 h w -> b c h w',
c=3)
elif channels == 4:
input_vgg_input = input_vgg_input[:, :3]
recon_vgg_input = recon_vgg_input[:, :3]
input_vgg_feats = self.vgg(input_vgg_input)
recon_vgg_feats = self.vgg(recon_vgg_input)
perceptual_loss = F.mse_loss(input_vgg_feats, recon_vgg_feats)
else:
perceptual_loss = self.zero
# get gradient with respect to perceptual loss for last decoder layer
# needed for adaptive weighting
last_dec_layer = self.conv_out.conv.weight
norm_grad_wrt_perceptual_loss = None
if self.training and self.use_vgg and (self.has_gan
or self.has_multiscale_discrs):
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(
perceptual_loss, last_dec_layer).norm(p=2)
# per-frame image discriminator
recon_video_frames = None
if self.has_gan:
frame_indices = torch.randn((batch, frames)).topk(1,
dim=-1).indices
recon_video_frames = pick_video_frame(recon_video, frame_indices)
fake_logits = self.discr(recon_video_frames)
gen_loss = hinge_gen_loss(fake_logits)
adaptive_weight = 1.0
if exists(norm_grad_wrt_perceptual_loss):
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(
gen_loss, last_dec_layer).norm(p=2)
adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(
min=1e-3)
adaptive_weight.clamp_(max=1e3)
if torch.isnan(adaptive_weight).any():
adaptive_weight = 1.0
else:
gen_loss = self.zero
adaptive_weight = 0.0
# multiscale discriminator losses
multiscale_gen_losses = []
multiscale_gen_adaptive_weights = []
if self.has_multiscale_gan and self.has_multiscale_discrs:
if not exists(recon_video_frames):
recon_video_frames = pick_video_frame(recon_video,
frame_indices)
for discr in self.multiscale_discrs:
                fake_logits = discr(recon_video_frames)
multiscale_gen_loss = hinge_gen_loss(fake_logits)
multiscale_gen_losses.append(multiscale_gen_loss)
multiscale_adaptive_weight = 1.0
if exists(norm_grad_wrt_perceptual_loss):
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(
multiscale_gen_loss, last_dec_layer).norm(p=2)
multiscale_adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(
min=1e-5)
multiscale_adaptive_weight.clamp_(max=1e3)
multiscale_gen_adaptive_weights.append(
multiscale_adaptive_weight)
# calculate total loss
total_loss = (recon_loss +
aux_losses * self.quantizer_aux_loss_weight +
perceptual_loss * self.perceptual_loss_weight +
gen_loss * adaptive_weight * adversarial_loss_weight)
if self.has_multiscale_discrs:
weighted_multiscale_gen_losses = sum(
loss * weight for loss, weight in zip(
multiscale_gen_losses, multiscale_gen_adaptive_weights))
total_loss = total_loss + weighted_multiscale_gen_losses * multiscale_adversarial_loss_weight
# loss breakdown
loss_breakdown = LossBreakdown(
recon_loss,
aux_losses,
quantizer_loss_breakdown,
perceptual_loss,
gen_loss,
adaptive_weight,
multiscale_gen_losses,
multiscale_gen_adaptive_weights,
)
return total_loss, loss_breakdown
# main class
class MagViT2(Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x
from abc import abstractmethod
from typing import Any, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from ....modules.distributions.distributions import \
DiagonalGaussianDistribution
from .base import AbstractRegularizer
class DiagonalGaussianRegularizer(AbstractRegularizer):
def __init__(self, sample: bool = True):
super().__init__()
self.sample = sample
def get_trainable_parameters(self) -> Any:
yield from ()
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
log = dict()
posterior = DiagonalGaussianDistribution(z)
if self.sample:
z = posterior.sample()
else:
z = posterior.mode()
kl_loss = posterior.kl()
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
log['kl_loss'] = kl_loss
return z, log
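# Usage sketch (illustrative, not from the original file): assuming DiagonalGaussianDistribution
# splits its input into mean and log-variance along the channel axis, a latent with
# 2 * C channels yields a sampled latent with C channels plus a scalar KL term.
def _example_diagonal_gaussian_regularizer():
    reg = DiagonalGaussianRegularizer(sample=True)
    moments = torch.randn(2, 8, 16, 16)  # (batch, 2 * latent_channels, h, w)
    z, log = reg(moments)  # z: (2, 4, 16, 16); log['kl_loss'] is a scalar
    return z, log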
from abc import abstractmethod
from typing import Any, Tuple
import torch
import torch.nn.functional as F
from torch import nn
class AbstractRegularizer(nn.Module):
def __init__(self):
super().__init__()
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
raise NotImplementedError()
@abstractmethod
def get_trainable_parameters(self) -> Any:
raise NotImplementedError()
class IdentityRegularizer(AbstractRegularizer):
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
return z, dict()
def get_trainable_parameters(self) -> Any:
yield from ()
def measure_perplexity(
predicted_indices: torch.Tensor,
num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]:
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
encodings = F.one_hot(predicted_indices,
num_centroids).float().reshape(-1, num_centroids)
avg_probs = encodings.mean(0)
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
cluster_use = torch.sum(avg_probs > 0)
return perplexity, cluster_use
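# Usage sketch (illustrative, not from the original file): perplexity approaches
# num_centroids when all codes are used equally often, and cluster_use counts how
# many distinct codes actually appear in the batch.
def _example_measure_perplexity():
    indices = torch.randint(0, 16, (4, 32 * 32))  # hypothetical code indices
    perplexity, cluster_use = measure_perplexity(indices, num_centroids=16)
    return perplexity, cluster_use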
"""
Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
Code adapted from Jax version in Appendix A.1
"""
from typing import List, Optional
import torch
import torch.nn as nn
from einops import pack, rearrange, unpack
from torch import Tensor, int32
from torch.cuda.amp import autocast
from torch.nn import Module
# helper functions
def exists(v):
return v is not None
def default(*args):
for arg in args:
if exists(arg):
return arg
return None
def pack_one(t, pattern):
return pack([t], pattern)
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
# tensor helpers
def round_ste(z: Tensor) -> Tensor:
"""Round with straight through gradients."""
zhat = z.round()
return z + (zhat - z).detach()
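# Usage sketch (illustrative, not from the original file): the forward value is the
# rounded tensor, but the gradient is that of the identity, which is what lets FSQ
# train through the discretization step.
def _example_round_ste():
    z = torch.tensor([0.2, 1.7, -0.6], requires_grad=True)
    zhat = round_ste(z)  # tensor([0., 2., -1.]) in the forward pass
    zhat.sum().backward()
    return z.grad  # tensor([1., 1., 1.])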
# main class
class FSQ(Module):
def __init__(
self,
levels: List[int],
dim: Optional[int] = None,
num_codebooks=1,
keep_num_codebooks_dim: Optional[bool] = None,
scale: Optional[float] = None,
):
super().__init__()
_levels = torch.tensor(levels, dtype=int32)
self.register_buffer('_levels', _levels, persistent=False)
_basis = torch.cumprod(torch.tensor([1] + levels[:-1]),
dim=0,
dtype=int32)
self.register_buffer('_basis', _basis, persistent=False)
self.scale = scale
codebook_dim = len(levels)
self.codebook_dim = codebook_dim
effective_codebook_dim = codebook_dim * num_codebooks
self.num_codebooks = num_codebooks
self.effective_codebook_dim = effective_codebook_dim
keep_num_codebooks_dim = default(keep_num_codebooks_dim,
num_codebooks > 1)
assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
self.keep_num_codebooks_dim = keep_num_codebooks_dim
self.dim = default(dim, len(_levels) * num_codebooks)
has_projections = self.dim != effective_codebook_dim
self.project_in = nn.Linear(
self.dim,
effective_codebook_dim) if has_projections else nn.Identity()
self.project_out = nn.Linear(
effective_codebook_dim,
self.dim) if has_projections else nn.Identity()
self.has_projections = has_projections
self.codebook_size = self._levels.prod().item()
implicit_codebook = self.indices_to_codes(torch.arange(
self.codebook_size),
project_out=False)
self.register_buffer('implicit_codebook',
implicit_codebook,
persistent=False)
def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor:
"""Bound `z`, an array of shape (..., d)."""
half_l = (self._levels - 1) * (1 + eps) / 2
offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
shift = (offset / half_l).atanh()
return (z + shift).tanh() * half_l - offset
def quantize(self, z: Tensor) -> Tensor:
"""Quantizes z, returns quantized zhat, same shape as z."""
quantized = round_ste(self.bound(z))
half_width = self._levels // 2 # Renormalize to [-1, 1].
return quantized / half_width
def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor:
half_width = self._levels // 2
return (zhat_normalized * half_width) + half_width
def _scale_and_shift_inverse(self, zhat: Tensor) -> Tensor:
half_width = self._levels // 2
return (zhat - half_width) / half_width
def codes_to_indices(self, zhat: Tensor) -> Tensor:
"""Converts a `code` to an index in the codebook."""
assert zhat.shape[-1] == self.codebook_dim
zhat = self._scale_and_shift(zhat)
return (zhat * self._basis).sum(dim=-1).to(int32)
def indices_to_codes(self, indices: Tensor, project_out=True) -> Tensor:
"""Inverse of `codes_to_indices`."""
is_img_or_video = indices.ndim >= (3 +
int(self.keep_num_codebooks_dim))
indices = rearrange(indices, '... -> ... 1')
codes_non_centered = (indices // self._basis) % self._levels
codes = self._scale_and_shift_inverse(codes_non_centered)
if self.keep_num_codebooks_dim:
codes = rearrange(codes, '... c d -> ... (c d)')
if project_out:
codes = self.project_out(codes)
if is_img_or_video:
codes = rearrange(codes, 'b ... d -> b d ...')
return codes
@autocast(enabled=False)
def forward(self, z: Tensor) -> Tensor:
"""
einstein notation
b - batch
n - sequence (or flattened spatial dimensions)
d - feature dimension
c - number of codebook dim
"""
is_img_or_video = z.ndim >= 4
# standardize image or video into (batch, seq, dimension)
if is_img_or_video:
z = rearrange(z, 'b d ... -> b ... d')
z, ps = pack_one(z, 'b * d')
assert z.shape[
-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}'
z = self.project_in(z)
z = rearrange(z, 'b n (c d) -> b n c d', c=self.num_codebooks)
codes = self.quantize(z)
indices = self.codes_to_indices(codes)
codes = rearrange(codes, 'b n c d -> b n (c d)')
out = self.project_out(codes)
# reconstitute image or video dimensions
if is_img_or_video:
out = unpack_one(out, ps, 'b * d')
out = rearrange(out, 'b ... d -> b d ...')
indices = unpack_one(indices, ps, 'b * c')
if not self.keep_num_codebooks_dim:
indices = rearrange(indices, '... 1 -> ...')
return out, indices
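# Usage sketch (illustrative, not from the original file): with levels [8, 5, 5, 5] the
# implicit codebook has 8 * 5 * 5 * 5 = 1000 entries; inputs are projected down to
# len(levels) dims, rounded per dimension with straight-through gradients, and the
# returned indices can be mapped back through indices_to_codes.
def _example_fsq():
    quantizer = FSQ(levels=[8, 5, 5, 5], dim=16)
    x = torch.randn(2, 16, 8, 8)  # (batch, dim, h, w)
    out, indices = quantizer(x)  # out: (2, 16, 8, 8); indices: (2, 8, 8)
    codes = quantizer.indices_to_codes(indices)  # reconstructs `out` from the indices
    return out, indices, codes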
"""
Lookup Free Quantization
Proposed in https://arxiv.org/abs/2310.05737
In the simplest setup, each dimension is quantized into {-1, 1}.
An entropy penalty is used to encourage utilization.
"""
from collections import namedtuple
from math import ceil, log2
import torch
import torch.nn.functional as F
from einops import pack, rearrange, reduce, unpack
from torch import einsum, nn
from torch.cuda.amp import autocast
from torch.nn import Module
# constants
Return = namedtuple('Return', ['quantized', 'indices', 'entropy_aux_loss'])
LossBreakdown = namedtuple(
'LossBreakdown', ['per_sample_entropy', 'batch_entropy', 'commitment'])
# helper functions
def exists(v):
return v is not None
def default(*args):
for arg in args:
if exists(arg):
return arg() if callable(arg) else arg
return None
def pack_one(t, pattern):
return pack([t], pattern)
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
# entropy
def log(t, eps=1e-5):
return t.clamp(min=eps).log()
def entropy(prob):
return (-prob * log(prob)).sum(dim=-1)
# class
class LFQ(Module):
def __init__(
self,
*,
dim=None,
codebook_size=None,
entropy_loss_weight=0.1,
commitment_loss_weight=0.25,
diversity_gamma=1.0,
straight_through_activation=nn.Identity(),
num_codebooks=1,
keep_num_codebooks_dim=None,
codebook_scale=1.0, # for residual LFQ, codebook scaled down by 2x at each layer
frac_per_sample_entropy=1.0, # make less than 1. to only use a random fraction of the probs for per sample entropy
):
super().__init__()
# some assert validations
assert exists(dim) or exists(
codebook_size
), 'either dim or codebook_size must be specified for LFQ'
assert (
not exists(codebook_size) or log2(codebook_size).is_integer()
), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})'
codebook_size = default(codebook_size, lambda: 2**dim)
codebook_dim = int(log2(codebook_size))
codebook_dims = codebook_dim * num_codebooks
dim = default(dim, codebook_dims)
has_projections = dim != codebook_dims
self.project_in = nn.Linear(
dim, codebook_dims) if has_projections else nn.Identity()
self.project_out = nn.Linear(
codebook_dims, dim) if has_projections else nn.Identity()
self.has_projections = has_projections
self.dim = dim
self.codebook_dim = codebook_dim
self.num_codebooks = num_codebooks
keep_num_codebooks_dim = default(keep_num_codebooks_dim,
num_codebooks > 1)
assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
self.keep_num_codebooks_dim = keep_num_codebooks_dim
# straight through activation
self.activation = straight_through_activation
# entropy aux loss related weights
assert 0 < frac_per_sample_entropy <= 1.0
self.frac_per_sample_entropy = frac_per_sample_entropy
self.diversity_gamma = diversity_gamma
self.entropy_loss_weight = entropy_loss_weight
# codebook scale
self.codebook_scale = codebook_scale
# commitment loss
self.commitment_loss_weight = commitment_loss_weight
# for no auxiliary loss, during inference
self.register_buffer('mask', 2**torch.arange(codebook_dim - 1, -1, -1))
self.register_buffer('zero', torch.tensor(0.0), persistent=False)
# codes
all_codes = torch.arange(codebook_size)
bits = ((all_codes[..., None].int() & self.mask) != 0).float()
codebook = self.bits_to_codes(bits)
self.register_buffer('codebook', codebook, persistent=False)
def bits_to_codes(self, bits):
return bits * self.codebook_scale * 2 - self.codebook_scale
@property
def dtype(self):
return self.codebook.dtype
def indices_to_codes(self, indices, project_out=True):
is_img_or_video = indices.ndim >= (3 +
int(self.keep_num_codebooks_dim))
if not self.keep_num_codebooks_dim:
indices = rearrange(indices, '... -> ... 1')
# indices to codes, which are bits of either -1 or 1
bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype)
codes = self.bits_to_codes(bits)
codes = rearrange(codes, '... c d -> ... (c d)')
# whether to project codes out to original dimensions
# if the input feature dimensions were not log2(codebook size)
if project_out:
codes = self.project_out(codes)
# rearrange codes back to original shape
if is_img_or_video:
codes = rearrange(codes, 'b ... d -> b d ...')
return codes
@autocast(enabled=False)
def forward(
self,
x,
inv_temperature=100.0,
return_loss_breakdown=False,
mask=None,
):
"""
einstein notation
b - batch
n - sequence (or flattened spatial dimensions)
d - feature dimension, which is also log2(codebook size)
c - number of codebook dim
"""
x = x.float()
is_img_or_video = x.ndim >= 4
# standardize image or video into (batch, seq, dimension)
if is_img_or_video:
x = rearrange(x, 'b d ... -> b ... d')
x, ps = pack_one(x, 'b * d')
assert x.shape[
-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}'
x = self.project_in(x)
# split out number of codebooks
x = rearrange(x, 'b n (c d) -> b n c d', c=self.num_codebooks)
# quantize by eq 3.
original_input = x
codebook_value = torch.ones_like(x) * self.codebook_scale
quantized = torch.where(x > 0, codebook_value, -codebook_value)
# use straight-through gradients (optionally with custom activation fn) if training
if self.training:
x = self.activation(x)
x = x + (quantized - x).detach()
else:
x = quantized
# calculate indices
indices = reduce((x > 0).int() * self.mask.int(), 'b n c d -> b n c',
'sum')
# entropy aux loss
if self.training:
# the same as euclidean distance up to a constant
distance = -2 * einsum('... i d, j d -> ... i j', original_input,
self.codebook)
prob = (-distance * inv_temperature).softmax(dim=-1)
# account for mask
if exists(mask):
prob = prob[mask]
else:
prob = rearrange(prob, 'b n ... -> (b n) ...')
# whether to only use a fraction of probs, for reducing memory
if self.frac_per_sample_entropy < 1.0:
num_tokens = prob.shape[0]
num_sampled_tokens = int(num_tokens *
self.frac_per_sample_entropy)
rand_mask = torch.randn(num_tokens).argsort(
dim=-1) < num_sampled_tokens
per_sample_probs = prob[rand_mask]
else:
per_sample_probs = prob
# calculate per sample entropy
per_sample_entropy = entropy(per_sample_probs).mean()
# distribution over all available tokens in the batch
avg_prob = reduce(per_sample_probs, '... c d -> c d', 'mean')
codebook_entropy = entropy(avg_prob).mean()
# 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions
# 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch
entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
else:
# if not training, just return dummy 0
entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero
# commit loss
if self.training:
commit_loss = F.mse_loss(original_input,
quantized.detach(),
reduction='none')
if exists(mask):
commit_loss = commit_loss[mask]
commit_loss = commit_loss.mean()
else:
commit_loss = self.zero
# merge back codebook dim
x = rearrange(x, 'b n c d -> b n (c d)')
# project out to feature dimension if needed
x = self.project_out(x)
# reconstitute image or video dimensions
if is_img_or_video:
x = unpack_one(x, ps, 'b * d')
x = rearrange(x, 'b ... d -> b d ...')
indices = unpack_one(indices, ps, 'b * c')
# whether to remove single codebook dim
if not self.keep_num_codebooks_dim:
indices = rearrange(indices, '... 1 -> ...')
# complete aux loss
aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight
ret = Return(x, indices, aux_loss)
if not return_loss_breakdown:
return ret
return ret, LossBreakdown(per_sample_entropy, codebook_entropy,
commit_loss)
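# Usage sketch (illustrative, not from the original file): with codebook_size = 2 ** 16
# each token is described by 16 binary dimensions; the 64-dim input is projected down,
# sign-quantized to {-1, 1}, and projected back, with the entropy + commitment
# auxiliary loss returned alongside the indices.
def _example_lfq():
    quantizer = LFQ(dim=64, codebook_size=2**16)
    x = torch.randn(2, 64, 8, 8)  # (batch, dim, h, w)
    quantized, indices, aux_loss = quantizer(x)  # quantized: (2, 64, 8, 8); indices: (2, 8, 8)
    return quantized, indices, aux_loss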
import logging
from abc import abstractmethod
from typing import Dict, Iterator, Literal, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import einsum
from .base import AbstractRegularizer, measure_perplexity
logpy = logging.getLogger(__name__)
class AbstractQuantizer(AbstractRegularizer):
def __init__(self):
super().__init__()
# Define these in your init
# shape (N,)
self.used: Optional[torch.Tensor]
self.re_embed: int
self.unknown_index: Union[Literal['random'], int]
def remap_to_used(self, inds: torch.Tensor) -> torch.Tensor:
assert self.used is not None, 'You need to define used indices for remap'
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
used = self.used.to(inds)
match = (inds[:, :, None] == used[None, None, ...]).long()
new = match.argmax(-1)
unknown = match.sum(2) < 1
if self.unknown_index == 'random':
new[unknown] = torch.randint(
0, self.re_embed,
size=new[unknown].shape).to(device=new.device)
else:
new[unknown] = self.unknown_index
return new.reshape(ishape)
def unmap_to_all(self, inds: torch.Tensor) -> torch.Tensor:
assert self.used is not None, 'You need to define used indices for remap'
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
used = self.used.to(inds)
if self.re_embed > self.used.shape[0]: # extra token
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
return back.reshape(ishape)
@abstractmethod
def get_codebook_entry(
self,
indices: torch.Tensor,
shape: Optional[Tuple[int, ...]] = None) -> torch.Tensor:
raise NotImplementedError()
def get_trainable_parameters(self) -> Iterator[torch.nn.Parameter]:
yield from self.parameters()
class GumbelQuantizer(AbstractQuantizer):
"""
credit to @karpathy:
https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
Gumbel Softmax trick quantizer
Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
https://arxiv.org/abs/1611.01144
"""
def __init__(
self,
num_hiddens: int,
embedding_dim: int,
n_embed: int,
straight_through: bool = True,
kl_weight: float = 5e-4,
temp_init: float = 1.0,
remap: Optional[str] = None,
unknown_index: str = 'random',
loss_key: str = 'loss/vq',
) -> None:
super().__init__()
self.loss_key = loss_key
self.embedding_dim = embedding_dim
self.n_embed = n_embed
self.straight_through = straight_through
self.temperature = temp_init
self.kl_weight = kl_weight
self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
self.embed = nn.Embedding(n_embed, embedding_dim)
self.remap = remap
if self.remap is not None:
self.register_buffer('used', torch.tensor(np.load(self.remap)))
self.re_embed = self.used.shape[0]
else:
self.used = None
self.re_embed = n_embed
if unknown_index == 'extra':
self.unknown_index = self.re_embed
self.re_embed = self.re_embed + 1
else:
assert unknown_index == 'random' or isinstance(
unknown_index, int
), "unknown index needs to be 'random', 'extra' or any integer"
self.unknown_index = unknown_index # "random" or "extra" or integer
if self.remap is not None:
logpy.info(
f'Remapping {self.n_embed} indices to {self.re_embed} indices. '
f'Using {self.unknown_index} for unknown indices.')
def forward(self,
z: torch.Tensor,
temp: Optional[float] = None,
return_logits: bool = False) -> Tuple[torch.Tensor, Dict]:
# force hard = True when we are in eval mode, as we must quantize.
# actually, always true seems to work
hard = self.straight_through if self.training else True
temp = self.temperature if temp is None else temp
out_dict = {}
logits = self.proj(z)
if self.remap is not None:
# continue only with used logits
full_zeros = torch.zeros_like(logits)
logits = logits[:, self.used, ...]
soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
if self.remap is not None:
# go back to all entries but unused set to zero
full_zeros[:, self.used, ...] = soft_one_hot
soft_one_hot = full_zeros
z_q = einsum('b n h w, n d -> b d h w', soft_one_hot,
self.embed.weight)
# + kl divergence to the prior loss
qy = F.softmax(logits, dim=1)
diff = self.kl_weight * torch.sum(
qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
out_dict[self.loss_key] = diff
ind = soft_one_hot.argmax(dim=1)
out_dict['indices'] = ind
if self.remap is not None:
ind = self.remap_to_used(ind)
if return_logits:
out_dict['logits'] = logits
return z_q, out_dict
def get_codebook_entry(self, indices, shape):
# TODO: shape not yet optional
b, h, w, c = shape
assert b * h * w == indices.shape[0]
indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w)
if self.remap is not None:
indices = self.unmap_to_all(indices)
one_hot = F.one_hot(indices,
num_classes=self.n_embed).permute(0, 3, 1,
2).float()
z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight)
return z_q
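# Usage sketch (illustrative, not from the original file): the 1x1 conv maps encoder
# features (num_hiddens) to n_embed logits per spatial position, and gumbel-softmax
# selects codebook rows of size embedding_dim.
def _example_gumbel_quantizer():
    quantizer = GumbelQuantizer(num_hiddens=128, embedding_dim=64, n_embed=512)
    z = torch.randn(2, 128, 16, 16)
    z_q, out = quantizer(z)  # z_q: (2, 64, 16, 16); out carries 'loss/vq' and 'indices'
    return z_q, out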
class VectorQuantizer(AbstractQuantizer):
"""
____________________________________________
Discretization bottleneck part of the VQ-VAE.
Inputs:
- n_e : number of embeddings
- e_dim : dimension of embedding
- beta : commitment cost used in loss term,
beta * ||z_e(x)-sg[e]||^2
_____________________________________________
"""
def __init__(
self,
n_e: int,
e_dim: int,
beta: float = 0.25,
remap: Optional[str] = None,
unknown_index: str = 'random',
sane_index_shape: bool = False,
log_perplexity: bool = False,
embedding_weight_norm: bool = False,
loss_key: str = 'loss/vq',
):
super().__init__()
self.n_e = n_e
self.e_dim = e_dim
self.beta = beta
self.loss_key = loss_key
if not embedding_weight_norm:
self.embedding = nn.Embedding(self.n_e, self.e_dim)
self.embedding.weight.data.uniform_(-1.0 / self.n_e,
1.0 / self.n_e)
else:
self.embedding = torch.nn.utils.weight_norm(nn.Embedding(
self.n_e, self.e_dim),
dim=1)
self.remap = remap
if self.remap is not None:
self.register_buffer('used', torch.tensor(np.load(self.remap)))
self.re_embed = self.used.shape[0]
else:
self.used = None
self.re_embed = n_e
if unknown_index == 'extra':
self.unknown_index = self.re_embed
self.re_embed = self.re_embed + 1
else:
assert unknown_index == 'random' or isinstance(
unknown_index, int
), "unknown index needs to be 'random', 'extra' or any integer"
self.unknown_index = unknown_index # "random" or "extra" or integer
if self.remap is not None:
logpy.info(
f'Remapping {self.n_e} indices to {self.re_embed} indices. '
f'Using {self.unknown_index} for unknown indices.')
self.sane_index_shape = sane_index_shape
self.log_perplexity = log_perplexity
def forward(
self,
z: torch.Tensor,
) -> Tuple[torch.Tensor, Dict]:
do_reshape = z.ndim == 4
if do_reshape:
# # reshape z -> (batch, height, width, channel) and flatten
z = rearrange(z, 'b c h w -> b h w c').contiguous()
else:
assert z.ndim < 4, 'No reshaping strategy for inputs > 4 dimensions defined'
z = z.contiguous()
z_flattened = z.view(-1, self.e_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = (torch.sum(z_flattened**2, dim=1, keepdim=True) +
torch.sum(self.embedding.weight**2, dim=1) -
2 * torch.einsum('bd,dn->bn', z_flattened,
rearrange(self.embedding.weight, 'n d -> d n')))
min_encoding_indices = torch.argmin(d, dim=1)
z_q = self.embedding(min_encoding_indices).view(z.shape)
loss_dict = {}
if self.log_perplexity:
perplexity, cluster_usage = measure_perplexity(
min_encoding_indices.detach(), self.n_e)
loss_dict.update({
'perplexity': perplexity,
'cluster_usage': cluster_usage
})
# compute loss for embedding
loss = self.beta * torch.mean((z_q.detach() - z)**2) + torch.mean(
(z_q - z.detach())**2)
loss_dict[self.loss_key] = loss
# preserve gradients
z_q = z + (z_q - z).detach()
# reshape back to match original input shape
if do_reshape:
z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
if self.remap is not None:
min_encoding_indices = min_encoding_indices.reshape(
z.shape[0], -1) # add batch axis
min_encoding_indices = self.remap_to_used(min_encoding_indices)
min_encoding_indices = min_encoding_indices.reshape(-1,
1) # flatten
if self.sane_index_shape:
if do_reshape:
min_encoding_indices = min_encoding_indices.reshape(
z_q.shape[0], z_q.shape[2], z_q.shape[3])
else:
min_encoding_indices = rearrange(min_encoding_indices,
'(b s) 1 -> b s',
b=z_q.shape[0])
loss_dict['min_encoding_indices'] = min_encoding_indices
return z_q, loss_dict
def get_codebook_entry(
self,
indices: torch.Tensor,
shape: Optional[Tuple[int, ...]] = None) -> torch.Tensor:
# shape specifying (batch, height, width, channel)
if self.remap is not None:
assert shape is not None, 'Need to give shape for remap'
indices = indices.reshape(shape[0], -1) # add batch axis
indices = self.unmap_to_all(indices)
indices = indices.reshape(-1) # flatten again
# get quantized latent vectors
z_q = self.embedding(indices)
if shape is not None:
z_q = z_q.view(shape)
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q
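# Usage sketch (illustrative, not from the original file): nearest-neighbour lookup over
# n_e embeddings of size e_dim, returning the straight-through quantized latent plus a
# dict with the commitment / codebook loss under 'loss/vq'.
def _example_vector_quantizer():
    quantizer = VectorQuantizer(n_e=1024, e_dim=8)
    z = torch.randn(2, 8, 16, 16)  # (batch, e_dim, h, w)
    z_q, loss_dict = quantizer(z)  # z_q: (2, 8, 16, 16)
    return z_q, loss_dict['loss/vq']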
class EmbeddingEMA(nn.Module):
def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
super().__init__()
self.decay = decay
self.eps = eps
weight = torch.randn(num_tokens, codebook_dim)
self.weight = nn.Parameter(weight, requires_grad=False)
self.cluster_size = nn.Parameter(torch.zeros(num_tokens),
requires_grad=False)
self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
self.update = True
def forward(self, embed_id):
return F.embedding(embed_id, self.weight)
def cluster_size_ema_update(self, new_cluster_size):
self.cluster_size.data.mul_(self.decay).add_(new_cluster_size,
alpha=1 - self.decay)
def embed_avg_ema_update(self, new_embed_avg):
self.embed_avg.data.mul_(self.decay).add_(new_embed_avg,
alpha=1 - self.decay)
def weight_update(self, num_tokens):
n = self.cluster_size.sum()
smoothed_cluster_size = (self.cluster_size +
self.eps) / (n + num_tokens * self.eps) * n
# normalize embedding average with smoothed cluster size
embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
self.weight.data.copy_(embed_normalized)
class EMAVectorQuantizer(AbstractQuantizer):
def __init__(
self,
n_embed: int,
embedding_dim: int,
beta: float,
decay: float = 0.99,
eps: float = 1e-5,
remap: Optional[str] = None,
unknown_index: str = 'random',
loss_key: str = 'loss/vq',
):
super().__init__()
self.codebook_dim = embedding_dim
self.num_tokens = n_embed
self.beta = beta
self.loss_key = loss_key
self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim,
decay, eps)
self.remap = remap
if self.remap is not None:
self.register_buffer('used', torch.tensor(np.load(self.remap)))
self.re_embed = self.used.shape[0]
else:
self.used = None
self.re_embed = n_embed
if unknown_index == 'extra':
self.unknown_index = self.re_embed
self.re_embed = self.re_embed + 1
else:
assert unknown_index == 'random' or isinstance(
unknown_index, int
), "unknown index needs to be 'random', 'extra' or any integer"
self.unknown_index = unknown_index # "random" or "extra" or integer
if self.remap is not None:
logpy.info(
                f'Remapping {self.num_tokens} indices to {self.re_embed} indices. '
f'Using {self.unknown_index} for unknown indices.')
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
# reshape z -> (batch, height, width, channel) and flatten
# z, 'b c h w -> b h w c'
z = rearrange(z, 'b c h w -> b h w c')
z_flattened = z.reshape(-1, self.codebook_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = (z_flattened.pow(2).sum(dim=1, keepdim=True) +
self.embedding.weight.pow(2).sum(dim=1) -
2 * torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight)
) # 'n d -> d n'
encoding_indices = torch.argmin(d, dim=1)
z_q = self.embedding(encoding_indices).view(z.shape)
encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
avg_probs = torch.mean(encodings, dim=0)
perplexity = torch.exp(-torch.sum(avg_probs *
torch.log(avg_probs + 1e-10)))
if self.training and self.embedding.update:
# EMA cluster size
encodings_sum = encodings.sum(0)
self.embedding.cluster_size_ema_update(encodings_sum)
# EMA embedding average
embed_sum = encodings.transpose(0, 1) @ z_flattened
self.embedding.embed_avg_ema_update(embed_sum)
# normalize embed_avg and update weight
self.embedding.weight_update(self.num_tokens)
# compute loss for embedding
loss = self.beta * F.mse_loss(z_q.detach(), z)
# preserve gradients
z_q = z + (z_q - z).detach()
# reshape back to match original input shape
# z_q, 'b h w c -> b c h w'
z_q = rearrange(z_q, 'b h w c -> b c h w')
out_dict = {
self.loss_key: loss,
'encodings': encodings,
'encoding_indices': encoding_indices,
'perplexity': perplexity,
}
return z_q, out_dict
class VectorQuantizerWithInputProjection(VectorQuantizer):
def __init__(
self,
input_dim: int,
n_codes: int,
codebook_dim: int,
beta: float = 1.0,
output_dim: Optional[int] = None,
**kwargs,
):
super().__init__(n_codes, codebook_dim, beta, **kwargs)
self.proj_in = nn.Linear(input_dim, codebook_dim)
self.output_dim = output_dim
if output_dim is not None:
self.proj_out = nn.Linear(codebook_dim, output_dim)
else:
self.proj_out = nn.Identity()
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
rearr = False
in_shape = z.shape
if z.ndim > 3:
rearr = self.output_dim is not None
z = rearrange(z, 'b c ... -> b (...) c')
z = self.proj_in(z)
z_q, loss_dict = super().forward(z)
z_q = self.proj_out(z_q)
if rearr:
if len(in_shape) == 4:
z_q = rearrange(z_q, 'b (h w) c -> b c h w ', w=in_shape[-1])
elif len(in_shape) == 5:
z_q = rearrange(z_q,
'b (t h w) c -> b c t h w ',
w=in_shape[-1],
h=in_shape[-2])
else:
raise NotImplementedError(
f'rearranging not available for {len(in_shape)}-dimensional input.'
)
return z_q, loss_dict
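# Usage sketch (illustrative, not from the original file): encoder features (input_dim)
# are projected into the codebook space, quantized, and projected back out; for 5D video
# latents the (t, h, w) layout is only restored when output_dim is given.
def _example_vq_with_input_projection():
    quantizer = VectorQuantizerWithInputProjection(
        input_dim=16, n_codes=512, codebook_dim=8, output_dim=16)
    z = torch.randn(2, 16, 4, 8, 8)  # (batch, channels, t, h, w)
    z_q, loss_dict = quantizer(z)  # z_q: (2, 16, 4, 8, 8)
    return z_q, loss_dict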
from typing import Callable, Iterable, Union
import torch
from einops import rearrange, repeat
from sgm.modules.diffusionmodules.model import (XFORMERS_IS_AVAILABLE,
AttnBlock, Decoder,
MemoryEfficientAttnBlock,
ResnetBlock)
from sgm.modules.diffusionmodules.openaimodel import (ResBlock,
timestep_embedding)
from sgm.modules.video_attention import VideoTransformerBlock
from sgm.util import partialclass
class VideoResBlock(ResnetBlock):
def __init__(
self,
out_channels,
*args,
dropout=0.0,
video_kernel_size=3,
alpha=0.0,
merge_strategy='learned',
**kwargs,
):
super().__init__(out_channels=out_channels,
dropout=dropout,
*args,
**kwargs)
if video_kernel_size is None:
video_kernel_size = [3, 1, 1]
self.time_stack = ResBlock(
channels=out_channels,
emb_channels=0,
dropout=dropout,
dims=3,
use_scale_shift_norm=False,
use_conv=False,
up=False,
down=False,
kernel_size=video_kernel_size,
use_checkpoint=False,
skip_t_emb=True,
)
self.merge_strategy = merge_strategy
if self.merge_strategy == 'fixed':
self.register_buffer('mix_factor', torch.Tensor([alpha]))
elif self.merge_strategy == 'learned':
self.register_parameter('mix_factor',
torch.nn.Parameter(torch.Tensor([alpha])))
else:
raise ValueError(f'unknown merge strategy {self.merge_strategy}')
def get_alpha(self, bs):
if self.merge_strategy == 'fixed':
return self.mix_factor
elif self.merge_strategy == 'learned':
return torch.sigmoid(self.mix_factor)
else:
raise NotImplementedError()
def forward(self, x, temb, skip_video=False, timesteps=None):
if timesteps is None:
timesteps = self.timesteps
b, c, h, w = x.shape
x = super().forward(x, temb)
if not skip_video:
x_mix = rearrange(x, '(b t) c h w -> b c t h w', t=timesteps)
x = rearrange(x, '(b t) c h w -> b c t h w', t=timesteps)
x = self.time_stack(x, temb)
alpha = self.get_alpha(bs=b // timesteps)
x = alpha * x + (1.0 - alpha) * x_mix
x = rearrange(x, 'b c t h w -> (b t) c h w')
return x
class AE3DConv(torch.nn.Conv2d):
def __init__(self,
in_channels,
out_channels,
video_kernel_size=3,
*args,
**kwargs):
super().__init__(in_channels, out_channels, *args, **kwargs)
if isinstance(video_kernel_size, Iterable):
padding = [int(k // 2) for k in video_kernel_size]
else:
padding = int(video_kernel_size // 2)
self.time_mix_conv = torch.nn.Conv3d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=video_kernel_size,
padding=padding,
)
def forward(self, input, timesteps, skip_video=False):
x = super().forward(input)
if skip_video:
return x
x = rearrange(x, '(b t) c h w -> b c t h w', t=timesteps)
x = self.time_mix_conv(x)
return rearrange(x, 'b c t h w -> (b t) c h w')
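# Usage sketch (illustrative, not from the original file): frames are processed as a
# flattened (batch * time) batch by the 2D conv, then re-folded so the 3D conv can mix
# information across time; kernel_size / padding here are pass-through nn.Conv2d kwargs.
def _example_ae3dconv():
    conv = AE3DConv(64, 64, video_kernel_size=3, kernel_size=3, padding=1)
    x = torch.randn(2 * 5, 64, 32, 32)  # (batch * frames, channels, h, w)
    out = conv(x, timesteps=5)  # (10, 64, 32, 32)
    return out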
class VideoBlock(AttnBlock):
def __init__(self,
in_channels: int,
alpha: float = 0,
merge_strategy: str = 'learned'):
super().__init__(in_channels)
# no context, single headed, as in base class
self.time_mix_block = VideoTransformerBlock(
dim=in_channels,
n_heads=1,
d_head=in_channels,
checkpoint=False,
ff_in=True,
attn_mode='softmax',
)
time_embed_dim = self.in_channels * 4
self.video_time_embed = torch.nn.Sequential(
torch.nn.Linear(self.in_channels, time_embed_dim),
torch.nn.SiLU(),
torch.nn.Linear(time_embed_dim, self.in_channels),
)
self.merge_strategy = merge_strategy
if self.merge_strategy == 'fixed':
self.register_buffer('mix_factor', torch.Tensor([alpha]))
elif self.merge_strategy == 'learned':
self.register_parameter('mix_factor',
torch.nn.Parameter(torch.Tensor([alpha])))
else:
raise ValueError(f'unknown merge strategy {self.merge_strategy}')
def forward(self, x, timesteps, skip_video=False):
if skip_video:
return super().forward(x)
x_in = x
x = self.attention(x)
h, w = x.shape[2:]
x = rearrange(x, 'b c h w -> b (h w) c')
x_mix = x
num_frames = torch.arange(timesteps, device=x.device)
num_frames = repeat(num_frames, 't -> b t', b=x.shape[0] // timesteps)
num_frames = rearrange(num_frames, 'b t -> (b t)')
t_emb = timestep_embedding(num_frames,
self.in_channels,
repeat_only=False)
emb = self.video_time_embed(t_emb) # b, n_channels
emb = emb[:, None, :]
x_mix = x_mix + emb
alpha = self.get_alpha()
x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
x = self.proj_out(x)
return x_in + x
def get_alpha(self, ):
if self.merge_strategy == 'fixed':
return self.mix_factor
elif self.merge_strategy == 'learned':
return torch.sigmoid(self.mix_factor)
else:
raise NotImplementedError(
f'unknown merge strategy {self.merge_strategy}')
class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock):
def __init__(self,
in_channels: int,
alpha: float = 0,
merge_strategy: str = 'learned'):
super().__init__(in_channels)
# no context, single headed, as in base class
self.time_mix_block = VideoTransformerBlock(
dim=in_channels,
n_heads=1,
d_head=in_channels,
checkpoint=False,
ff_in=True,
attn_mode='softmax-xformers',
)
time_embed_dim = self.in_channels * 4
self.video_time_embed = torch.nn.Sequential(
torch.nn.Linear(self.in_channels, time_embed_dim),
torch.nn.SiLU(),
torch.nn.Linear(time_embed_dim, self.in_channels),
)
self.merge_strategy = merge_strategy
if self.merge_strategy == 'fixed':
self.register_buffer('mix_factor', torch.Tensor([alpha]))
elif self.merge_strategy == 'learned':
self.register_parameter('mix_factor',
torch.nn.Parameter(torch.Tensor([alpha])))
else:
raise ValueError(f'unknown merge strategy {self.merge_strategy}')
def forward(self, x, timesteps, skip_time_block=False):
if skip_time_block:
return super().forward(x)
x_in = x
x = self.attention(x)
h, w = x.shape[2:]
x = rearrange(x, 'b c h w -> b (h w) c')
x_mix = x
num_frames = torch.arange(timesteps, device=x.device)
num_frames = repeat(num_frames, 't -> b t', b=x.shape[0] // timesteps)
num_frames = rearrange(num_frames, 'b t -> (b t)')
t_emb = timestep_embedding(num_frames,
self.in_channels,
repeat_only=False)
emb = self.video_time_embed(t_emb) # b, n_channels
emb = emb[:, None, :]
x_mix = x_mix + emb
alpha = self.get_alpha()
x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
x = self.proj_out(x)
return x_in + x
def get_alpha(self, ):
if self.merge_strategy == 'fixed':
return self.mix_factor
elif self.merge_strategy == 'learned':
return torch.sigmoid(self.mix_factor)
else:
raise NotImplementedError(
f'unknown merge strategy {self.merge_strategy}')
def make_time_attn(
in_channels,
attn_type='vanilla',
attn_kwargs=None,
alpha: float = 0,
merge_strategy: str = 'learned',
):
assert attn_type in [
'vanilla',
'vanilla-xformers',
], f'attn_type {attn_type} not supported for spatio-temporal attention'
print(
f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels"
)
if not XFORMERS_IS_AVAILABLE and attn_type == 'vanilla-xformers':
print(
f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. "
f'This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}'
)
attn_type = 'vanilla'
if attn_type == 'vanilla':
assert attn_kwargs is None
return partialclass(VideoBlock,
in_channels,
alpha=alpha,
merge_strategy=merge_strategy)
elif attn_type == 'vanilla-xformers':
print(
f'building MemoryEfficientAttnBlock with {in_channels} in_channels...'
)
return partialclass(
MemoryEfficientVideoBlock,
in_channels,
alpha=alpha,
merge_strategy=merge_strategy,
)
else:
        raise NotImplementedError()
class Conv2DWrapper(torch.nn.Conv2d):
def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
return super().forward(input)
class VideoDecoder(Decoder):
available_time_modes = ['all', 'conv-only', 'attn-only']
def __init__(
self,
*args,
video_kernel_size: Union[int, list] = 3,
alpha: float = 0.0,
merge_strategy: str = 'learned',
time_mode: str = 'conv-only',
**kwargs,
):
self.video_kernel_size = video_kernel_size
self.alpha = alpha
self.merge_strategy = merge_strategy
self.time_mode = time_mode
assert (
self.time_mode in self.available_time_modes
), f'time_mode parameter has to be in {self.available_time_modes}'
super().__init__(*args, **kwargs)
def get_last_layer(self, skip_time_mix=False, **kwargs):
if self.time_mode == 'attn-only':
raise NotImplementedError('TODO')
else:
return self.conv_out.time_mix_conv.weight if not skip_time_mix else self.conv_out.weight
def _make_attn(self) -> Callable:
if self.time_mode not in ['conv-only', 'only-last-conv']:
return partialclass(
make_time_attn,
alpha=self.alpha,
merge_strategy=self.merge_strategy,
)
else:
return super()._make_attn()
def _make_conv(self) -> Callable:
if self.time_mode != 'attn-only':
return partialclass(AE3DConv,
video_kernel_size=self.video_kernel_size)
else:
return Conv2DWrapper
def _make_resblock(self) -> Callable:
if self.time_mode not in ['attn-only', 'only-last-conv']:
return partialclass(
VideoResBlock,
video_kernel_size=self.video_kernel_size,
alpha=self.alpha,
merge_strategy=self.merge_strategy,
)
else:
return super()._make_resblock()