Commit 1f5da520 authored by yangzhong's avatar yangzhong
Browse files

git init

parents
Pipeline #3144 failed with stages
in 0 seconds
import math
from inspect import isfunction
from typing import Any, Optional
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from packaging import version
from torch import nn
if version.parse(torch.__version__) >= version.parse("2.0.0"):
SDP_IS_AVAILABLE = True
from torch.backends.cuda import SDPBackend, sdp_kernel
BACKEND_MAP = {
SDPBackend.MATH: {
"enable_math": True,
"enable_flash": False,
"enable_mem_efficient": False,
},
SDPBackend.FLASH_ATTENTION: {
"enable_math": False,
"enable_flash": True,
"enable_mem_efficient": False,
},
SDPBackend.EFFICIENT_ATTENTION: {
"enable_math": False,
"enable_flash": False,
"enable_mem_efficient": True,
},
None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
}
else:
from contextlib import nullcontext
SDP_IS_AVAILABLE = False
sdp_kernel = nullcontext
BACKEND_MAP = {}
print(
f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
)
try:
import xformers
import xformers.ops
XFORMERS_IS_AVAILABLE = True
except:
XFORMERS_IS_AVAILABLE = False
print("no module 'xformers'. Processing without...")
from .diffusionmodules.util import checkpoint
def exists(val):
return val is not None
def uniq(arr):
return {el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def max_neg_value(t):
return -torch.finfo(t.dtype).max
def init_(tensor):
dim = tensor.shape[-1]
std = 1 / math.sqrt(dim)
tensor.uniform_(-std, std)
return tensor
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
def forward(self, x):
return self.net(x)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
k = k.softmax(dim=-1)
context = torch.einsum("bhdn,bhen->bhde", k, v)
out = torch.einsum("bhde,bhdn->bhen", context, q)
out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
return self.to_out(out)
class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = rearrange(q, "b c h w -> b (h w) c")
k = rearrange(k, "b c h w -> b c (h w)")
w_ = torch.einsum("bij,bjk->bik", q, k)
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = rearrange(v, "b c h w -> b c (h w)")
w_ = rearrange(w_, "b i j -> b j i")
h_ = torch.einsum("bij,bjk->bik", v, w_)
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
h_ = self.proj_out(h_)
return x + h_
class CrossAttention(nn.Module):
def __init__(
self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.0,
backend=None,
):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head**-0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
self.backend = backend
def forward(
self,
x,
context=None,
mask=None,
additional_tokens=None,
n_times_crossframe_attn_in_self=0,
):
h = self.heads
if additional_tokens is not None:
# get the number of masked tokens at the beginning of the output sequence
n_tokens_to_mask = additional_tokens.shape[1]
# add additional token
x = torch.cat([additional_tokens, x], dim=1)
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
if n_times_crossframe_attn_in_self:
# reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
assert x.shape[0] % n_times_crossframe_attn_in_self == 0
n_cp = x.shape[0] // n_times_crossframe_attn_in_self
k = repeat(k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp)
v = repeat(v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
## old
"""
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
del q, k
if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', sim, v)
"""
## new
with sdp_kernel(**BACKEND_MAP[self.backend]):
# print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) # scale is dim_head ** -0.5 per default
del q, k, v
out = rearrange(out, "b h n d -> b n (h d)", h=h)
if additional_tokens is not None:
# remove additional token
out = out[:, n_tokens_to_mask:]
return self.to_out(out)
class MemoryEfficientCrossAttention(nn.Module):
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs):
super().__init__()
print(
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
f"{heads} heads with a dimension of {dim_head}."
)
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.heads = heads
self.dim_head = dim_head
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None
def forward(
self,
x,
context=None,
mask=None,
additional_tokens=None,
n_times_crossframe_attn_in_self=0,
):
if additional_tokens is not None:
# get the number of masked tokens at the beginning of the output sequence
n_tokens_to_mask = additional_tokens.shape[1]
# add additional token
x = torch.cat([additional_tokens, x], dim=1)
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
if n_times_crossframe_attn_in_self:
# reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
assert x.shape[0] % n_times_crossframe_attn_in_self == 0
# n_cp = x.shape[0]//n_times_crossframe_attn_in_self
k = repeat(
k[::n_times_crossframe_attn_in_self],
"b ... -> (b n) ...",
n=n_times_crossframe_attn_in_self,
)
v = repeat(
v[::n_times_crossframe_attn_in_self],
"b ... -> (b n) ...",
n=n_times_crossframe_attn_in_self,
)
b, _, _ = q.shape
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, t.shape[1], self.heads, self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b * self.heads, t.shape[1], self.dim_head)
.contiguous(),
(q, k, v),
)
# actually compute the attention, what we cannot get enough of
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
# TODO: Use this directly in the attention operation, as a bias
if exists(mask):
raise NotImplementedError
out = (
out.unsqueeze(0)
.reshape(b, self.heads, out.shape[1], self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b, out.shape[1], self.heads * self.dim_head)
)
if additional_tokens is not None:
# remove additional token
out = out[:, n_tokens_to_mask:]
return self.to_out(out)
class BasicTransformerBlock(nn.Module):
ATTENTION_MODES = {
"softmax": CrossAttention, # vanilla attention
"softmax-xformers": MemoryEfficientCrossAttention, # ampere
}
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True,
disable_self_attn=False,
attn_mode="softmax",
sdp_backend=None,
):
super().__init__()
assert attn_mode in self.ATTENTION_MODES
if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
print(
f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
)
attn_mode = "softmax"
elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
print("We do not support vanilla attention anymore, as it is too expensive. Sorry.")
if not XFORMERS_IS_AVAILABLE:
assert False, "Please install xformers via e.g. 'pip install xformers==0.0.16'"
else:
print("Falling back to xformers efficient attention.")
attn_mode = "softmax-xformers"
attn_cls = self.ATTENTION_MODES[attn_mode]
if version.parse(torch.__version__) >= version.parse("2.0.0"):
assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
else:
assert sdp_backend is None
self.disable_self_attn = disable_self_attn
self.attn1 = attn_cls(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None,
backend=sdp_backend,
) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = attn_cls(
query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
backend=sdp_backend,
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
if self.checkpoint:
print(f"{self.__class__.__name__} is using checkpointing")
def forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
kwargs = {"x": x}
if context is not None:
kwargs.update({"context": context})
if additional_tokens is not None:
kwargs.update({"additional_tokens": additional_tokens})
if n_times_crossframe_attn_in_self:
kwargs.update({"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self})
# return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
def _forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
x = (
self.attn1(
self.norm1(x),
context=context if self.disable_self_attn else None,
additional_tokens=additional_tokens,
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0,
)
+ x
)
x = self.attn2(self.norm2(x), context=context, additional_tokens=additional_tokens) + x
x = self.ff(self.norm3(x)) + x
return x
class BasicTransformerSingleLayerBlock(nn.Module):
ATTENTION_MODES = {
"softmax": CrossAttention, # vanilla attention
"softmax-xformers": MemoryEfficientCrossAttention, # on the A100s not quite as fast as the above version
# (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
}
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True,
attn_mode="softmax",
):
super().__init__()
assert attn_mode in self.ATTENTION_MODES
attn_cls = self.ATTENTION_MODES[attn_mode]
self.attn1 = attn_cls(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=context_dim,
)
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward(self, x, context=None):
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
def _forward(self, x, context=None):
x = self.attn1(self.norm1(x), context=context) + x
x = self.ff(self.norm2(x)) + x
return x
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data.
First, project the input (aka embedding)
and reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape to image
NEW: use_linear for more efficiency instead of the 1x1 convs
"""
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
context_dim=None,
disable_self_attn=False,
use_linear=False,
attn_type="softmax",
use_checkpoint=True,
# sdp_backend=SDPBackend.FLASH_ATTENTION
sdp_backend=None,
):
super().__init__()
print(f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads")
from omegaconf import ListConfig
if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
context_dim = [context_dim]
if exists(context_dim) and isinstance(context_dim, list):
if depth != len(context_dim):
print(
f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
)
# depth does not match context dims.
assert all(
map(lambda x: x == context_dim[0], context_dim)
), "need homogenous context_dim to match depth automatically"
context_dim = depth * [context_dim[0]]
elif context_dim is None:
context_dim = [None] * depth
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
if not use_linear:
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
else:
self.proj_in = nn.Linear(in_channels, inner_dim)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
n_heads,
d_head,
dropout=dropout,
context_dim=context_dim[d],
disable_self_attn=disable_self_attn,
attn_mode=attn_type,
checkpoint=use_checkpoint,
sdp_backend=sdp_backend,
)
for d in range(depth)
]
)
if not use_linear:
self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
else:
# self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
self.use_linear = use_linear
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
if not isinstance(context, list):
context = [context]
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
if self.use_linear:
x = self.proj_in(x)
for i, block in enumerate(self.transformer_blocks):
if i > 0 and len(context) == 1:
i = 0 # use same context for each block
x = block(x, context=context[i])
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
if not self.use_linear:
x = self.proj_out(x)
return x + x_in
__all__ = [
"GeneralLPIPSWithDiscriminator",
"LatentLPIPS",
]
from .discriminator_loss import GeneralLPIPSWithDiscriminator
from .lpips import LatentLPIPS
from .video_loss import VideoAutoencoderLoss
from typing import Dict, Iterator, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torchvision
from einops import rearrange
from matplotlib import colormaps
from matplotlib import pyplot as plt
from ....util import default, instantiate_from_config
from ..lpips.loss.lpips import LPIPS
from ..lpips.model.model import weights_init
from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
class GeneralLPIPSWithDiscriminator(nn.Module):
def __init__(
self,
disc_start: int,
logvar_init: float = 0.0,
disc_num_layers: int = 3,
disc_in_channels: int = 3,
disc_factor: float = 1.0,
disc_weight: float = 1.0,
perceptual_weight: float = 1.0,
disc_loss: str = "hinge",
scale_input_to_tgt_size: bool = False,
dims: int = 2,
learn_logvar: bool = False,
regularization_weights: Union[None, Dict[str, float]] = None,
additional_log_keys: Optional[List[str]] = None,
discriminator_config: Optional[Dict] = None,
):
super().__init__()
self.dims = dims
if self.dims > 2:
print(
f"running with dims={dims}. This means that for perceptual loss "
f"calculation, the LPIPS loss will be applied to each frame "
f"independently."
)
self.scale_input_to_tgt_size = scale_input_to_tgt_size
assert disc_loss in ["hinge", "vanilla"]
self.perceptual_loss = LPIPS().eval()
self.perceptual_weight = perceptual_weight
# output log variance
self.logvar = nn.Parameter(torch.full((), logvar_init), requires_grad=learn_logvar)
self.learn_logvar = learn_logvar
discriminator_config = default(
discriminator_config,
{
"target": "sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator",
"params": {
"input_nc": disc_in_channels,
"n_layers": disc_num_layers,
"use_actnorm": False,
},
},
)
self.discriminator = instantiate_from_config(discriminator_config).apply(weights_init)
self.discriminator_iter_start = disc_start
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
self.disc_factor = disc_factor
self.discriminator_weight = disc_weight
self.regularization_weights = default(regularization_weights, {})
self.forward_keys = [
"optimizer_idx",
"global_step",
"last_layer",
"split",
"regularization_log",
]
self.additional_log_keys = set(default(additional_log_keys, []))
self.additional_log_keys.update(set(self.regularization_weights.keys()))
def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
return self.discriminator.parameters()
def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]:
if self.learn_logvar:
yield self.logvar
yield from ()
@torch.no_grad()
def log_images(self, inputs: torch.Tensor, reconstructions: torch.Tensor) -> Dict[str, torch.Tensor]:
# calc logits of real/fake
logits_real = self.discriminator(inputs.contiguous().detach())
if len(logits_real.shape) < 4:
# Non patch-discriminator
return dict()
logits_fake = self.discriminator(reconstructions.contiguous().detach())
# -> (b, 1, h, w)
# parameters for colormapping
high = max(logits_fake.abs().max(), logits_real.abs().max()).item()
cmap = colormaps["PiYG"] # diverging colormap
def to_colormap(logits: torch.Tensor) -> torch.Tensor:
"""(b, 1, ...) -> (b, 3, ...)"""
logits = (logits + high) / (2 * high)
logits_np = cmap(logits.cpu().numpy())[..., :3] # truncate alpha channel
# -> (b, 1, ..., 3)
logits = torch.from_numpy(logits_np).to(logits.device)
return rearrange(logits, "b 1 ... c -> b c ...")
logits_real = torch.nn.functional.interpolate(
logits_real,
size=inputs.shape[-2:],
mode="nearest",
antialias=False,
)
logits_fake = torch.nn.functional.interpolate(
logits_fake,
size=reconstructions.shape[-2:],
mode="nearest",
antialias=False,
)
# alpha value of logits for overlay
alpha_real = torch.abs(logits_real) / high
alpha_fake = torch.abs(logits_fake) / high
# -> (b, 1, h, w) in range [0, 0.5]
# alpha value of lines don't really matter, since the values are the same
# for both images and logits anyway
grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4)
grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4)
grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1)
# -> (1, h, w)
# blend logits and images together
# prepare logits for plotting
logits_real = to_colormap(logits_real)
logits_fake = to_colormap(logits_fake)
# resize logits
# -> (b, 3, h, w)
# make some grids
# add all logits to one plot
logits_real = torchvision.utils.make_grid(logits_real, nrow=4)
logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4)
# I just love how torchvision calls the number of columns `nrow`
grid_logits = torch.cat((logits_real, logits_fake), dim=1)
# -> (3, h, w)
grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4)
grid_images_fake = torchvision.utils.make_grid(0.5 * reconstructions + 0.5, nrow=4)
grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1)
# -> (3, h, w) in range [0, 1]
grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images
# Create labeled colorbar
dpi = 100
height = 128 / dpi
width = grid_logits.shape[2] / dpi
fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)
img = ax.imshow(np.array([[-high, high]]), cmap=cmap)
plt.colorbar(
img,
cax=ax,
orientation="horizontal",
fraction=0.9,
aspect=width / height,
pad=0.0,
)
img.set_visible(False)
fig.tight_layout()
fig.canvas.draw()
# manually convert figure to numpy
cbar_np = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
cbar_np = cbar_np.reshape(fig.canvas.get_width_height()[::-1] + (3,))
cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0
cbar = rearrange(cbar, "h w c -> c h w").to(grid_logits.device)
# Add colorbar to plot
annotated_grid = torch.cat((grid_logits, cbar), dim=1)
blended_grid = torch.cat((grid_blend, cbar), dim=1)
return {
"vis_logits": 2 * annotated_grid[None, ...] - 1,
"vis_logits_blended": 2 * blended_grid[None, ...] - 1,
}
def calculate_adaptive_weight(
self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor
) -> torch.Tensor:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.discriminator_weight
return d_weight
def forward(
self,
inputs: torch.Tensor,
reconstructions: torch.Tensor,
*, # added because I changed the order here
regularization_log: Dict[str, torch.Tensor],
optimizer_idx: int,
global_step: int,
last_layer: torch.Tensor,
split: str = "train",
weights: Union[None, float, torch.Tensor] = None,
) -> Tuple[torch.Tensor, dict]:
if self.scale_input_to_tgt_size:
inputs = torch.nn.functional.interpolate(inputs, reconstructions.shape[2:], mode="bicubic", antialias=True)
if self.dims > 2:
inputs, reconstructions = map(
lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
(inputs, reconstructions),
)
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
if self.perceptual_weight > 0:
frame_indices = torch.randn((inputs.shape[0], inputs.shape[2])).topk(1, dim=-1).indices
from sgm.modules.autoencoding.losses.video_loss import pick_video_frame
input_frames = pick_video_frame(inputs, frame_indices)
recon_frames = pick_video_frame(reconstructions, frame_indices)
p_loss = self.perceptual_loss(input_frames.contiguous(), recon_frames.contiguous()).mean()
rec_loss = rec_loss + self.perceptual_weight * p_loss
nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)
# now the GAN part
if optimizer_idx == 0:
# generator update
if global_step >= self.discriminator_iter_start or not self.training:
logits_fake = self.discriminator(reconstructions.contiguous())
g_loss = -torch.mean(logits_fake)
if self.training:
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
else:
d_weight = torch.tensor(1.0)
else:
d_weight = torch.tensor(0.0)
g_loss = torch.tensor(0.0, requires_grad=True)
loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss
log = dict()
for k in regularization_log:
if k in self.regularization_weights:
loss = loss + self.regularization_weights[k] * regularization_log[k]
if k in self.additional_log_keys:
log[f"{split}/{k}"] = regularization_log[k].detach().float().mean()
log.update(
{
f"{split}/loss/total": loss.clone().detach().mean(),
f"{split}/loss/nll": nll_loss.detach().mean(),
f"{split}/loss/rec": rec_loss.detach().mean(),
f"{split}/loss/percep": p_loss.detach().mean(),
f"{split}/loss/rec": rec_loss.detach().mean(),
f"{split}/loss/g": g_loss.detach().mean(),
f"{split}/scalars/logvar": self.logvar.detach(),
f"{split}/scalars/d_weight": d_weight.detach(),
}
)
return loss, log
elif optimizer_idx == 1:
# second pass for discriminator update
logits_real = self.discriminator(inputs.contiguous().detach())
logits_fake = self.discriminator(reconstructions.contiguous().detach())
if global_step >= self.discriminator_iter_start or not self.training:
d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake)
else:
d_loss = torch.tensor(0.0, requires_grad=True)
log = {
f"{split}/loss/disc": d_loss.clone().detach().mean(),
f"{split}/logits/real": logits_real.detach().mean(),
f"{split}/logits/fake": logits_fake.detach().mean(),
}
return d_loss, log
else:
raise NotImplementedError(f"Unknown optimizer_idx {optimizer_idx}")
def get_nll_loss(
self,
rec_loss: torch.Tensor,
weights: Optional[Union[float, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
weighted_nll_loss = nll_loss
if weights is not None:
weighted_nll_loss = weights * nll_loss
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
return nll_loss, weighted_nll_loss
import torch
import torch.nn as nn
from ....util import default, instantiate_from_config
from ..lpips.loss.lpips import LPIPS
class LatentLPIPS(nn.Module):
def __init__(
self,
decoder_config,
perceptual_weight=1.0,
latent_weight=1.0,
scale_input_to_tgt_size=False,
scale_tgt_to_input_size=False,
perceptual_weight_on_inputs=0.0,
):
super().__init__()
self.scale_input_to_tgt_size = scale_input_to_tgt_size
self.scale_tgt_to_input_size = scale_tgt_to_input_size
self.init_decoder(decoder_config)
self.perceptual_loss = LPIPS().eval()
self.perceptual_weight = perceptual_weight
self.latent_weight = latent_weight
self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
def init_decoder(self, config):
self.decoder = instantiate_from_config(config)
if hasattr(self.decoder, "encoder"):
del self.decoder.encoder
def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
log = dict()
loss = (latent_inputs - latent_predictions) ** 2
log[f"{split}/latent_l2_loss"] = loss.mean().detach()
image_reconstructions = None
if self.perceptual_weight > 0.0:
image_reconstructions = self.decoder.decode(latent_predictions)
image_targets = self.decoder.decode(latent_inputs)
perceptual_loss = self.perceptual_loss(image_targets.contiguous(), image_reconstructions.contiguous())
loss = self.latent_weight * loss.mean() + self.perceptual_weight * perceptual_loss.mean()
log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
if self.perceptual_weight_on_inputs > 0.0:
image_reconstructions = default(image_reconstructions, self.decoder.decode(latent_predictions))
if self.scale_input_to_tgt_size:
image_inputs = torch.nn.functional.interpolate(
image_inputs,
image_reconstructions.shape[2:],
mode="bicubic",
antialias=True,
)
elif self.scale_tgt_to_input_size:
image_reconstructions = torch.nn.functional.interpolate(
image_reconstructions,
image_inputs.shape[2:],
mode="bicubic",
antialias=True,
)
perceptual_loss2 = self.perceptual_loss(image_inputs.contiguous(), image_reconstructions.contiguous())
loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
return loss, log
from typing import Any, Union
from math import log2
from beartype import beartype
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.autograd import grad as torch_grad
from torch.cuda.amp import autocast
import torchvision
from torchvision.models import VGG16_Weights
from einops import rearrange, einsum, repeat
from einops.layers.torch import Rearrange
from kornia.filters import filter3d
from ..magvit2_pytorch import Residual, FeedForward, LinearSpaceAttention
from .lpips import LPIPS
from sgm.modules.autoencoding.vqvae.movq_enc_3d import CausalConv3d, DownSample3D
from sgm.util import instantiate_from_config
def exists(v):
return v is not None
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def leaky_relu(p=0.1):
return nn.LeakyReLU(p)
def hinge_discr_loss(fake, real):
return (F.relu(1 + fake) + F.relu(1 - real)).mean()
def hinge_gen_loss(fake):
return -fake.mean()
@autocast(enabled=False)
@beartype
def grad_layer_wrt_loss(loss: Tensor, layer: nn.Parameter):
return torch_grad(outputs=loss, inputs=layer, grad_outputs=torch.ones_like(loss), retain_graph=True)[0].detach()
def pick_video_frame(video, frame_indices):
batch, device = video.shape[0], video.device
video = rearrange(video, "b c f ... -> b f c ...")
batch_indices = torch.arange(batch, device=device)
batch_indices = rearrange(batch_indices, "b -> b 1")
images = video[batch_indices, frame_indices]
images = rearrange(images, "b 1 c ... -> b c ...")
return images
def gradient_penalty(images, output):
batch_size = images.shape[0]
gradients = torch_grad(
outputs=output,
inputs=images,
grad_outputs=torch.ones(output.size(), device=images.device),
create_graph=True,
retain_graph=True,
only_inputs=True,
)[0]
gradients = rearrange(gradients, "b ... -> b (...)")
return ((gradients.norm(2, dim=1) - 1) ** 2).mean()
# discriminator with anti-aliased downsampling (blurpool Zhang et al.)
class Blur(nn.Module):
def __init__(self):
super().__init__()
f = torch.Tensor([1, 2, 1])
self.register_buffer("f", f)
def forward(self, x, space_only=False, time_only=False):
assert not (space_only and time_only)
f = self.f
if space_only:
f = einsum("i, j -> i j", f, f)
f = rearrange(f, "... -> 1 1 ...")
elif time_only:
f = rearrange(f, "f -> 1 f 1 1")
else:
f = einsum("i, j, k -> i j k", f, f, f)
f = rearrange(f, "... -> 1 ...")
is_images = x.ndim == 4
if is_images:
x = rearrange(x, "b c h w -> b c 1 h w")
out = filter3d(x, f, normalized=True)
if is_images:
out = rearrange(out, "b c 1 h w -> b c h w")
return out
class DiscriminatorBlock(nn.Module):
def __init__(self, input_channels, filters, downsample=True, antialiased_downsample=True):
super().__init__()
self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))
self.net = nn.Sequential(
nn.Conv2d(input_channels, filters, 3, padding=1),
leaky_relu(),
nn.Conv2d(filters, filters, 3, padding=1),
leaky_relu(),
)
self.maybe_blur = Blur() if antialiased_downsample else None
self.downsample = (
nn.Sequential(
Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2), nn.Conv2d(filters * 4, filters, 1)
)
if downsample
else None
)
def forward(self, x):
res = self.conv_res(x)
x = self.net(x)
if exists(self.downsample):
if exists(self.maybe_blur):
x = self.maybe_blur(x, space_only=True)
x = self.downsample(x)
x = (x + res) * (2**-0.5)
return x
class Discriminator(nn.Module):
@beartype
def __init__(
self,
*,
dim,
image_size,
channels=3,
max_dim=512,
attn_heads=8,
attn_dim_head=32,
linear_attn_dim_head=8,
linear_attn_heads=16,
ff_mult=4,
antialiased_downsample=False,
):
super().__init__()
image_size = pair(image_size)
min_image_resolution = min(image_size)
num_layers = int(log2(min_image_resolution) - 2)
blocks = []
layer_dims = [channels] + [(dim * 4) * (2**i) for i in range(num_layers + 1)]
layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims]
layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:]))
blocks = []
attn_blocks = []
image_resolution = min_image_resolution
for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out):
num_layer = ind + 1
is_not_last = ind != (len(layer_dims_in_out) - 1)
block = DiscriminatorBlock(
in_chan, out_chan, downsample=is_not_last, antialiased_downsample=antialiased_downsample
)
attn_block = nn.Sequential(
Residual(LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head)),
Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
)
blocks.append(nn.ModuleList([block, attn_block]))
image_resolution //= 2
self.blocks = nn.ModuleList(blocks)
dim_last = layer_dims[-1]
downsample_factor = 2**num_layers
last_fmap_size = tuple(map(lambda n: n // downsample_factor, image_size))
latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last
self.to_logits = nn.Sequential(
nn.Conv2d(dim_last, dim_last, 3, padding=1),
leaky_relu(),
Rearrange("b ... -> b (...)"),
nn.Linear(latent_dim, 1),
Rearrange("b 1 -> b"),
)
def forward(self, x):
for block, attn_block in self.blocks:
x = block(x)
x = attn_block(x)
return self.to_logits(x)
class DiscriminatorBlock3D(nn.Module):
def __init__(
self,
input_channels,
filters,
antialiased_downsample=True,
):
super().__init__()
self.conv_res = nn.Conv3d(input_channels, filters, 1, stride=2)
self.net = nn.Sequential(
nn.Conv3d(input_channels, filters, 3, padding=1),
leaky_relu(),
nn.Conv3d(filters, filters, 3, padding=1),
leaky_relu(),
)
self.maybe_blur = Blur() if antialiased_downsample else None
self.downsample = nn.Sequential(
Rearrange("b c (f p1) (h p2) (w p3) -> b (c p1 p2 p3) f h w", p1=2, p2=2, p3=2),
nn.Conv3d(filters * 8, filters, 1),
)
def forward(self, x):
res = self.conv_res(x)
x = self.net(x)
if exists(self.downsample):
if exists(self.maybe_blur):
x = self.maybe_blur(x, space_only=True)
x = self.downsample(x)
x = (x + res) * (2**-0.5)
return x
class DiscriminatorBlock3DWithfirstframe(nn.Module):
def __init__(
self,
input_channels,
filters,
antialiased_downsample=True,
pad_mode="first",
):
super().__init__()
self.downsample_res = DownSample3D(
in_channels=input_channels,
out_channels=filters,
with_conv=True,
compress_time=True,
)
self.net = nn.Sequential(
CausalConv3d(input_channels, filters, kernel_size=3, pad_mode=pad_mode),
leaky_relu(),
CausalConv3d(filters, filters, kernel_size=3, pad_mode=pad_mode),
leaky_relu(),
)
self.maybe_blur = Blur() if antialiased_downsample else None
self.downsample = DownSample3D(
in_channels=filters,
out_channels=filters,
with_conv=True,
compress_time=True,
)
def forward(self, x):
res = self.downsample_res(x)
x = self.net(x)
if exists(self.downsample):
if exists(self.maybe_blur):
x = self.maybe_blur(x, space_only=True)
x = self.downsample(x)
x = (x + res) * (2**-0.5)
return x
class Discriminator3D(nn.Module):
@beartype
def __init__(
self,
*,
dim,
image_size,
frame_num,
channels=3,
max_dim=512,
linear_attn_dim_head=8,
linear_attn_heads=16,
ff_mult=4,
antialiased_downsample=False,
):
super().__init__()
image_size = pair(image_size)
min_image_resolution = min(image_size)
num_layers = int(log2(min_image_resolution) - 2)
temporal_num_layers = int(log2(frame_num))
self.temporal_num_layers = temporal_num_layers
layer_dims = [channels] + [(dim * 4) * (2**i) for i in range(num_layers + 1)]
layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims]
layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:]))
blocks = []
image_resolution = min_image_resolution
frame_resolution = frame_num
for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out):
num_layer = ind + 1
is_not_last = ind != (len(layer_dims_in_out) - 1)
if ind < temporal_num_layers:
block = DiscriminatorBlock3D(
in_chan,
out_chan,
antialiased_downsample=antialiased_downsample,
)
blocks.append(block)
frame_resolution //= 2
else:
block = DiscriminatorBlock(
in_chan,
out_chan,
downsample=is_not_last,
antialiased_downsample=antialiased_downsample,
)
attn_block = nn.Sequential(
Residual(
LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head)
),
Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
)
blocks.append(nn.ModuleList([block, attn_block]))
image_resolution //= 2
self.blocks = nn.ModuleList(blocks)
dim_last = layer_dims[-1]
downsample_factor = 2**num_layers
last_fmap_size = tuple(map(lambda n: n // downsample_factor, image_size))
latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last
self.to_logits = nn.Sequential(
nn.Conv2d(dim_last, dim_last, 3, padding=1),
leaky_relu(),
Rearrange("b ... -> b (...)"),
nn.Linear(latent_dim, 1),
Rearrange("b 1 -> b"),
)
def forward(self, x):
for i, layer in enumerate(self.blocks):
if i < self.temporal_num_layers:
x = layer(x)
if i == self.temporal_num_layers - 1:
x = rearrange(x, "b c f h w -> (b f) c h w")
else:
block, attn_block = layer
x = block(x)
x = attn_block(x)
return self.to_logits(x)
class Discriminator3DWithfirstframe(nn.Module):
@beartype
def __init__(
self,
*,
dim,
image_size,
frame_num,
channels=3,
max_dim=512,
linear_attn_dim_head=8,
linear_attn_heads=16,
ff_mult=4,
antialiased_downsample=False,
):
super().__init__()
image_size = pair(image_size)
min_image_resolution = min(image_size)
num_layers = int(log2(min_image_resolution) - 2)
temporal_num_layers = int(log2(frame_num))
self.temporal_num_layers = temporal_num_layers
layer_dims = [channels] + [(dim * 4) * (2**i) for i in range(num_layers + 1)]
layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims]
layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:]))
blocks = []
image_resolution = min_image_resolution
frame_resolution = frame_num
for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out):
num_layer = ind + 1
is_not_last = ind != (len(layer_dims_in_out) - 1)
if ind < temporal_num_layers:
block = DiscriminatorBlock3DWithfirstframe(
in_chan,
out_chan,
antialiased_downsample=antialiased_downsample,
)
blocks.append(block)
frame_resolution //= 2
else:
block = DiscriminatorBlock(
in_chan,
out_chan,
downsample=is_not_last,
antialiased_downsample=antialiased_downsample,
)
attn_block = nn.Sequential(
Residual(
LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head)
),
Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
)
blocks.append(nn.ModuleList([block, attn_block]))
image_resolution //= 2
self.blocks = nn.ModuleList(blocks)
dim_last = layer_dims[-1]
downsample_factor = 2**num_layers
last_fmap_size = tuple(map(lambda n: n // downsample_factor, image_size))
latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last
self.to_logits = nn.Sequential(
nn.Conv2d(dim_last, dim_last, 3, padding=1),
leaky_relu(),
Rearrange("b ... -> b (...)"),
nn.Linear(latent_dim, 1),
Rearrange("b 1 -> b"),
)
def forward(self, x):
for i, layer in enumerate(self.blocks):
if i < self.temporal_num_layers:
x = layer(x)
if i == self.temporal_num_layers - 1:
x = x.mean(dim=2)
# x = rearrange(x, "b c f h w -> (b f) c h w")
else:
block, attn_block = layer
x = block(x)
x = attn_block(x)
return self.to_logits(x)
class VideoAutoencoderLoss(nn.Module):
def __init__(
self,
disc_start,
perceptual_weight=1,
adversarial_loss_weight=0,
multiscale_adversarial_loss_weight=0,
grad_penalty_loss_weight=0,
quantizer_aux_loss_weight=0,
vgg_weights=VGG16_Weights.DEFAULT,
discr_kwargs=None,
discr_3d_kwargs=None,
):
super().__init__()
self.disc_start = disc_start
self.perceptual_weight = perceptual_weight
self.adversarial_loss_weight = adversarial_loss_weight
self.multiscale_adversarial_loss_weight = multiscale_adversarial_loss_weight
self.grad_penalty_loss_weight = grad_penalty_loss_weight
self.quantizer_aux_loss_weight = quantizer_aux_loss_weight
if self.perceptual_weight > 0:
self.perceptual_model = LPIPS().eval()
# self.vgg = torchvision.models.vgg16(pretrained = True)
# self.vgg.requires_grad_(False)
# if self.adversarial_loss_weight > 0:
# self.discr = Discriminator(**discr_kwargs)
# else:
# self.discr = None
# if self.multiscale_adversarial_loss_weight > 0:
# self.multiscale_discrs = nn.ModuleList([*multiscale_discrs])
# else:
# self.multiscale_discrs = None
if discr_kwargs is not None:
self.discr = Discriminator(**discr_kwargs)
else:
self.discr = None
if discr_3d_kwargs is not None:
# self.discr_3d = Discriminator3D(**discr_3d_kwargs)
self.discr_3d = instantiate_from_config(discr_3d_kwargs)
else:
self.discr_3d = None
# self.multiscale_discrs = nn.ModuleList([*multiscale_discrs])
self.register_buffer("zero", torch.tensor(0.0), persistent=False)
def get_trainable_params(self) -> Any:
params = []
if self.discr is not None:
params += list(self.discr.parameters())
if self.discr_3d is not None:
params += list(self.discr_3d.parameters())
# if self.multiscale_discrs is not None:
# for discr in self.multiscale_discrs:
# params += list(discr.parameters())
return params
def get_trainable_parameters(self) -> Any:
return self.get_trainable_params()
def forward(
self,
inputs,
reconstructions,
optimizer_idx,
global_step,
aux_losses=None,
last_layer=None,
split="train",
):
batch, channels, frames = inputs.shape[:3]
if optimizer_idx == 0:
recon_loss = F.mse_loss(inputs, reconstructions)
if self.perceptual_weight > 0:
frame_indices = torch.randn((batch, frames)).topk(1, dim=-1).indices
input_frames = pick_video_frame(inputs, frame_indices)
recon_frames = pick_video_frame(reconstructions, frame_indices)
perceptual_loss = self.perceptual_model(input_frames.contiguous(), recon_frames.contiguous()).mean()
else:
perceptual_loss = self.zero
if global_step >= self.disc_start or not self.training or self.adversarial_loss_weight == 0:
gen_loss = self.zero
adaptive_weight = 0
else:
# frame_indices = torch.randn((batch, frames)).topk(1, dim = -1).indices
# recon_video_frames = pick_video_frame(reconstructions, frame_indices)
# fake_logits = self.discr(recon_video_frames)
fake_logits = self.discr_3d(reconstructions)
gen_loss = hinge_gen_loss(fake_logits)
adaptive_weight = 1
if self.perceptual_weight > 0 and last_layer is not None:
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_layer).norm(p=2)
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_layer).norm(p=2)
adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-3)
adaptive_weight.clamp_(max=1e3)
if torch.isnan(adaptive_weight).any():
adaptive_weight = 1
# multiscale discriminator losses
# multiscale_gen_losses = []
# multiscale_gen_adaptive_weights = []
# if self.multiscale_adversarial_loss_weight > 0:
# if not exists(recon_video_frames):
# frame_indices = torch.randn((batch, frames)).topk(1, dim = -1).indices
# recon_video_frames = pick_video_frame(reconstructions, frame_indices)
# for discr in self.multiscale_discrs:
# fake_logits = recon_video_frames
# multiscale_gen_loss = hinge_gen_loss(fake_logits)
# multiscale_gen_losses.append(multiscale_gen_loss)
# multiscale_adaptive_weight = 1.
# if exists(norm_grad_wrt_perceptual_loss):
# norm_grad_wrt_gen_loss = grad_layer_wrt_loss(multiscale_gen_loss, last_layer).norm(p = 2)
# multiscale_adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min = 1e-5)
# multiscale_adaptive_weight.clamp_(max = 1e3)
# multiscale_gen_adaptive_weights.append(multiscale_adaptive_weight)
# weighted_multiscale_gen_losses = sum(loss * weight for loss, weight in zip(multiscale_gen_losses, multiscale_gen_adaptive_weights))
# else:
# weighted_multiscale_gen_losses = self.zero
if aux_losses is None:
aux_losses = self.zero
total_loss = (
recon_loss
+ aux_losses * self.quantizer_aux_loss_weight
+ perceptual_loss * self.perceptual_weight
+ gen_loss * self.adversarial_loss_weight
)
# gen_loss * adaptive_weight * self.adversarial_loss_weight + \
# weighted_multiscale_gen_losses * self.multiscale_adversarial_loss_weight
log = {
"{}/total_loss".format(split): total_loss.detach(),
"{}/recon_loss".format(split): recon_loss.detach(),
"{}/perceptual_loss".format(split): perceptual_loss.detach(),
"{}/gen_loss".format(split): gen_loss.detach(),
"{}/aux_losses".format(split): aux_losses.detach(),
# "{}/weighted_multiscale_gen_losses".format(split): weighted_multiscale_gen_losses.detach(),
"{}/adaptive_weight".format(split): adaptive_weight,
# "{}/multiscale_adaptive_weights".format(split): sum(multiscale_gen_adaptive_weights),
}
return total_loss, log
if optimizer_idx == 1:
# frame_indices = torch.randn((batch, frames)).topk(1, dim = -1).indices
# real = pick_video_frame(inputs, frame_indices)
# fake = pick_video_frame(reconstructions, frame_indices)
# apply_gradient_penalty = self.grad_penalty_loss_weight > 0
# if apply_gradient_penalty:
# real = real.requires_grad_()
# real_logits = self.discr(real)
# fake_logits = self.discr(fake.detach())
apply_gradient_penalty = self.grad_penalty_loss_weight > 0
if apply_gradient_penalty:
inputs = inputs.requires_grad_()
real_logits = self.discr_3d(inputs)
fake_logits = self.discr_3d(reconstructions.detach())
discr_loss = hinge_discr_loss(fake_logits, real_logits)
# # multiscale discriminators
# multiscale_discr_losses = []
# if self.multiscale_adversarial_loss_weight > 0:
# for discr in self.multiscale_discrs:
# multiscale_real_logits = discr(inputs)
# multiscale_fake_logits = discr(reconstructions.detach())
# multiscale_discr_loss = hinge_discr_loss(multiscale_fake_logits, multiscale_real_logits)
# multiscale_discr_losses.append(multiscale_discr_loss)
# else:
# multiscale_discr_losses.append(self.zero)
# gradient penalty
if apply_gradient_penalty:
# gradient_penalty_loss = gradient_penalty(real, real_logits)
gradient_penalty_loss = gradient_penalty(inputs, real_logits)
else:
gradient_penalty_loss = self.zero
total_loss = discr_loss + self.grad_penalty_loss_weight * gradient_penalty_loss
# self.grad_penalty_loss_weight * gradient_penalty_loss + \
# sum(multiscale_discr_losses) * self.multiscale_adversarial_loss_weight
log = {
"{}/total_disc_loss".format(split): total_loss.detach(),
"{}/discr_loss".format(split): discr_loss.detach(),
"{}/grad_penalty_loss".format(split): gradient_penalty_loss.detach(),
# "{}/multiscale_discr_loss".format(split): sum(multiscale_discr_losses).detach(),
"{}/logits_real".format(split): real_logits.detach().mean(),
"{}/logits_fake".format(split): fake_logits.detach().mean(),
}
return total_loss, log
Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
from collections import namedtuple
import torch
import torch.nn as nn
from torchvision import models
from ..util import get_ckpt_path
class LPIPS(nn.Module):
# Learned perceptual metric
def __init__(self, use_dropout=True):
super().__init__()
self.scaling_layer = ScalingLayer()
self.chns = [64, 128, 256, 512, 512] # vg16 features
self.net = vgg16(pretrained=True, requires_grad=False)
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
self.load_from_pretrained()
for param in self.parameters():
param.requires_grad = False
def load_from_pretrained(self, name="vgg_lpips"):
ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss")
self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
print("loaded pretrained LPIPS loss from {}".format(ckpt))
@classmethod
def from_pretrained(cls, name="vgg_lpips"):
if name != "vgg_lpips":
raise NotImplementedError
model = cls()
ckpt = get_ckpt_path(name)
model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
return model
def forward(self, input, target):
in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
outs0, outs1 = self.net(in0_input), self.net(in1_input)
feats0, feats1, diffs = {}, {}, {}
lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
for kk in range(len(self.chns)):
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
val = res[0]
for l in range(1, len(self.chns)):
val += res[l]
return val
class ScalingLayer(nn.Module):
def __init__(self):
super(ScalingLayer, self).__init__()
self.register_buffer("shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None])
self.register_buffer("scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None])
def forward(self, inp):
return (inp - self.shift) / self.scale
class NetLinLayer(nn.Module):
"""A single linear layer which does a 1x1 conv"""
def __init__(self, chn_in, chn_out=1, use_dropout=False):
super(NetLinLayer, self).__init__()
layers = (
[
nn.Dropout(),
]
if (use_dropout)
else []
)
layers += [
nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
]
self.model = nn.Sequential(*layers)
class vgg16(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(vgg16, self).__init__()
vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.N_slices = 5
for x in range(4):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(4, 9):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(9, 16):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(16, 23):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(23, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1_2 = h
h = self.slice2(h)
h_relu2_2 = h
h = self.slice3(h)
h_relu3_3 = h
h = self.slice4(h)
h_relu4_3 = h
h = self.slice5(h)
h_relu5_3 = h
vgg_outputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"])
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
return out
def normalize_tensor(x, eps=1e-10):
norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
return x / (norm_factor + eps)
def spatial_average(x, keepdim=True):
return x.mean([2, 3], keepdim=keepdim)
Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------- LICENSE FOR pix2pix --------------------------------
BSD License
For pix2pix software
Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
----------------------------- LICENSE FOR DCGAN --------------------------------
BSD License
For dcgan.torch software
Copyright (c) 2015, Facebook, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment