Commit 3b804999 authored by chenzk

v1.0

import logging
import math
import random
import re
from abc import abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import pytorch_lightning as pl
import torch
import torch.distributed
import torch.nn as nn
from einops import rearrange
from packaging import version
from ..modules.autoencoding.regularizers import AbstractRegularizer
from ..modules.cp_enc_dec import _conv_gather, _conv_split
from ..modules.ema import LitEma
from ..util import (default, get_context_parallel_group,
get_context_parallel_group_rank, get_nested_attribute,
get_obj_from_str, initialize_context_parallel,
instantiate_from_config, is_context_parallel_initialized)
logpy = logging.getLogger(__name__)
class AbstractAutoencoder(pl.LightningModule):
"""
This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
unCLIP models, etc. Hence, it is fairly general, and specific features
(e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
"""
def __init__(
self,
ema_decay: Union[None, float] = None,
monitor: Union[None, str] = None,
input_key: str = 'jpg',
):
super().__init__()
self.input_key = input_key
self.use_ema = ema_decay is not None
if monitor is not None:
self.monitor = monitor
if self.use_ema:
self.model_ema = LitEma(self, decay=ema_decay)
logpy.info(
f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
if version.parse(torch.__version__) >= version.parse('2.0.0'):
self.automatic_optimization = False
def apply_ckpt(self, ckpt: Union[None, str, dict]):
if ckpt is None:
return
if isinstance(ckpt, str):
ckpt = {
'target': 'sgm.modules.checkpoint.CheckpointEngine',
'params': {
'ckpt_path': ckpt
},
}
engine = instantiate_from_config(ckpt)
engine(self)
@abstractmethod
def get_input(self, batch) -> Any:
raise NotImplementedError()
def on_train_batch_end(self, *args, **kwargs):
# for EMA computation
if self.use_ema:
self.model_ema(self)
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
logpy.info(f'{context}: Switched to EMA weights')
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
logpy.info(f'{context}: Restored training weights')
@abstractmethod
def encode(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError(
'encode()-method of abstract base class called')
@abstractmethod
def decode(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError(
'decode()-method of abstract base class called')
def instantiate_optimizer_from_config(self, params, lr, cfg):
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
return get_obj_from_str(cfg['target'])(params,
lr=lr,
**cfg.get('params', dict()))
def configure_optimizers(self) -> Any:
raise NotImplementedError()
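# Illustrative sketch (not part of the original file): subclasses use `ema_scope`
# to temporarily swap in the EMA weights, e.g. for validation or image logging.
# `model` is assumed to be a concrete subclass that implements `encode`.
def _demo_ema_scope(model: AbstractAutoencoder, batch: dict):
    x = model.get_input(batch)
    z = model.encode(x)                  # regular (training) weights
    with model.ema_scope("demo"):
        z_ema = model.encode(x)          # EMA weights inside the context
    return z, z_ema                      # training weights are restored afterwards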
class AutoencodingEngine(AbstractAutoencoder):
"""
Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
(we also restore them explicitly as special cases for legacy reasons).
Regularizations such as KL or VQ are moved to the regularizer class.
"""
def __init__(
self,
*args,
encoder_config: Dict,
decoder_config: Dict,
loss_config: Dict,
regularizer_config: Dict,
optimizer_config: Union[Dict, None] = None,
lr_g_factor: float = 1.0,
trainable_ae_params: Optional[List[List[str]]] = None,
ae_optimizer_args: Optional[List[dict]] = None,
trainable_disc_params: Optional[List[List[str]]] = None,
disc_optimizer_args: Optional[List[dict]] = None,
disc_start_iter: int = 0,
diff_boost_factor: float = 3.0,
ckpt_engine: Union[None, str, dict] = None,
ckpt_path: Optional[str] = None,
additional_decode_keys: Optional[List[str]] = None,
**kwargs,
):
super().__init__(*args, **kwargs)
self.automatic_optimization = False # pytorch lightning
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
self.loss: torch.nn.Module = instantiate_from_config(loss_config)
self.regularization: AbstractRegularizer = instantiate_from_config(
regularizer_config)
self.optimizer_config = default(optimizer_config,
{'target': 'torch.optim.Adam'})
self.diff_boost_factor = diff_boost_factor
self.disc_start_iter = disc_start_iter
self.lr_g_factor = lr_g_factor
self.trainable_ae_params = trainable_ae_params
if self.trainable_ae_params is not None:
self.ae_optimizer_args = default(
ae_optimizer_args,
[{} for _ in range(len(self.trainable_ae_params))],
)
assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
else:
self.ae_optimizer_args = [{}] # makes type consistent
self.trainable_disc_params = trainable_disc_params
if self.trainable_disc_params is not None:
self.disc_optimizer_args = default(
disc_optimizer_args,
[{} for _ in range(len(self.trainable_disc_params))],
)
assert len(self.disc_optimizer_args) == len(
self.trainable_disc_params)
else:
self.disc_optimizer_args = [{}] # makes type consistent
if ckpt_path is not None:
assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
logpy.warn(
'Checkpoint path is deprecated, use `ckpt_engine` instead'
)
self.apply_ckpt(default(ckpt_path, ckpt_engine))
self.additional_decode_keys = set(default(additional_decode_keys, []))
def get_input(self, batch: Dict) -> torch.Tensor:
# assuming unified data format, dataloader returns a dict.
# image tensors should be scaled to -1 ... 1 and in channels-first
# format (e.g., bchw instead of bhwc)
return batch[self.input_key]
def get_autoencoder_params(self) -> list:
params = []
if hasattr(self.loss, 'get_trainable_autoencoder_parameters'):
params += list(self.loss.get_trainable_autoencoder_parameters())
if hasattr(self.regularization, 'get_trainable_parameters'):
params += list(self.regularization.get_trainable_parameters())
params = params + list(self.encoder.parameters())
params = params + list(self.decoder.parameters())
return params
def get_discriminator_params(self) -> list:
if hasattr(self.loss, 'get_trainable_parameters'):
params = list(
self.loss.get_trainable_parameters()) # e.g., discriminator
else:
params = []
return params
def get_last_layer(self):
return self.decoder.get_last_layer()
def encode(
self,
x: torch.Tensor,
return_reg_log: bool = False,
unregularized: bool = False,
**kwargs,
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
z = self.encoder(x, **kwargs)
if unregularized:
return z, dict()
z, reg_log = self.regularization(z)
if return_reg_log:
return z, reg_log
return z
def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
x = self.decoder(z, **kwargs)
return x
def forward(
self, x: torch.Tensor, **additional_decode_kwargs
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
z, reg_log = self.encode(x, return_reg_log=True)
dec = self.decode(z, **additional_decode_kwargs)
return z, dec, reg_log
def inner_training_step(self,
batch: dict,
batch_idx: int,
optimizer_idx: int = 0) -> torch.Tensor:
x = self.get_input(batch)
additional_decode_kwargs = {
key: batch[key]
for key in self.additional_decode_keys.intersection(batch)
}
z, xrec, regularization_log = self(x, **additional_decode_kwargs)
if hasattr(self.loss, 'forward_keys'):
extra_info = {
'z': z,
'optimizer_idx': optimizer_idx,
'global_step': self.global_step,
'last_layer': self.get_last_layer(),
'split': 'train',
'regularization_log': regularization_log,
'autoencoder': self,
}
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
else:
extra_info = dict()
if optimizer_idx == 0:
# autoencode
out_loss = self.loss(x, xrec, **extra_info)
if isinstance(out_loss, tuple):
aeloss, log_dict_ae = out_loss
else:
# simple loss function
aeloss = out_loss
log_dict_ae = {'train/loss/rec': aeloss.detach()}
self.log_dict(
log_dict_ae,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=True,
sync_dist=False,
)
self.log(
'loss',
aeloss.mean().detach(),
prog_bar=True,
logger=False,
on_epoch=False,
on_step=True,
)
return aeloss
elif optimizer_idx == 1:
# discriminator
discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
# -> discriminator always needs to return a tuple
self.log_dict(log_dict_disc,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=True)
return discloss
else:
raise NotImplementedError(f'Unknown optimizer {optimizer_idx}')
def training_step(self, batch: dict, batch_idx: int):
opts = self.optimizers()
if not isinstance(opts, list):
# Non-adversarial case
opts = [opts]
optimizer_idx = batch_idx % len(opts)
if self.global_step < self.disc_start_iter:
optimizer_idx = 0
opt = opts[optimizer_idx]
opt.zero_grad()
with opt.toggle_model():
loss = self.inner_training_step(batch,
batch_idx,
optimizer_idx=optimizer_idx)
self.manual_backward(loss)
opt.step()
def validation_step(self, batch: dict, batch_idx: int) -> Dict:
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch,
batch_idx,
postfix='_ema')
log_dict.update(log_dict_ema)
return log_dict
def _validation_step(self,
batch: dict,
batch_idx: int,
postfix: str = '') -> Dict:
x = self.get_input(batch)
z, xrec, regularization_log = self(x)
if hasattr(self.loss, 'forward_keys'):
extra_info = {
'z': z,
'optimizer_idx': 0,
'global_step': self.global_step,
'last_layer': self.get_last_layer(),
'split': 'val' + postfix,
'regularization_log': regularization_log,
'autoencoder': self,
}
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
else:
extra_info = dict()
out_loss = self.loss(x, xrec, **extra_info)
if isinstance(out_loss, tuple):
aeloss, log_dict_ae = out_loss
else:
# simple loss function
aeloss = out_loss
log_dict_ae = {f'val{postfix}/loss/rec': aeloss.detach()}
full_log_dict = log_dict_ae
if 'optimizer_idx' in extra_info:
extra_info['optimizer_idx'] = 1
discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
full_log_dict.update(log_dict_disc)
self.log(
f'val{postfix}/loss/rec',
log_dict_ae[f'val{postfix}/loss/rec'],
sync_dist=True,
)
self.log_dict(full_log_dict, sync_dist=True)
return full_log_dict
def get_param_groups(
self, parameter_names: List[List[str]],
optimizer_args: List[dict]) -> Tuple[List[Dict[str, Any]], int]:
groups = []
num_params = 0
for names, args in zip(parameter_names, optimizer_args):
params = []
for pattern_ in names:
pattern_params = []
pattern = re.compile(pattern_)
for p_name, param in self.named_parameters():
if re.match(pattern, p_name):
pattern_params.append(param)
num_params += param.numel()
if len(pattern_params) == 0:
logpy.warn(
f'Did not find parameters for pattern {pattern_}')
params.extend(pattern_params)
groups.append({'params': params, **args})
return groups, num_params
def configure_optimizers(self) -> List[torch.optim.Optimizer]:
if self.trainable_ae_params is None:
ae_params = self.get_autoencoder_params()
else:
ae_params, num_ae_params = self.get_param_groups(
self.trainable_ae_params, self.ae_optimizer_args)
logpy.info(
f'Number of trainable autoencoder parameters: {num_ae_params:,}'
)
if self.trainable_disc_params is None:
disc_params = self.get_discriminator_params()
else:
disc_params, num_disc_params = self.get_param_groups(
self.trainable_disc_params, self.disc_optimizer_args)
logpy.info(
f'Number of trainable discriminator parameters: {num_disc_params:,}'
)
opt_ae = self.instantiate_optimizer_from_config(
ae_params,
default(self.lr_g_factor, 1.0) * self.learning_rate,
self.optimizer_config,
)
opts = [opt_ae]
if len(disc_params) > 0:
opt_disc = self.instantiate_optimizer_from_config(
disc_params, self.learning_rate, self.optimizer_config)
opts.append(opt_disc)
return opts
@torch.no_grad()
def log_images(self,
batch: dict,
additional_log_kwargs: Optional[Dict] = None,
**kwargs) -> dict:
log = dict()
additional_decode_kwargs = {}
x = self.get_input(batch)
additional_decode_kwargs.update({
key: batch[key]
for key in self.additional_decode_keys.intersection(batch)
})
_, xrec, _ = self(x, **additional_decode_kwargs)
log['inputs'] = x
log['reconstructions'] = xrec
diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
diff.clamp_(0, 1.0)
log['diff'] = 2.0 * diff - 1.0
# diff_boost shows location of small errors, by boosting their
# brightness.
log['diff_boost'] = 2.0 * torch.clamp(self.diff_boost_factor * diff,
0.0, 1.0) - 1
if hasattr(self.loss, 'log_images'):
log.update(self.loss.log_images(x, xrec))
with self.ema_scope():
_, xrec_ema, _ = self(x, **additional_decode_kwargs)
log['reconstructions_ema'] = xrec_ema
diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
diff_ema.clamp_(0, 1.0)
log['diff_ema'] = 2.0 * diff_ema - 1.0
log['diff_boost_ema'] = 2.0 * torch.clamp(
self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
if additional_log_kwargs:
additional_decode_kwargs.update(additional_log_kwargs)
_, xrec_add, _ = self(x, **additional_decode_kwargs)
log_str = 'reconstructions-' + '-'.join([
f'{key}={additional_log_kwargs[key]}'
for key in additional_log_kwargs
])
log[log_str] = xrec_add
return log
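# Illustrative sketch (not part of the original file): a plausible
# `instantiate_from_config`-style configuration for the engine above. The
# ddconfig values and loss params are placeholders; the encoder/decoder targets
# are the ones used by AutoencodingEngineLegacy below, and the module path of
# the engine itself is assumed to be sgm.models.autoencoder.
_example_ddconfig = {
    'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3,
    'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2,
    'attn_resolutions': [], 'dropout': 0.0,
}
_EXAMPLE_AUTOENCODER_CONFIG = {
    'target': 'sgm.models.autoencoder.AutoencodingEngine',
    'params': {
        'input_key': 'jpg',
        'encoder_config': {
            'target': 'sgm.modules.diffusionmodules.model.Encoder',
            'params': _example_ddconfig,
        },
        'decoder_config': {
            'target': 'sgm.modules.diffusionmodules.model.Decoder',
            'params': _example_ddconfig,
        },
        'regularizer_config': {
            'target': 'sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer',
        },
        'loss_config': {
            'target': 'sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator',
            'params': {'disc_start': 50001, 'perceptual_weight': 0.25, 'disc_weight': 0.5},
        },
    },
}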
class AutoencodingEngineLegacy(AutoencodingEngine):
def __init__(self, embed_dim: int, **kwargs):
self.max_batch_size = kwargs.pop('max_batch_size', None)
ddconfig = kwargs.pop('ddconfig')
ckpt_path = kwargs.pop('ckpt_path', None)
ckpt_engine = kwargs.pop('ckpt_engine', None)
super().__init__(
encoder_config={
'target': 'sgm.modules.diffusionmodules.model.Encoder',
'params': ddconfig,
},
decoder_config={
'target': 'sgm.modules.diffusionmodules.model.Decoder',
'params': ddconfig,
},
**kwargs,
)
self.quant_conv = torch.nn.Conv2d(
(1 + ddconfig['double_z']) * ddconfig['z_channels'],
(1 + ddconfig['double_z']) * embed_dim,
1,
)
self.post_quant_conv = torch.nn.Conv2d(embed_dim,
ddconfig['z_channels'], 1)
self.embed_dim = embed_dim
self.apply_ckpt(default(ckpt_path, ckpt_engine))
def get_autoencoder_params(self) -> list:
params = super().get_autoencoder_params()
return params
def encode(
self,
x: torch.Tensor,
return_reg_log: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
if self.max_batch_size is None:
z = self.encoder(x)
z = self.quant_conv(z)
else:
N = x.shape[0]
bs = self.max_batch_size
n_batches = int(math.ceil(N / bs))
z = list()
for i_batch in range(n_batches):
z_batch = self.encoder(x[i_batch * bs:(i_batch + 1) * bs])
z_batch = self.quant_conv(z_batch)
z.append(z_batch)
z = torch.cat(z, 0)
z, reg_log = self.regularization(z)
if return_reg_log:
return z, reg_log
return z
def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
if self.max_batch_size is None:
dec = self.post_quant_conv(z)
dec = self.decoder(dec, **decoder_kwargs)
else:
N = z.shape[0]
bs = self.max_batch_size
n_batches = int(math.ceil(N / bs))
dec = list()
for i_batch in range(n_batches):
dec_batch = self.post_quant_conv(z[i_batch * bs:(i_batch + 1) *
bs])
dec_batch = self.decoder(dec_batch, **decoder_kwargs)
dec.append(dec_batch)
dec = torch.cat(dec, 0)
return dec
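# Illustrative sketch (not part of the original file): with `max_batch_size`
# set, encode/decode above process the batch in chunks of at most that many
# samples and concatenate the results again.
def _demo_chunk_bounds(n_samples: int = 10, max_batch_size: int = 4) -> list:
    n_batches = int(math.ceil(n_samples / max_batch_size))
    # chunk boundaries as used by the slicing x[i * bs:(i + 1) * bs]
    return [(i * max_batch_size, min((i + 1) * max_batch_size, n_samples))
            for i in range(n_batches)]   # -> [(0, 4), (4, 8), (8, 10)]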
class IdentityFirstStage(AbstractAutoencoder):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def get_input(self, x: Any) -> Any:
return x
def encode(self, x: Any, *args, **kwargs) -> Any:
return x
def decode(self, x: Any, *args, **kwargs) -> Any:
return x
import os
class VideoAutoencodingEngine(AutoencodingEngine):
def __init__(
self,
ckpt_path: Union[None, str] = None,
ignore_keys: Union[Tuple, list] = (),
image_video_weights=[1, 1],
only_train_decoder=False,
context_parallel_size=0,
**kwargs,
):
super().__init__(**kwargs)
self.context_parallel_size = context_parallel_size
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def log_videos(self,
batch: dict,
additional_log_kwargs: Optional[Dict] = None,
**kwargs) -> dict:
return self.log_images(batch, additional_log_kwargs, **kwargs)
def get_input(self, batch: dict) -> torch.Tensor:
if self.context_parallel_size > 0:
if not is_context_parallel_initialized():
initialize_context_parallel(self.context_parallel_size)
batch = batch[self.input_key]
global_src_rank = get_context_parallel_group_rank(
) * self.context_parallel_size
torch.distributed.broadcast(batch,
src=global_src_rank,
group=get_context_parallel_group())
batch = _conv_split(batch, dim=2, kernel_size=1)
return batch
return batch[self.input_key]
def apply_ckpt(self, ckpt: Union[None, str, dict]):
if ckpt is None:
return
self.init_from_ckpt(ckpt)
def init_from_ckpt(self, path, ignore_keys=list()):
if os.environ.get('SKIP_LOAD', False):
print(f'skip loading from {path}')
return
sd = torch.load(path, map_location='cpu')['state_dict']
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
del sd[k]
missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
print('Missing keys: ', missing_keys)
print('Unexpected keys: ', unexpected_keys)
print(f'Restored from {path}')
from .encoders.modules import GeneralConditioner
UNCONDITIONAL_CONFIG = {
'target': 'sgm.modules.GeneralConditioner',
'params': {
'emb_models': []
},
}
import math
from inspect import isfunction
from typing import Any, Optional
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from packaging import version
from torch import nn
if version.parse(torch.__version__) >= version.parse('2.0.0'):
SDP_IS_AVAILABLE = True
from torch.backends.cuda import SDPBackend, sdp_kernel
BACKEND_MAP = {
SDPBackend.MATH: {
'enable_math': True,
'enable_flash': False,
'enable_mem_efficient': False,
},
SDPBackend.FLASH_ATTENTION: {
'enable_math': False,
'enable_flash': True,
'enable_mem_efficient': False,
},
SDPBackend.EFFICIENT_ATTENTION: {
'enable_math': False,
'enable_flash': False,
'enable_mem_efficient': True,
},
None: {
'enable_math': True,
'enable_flash': True,
'enable_mem_efficient': True
},
}
else:
from contextlib import nullcontext
SDP_IS_AVAILABLE = False
sdp_kernel = nullcontext
BACKEND_MAP = {}
print(
f'No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, '
f'you are using PyTorch {torch.__version__}. You might want to consider upgrading.'
)
try:
import xformers
import xformers.ops
XFORMERS_IS_AVAILABLE = True
except ImportError:
XFORMERS_IS_AVAILABLE = False
print("no module 'xformers'. Processing without...")
from .diffusionmodules.util import checkpoint
def exists(val):
return val is not None
def uniq(arr):
return {el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def max_neg_value(t):
return -torch.finfo(t.dtype).max
def init_(tensor):
dim = tensor.shape[-1]
std = 1 / math.sqrt(dim)
tensor.uniform_(-std, std)
return tensor
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
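# Illustrative sketch (not part of the original file): GEGLU projects to
# 2 * dim_out and gates one half with GELU of the other, i.e.
# out = a * GELU(b) where [a, b] = split(proj(x)).
def _demo_geglu() -> torch.Size:
    layer = GEGLU(dim_in=8, dim_out=16)
    x = torch.randn(2, 4, 8)           # (batch, tokens, dim_in)
    return layer(x).shape              # -> torch.Size([2, 4, 16])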
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(nn.Linear(
dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(project_in, nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out))
def forward(self, x):
return self.net(x)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32,
num_channels=in_channels,
eps=1e-6,
affine=True)
class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv,
'b (qkv heads c) h w -> qkv b heads c (h w)',
heads=self.heads,
qkv=3)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out,
'b heads c (h w) -> b (heads c) h w',
heads=self.heads,
h=h,
w=w)
return self.to_out(out)
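# Illustrative sketch (not part of the original file): LinearAttention
# softmaxes the keys over the spatial axis and forms a (dim_head x dim_head)
# context per head, so cost grows linearly in the number of pixels instead of
# quadratically.
def _demo_linear_attention() -> torch.Size:
    attn = LinearAttention(dim=32, heads=4, dim_head=8)
    x = torch.randn(1, 32, 16, 16)     # (b, c, h, w)
    return attn(x).shape               # -> torch.Size([1, 32, 16, 16])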
class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = rearrange(q, 'b c h w -> b (h w) c')
k = rearrange(k, 'b c h w -> b c (h w)')
w_ = torch.einsum('bij,bjk->bik', q, k)
w_ = w_ * (int(c)**(-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = rearrange(v, 'b c h w -> b c (h w)')
w_ = rearrange(w_, 'b i j -> b j i')
h_ = torch.einsum('bij,bjk->bik', v, w_)
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
h_ = self.proj_out(h_)
return x + h_
class CrossAttention(nn.Module):
def __init__(
self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.0,
backend=None,
):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head**-0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout))
self.backend = backend
def forward(
self,
x,
context=None,
mask=None,
additional_tokens=None,
n_times_crossframe_attn_in_self=0,
):
h = self.heads
if additional_tokens is not None:
# get the number of masked tokens at the beginning of the output sequence
n_tokens_to_mask = additional_tokens.shape[1]
# add additional token
x = torch.cat([additional_tokens, x], dim=1)
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
if n_times_crossframe_attn_in_self:
# reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
assert x.shape[0] % n_times_crossframe_attn_in_self == 0
n_cp = x.shape[0] // n_times_crossframe_attn_in_self
k = repeat(k[::n_times_crossframe_attn_in_self],
'b ... -> (b n) ...',
n=n_cp)
v = repeat(v[::n_times_crossframe_attn_in_self],
'b ... -> (b n) ...',
n=n_cp)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h),
(q, k, v))
## old
"""
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
del q, k
if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', sim, v)
"""
## new
with sdp_kernel(**BACKEND_MAP[self.backend]):
# print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
out = F.scaled_dot_product_attention(
q, k, v,
attn_mask=mask) # scale is dim_head ** -0.5 per default
del q, k, v
out = rearrange(out, 'b h n d -> b n (h d)', h=h)
if additional_tokens is not None:
# remove additional token
out = out[:, n_tokens_to_mask:]
return self.to_out(out)
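# Illustrative sketch (not part of the original file, assumes PyTorch >= 2.0):
# CrossAttention is self-attention when `context` is None and cross-attention
# otherwise; the output keeps the query's sequence length and `query_dim`.
def _demo_cross_attention() -> torch.Size:
    attn = CrossAttention(query_dim=64, context_dim=128, heads=4, dim_head=16)
    x = torch.randn(2, 10, 64)         # queries: (batch, n_queries, query_dim)
    ctx = torch.randn(2, 7, 128)       # context: (batch, n_context, context_dim)
    return attn(x, context=ctx).shape  # -> torch.Size([2, 10, 64])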
class MemoryEfficientCrossAttention(nn.Module):
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
def __init__(self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.0,
**kwargs):
super().__init__()
print(
f'Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using '
f'{heads} heads with a dimension of {dim_head}.')
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.heads = heads
self.dim_head = dim_head
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout))
self.attention_op: Optional[Any] = None
def forward(
self,
x,
context=None,
mask=None,
additional_tokens=None,
n_times_crossframe_attn_in_self=0,
):
if additional_tokens is not None:
# get the number of masked tokens at the beginning of the output sequence
n_tokens_to_mask = additional_tokens.shape[1]
# add additional token
x = torch.cat([additional_tokens, x], dim=1)
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
if n_times_crossframe_attn_in_self:
# reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
assert x.shape[0] % n_times_crossframe_attn_in_self == 0
# n_cp = x.shape[0]//n_times_crossframe_attn_in_self
k = repeat(
k[::n_times_crossframe_attn_in_self],
'b ... -> (b n) ...',
n=n_times_crossframe_attn_in_self,
)
v = repeat(
v[::n_times_crossframe_attn_in_self],
'b ... -> (b n) ...',
n=n_times_crossframe_attn_in_self,
)
b, _, _ = q.shape
q, k, v = map(
lambda t: t.unsqueeze(3).reshape(b, t.shape[
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
b * self.heads, t.shape[1], self.dim_head).contiguous(),
(q, k, v),
)
# actually compute the attention, what we cannot get enough of
out = xformers.ops.memory_efficient_attention(q,
k,
v,
attn_bias=None,
op=self.attention_op)
# TODO: Use this directly in the attention operation, as a bias
if exists(mask):
raise NotImplementedError
out = (out.unsqueeze(0).reshape(
b, self.heads, out.shape[1],
self.dim_head).permute(0, 2, 1,
3).reshape(b, out.shape[1],
self.heads * self.dim_head))
if additional_tokens is not None:
# remove additional token
out = out[:, n_tokens_to_mask:]
return self.to_out(out)
class BasicTransformerBlock(nn.Module):
ATTENTION_MODES = {
'softmax': CrossAttention, # vanilla attention
'softmax-xformers': MemoryEfficientCrossAttention, # ampere
}
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True,
disable_self_attn=False,
attn_mode='softmax',
sdp_backend=None,
):
super().__init__()
assert attn_mode in self.ATTENTION_MODES
if attn_mode != 'softmax' and not XFORMERS_IS_AVAILABLE:
print(
f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
f'This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}'
)
attn_mode = 'softmax'
elif attn_mode == 'softmax' and not SDP_IS_AVAILABLE:
print(
'We do not support vanilla attention anymore, as it is too expensive. Sorry.'
)
if not XFORMERS_IS_AVAILABLE:
assert False, "Please install xformers via e.g. 'pip install xformers==0.0.16'"
else:
print('Falling back to xformers efficient attention.')
attn_mode = 'softmax-xformers'
attn_cls = self.ATTENTION_MODES[attn_mode]
if version.parse(torch.__version__) >= version.parse('2.0.0'):
assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
else:
assert sdp_backend is None
self.disable_self_attn = disable_self_attn
self.attn1 = attn_cls(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None,
backend=sdp_backend,
) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = attn_cls(
query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
backend=sdp_backend,
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
if self.checkpoint:
print(f'{self.__class__.__name__} is using checkpointing')
def forward(self,
x,
context=None,
additional_tokens=None,
n_times_crossframe_attn_in_self=0):
kwargs = {'x': x}
if context is not None:
kwargs.update({'context': context})
if additional_tokens is not None:
kwargs.update({'additional_tokens': additional_tokens})
if n_times_crossframe_attn_in_self:
kwargs.update({
'n_times_crossframe_attn_in_self':
n_times_crossframe_attn_in_self
})
# return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
return checkpoint(self._forward, (x, context), self.parameters(),
self.checkpoint)
def _forward(self,
x,
context=None,
additional_tokens=None,
n_times_crossframe_attn_in_self=0):
x = (self.attn1(
self.norm1(x),
context=context if self.disable_self_attn else None,
additional_tokens=additional_tokens,
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
if not self.disable_self_attn else 0,
) + x)
x = self.attn2(self.norm2(x),
context=context,
additional_tokens=additional_tokens) + x
x = self.ff(self.norm3(x)) + x
return x
class BasicTransformerSingleLayerBlock(nn.Module):
ATTENTION_MODES = {
'softmax': CrossAttention, # vanilla attention
'softmax-xformers':
MemoryEfficientCrossAttention, # on the A100s not quite as fast as the above version
# (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
}
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True,
attn_mode='softmax',
):
super().__init__()
assert attn_mode in self.ATTENTION_MODES
attn_cls = self.ATTENTION_MODES[attn_mode]
self.attn1 = attn_cls(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=context_dim,
)
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward(self, x, context=None):
return checkpoint(self._forward, (x, context), self.parameters(),
self.checkpoint)
def _forward(self, x, context=None):
x = self.attn1(self.norm1(x), context=context) + x
x = self.ff(self.norm2(x)) + x
return x
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data.
First, project the input (aka embedding)
and reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape to image
NEW: use_linear for more efficiency instead of the 1x1 convs
"""
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
context_dim=None,
disable_self_attn=False,
use_linear=False,
attn_type='softmax',
use_checkpoint=True,
# sdp_backend=SDPBackend.FLASH_ATTENTION
sdp_backend=None,
):
super().__init__()
print(
f'constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads'
)
from omegaconf import ListConfig
if exists(context_dim) and not isinstance(context_dim,
(list, ListConfig)):
context_dim = [context_dim]
if exists(context_dim) and isinstance(context_dim, list):
if depth != len(context_dim):
print(
f'WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, '
f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
)
# depth does not match context dims.
assert all(
map(lambda x: x == context_dim[0], context_dim)
), 'need homogeneous context_dim to match depth automatically'
context_dim = depth * [context_dim[0]]
elif context_dim is None:
context_dim = [None] * depth
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
if not use_linear:
self.proj_in = nn.Conv2d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0)
else:
self.proj_in = nn.Linear(in_channels, inner_dim)
self.transformer_blocks = nn.ModuleList([
BasicTransformerBlock(
inner_dim,
n_heads,
d_head,
dropout=dropout,
context_dim=context_dim[d],
disable_self_attn=disable_self_attn,
attn_mode=attn_type,
checkpoint=use_checkpoint,
sdp_backend=sdp_backend,
) for d in range(depth)
])
if not use_linear:
self.proj_out = zero_module(
nn.Conv2d(inner_dim,
in_channels,
kernel_size=1,
stride=1,
padding=0))
else:
# self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
self.use_linear = use_linear
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
if not isinstance(context, list):
context = [context]
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
if self.use_linear:
x = self.proj_in(x)
for i, block in enumerate(self.transformer_blocks):
if i > 0 and len(context) == 1:
i = 0 # use same context for each block
x = block(x, context=context[i])
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
if not self.use_linear:
x = self.proj_out(x)
return x + x_in
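# Illustrative sketch (not part of the original file, assumes PyTorch >= 2.0):
# the SpatialTransformer keeps the (b, c, h, w) layout, flattening the spatial
# dims to tokens internally and adding a residual connection to the input.
def _demo_spatial_transformer() -> torch.Size:
    block = SpatialTransformer(in_channels=32, n_heads=4, d_head=8, depth=1,
                               context_dim=64, use_linear=True,
                               use_checkpoint=False)
    x = torch.randn(1, 32, 16, 16)
    ctx = torch.randn(1, 10, 64)
    return block(x, context=ctx).shape  # -> torch.Size([1, 32, 16, 16])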
__all__ = [
'GeneralLPIPSWithDiscriminator',
'LatentLPIPS',
'VideoAutoencoderLoss',
]
from .discriminator_loss import GeneralLPIPSWithDiscriminator
from .lpips import LatentLPIPS
from .video_loss import VideoAutoencoderLoss
from typing import Dict, Iterator, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torchvision
from einops import rearrange
from matplotlib import colormaps
from matplotlib import pyplot as plt
from ....util import default, instantiate_from_config
from ..lpips.loss.lpips import LPIPS
from ..lpips.model.model import weights_init
from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
class GeneralLPIPSWithDiscriminator(nn.Module):
def __init__(
self,
disc_start: int,
logvar_init: float = 0.0,
disc_num_layers: int = 3,
disc_in_channels: int = 3,
disc_factor: float = 1.0,
disc_weight: float = 1.0,
perceptual_weight: float = 1.0,
disc_loss: str = 'hinge',
scale_input_to_tgt_size: bool = False,
dims: int = 2,
learn_logvar: bool = False,
regularization_weights: Union[None, Dict[str, float]] = None,
additional_log_keys: Optional[List[str]] = None,
discriminator_config: Optional[Dict] = None,
):
super().__init__()
self.dims = dims
if self.dims > 2:
print(
f'running with dims={dims}. This means that for perceptual loss '
f'calculation, the LPIPS loss will be applied to each frame '
f'independently.')
self.scale_input_to_tgt_size = scale_input_to_tgt_size
assert disc_loss in ['hinge', 'vanilla']
self.perceptual_loss = LPIPS().eval()
self.perceptual_weight = perceptual_weight
# output log variance
self.logvar = nn.Parameter(torch.full((), logvar_init),
requires_grad=learn_logvar)
self.learn_logvar = learn_logvar
discriminator_config = default(
discriminator_config,
{
'target':
'sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator',
'params': {
'input_nc': disc_in_channels,
'n_layers': disc_num_layers,
'use_actnorm': False,
},
},
)
self.discriminator = instantiate_from_config(
discriminator_config).apply(weights_init)
self.discriminator_iter_start = disc_start
self.disc_loss = hinge_d_loss if disc_loss == 'hinge' else vanilla_d_loss
self.disc_factor = disc_factor
self.discriminator_weight = disc_weight
self.regularization_weights = default(regularization_weights, {})
self.forward_keys = [
'optimizer_idx',
'global_step',
'last_layer',
'split',
'regularization_log',
]
self.additional_log_keys = set(default(additional_log_keys, []))
self.additional_log_keys.update(set(
self.regularization_weights.keys()))
def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
return self.discriminator.parameters()
def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]:
if self.learn_logvar:
yield self.logvar
yield from ()
@torch.no_grad()
def log_images(self, inputs: torch.Tensor,
reconstructions: torch.Tensor) -> Dict[str, torch.Tensor]:
# calc logits of real/fake
logits_real = self.discriminator(inputs.contiguous().detach())
if len(logits_real.shape) < 4:
# Non patch-discriminator
return dict()
logits_fake = self.discriminator(reconstructions.contiguous().detach())
# -> (b, 1, h, w)
# parameters for colormapping
high = max(logits_fake.abs().max(), logits_real.abs().max()).item()
cmap = colormaps['PiYG'] # diverging colormap
def to_colormap(logits: torch.Tensor) -> torch.Tensor:
"""(b, 1, ...) -> (b, 3, ...)"""
logits = (logits + high) / (2 * high)
logits_np = cmap(
logits.cpu().numpy())[..., :3] # truncate alpha channel
# -> (b, 1, ..., 3)
logits = torch.from_numpy(logits_np).to(logits.device)
return rearrange(logits, 'b 1 ... c -> b c ...')
logits_real = torch.nn.functional.interpolate(
logits_real,
size=inputs.shape[-2:],
mode='nearest',
antialias=False,
)
logits_fake = torch.nn.functional.interpolate(
logits_fake,
size=reconstructions.shape[-2:],
mode='nearest',
antialias=False,
)
# alpha value of logits for overlay
alpha_real = torch.abs(logits_real) / high
alpha_fake = torch.abs(logits_fake) / high
# -> (b, 1, h, w) in range [0, 0.5]
# the alpha values of the grid lines don't really matter, since the values
# are the same for both images and logits anyway
grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4)
grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4)
grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1)
# -> (1, h, w)
# blend logits and images together
# prepare logits for plotting
logits_real = to_colormap(logits_real)
logits_fake = to_colormap(logits_fake)
# resize logits
# -> (b, 3, h, w)
# make some grids
# add all logits to one plot
logits_real = torchvision.utils.make_grid(logits_real, nrow=4)
logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4)
# I just love how torchvision calls the number of columns `nrow`
grid_logits = torch.cat((logits_real, logits_fake), dim=1)
# -> (3, h, w)
grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5,
nrow=4)
grid_images_fake = torchvision.utils.make_grid(0.5 * reconstructions +
0.5,
nrow=4)
grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1)
# -> (3, h, w) in range [0, 1]
grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images
# Create labeled colorbar
dpi = 100
height = 128 / dpi
width = grid_logits.shape[2] / dpi
fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)
img = ax.imshow(np.array([[-high, high]]), cmap=cmap)
plt.colorbar(
img,
cax=ax,
orientation='horizontal',
fraction=0.9,
aspect=width / height,
pad=0.0,
)
img.set_visible(False)
fig.tight_layout()
fig.canvas.draw()
# manually convert figure to numpy
cbar_np = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
cbar_np = cbar_np.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0
cbar = rearrange(cbar, 'h w c -> c h w').to(grid_logits.device)
# Add colorbar to plot
annotated_grid = torch.cat((grid_logits, cbar), dim=1)
blended_grid = torch.cat((grid_blend, cbar), dim=1)
return {
'vis_logits': 2 * annotated_grid[None, ...] - 1,
'vis_logits_blended': 2 * blended_grid[None, ...] - 1,
}
def calculate_adaptive_weight(self, nll_loss: torch.Tensor,
g_loss: torch.Tensor,
last_layer: torch.Tensor) -> torch.Tensor:
nll_grads = torch.autograd.grad(nll_loss,
last_layer,
retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.discriminator_weight
return d_weight
def forward(
self,
inputs: torch.Tensor,
reconstructions: torch.Tensor,
*, # added because I changed the order here
regularization_log: Dict[str, torch.Tensor],
optimizer_idx: int,
global_step: int,
last_layer: torch.Tensor,
split: str = 'train',
weights: Union[None, float, torch.Tensor] = None,
) -> Tuple[torch.Tensor, dict]:
if self.scale_input_to_tgt_size:
inputs = torch.nn.functional.interpolate(inputs,
reconstructions.shape[2:],
mode='bicubic',
antialias=True)
if self.dims > 2:
inputs, reconstructions = map(
lambda x: rearrange(x, 'b c t h w -> (b t) c h w'),
(inputs, reconstructions),
)
rec_loss = torch.abs(inputs.contiguous() -
reconstructions.contiguous())
if self.perceptual_weight > 0:
frame_indices = torch.randn(
(inputs.shape[0], inputs.shape[2])).topk(1, dim=-1).indices
from sgm.modules.autoencoding.losses.video_loss import \
pick_video_frame
input_frames = pick_video_frame(inputs, frame_indices)
recon_frames = pick_video_frame(reconstructions, frame_indices)
p_loss = self.perceptual_loss(input_frames.contiguous(),
recon_frames.contiguous()).mean()
rec_loss = rec_loss + self.perceptual_weight * p_loss
else:
p_loss = torch.tensor(0.0, device=rec_loss.device)  # keep the log entry below well-defined
nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)
# now the GAN part
if optimizer_idx == 0:
# generator update
if global_step >= self.discriminator_iter_start or not self.training:
logits_fake = self.discriminator(reconstructions.contiguous())
g_loss = -torch.mean(logits_fake)
if self.training:
d_weight = self.calculate_adaptive_weight(
nll_loss, g_loss, last_layer=last_layer)
else:
d_weight = torch.tensor(1.0)
else:
d_weight = torch.tensor(0.0)
g_loss = torch.tensor(0.0, requires_grad=True)
loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss
log = dict()
for k in regularization_log:
if k in self.regularization_weights:
loss = loss + self.regularization_weights[
k] * regularization_log[k]
if k in self.additional_log_keys:
log[f'{split}/{k}'] = regularization_log[k].detach().float(
).mean()
log.update({
f'{split}/loss/total': loss.clone().detach().mean(),
f'{split}/loss/nll': nll_loss.detach().mean(),
f'{split}/loss/rec': rec_loss.detach().mean(),
f'{split}/loss/percep': p_loss.detach().mean(),
f'{split}/loss/g': g_loss.detach().mean(),
f'{split}/scalars/logvar': self.logvar.detach(),
f'{split}/scalars/d_weight': d_weight.detach(),
})
return loss, log
elif optimizer_idx == 1:
# second pass for discriminator update
logits_real = self.discriminator(inputs.contiguous().detach())
logits_fake = self.discriminator(
reconstructions.contiguous().detach())
if global_step >= self.discriminator_iter_start or not self.training:
d_loss = self.disc_factor * self.disc_loss(
logits_real, logits_fake)
else:
d_loss = torch.tensor(0.0, requires_grad=True)
log = {
f'{split}/loss/disc': d_loss.clone().detach().mean(),
f'{split}/logits/real': logits_real.detach().mean(),
f'{split}/logits/fake': logits_fake.detach().mean(),
}
return d_loss, log
else:
raise NotImplementedError(f'Unknown optimizer_idx {optimizer_idx}')
def get_nll_loss(
self,
rec_loss: torch.Tensor,
weights: Optional[Union[float, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
weighted_nll_loss = nll_loss
if weights is not None:
weighted_nll_loss = weights * nll_loss
weighted_nll_loss = torch.sum(
weighted_nll_loss) / weighted_nll_loss.shape[0]
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
return nll_loss, weighted_nll_loss
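# Illustrative sketch (not part of the original file): the reconstruction term
# used above is nll = rec / exp(logvar) + logvar, summed over all elements and
# divided by the batch size; `calculate_adaptive_weight` then balances gradient
# norms at the decoder's last layer via
# d_weight = ||grad(nll)|| / (||grad(g_loss)|| + 1e-4), clamped to [0, 1e4] and
# scaled by `disc_weight`.
def _demo_nll_with_logvar() -> torch.Tensor:
    rec_loss = torch.abs(torch.randn(4, 3, 8, 8))   # stand-in for |x - x_rec|
    logvar = torch.tensor(0.0)                      # exp(0) = 1, so nll == rec_loss here
    nll = rec_loss / torch.exp(logvar) + logvar
    return torch.sum(nll) / nll.shape[0]            # matches get_nll_loss with weights=None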
import torch
import torch.nn as nn
from ....util import default, instantiate_from_config
from ..lpips.loss.lpips import LPIPS
class LatentLPIPS(nn.Module):
def __init__(
self,
decoder_config,
perceptual_weight=1.0,
latent_weight=1.0,
scale_input_to_tgt_size=False,
scale_tgt_to_input_size=False,
perceptual_weight_on_inputs=0.0,
):
super().__init__()
self.scale_input_to_tgt_size = scale_input_to_tgt_size
self.scale_tgt_to_input_size = scale_tgt_to_input_size
self.init_decoder(decoder_config)
self.perceptual_loss = LPIPS().eval()
self.perceptual_weight = perceptual_weight
self.latent_weight = latent_weight
self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
def init_decoder(self, config):
self.decoder = instantiate_from_config(config)
if hasattr(self.decoder, 'encoder'):
del self.decoder.encoder
def forward(self,
latent_inputs,
latent_predictions,
image_inputs,
split='train'):
log = dict()
loss = (latent_inputs - latent_predictions)**2
log[f'{split}/latent_l2_loss'] = loss.mean().detach()
image_reconstructions = None
if self.perceptual_weight > 0.0:
image_reconstructions = self.decoder.decode(latent_predictions)
image_targets = self.decoder.decode(latent_inputs)
perceptual_loss = self.perceptual_loss(
image_targets.contiguous(), image_reconstructions.contiguous())
loss = self.latent_weight * loss.mean(
) + self.perceptual_weight * perceptual_loss.mean()
log[f'{split}/perceptual_loss'] = perceptual_loss.mean().detach()
if self.perceptual_weight_on_inputs > 0.0:
image_reconstructions = default(
image_reconstructions, self.decoder.decode(latent_predictions))
if self.scale_input_to_tgt_size:
image_inputs = torch.nn.functional.interpolate(
image_inputs,
image_reconstructions.shape[2:],
mode='bicubic',
antialias=True,
)
elif self.scale_tgt_to_input_size:
image_reconstructions = torch.nn.functional.interpolate(
image_reconstructions,
image_inputs.shape[2:],
mode='bicubic',
antialias=True,
)
perceptual_loss2 = self.perceptual_loss(
image_inputs.contiguous(), image_reconstructions.contiguous())
loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean(
)
log[f'{split}/perceptual_loss_on_inputs'] = perceptual_loss2.mean(
).detach()
return loss, log
from math import log2
from typing import Any, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from beartype import beartype
from einops import einsum, rearrange, repeat
from einops.layers.torch import Rearrange
from kornia.filters import filter3d
from sgm.modules.autoencoding.vqvae.movq_enc_3d import (CausalConv3d,
DownSample3D)
from sgm.util import instantiate_from_config
from torch import Tensor
from torch.autograd import grad as torch_grad
from torch.cuda.amp import autocast
from torchvision.models import VGG16_Weights
from ..magvit2_pytorch import FeedForward, LinearSpaceAttention, Residual
from .lpips import LPIPS
def exists(v):
return v is not None
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def leaky_relu(p=0.1):
return nn.LeakyReLU(p)
def hinge_discr_loss(fake, real):
return (F.relu(1 + fake) + F.relu(1 - real)).mean()
def hinge_gen_loss(fake):
return -fake.mean()
@autocast(enabled=False)
@beartype
def grad_layer_wrt_loss(loss: Tensor, layer: nn.Parameter):
return torch_grad(outputs=loss,
inputs=layer,
grad_outputs=torch.ones_like(loss),
retain_graph=True)[0].detach()
def pick_video_frame(video, frame_indices):
batch, device = video.shape[0], video.device
video = rearrange(video, 'b c f ... -> b f c ...')
batch_indices = torch.arange(batch, device=device)
batch_indices = rearrange(batch_indices, 'b -> b 1')
images = video[batch_indices, frame_indices]
images = rearrange(images, 'b 1 c ... -> b c ...')
return images
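# Illustrative sketch (not part of the original file): pick_video_frame gathers
# one frame per sample, turning a (b, c, f, h, w) video into a (b, c, h, w)
# image batch.
def _demo_pick_video_frame() -> torch.Size:
    video = torch.randn(2, 3, 5, 8, 8)                    # (b, c, f, h, w)
    frame_indices = torch.randint(0, 5, (2, 1))           # one frame index per sample
    return pick_video_frame(video, frame_indices).shape   # -> torch.Size([2, 3, 8, 8])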
def gradient_penalty(images, output):
batch_size = images.shape[0]
gradients = torch_grad(
outputs=output,
inputs=images,
grad_outputs=torch.ones(output.size(), device=images.device),
create_graph=True,
retain_graph=True,
only_inputs=True,
)[0]
gradients = rearrange(gradients, 'b ... -> b (...)')
return ((gradients.norm(2, dim=1) - 1)**2).mean()
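# Illustrative sketch (not part of the original file): the penalty pushes the
# discriminator's gradient norm w.r.t. its (real) inputs towards 1, WGAN-GP
# style.
def _demo_gradient_penalty() -> torch.Tensor:
    images = torch.randn(2, 3, 8, 8, requires_grad=True)
    logits = (2.0 * images).sum(dim=(1, 2, 3))   # stand-in discriminator output
    return gradient_penalty(images, logits)      # ((||grad||_2 - 1) ** 2).mean()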
# discriminator with anti-aliased downsampling (blurpool Zhang et al.)
class Blur(nn.Module):
def __init__(self):
super().__init__()
f = torch.Tensor([1, 2, 1])
self.register_buffer('f', f)
def forward(self, x, space_only=False, time_only=False):
assert not (space_only and time_only)
f = self.f
if space_only:
f = einsum('i, j -> i j', f, f)
f = rearrange(f, '... -> 1 1 ...')
elif time_only:
f = rearrange(f, 'f -> 1 f 1 1')
else:
f = einsum('i, j, k -> i j k', f, f, f)
f = rearrange(f, '... -> 1 ...')
is_images = x.ndim == 4
if is_images:
x = rearrange(x, 'b c h w -> b c 1 h w')
out = filter3d(x, f, normalized=True)
if is_images:
out = rearrange(out, 'b c 1 h w -> b c h w')
return out
class DiscriminatorBlock(nn.Module):
def __init__(self,
input_channels,
filters,
downsample=True,
antialiased_downsample=True):
super().__init__()
self.conv_res = nn.Conv2d(input_channels,
filters,
1,
stride=(2 if downsample else 1))
self.net = nn.Sequential(
nn.Conv2d(input_channels, filters, 3, padding=1),
leaky_relu(),
nn.Conv2d(filters, filters, 3, padding=1),
leaky_relu(),
)
self.maybe_blur = Blur() if antialiased_downsample else None
self.downsample = (nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1=2, p2=2),
nn.Conv2d(filters * 4, filters, 1)) if downsample else None)
def forward(self, x):
res = self.conv_res(x)
x = self.net(x)
if exists(self.downsample):
if exists(self.maybe_blur):
x = self.maybe_blur(x, space_only=True)
x = self.downsample(x)
x = (x + res) * (2**-0.5)
return x
class Discriminator(nn.Module):
@beartype
def __init__(
self,
*,
dim,
image_size,
channels=3,
max_dim=512,
attn_heads=8,
attn_dim_head=32,
linear_attn_dim_head=8,
linear_attn_heads=16,
ff_mult=4,
antialiased_downsample=False,
):
super().__init__()
image_size = pair(image_size)
min_image_resolution = min(image_size)
num_layers = int(log2(min_image_resolution) - 2)
blocks = []
layer_dims = [channels] + [(dim * 4) * (2**i)
for i in range(num_layers + 1)]
layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims]
layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:]))
image_resolution = min_image_resolution
for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out):
num_layer = ind + 1
is_not_last = ind != (len(layer_dims_in_out) - 1)
block = DiscriminatorBlock(
in_chan,
out_chan,
downsample=is_not_last,
antialiased_downsample=antialiased_downsample)
attn_block = nn.Sequential(
Residual(
LinearSpaceAttention(dim=out_chan,
heads=linear_attn_heads,
dim_head=linear_attn_dim_head)),
Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
)
blocks.append(nn.ModuleList([block, attn_block]))
image_resolution //= 2
self.blocks = nn.ModuleList(blocks)
dim_last = layer_dims[-1]
downsample_factor = 2**num_layers
last_fmap_size = tuple(
map(lambda n: n // downsample_factor, image_size))
latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last
self.to_logits = nn.Sequential(
nn.Conv2d(dim_last, dim_last, 3, padding=1),
leaky_relu(),
Rearrange('b ... -> b (...)'),
nn.Linear(latent_dim, 1),
Rearrange('b 1 -> b'),
)
def forward(self, x):
for block, attn_block in self.blocks:
x = block(x)
x = attn_block(x)
return self.to_logits(x)
class DiscriminatorBlock3D(nn.Module):
def __init__(
self,
input_channels,
filters,
antialiased_downsample=True,
):
super().__init__()
self.conv_res = nn.Conv3d(input_channels, filters, 1, stride=2)
self.net = nn.Sequential(
nn.Conv3d(input_channels, filters, 3, padding=1),
leaky_relu(),
nn.Conv3d(filters, filters, 3, padding=1),
leaky_relu(),
)
self.maybe_blur = Blur() if antialiased_downsample else None
self.downsample = nn.Sequential(
Rearrange('b c (f p1) (h p2) (w p3) -> b (c p1 p2 p3) f h w',
p1=2,
p2=2,
p3=2),
nn.Conv3d(filters * 8, filters, 1),
)
def forward(self, x):
res = self.conv_res(x)
x = self.net(x)
if exists(self.downsample):
if exists(self.maybe_blur):
x = self.maybe_blur(x, space_only=True)
x = self.downsample(x)
x = (x + res) * (2**-0.5)
return x
class DiscriminatorBlock3DWithfirstframe(nn.Module):
def __init__(
self,
input_channels,
filters,
antialiased_downsample=True,
pad_mode='first',
):
super().__init__()
self.downsample_res = DownSample3D(
in_channels=input_channels,
out_channels=filters,
with_conv=True,
compress_time=True,
)
self.net = nn.Sequential(
CausalConv3d(input_channels,
filters,
kernel_size=3,
pad_mode=pad_mode),
leaky_relu(),
CausalConv3d(filters, filters, kernel_size=3, pad_mode=pad_mode),
leaky_relu(),
)
self.maybe_blur = Blur() if antialiased_downsample else None
self.downsample = DownSample3D(
in_channels=filters,
out_channels=filters,
with_conv=True,
compress_time=True,
)
def forward(self, x):
res = self.downsample_res(x)
x = self.net(x)
if exists(self.downsample):
if exists(self.maybe_blur):
x = self.maybe_blur(x, space_only=True)
x = self.downsample(x)
x = (x + res) * (2**-0.5)
return x
class Discriminator3D(nn.Module):
@beartype
def __init__(
self,
*,
dim,
image_size,
frame_num,
channels=3,
max_dim=512,
linear_attn_dim_head=8,
linear_attn_heads=16,
ff_mult=4,
antialiased_downsample=False,
):
super().__init__()
image_size = pair(image_size)
min_image_resolution = min(image_size)
num_layers = int(log2(min_image_resolution) - 2)
temporal_num_layers = int(log2(frame_num))
self.temporal_num_layers = temporal_num_layers
layer_dims = [channels] + [(dim * 4) * (2**i)
for i in range(num_layers + 1)]
layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims]
layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:]))
blocks = []
image_resolution = min_image_resolution
frame_resolution = frame_num
for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out):
num_layer = ind + 1
is_not_last = ind != (len(layer_dims_in_out) - 1)
if ind < temporal_num_layers:
block = DiscriminatorBlock3D(
in_chan,
out_chan,
antialiased_downsample=antialiased_downsample,
)
blocks.append(block)
frame_resolution //= 2
else:
block = DiscriminatorBlock(
in_chan,
out_chan,
downsample=is_not_last,
antialiased_downsample=antialiased_downsample,
)
attn_block = nn.Sequential(
Residual(
LinearSpaceAttention(dim=out_chan,
heads=linear_attn_heads,
dim_head=linear_attn_dim_head)),
Residual(
FeedForward(dim=out_chan, mult=ff_mult, images=True)),
)
blocks.append(nn.ModuleList([block, attn_block]))
image_resolution //= 2
self.blocks = nn.ModuleList(blocks)
dim_last = layer_dims[-1]
downsample_factor = 2**num_layers
last_fmap_size = tuple(
map(lambda n: n // downsample_factor, image_size))
latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last
self.to_logits = nn.Sequential(
nn.Conv2d(dim_last, dim_last, 3, padding=1),
leaky_relu(),
Rearrange('b ... -> b (...)'),
nn.Linear(latent_dim, 1),
Rearrange('b 1 -> b'),
)
def forward(self, x):
for i, layer in enumerate(self.blocks):
if i < self.temporal_num_layers:
x = layer(x)
if i == self.temporal_num_layers - 1:
x = rearrange(x, 'b c f h w -> (b f) c h w')
else:
block, attn_block = layer
x = block(x)
x = attn_block(x)
return self.to_logits(x)
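# Illustrative sketch (not part of the original file): the 3D discriminator
# first downsamples time and space jointly for log2(frame_num) blocks, then
# folds the remaining frames into the batch and continues with 2D blocks,
# ending in one logit per (sample x remaining frame).
def _demo_discriminator3d() -> torch.Size:
    discr = Discriminator3D(dim=16, image_size=64, frame_num=8, channels=3)
    video = torch.randn(1, 3, 8, 64, 64)   # (b, c, f, h, w)
    return discr(video).shape              # -> torch.Size([1]); 8 frames collapse to 1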
class Discriminator3DWithfirstframe(nn.Module):
@beartype
def __init__(
self,
*,
dim,
image_size,
frame_num,
channels=3,
max_dim=512,
linear_attn_dim_head=8,
linear_attn_heads=16,
ff_mult=4,
antialiased_downsample=False,
):
super().__init__()
image_size = pair(image_size)
min_image_resolution = min(image_size)
num_layers = int(log2(min_image_resolution) - 2)
temporal_num_layers = int(log2(frame_num))
self.temporal_num_layers = temporal_num_layers
layer_dims = [channels] + [(dim * 4) * (2**i)
for i in range(num_layers + 1)]
layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims]
layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:]))
blocks = []
image_resolution = min_image_resolution
frame_resolution = frame_num
for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out):
num_layer = ind + 1
is_not_last = ind != (len(layer_dims_in_out) - 1)
if ind < temporal_num_layers:
block = DiscriminatorBlock3DWithfirstframe(
in_chan,
out_chan,
antialiased_downsample=antialiased_downsample,
)
blocks.append(block)
frame_resolution //= 2
else:
block = DiscriminatorBlock(
in_chan,
out_chan,
downsample=is_not_last,
antialiased_downsample=antialiased_downsample,
)
attn_block = nn.Sequential(
Residual(
LinearSpaceAttention(dim=out_chan,
heads=linear_attn_heads,
dim_head=linear_attn_dim_head)),
Residual(
FeedForward(dim=out_chan, mult=ff_mult, images=True)),
)
blocks.append(nn.ModuleList([block, attn_block]))
image_resolution //= 2
self.blocks = nn.ModuleList(blocks)
dim_last = layer_dims[-1]
downsample_factor = 2**num_layers
last_fmap_size = tuple(
map(lambda n: n // downsample_factor, image_size))
latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last
self.to_logits = nn.Sequential(
nn.Conv2d(dim_last, dim_last, 3, padding=1),
leaky_relu(),
Rearrange('b ... -> b (...)'),
nn.Linear(latent_dim, 1),
Rearrange('b 1 -> b'),
)
def forward(self, x):
for i, layer in enumerate(self.blocks):
if i < self.temporal_num_layers:
x = layer(x)
if i == self.temporal_num_layers - 1:
x = x.mean(dim=2)
# x = rearrange(x, "b c f h w -> (b f) c h w")
else:
block, attn_block = layer
x = block(x)
x = attn_block(x)
return self.to_logits(x)
class VideoAutoencoderLoss(nn.Module):
def __init__(
self,
disc_start,
perceptual_weight=1,
adversarial_loss_weight=0,
multiscale_adversarial_loss_weight=0,
grad_penalty_loss_weight=0,
quantizer_aux_loss_weight=0,
vgg_weights=VGG16_Weights.DEFAULT,
discr_kwargs=None,
discr_3d_kwargs=None,
):
super().__init__()
self.disc_start = disc_start
self.perceptual_weight = perceptual_weight
self.adversarial_loss_weight = adversarial_loss_weight
self.multiscale_adversarial_loss_weight = multiscale_adversarial_loss_weight
self.grad_penalty_loss_weight = grad_penalty_loss_weight
self.quantizer_aux_loss_weight = quantizer_aux_loss_weight
if self.perceptual_weight > 0:
self.perceptual_model = LPIPS().eval()
# self.vgg = torchvision.models.vgg16(pretrained = True)
# self.vgg.requires_grad_(False)
# if self.adversarial_loss_weight > 0:
# self.discr = Discriminator(**discr_kwargs)
# else:
# self.discr = None
# if self.multiscale_adversarial_loss_weight > 0:
# self.multiscale_discrs = nn.ModuleList([*multiscale_discrs])
# else:
# self.multiscale_discrs = None
if discr_kwargs is not None:
self.discr = Discriminator(**discr_kwargs)
else:
self.discr = None
if discr_3d_kwargs is not None:
# self.discr_3d = Discriminator3D(**discr_3d_kwargs)
self.discr_3d = instantiate_from_config(discr_3d_kwargs)
else:
self.discr_3d = None
# self.multiscale_discrs = nn.ModuleList([*multiscale_discrs])
self.register_buffer('zero', torch.tensor(0.0), persistent=False)
def get_trainable_params(self) -> Any:
params = []
if self.discr is not None:
params += list(self.discr.parameters())
if self.discr_3d is not None:
params += list(self.discr_3d.parameters())
# if self.multiscale_discrs is not None:
# for discr in self.multiscale_discrs:
# params += list(discr.parameters())
return params
def get_trainable_parameters(self) -> Any:
return self.get_trainable_params()
def forward(
self,
inputs,
reconstructions,
optimizer_idx,
global_step,
aux_losses=None,
last_layer=None,
split='train',
):
batch, channels, frames = inputs.shape[:3]
if optimizer_idx == 0:
recon_loss = F.mse_loss(inputs, reconstructions)
if self.perceptual_weight > 0:
frame_indices = torch.randn(
(batch, frames)).topk(1, dim=-1).indices
input_frames = pick_video_frame(inputs, frame_indices)
recon_frames = pick_video_frame(reconstructions, frame_indices)
perceptual_loss = self.perceptual_model(
input_frames.contiguous(),
recon_frames.contiguous()).mean()
else:
perceptual_loss = self.zero
if global_step >= self.disc_start or not self.training or self.adversarial_loss_weight == 0:
gen_loss = self.zero
adaptive_weight = 0
else:
# frame_indices = torch.randn((batch, frames)).topk(1, dim = -1).indices
# recon_video_frames = pick_video_frame(reconstructions, frame_indices)
# fake_logits = self.discr(recon_video_frames)
fake_logits = self.discr_3d(reconstructions)
gen_loss = hinge_gen_loss(fake_logits)
adaptive_weight = 1
if self.perceptual_weight > 0 and last_layer is not None:
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(
perceptual_loss, last_layer).norm(p=2)
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(
gen_loss, last_layer).norm(p=2)
adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(
min=1e-3)
adaptive_weight.clamp_(max=1e3)
if torch.isnan(adaptive_weight).any():
adaptive_weight = 1
# multiscale discriminator losses
# multiscale_gen_losses = []
# multiscale_gen_adaptive_weights = []
# if self.multiscale_adversarial_loss_weight > 0:
# if not exists(recon_video_frames):
# frame_indices = torch.randn((batch, frames)).topk(1, dim = -1).indices
# recon_video_frames = pick_video_frame(reconstructions, frame_indices)
# for discr in self.multiscale_discrs:
# fake_logits = recon_video_frames
# multiscale_gen_loss = hinge_gen_loss(fake_logits)
# multiscale_gen_losses.append(multiscale_gen_loss)
# multiscale_adaptive_weight = 1.
# if exists(norm_grad_wrt_perceptual_loss):
# norm_grad_wrt_gen_loss = grad_layer_wrt_loss(multiscale_gen_loss, last_layer).norm(p = 2)
# multiscale_adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min = 1e-5)
# multiscale_adaptive_weight.clamp_(max = 1e3)
# multiscale_gen_adaptive_weights.append(multiscale_adaptive_weight)
# weighted_multiscale_gen_losses = sum(loss * weight for loss, weight in zip(multiscale_gen_losses, multiscale_gen_adaptive_weights))
# else:
# weighted_multiscale_gen_losses = self.zero
if aux_losses is None:
aux_losses = self.zero
total_loss = (recon_loss +
aux_losses * self.quantizer_aux_loss_weight +
perceptual_loss * self.perceptual_weight +
gen_loss * self.adversarial_loss_weight)
# gen_loss * adaptive_weight * self.adversarial_loss_weight + \
# weighted_multiscale_gen_losses * self.multiscale_adversarial_loss_weight
log = {
f'{split}/total_loss': total_loss.detach(),
f'{split}/recon_loss': recon_loss.detach(),
f'{split}/perceptual_loss': perceptual_loss.detach(),
f'{split}/gen_loss': gen_loss.detach(),
f'{split}/aux_losses': aux_losses.detach(),
# "{}/weighted_multiscale_gen_losses".format(split): weighted_multiscale_gen_losses.detach(),
f'{split}/adaptive_weight': adaptive_weight,
# "{}/multiscale_adaptive_weights".format(split): sum(multiscale_gen_adaptive_weights),
}
return total_loss, log
if optimizer_idx == 1:
# frame_indices = torch.randn((batch, frames)).topk(1, dim = -1).indices
# real = pick_video_frame(inputs, frame_indices)
# fake = pick_video_frame(reconstructions, frame_indices)
# apply_gradient_penalty = self.grad_penalty_loss_weight > 0
# if apply_gradient_penalty:
# real = real.requires_grad_()
# real_logits = self.discr(real)
# fake_logits = self.discr(fake.detach())
apply_gradient_penalty = self.grad_penalty_loss_weight > 0
if apply_gradient_penalty:
inputs = inputs.requires_grad_()
real_logits = self.discr_3d(inputs)
fake_logits = self.discr_3d(reconstructions.detach())
discr_loss = hinge_discr_loss(fake_logits, real_logits)
# # multiscale discriminators
# multiscale_discr_losses = []
# if self.multiscale_adversarial_loss_weight > 0:
# for discr in self.multiscale_discrs:
# multiscale_real_logits = discr(inputs)
# multiscale_fake_logits = discr(reconstructions.detach())
# multiscale_discr_loss = hinge_discr_loss(multiscale_fake_logits, multiscale_real_logits)
# multiscale_discr_losses.append(multiscale_discr_loss)
# else:
# multiscale_discr_losses.append(self.zero)
# gradient penalty
if apply_gradient_penalty:
# gradient_penalty_loss = gradient_penalty(real, real_logits)
gradient_penalty_loss = gradient_penalty(inputs, real_logits)
else:
gradient_penalty_loss = self.zero
total_loss = discr_loss + self.grad_penalty_loss_weight * gradient_penalty_loss
# self.grad_penalty_loss_weight * gradient_penalty_loss + \
# sum(multiscale_discr_losses) * self.multiscale_adversarial_loss_weight
log = {
f'{split}/total_disc_loss': total_loss.detach(),
f'{split}/discr_loss': discr_loss.detach(),
f'{split}/grad_penalty_loss': gradient_penalty_loss.detach(),
# "{}/multiscale_discr_loss".format(split): sum(multiscale_discr_losses).detach(),
f'{split}/logits_real': real_logits.detach().mean(),
f'{split}/logits_fake': fake_logits.detach().mean(),
}
return total_loss, log
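# Illustrative sketch (not part of the original file): VideoAutoencoderLoss is
# typically configured with a 3D discriminator via `instantiate_from_config`.
# The target below points at the Discriminator3D defined in this file; all
# parameter values are placeholders.
_EXAMPLE_VIDEO_LOSS_CONFIG = {
    'target': 'sgm.modules.autoencoding.losses.video_loss.VideoAutoencoderLoss',
    'params': {
        'disc_start': 10000,
        'perceptual_weight': 1.0,
        'adversarial_loss_weight': 0.1,
        'grad_penalty_loss_weight': 0.0,
        'discr_3d_kwargs': {
            'target': 'sgm.modules.autoencoding.losses.video_loss.Discriminator3D',
            'params': {
                'dim': 64,
                'image_size': 256,
                'frame_num': 16,
                'channels': 3,
            },
        },
    },
}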