Commit 23904d54 authored by patil-suraj

Merge branch 'main' of https://github.com/huggingface/diffusers into conversion-scripts

parents 32b93da8 c691bb2f
@@ -88,7 +88,7 @@ _deps = [
    "requests",
    "torch>=1.4",
    "tensorboard",
-    "modelcards=0.1.4"
+    "modelcards==0.1.4"
]
# this is a lookup table with items like:
...
@@ -14,4 +14,5 @@ deps = {
    "requests": "requests",
    "torch": "torch>=1.4",
    "tensorboard": "tensorboard",
+    "modelcards": "modelcards==0.1.4",
}
-import string
from abc import abstractmethod
+from functools import partial

import numpy as np
import torch
@@ -79,18 +79,25 @@ class Upsample(nn.Module):
    upsampling occurs in the inner-two dimensions.
    """

-    def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None):
+    def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        self.use_conv_transpose = use_conv_transpose
+        self.name = name

+        conv = None
        if use_conv_transpose:
-            self.conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
+            conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
        elif use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+            conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+
+        if name == "conv":
+            self.conv = conv
+        else:
+            self.Conv2d_0 = conv

    def forward(self, x):
        assert x.shape[1] == self.channels
@@ -103,7 +110,10 @@ class Upsample(nn.Module):
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")

        if self.use_conv:
-            x = self.conv(x)
+            if self.name == "conv":
+                x = self.conv(x)
+            else:
+                x = self.Conv2d_0(x)

        return x
@@ -135,6 +145,8 @@ class Downsample(nn.Module):
        if name == "conv":
            self.conv = conv
+        elif name == "Conv2d_0":
+            self.Conv2d_0 = conv
        else:
            self.op = conv
@@ -146,6 +158,8 @@ class Downsample(nn.Module):
        if self.name == "conv":
            return self.conv(x)
+        elif self.name == "Conv2d_0":
+            return self.Conv2d_0(x)
        else:
            return self.op(x)
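
A minimal sketch (not part of the commit) of what the `name` switch above buys: the same convolution is registered under a checkpoint-dependent attribute name, so state dicts from the original models (with keys like `Conv2d_0.weight` or `op.weight`) load without key remapping. `NamedConvBlock` is a hypothetical stand-in class:

import torch
import torch.nn as nn

class NamedConvBlock(nn.Module):
    # Register one conv under a checkpoint-dependent attribute name.
    def __init__(self, channels, name="conv"):
        super().__init__()
        self.name = name
        conv = nn.Conv2d(channels, channels, 3, padding=1)
        if name == "conv":
            self.conv = conv  # state_dict keys: conv.weight / conv.bias
        else:
            self.Conv2d_0 = conv  # state_dict keys: Conv2d_0.weight / Conv2d_0.bias

    def forward(self, x):
        return self.conv(x) if self.name == "conv" else self.Conv2d_0(x)

print(list(NamedConvBlock(4, name="Conv2d_0").state_dict().keys()))  # ['Conv2d_0.weight', 'Conv2d_0.bias']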
@@ -162,110 +176,7 @@ class Downsample(nn.Module):

# RESNETS

-# unet_glide.py & unet_ldm.py
+# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py
class ResBlock(TimestepBlock):
"""
A residual block that can optionally change the number of channels.
:param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout. :param out_channels: if specified, the number of out channels. :param
use_conv: if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing
on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for
downsampling.
"""
def __init__(
self,
channels,
emb_channels,
dropout,
out_channels=None,
use_conv=False,
use_scale_shift_norm=False,
dims=2,
use_checkpoint=False,
up=False,
down=False,
):
super().__init__()
self.channels = channels
self.emb_channels = emb_channels
self.dropout = dropout
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_checkpoint = use_checkpoint
self.use_scale_shift_norm = use_scale_shift_norm
self.in_layers = nn.Sequential(
normalization(channels, swish=1.0),
nn.Identity(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
)
self.updown = up or down
if up:
self.h_upd = Upsample(channels, use_conv=False, dims=dims)
self.x_upd = Upsample(channels, use_conv=False, dims=dims)
elif down:
self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
else:
self.h_upd = self.x_upd = nn.Identity()
self.emb_layers = nn.Sequential(
nn.SiLU(),
linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
nn.SiLU() if use_scale_shift_norm else nn.Identity(),
nn.Dropout(p=dropout),
zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
def forward(self, x, emb):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
:param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
h = self.h_upd(h)
x = self.x_upd(x)
h = in_conv(h)
else:
h = self.in_layers(x)
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = torch.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
h = out_rest(h)
else:
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h
-# unet.py and unet_grad_tts.py
class ResnetBlock(nn.Module):
    def __init__(
        self,
@@ -279,7 +190,12 @@ class ResnetBlock(nn.Module):
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
+        time_embedding_norm="default",
+        up=False,
+        down=False,
        overwrite_for_grad_tts=False,
+        overwrite_for_ldm=False,
+        overwrite_for_glide=False,
    ):
        super().__init__()
        self.pre_norm = pre_norm
@@ -287,6 +203,9 @@ class ResnetBlock(nn.Module):
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
+        self.time_embedding_norm = time_embedding_norm
+        self.up = up
+        self.down = down

        if self.pre_norm:
            self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
@@ -294,23 +213,38 @@ class ResnetBlock(nn.Module):
            self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)

        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
-        self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+        if time_embedding_norm == "default":
+            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+        elif time_embedding_norm == "scale_shift":
+            self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
        self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if non_linearity == "swish":
            self.nonlinearity = nonlinearity
        elif non_linearity == "mish":
            self.nonlinearity = Mish()
+        elif non_linearity == "silu":
+            self.nonlinearity = nn.SiLU()

+        if up:
+            self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
+            self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
+        elif down:
+            self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
+            self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
+
        if self.in_channels != self.out_channels:
-            if self.use_conv_shortcut:
-                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
-            else:
-                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+            self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

+        # TODO(SURAJ, PATRICK): ALL OF THE FOLLOWING OF THE INIT METHOD CAN BE DELETED ONCE WEIGHTS ARE CONVERTED
        self.is_overwritten = False
+        self.overwrite_for_glide = overwrite_for_glide
        self.overwrite_for_grad_tts = overwrite_for_grad_tts
+        self.overwrite_for_ldm = overwrite_for_ldm or overwrite_for_glide
        if self.overwrite_for_grad_tts:
            dim = in_channels
            dim_out = out_channels
@@ -324,6 +258,37 @@ class ResnetBlock(nn.Module):
                self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
            else:
                self.res_conv = torch.nn.Identity()
+        elif self.overwrite_for_ldm:
+            dims = 2
+            # eps = 1e-5
+            # non_linearity = "silu"
+            # overwrite_for_ldm
+            channels = in_channels
+            emb_channels = temb_channels
+            use_scale_shift_norm = False
+
+            self.in_layers = nn.Sequential(
+                normalization(channels, swish=1.0),
+                nn.Identity(),
+                conv_nd(dims, channels, self.out_channels, 3, padding=1),
+            )
+            self.emb_layers = nn.Sequential(
+                nn.SiLU(),
+                linear(
+                    emb_channels,
+                    2 * self.out_channels if self.time_embedding_norm == "scale_shift" else self.out_channels,
+                ),
+            )
+            self.out_layers = nn.Sequential(
+                normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
+                nn.SiLU() if use_scale_shift_norm else nn.Identity(),
+                nn.Dropout(p=dropout),
+                zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
+            )
+            if self.out_channels == in_channels:
+                self.skip_connection = nn.Identity()
+            else:
+                self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
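
A minimal sketch (not part of the commit) of the two conditioning modes introduced by `time_embedding_norm`: "default" adds the projected timestep embedding as a per-channel bias, while "scale_shift" projects to `2 * out_channels` and modulates the normalized activations. Sizes below are hypothetical:

import torch

temb = torch.randn(8, 512)      # (batch, temb_channels)
h = torch.randn(8, 64, 32, 32)  # (batch, out_channels, height, width)

# "default": project to out_channels, add as a per-channel bias
proj_default = torch.nn.Linear(512, 64)
h_default = h + proj_default(temb)[:, :, None, None]

# "scale_shift": project to 2 * out_channels, split, modulate the normalized h
proj_scale_shift = torch.nn.Linear(512, 128)
norm = torch.nn.GroupNorm(32, 64)
scale, shift = torch.chunk(proj_scale_shift(temb)[:, :, None, None], 2, dim=1)
h_scale_shift = norm(h) * (1 + scale) + shift  # same as norm(h) + norm(h) * scale + shift

print(h_default.shape, h_scale_shift.shape)  # both torch.Size([8, 64, 32, 32])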
    def set_weights_grad_tts(self):
        self.conv1.weight.data = self.block1.block[0].weight.data
@@ -343,27 +308,64 @@ class ResnetBlock(nn.Module):
        self.nin_shortcut.weight.data = self.res_conv.weight.data
        self.nin_shortcut.bias.data = self.res_conv.bias.data

-    def forward(self, x, temb, mask=None):
+    def set_weights_ldm(self):
+        self.norm1.weight.data = self.in_layers[0].weight.data
+        self.norm1.bias.data = self.in_layers[0].bias.data
+
+        self.conv1.weight.data = self.in_layers[-1].weight.data
+        self.conv1.bias.data = self.in_layers[-1].bias.data
+
+        self.temb_proj.weight.data = self.emb_layers[-1].weight.data
+        self.temb_proj.bias.data = self.emb_layers[-1].bias.data
+
+        self.norm2.weight.data = self.out_layers[0].weight.data
+        self.norm2.bias.data = self.out_layers[0].bias.data
+
+        self.conv2.weight.data = self.out_layers[-1].weight.data
+        self.conv2.bias.data = self.out_layers[-1].bias.data
+
+        if self.in_channels != self.out_channels:
+            self.nin_shortcut.weight.data = self.skip_connection.weight.data
+            self.nin_shortcut.bias.data = self.skip_connection.bias.data
+
+    def forward(self, x, temb, mask=1.0):
+        # TODO(Patrick) eventually this class should be split into multiple classes
+        # too many if else statements
        if self.overwrite_for_grad_tts and not self.is_overwritten:
            self.set_weights_grad_tts()
            self.is_overwritten = True
+        elif self.overwrite_for_ldm and not self.is_overwritten:
+            self.set_weights_ldm()
+            self.is_overwritten = True

        h = x
-        h = h * mask if mask is not None else h
+        h = h * mask
        if self.pre_norm:
            h = self.norm1(h)
            h = self.nonlinearity(h)

+        if self.up or self.down:
+            x = self.x_upd(x)
+            h = self.h_upd(h)
+
        h = self.conv1(h)

        if not self.pre_norm:
            h = self.norm1(h)
            h = self.nonlinearity(h)
-        h = h * mask if mask is not None else h
+        h = h * mask

-        h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
-
-        h = h * mask if mask is not None else h
+        temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
+        if self.time_embedding_norm == "scale_shift":
+            scale, shift = torch.chunk(temb, 2, dim=1)
+
+            h = self.norm2(h)
+            h = h + h * scale + shift
+            h = self.nonlinearity(h)
+        elif self.time_embedding_norm == "default":
+            h = h + temb
+
+        h = h * mask
        if self.pre_norm:
            h = self.norm2(h)
            h = self.nonlinearity(h)
@@ -374,13 +376,10 @@ class ResnetBlock(nn.Module):
        if not self.pre_norm:
            h = self.norm2(h)
            h = self.nonlinearity(h)
-        h = h * mask if mask is not None else h
+        h = h * mask

-        x = x * mask if mask is not None else x
+        x = x * mask
        if self.in_channels != self.out_channels:
-            if self.use_conv_shortcut:
-                x = self.conv_shortcut(x)
-            else:
-                x = self.nin_shortcut(x)
+            x = self.nin_shortcut(x)

        return x + h
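
A minimal sketch (not part of the commit) of the lazy weight-porting pattern that `set_weights_grad_tts`/`set_weights_ldm` follow: the original submodules are kept alongside the new layout and their weights are copied over on the first forward pass. `LazyPortingBlock` and its layers are hypothetical stand-ins:

import torch
import torch.nn as nn

class LazyPortingBlock(nn.Module):
    def __init__(self, overwrite=True):
        super().__init__()
        self.old_linear = nn.Linear(4, 4)  # stands in for in_layers / out_layers
        self.new_linear = nn.Linear(4, 4)  # stands in for conv1 / conv2 / temb_proj
        self.overwrite = overwrite
        self.is_overwritten = False

    def set_weights(self):
        # Copy weights from the old layout into the new one.
        self.new_linear.weight.data = self.old_linear.weight.data
        self.new_linear.bias.data = self.old_linear.bias.data

    def forward(self, x):
        if self.overwrite and not self.is_overwritten:
            self.set_weights()
            self.is_overwritten = True
        return self.new_linear(x)

x = torch.randn(2, 4)
block = LazyPortingBlock()
torch.testing.assert_close(block(x), block.old_linear(x))  # identical after porting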
@@ -394,10 +393,6 @@ class Block(torch.nn.Module):
            torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
        )

-    def forward(self, x, mask):
-        output = self.block(x * mask)
-        return output * mask

# unet_score_estimation.py
class ResnetBlockBigGANpp(nn.Module):
@@ -424,17 +419,29 @@ class ResnetBlockBigGANpp(nn.Module):
        self.fir = fir
        self.fir_kernel = fir_kernel

-        self.Conv_0 = conv3x3(in_ch, out_ch)
+        if self.up:
+            if self.fir:
+                self.upsample = partial(upsample_2d, k=self.fir_kernel, factor=2)
+            else:
+                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
+        elif self.down:
+            if self.fir:
+                self.downsample = partial(downsample_2d, k=self.fir_kernel, factor=2)
+            else:
+                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
+
+        self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        if temb_dim is not None:
            self.Dense_0 = nn.Linear(temb_dim, out_ch)
-            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
+            self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
            nn.init.zeros_(self.Dense_0.bias)

        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
        self.Dropout_0 = nn.Dropout(dropout)
-        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
+        self.Conv_1 = conv2d(out_ch, out_ch, init_scale=init_scale, kernel_size=3, padding=1)
        if in_ch != out_ch or up or down:
-            self.Conv_2 = conv1x1(in_ch, out_ch)
+            # 1x1 convolution with DDPM initialization.
+            self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)

        self.skip_rescale = skip_rescale
        self.act = act
@@ -445,19 +452,11 @@ class ResnetBlockBigGANpp(nn.Module):
        h = self.act(self.GroupNorm_0(x))

        if self.up:
-            if self.fir:
-                h = upsample_2d(h, self.fir_kernel, factor=2)
-                x = upsample_2d(x, self.fir_kernel, factor=2)
-            else:
-                h = naive_upsample_2d(h, factor=2)
-                x = naive_upsample_2d(x, factor=2)
+            h = self.upsample(h)
+            x = self.upsample(x)
        elif self.down:
-            if self.fir:
-                h = downsample_2d(h, self.fir_kernel, factor=2)
-                x = downsample_2d(x, self.fir_kernel, factor=2)
-            else:
-                h = naive_downsample_2d(h, factor=2)
-                x = naive_downsample_2d(x, factor=2)
+            h = self.downsample(h)
+            x = self.downsample(x)

        h = self.Conv_0(h)
        # Add bias to each feature map conditioned on the time embedding
@@ -476,62 +475,6 @@ class ResnetBlockBigGANpp(nn.Module):
        return (x + h) / np.sqrt(2.0)
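
A minimal sketch (not part of the commit) of the refactor above: binding the resampling operator once in `__init__` with `functools.partial` replaces the per-call `if self.fir:` branching in `forward`. `Resampler` is a hypothetical stand-in class:

from functools import partial

import torch
import torch.nn.functional as F

class Resampler(torch.nn.Module):
    def __init__(self, up: bool):
        super().__init__()
        # Choose the operator once; forward just calls it.
        if up:
            self.resample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
        else:
            self.resample = partial(F.avg_pool2d, kernel_size=2, stride=2)

    def forward(self, x):
        return self.resample(x)

x = torch.randn(1, 3, 8, 8)
print(Resampler(up=True)(x).shape)   # torch.Size([1, 3, 16, 16])
print(Resampler(up=False)(x).shape)  # torch.Size([1, 3, 4, 4])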
# unet_score_estimation.py
class ResnetBlockDDPMpp(nn.Module):
"""ResBlock adapted from DDPM."""
def __init__(
self,
act,
in_ch,
out_ch=None,
temb_dim=None,
conv_shortcut=False,
dropout=0.1,
skip_rescale=False,
init_scale=0.0,
):
super().__init__()
out_ch = out_ch if out_ch else in_ch
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
self.Conv_0 = conv3x3(in_ch, out_ch)
if temb_dim is not None:
self.Dense_0 = nn.Linear(temb_dim, out_ch)
self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
nn.init.zeros_(self.Dense_0.bias)
self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
self.Dropout_0 = nn.Dropout(dropout)
self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
if in_ch != out_ch:
if conv_shortcut:
self.Conv_2 = conv3x3(in_ch, out_ch)
else:
self.NIN_0 = NIN(in_ch, out_ch)
self.skip_rescale = skip_rescale
self.act = act
self.out_ch = out_ch
self.conv_shortcut = conv_shortcut
def forward(self, x, temb=None):
h = self.act(self.GroupNorm_0(x))
h = self.Conv_0(h)
if temb is not None:
h += self.Dense_0(self.act(temb))[:, :, None, None]
h = self.act(self.GroupNorm_1(h))
h = self.Dropout_0(h)
h = self.Conv_1(h)
if x.shape[1] != self.out_ch:
if self.conv_shortcut:
x = self.Conv_2(x)
else:
x = self.NIN_0(x)
if not self.skip_rescale:
return x + h
else:
return (x + h) / np.sqrt(2.0)
# unet_rl.py
class ResidualTemporalBlock(nn.Module):
    def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
@@ -649,32 +592,17 @@ class RearrangeDim(nn.Module):
        raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")

-def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
-    """1x1 convolution with DDPM initialization."""
-    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
-    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
-    nn.init.zeros_(conv.bias)
-    return conv
-
-
-def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
-    """3x3 convolution with DDPM initialization."""
-    conv = nn.Conv2d(
-        in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias
-    )
-    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
+def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
+    """nXn convolution with DDPM initialization."""
+    conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
+    conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape)
    nn.init.zeros_(conv.bias)
    return conv
-def default_init(scale=1.0):
-    """The same initialization used in DDPM."""
-    scale = 1e-10 if scale == 0 else scale
-    return variance_scaling(scale, "fan_avg", "uniform")
-
-
-def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
+def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
    """Ported from JAX."""
+    scale = 1e-10 if scale == 0 else scale

    def _compute_fans(shape, in_axis=1, out_axis=0):
        receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
@@ -684,21 +612,9 @@ def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
    def init(shape, dtype=dtype, device=device):
        fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
-        if mode == "fan_in":
-            denominator = fan_in
-        elif mode == "fan_out":
-            denominator = fan_out
-        elif mode == "fan_avg":
-            denominator = (fan_in + fan_out) / 2
-        else:
-            raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
+        denominator = (fan_in + fan_out) / 2
        variance = scale / denominator
-        if distribution == "normal":
-            return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
-        elif distribution == "uniform":
-            return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
-        else:
-            raise ValueError("invalid distribution for variance scaling initializer")
+        return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)

    return init
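
A quick numeric check (not part of the commit) of what the slimmed-down `variance_scaling` hard-codes: "fan_avg" mode with a uniform distribution, the only configuration the removed `default_init` ever used. The fan computation below assumes the standard definition implied by `_compute_fans`:

import numpy as np
import torch

shape = (64, 32, 3, 3)  # (out_channels, in_channels, kH, kW); out_axis=0, in_axis=1
receptive_field_size = np.prod(shape) / shape[1] / shape[0]  # 9.0
fan_in = shape[1] * receptive_field_size   # 288.0
fan_out = shape[0] * receptive_field_size  # 576.0
variance = 1.0 / ((fan_in + fan_out) / 2)  # scale=1.0, fan_avg
bound = np.sqrt(3 * variance)              # uniform on [-bound, bound] has this variance
w = (torch.rand(*shape) * 2.0 - 1.0) * bound
print(float(w.var()), variance)  # empirical variance is close to the target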
@@ -796,31 +712,6 @@ def downsample_2d(x, k=None, factor=2, gain=1):
    return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
def naive_upsample_2d(x, factor=2):
_N, C, H, W = x.shape
x = torch.reshape(x, (-1, C, H, 1, W, 1))
x = x.repeat(1, 1, 1, factor, 1, factor)
return torch.reshape(x, (-1, C, H * factor, W * factor))
def naive_downsample_2d(x, factor=2):
_N, C, H, W = x.shape
x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
return torch.mean(x, dim=(3, 5))
class NIN(nn.Module):
def __init__(self, in_dim, num_units, init_scale=0.1):
super().__init__()
self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
def forward(self, x):
x = x.permute(0, 2, 3, 1)
y = contract_inner(x, self.W) + self.b
return y.permute(0, 3, 1, 2)
def _setup_kernel(k):
    k = np.asarray(k, dtype=np.float32)
    if k.ndim == 1:
@@ -829,17 +720,3 @@ def _setup_kernel(k):
    assert k.ndim == 2
    assert k.shape[0] == k.shape[1]
    return k
def contract_inner(x, y):
"""tensordot(x, y, 1)."""
x_chars = list(string.ascii_lowercase[: len(x.shape)])
y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)])
y_chars[0] = x_chars[-1] # first axis of y and last of x get summed
out_chars = x_chars[:-1] + y_chars[1:]
return _einsum(x_chars, y_chars, out_chars, x, y)
def _einsum(a, b, c, x, y):
einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
return torch.einsum(einsum_str, x, y)
@@ -34,48 +34,6 @@ def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
# class ResnetBlock(nn.Module):
# def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
# super().__init__()
# self.in_channels = in_channels
# out_channels = in_channels if out_channels is None else out_channels
# self.out_channels = out_channels
# self.use_conv_shortcut = conv_shortcut
#
# self.norm1 = Normalize(in_channels)
# self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
# self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
# self.norm2 = Normalize(out_channels)
# self.dropout = torch.nn.Dropout(dropout)
# self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
# if self.in_channels != self.out_channels:
# if self.use_conv_shortcut:
# self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
# else:
# self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
#
# def forward(self, x, temb):
# h = x
# h = self.norm1(h)
# h = nonlinearity(h)
# h = self.conv1(h)
#
# h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
#
# h = self.norm2(h)
# h = nonlinearity(h)
# h = self.dropout(h)
# h = self.conv2(h)
#
# if self.in_channels != self.out_channels:
# if self.use_conv_shortcut:
# x = self.conv_shortcut(x)
# else:
# x = self.nin_shortcut(x)
#
# return x + h
class UNetModel(ModelMixin, ConfigMixin):
    def __init__(
        self,
...
@@ -6,7 +6,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResBlock, TimestepBlock, Upsample
+from .resnet import Downsample, ResnetBlock, TimestepBlock, Upsample
def convert_module_to_f16(l):
@@ -29,19 +29,6 @@ def convert_module_to_f32(l):
    l.bias.data = l.bias.data.float()
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
@@ -101,7 +88,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    def forward(self, x, emb, encoder_out=None):
        for layer in self:
-            if isinstance(layer, TimestepBlock):
+            if isinstance(layer, TimestepBlock) or isinstance(layer, ResnetBlock):
                x = layer(x, emb)
            elif isinstance(layer, AttentionBlock):
                x = layer(x, encoder_out)
@@ -190,14 +177,15 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
-                    ResBlock(
-                        ch,
-                        time_embed_dim,
-                        dropout,
-                        out_channels=int(mult * model_channels),
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
+                    ResnetBlock(
+                        in_channels=ch,
+                        out_channels=mult * model_channels,
+                        dropout=dropout,
+                        temb_channels=time_embed_dim,
+                        eps=1e-5,
+                        non_linearity="silu",
+                        time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
+                        overwrite_for_glide=True,
                    )
                ]
                ch = int(mult * model_channels)
@@ -218,14 +206,15 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
+                        ResnetBlock(
+                            in_channels=ch,
                            out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
+                            dropout=dropout,
+                            temb_channels=time_embed_dim,
+                            eps=1e-5,
+                            non_linearity="silu",
+                            time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
+                            overwrite_for_glide=True,
                            down=True,
                        )
                        if resblock_updown
@@ -240,13 +229,14 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
        self._feature_size += ch

        self.middle_block = TimestepEmbedSequential(
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            ResnetBlock(
+                in_channels=ch,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
+                overwrite_for_glide=True,
            ),
            AttentionBlock(
                ch,
@@ -255,13 +245,14 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                num_head_channels=num_head_channels,
                encoder_channels=transformer_dim,
            ),
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            ResnetBlock(
+                in_channels=ch,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
+                overwrite_for_glide=True,
            ),
        )
        self._feature_size += ch
@@ -271,15 +262,16 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
            for i in range(num_res_blocks + 1):
                ich = input_block_chans.pop()
                layers = [
-                    ResBlock(
-                        ch + ich,
-                        time_embed_dim,
-                        dropout,
-                        out_channels=int(model_channels * mult),
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
-                    )
+                    ResnetBlock(
+                        in_channels=ch + ich,
+                        out_channels=model_channels * mult,
+                        dropout=dropout,
+                        temb_channels=time_embed_dim,
+                        eps=1e-5,
+                        non_linearity="silu",
+                        time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
+                        overwrite_for_glide=True,
+                    ),
                ]
                ch = int(model_channels * mult)
                if ds in attention_resolutions:
@@ -295,14 +287,15 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                if level and i == num_res_blocks:
                    out_ch = ch
                    layers.append(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
+                        ResnetBlock(
+                            in_channels=ch,
                            out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
+                            dropout=dropout,
+                            temb_channels=time_embed_dim,
+                            eps=1e-5,
+                            non_linearity="silu",
+                            time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
+                            overwrite_for_glide=True,
                            up=True,
                        )
                        if resblock_updown
...
@@ -10,7 +10,10 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResBlock, TimestepBlock, Upsample
+from .resnet import Downsample, ResnetBlock, TimestepBlock, Upsample
+
+# from .resnet import ResBlock

def exists(val):
@@ -75,182 +78,6 @@ def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
# class LinearAttention(nn.Module):
# def __init__(self, dim, heads=4, dim_head=32):
# super().__init__()
# self.heads = heads
# hidden_dim = dim_head * heads
# self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
# self.to_out = nn.Conv2d(hidden_dim, dim, 1)
#
# def forward(self, x):
# b, c, h, w = x.shape
# qkv = self.to_qkv(x)
# q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
# import ipdb; ipdb.set_trace()
# k = k.softmax(dim=-1)
# context = torch.einsum("bhdn,bhen->bhde", k, v)
# out = torch.einsum("bhde,bhdn->bhen", context, q)
# out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
# return self.to_out(out)
#
# class SpatialSelfAttention(nn.Module):
# def __init__(self, in_channels):
# super().__init__()
# self.in_channels = in_channels
#
# self.norm = Normalize(in_channels)
# self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
# self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
# self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
# self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
#
# def forward(self, x):
# h_ = x
# h_ = self.norm(h_)
# q = self.q(h_)
# k = self.k(h_)
# v = self.v(h_)
#
# compute attention
# b, c, h, w = q.shape
# q = rearrange(q, "b c h w -> b (h w) c")
# k = rearrange(k, "b c h w -> b c (h w)")
# w_ = torch.einsum("bij,bjk->bik", q, k)
#
# w_ = w_ * (int(c) ** (-0.5))
# w_ = torch.nn.functional.softmax(w_, dim=2)
#
# attend to values
# v = rearrange(v, "b c h w -> b c (h w)")
# w_ = rearrange(w_, "b i j -> b j i")
# h_ = torch.einsum("bij,bjk->bik", v, w_)
# h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
# h_ = self.proj_out(h_)
#
# return x + h_
#
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head**-0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor
def reshape_batch_dim_to_heads(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
def forward(self, x, context=None, mask=None):
batch_size, sequence_length, dim = x.shape
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
q = self.reshape_heads_to_batch_dim(q)
k = self.reshape_heads_to_batch_dim(k)
v = self.reshape_heads_to_batch_dim(v)
sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
if exists(mask):
mask = mask.reshape(batch_size, -1)
max_neg_value = -torch.finfo(sim.dtype).max
mask = mask[:, None, :].repeat(h, 1, 1)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
attn = sim.softmax(dim=-1)
out = torch.einsum("b i j, b j d -> b i d", attn, v)
out = self.reshape_batch_dim_to_heads(out)
return self.to_out(out)
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
super().__init__()
self.attn1 = CrossAttention(
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(
query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward(self, x, context=None):
x = self.attn1(self.norm1(x)) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
standard transformer action. Finally, reshape to image
"""
def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
super().__init__()
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
for d in range(depth)
]
)
self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
x = self.proj_in(x)
x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
for block in self.transformer_blocks:
x = block(x, context=context)
x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
x = self.proj_out(x)
return x + x_in
def convert_module_to_f16(l):
    """
    Convert primitive modules to float16.
@@ -271,19 +98,6 @@ def convert_module_to_f32(l):
    l.bias.data = l.bias.data.float()
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
@@ -327,36 +141,6 @@ def normalization(channels, swish=0.0):
    return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
class AttentionPool2d(nn.Module):
"""
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
"""
def __init__(
self,
spacial_dim: int,
embed_dim: int,
num_heads_channels: int,
output_dim: int = None,
):
super().__init__()
self.positional_embedding = nn.Parameter(torch.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
self.num_heads = embed_dim // num_heads_channels
self.attention = QKVAttention(self.num_heads)
def forward(self, x):
b, c, *_spatial = x.shape
x = x.reshape(b, c, -1) # NC(HW)
x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
x = self.qkv_proj(x)
x = self.attention(x)
x = self.c_proj(x)
return x[:, :, 0]
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that support it as an extra input.
@@ -364,7 +148,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    def forward(self, x, emb, context=None):
        for layer in self:
-            if isinstance(layer, TimestepBlock):
+            if isinstance(layer, TimestepBlock) or isinstance(layer, ResnetBlock):
                x = layer(x, emb)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, context)

        return x
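
A minimal sketch (not part of the commit) of the dispatch pattern that `TimestepEmbedSequential` uses: one container holds plain layers, timestep-conditioned blocks, and context-conditioned transformers, and passes the extra inputs only to children that accept them. The stand-in classes below are hypothetical:

import torch
import torch.nn as nn

class TimestepBlock(nn.Module):
    # Stand-in for a block that takes (x, emb).
    def forward(self, x, emb):
        return x + emb.mean()

class TimestepEmbedSequential(nn.Sequential):
    # Pass the embedding only to children that accept it.
    def forward(self, x, emb):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            else:
                x = layer(x)
        return x

seq = TimestepEmbedSequential(nn.Linear(8, 8), TimestepBlock(), nn.ReLU())
out = seq(torch.randn(2, 8), torch.randn(2, 16))
print(out.shape)  # torch.Size([2, 8])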
class QKVAttention(nn.Module):
"""
A module which performs QKV attention and splits in a different order.
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv):
"""
Apply QKV attention. :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x
T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.chunk(3, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum(
"bct,bcs->bts",
(q * scale).view(bs * self.n_heads, ch, length),
(k * scale).view(bs * self.n_heads, ch, length),
) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
return a.reshape(bs, -1, length)
@staticmethod
def count_flops(model, _x, y):
return count_flops_attn(model, _x, y)
def count_flops_attn(model, _x, y):
    """
    A counter for the `thop` package to count the operations in an attention operation. Meant to be used like:
@@ -559,14 +310,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
-                    ResBlock(
-                        ch,
-                        time_embed_dim,
-                        dropout,
+                    ResnetBlock(
+                        in_channels=ch,
                        out_channels=mult * model_channels,
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
+                        dropout=dropout,
+                        temb_channels=time_embed_dim,
+                        eps=1e-5,
+                        non_linearity="silu",
+                        overwrite_for_ldm=True,
                    )
                ]
                ch = mult * model_channels
@@ -599,20 +350,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
-                            out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
-                            down=True,
-                        )
-                        if resblock_updown
-                        else Downsample(
-                            ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
-                        )
+                        Downsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op")
                    )
                )
                ch = out_ch
@@ -629,13 +367,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
        # num_heads = 1
        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            ResnetBlock(
+                in_channels=ch,
+                out_channels=None,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                overwrite_for_ldm=True,
            ),
            AttentionBlock(
                ch,
@@ -646,13 +385,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
            )
            if not use_spatial_transformer
            else SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim),
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            ResnetBlock(
+                in_channels=ch,
+                out_channels=None,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                overwrite_for_ldm=True,
            ),
        )
        self._feature_size += ch
@@ -662,15 +402,15 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
            for i in range(num_res_blocks + 1):
                ich = input_block_chans.pop()
                layers = [
-                    ResBlock(
-                        ch + ich,
-                        time_embed_dim,
-                        dropout,
+                    ResnetBlock(
+                        in_channels=ch + ich,
                        out_channels=model_channels * mult,
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
-                    )
+                        dropout=dropout,
+                        temb_channels=time_embed_dim,
+                        eps=1e-5,
+                        non_linearity="silu",
+                        overwrite_for_ldm=True,
+                    ),
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
@@ -697,20 +437,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                )
                if level and i == num_res_blocks:
                    out_ch = ch
-                    layers.append(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
-                            out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
-                            up=True,
-                        )
-                        if resblock_updown
-                        else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
-                    )
+                    layers.append(Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch))
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
@@ -777,212 +504,119 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
        return self.out(h)
-class EncoderUNetModel(nn.Module):
-    """
-    The half UNet model with attention and timestep embedding. For usage, see UNet.
-    """
-
-    def __init__(
-        self,
-        image_size,
-        in_channels,
-        model_channels,
-        out_channels,
-        num_res_blocks,
-        attention_resolutions,
-        dropout=0,
-        channel_mult=(1, 2, 4, 8),
-        conv_resample=True,
-        dims=2,
-        use_checkpoint=False,
-        use_fp16=False,
-        num_heads=1,
-        num_head_channels=-1,
-        num_heads_upsample=-1,
-        use_scale_shift_norm=False,
-        resblock_updown=False,
-        use_new_attention_order=False,
-        pool="adaptive",
-        *args,
-        **kwargs,
-    ):
-        super().__init__()
-
-        if num_heads_upsample == -1:
-            num_heads_upsample = num_heads
-
-        self.in_channels = in_channels
-        self.model_channels = model_channels
-        self.out_channels = out_channels
-        self.num_res_blocks = num_res_blocks
-        self.attention_resolutions = attention_resolutions
-        self.dropout = dropout
-        self.channel_mult = channel_mult
-        self.conv_resample = conv_resample
-        self.use_checkpoint = use_checkpoint
-        self.dtype = torch.float16 if use_fp16 else torch.float32
-        self.num_heads = num_heads
-        self.num_head_channels = num_head_channels
-        self.num_heads_upsample = num_heads_upsample
-
-        time_embed_dim = model_channels * 4
-        self.time_embed = nn.Sequential(
-            linear(model_channels, time_embed_dim),
-            nn.SiLU(),
-            linear(time_embed_dim, time_embed_dim),
-        )
-
-        self.input_blocks = nn.ModuleList(
-            [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
-        )
-        self._feature_size = model_channels
-        input_block_chans = [model_channels]
-        ch = model_channels
-        ds = 1
-        for level, mult in enumerate(channel_mult):
-            for _ in range(num_res_blocks):
-                layers = [
-                    ResBlock(
-                        ch,
-                        time_embed_dim,
-                        dropout,
-                        out_channels=mult * model_channels,
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
-                    )
-                ]
-                ch = mult * model_channels
-                if ds in attention_resolutions:
-                    layers.append(
-                        AttentionBlock(
-                            ch,
-                            use_checkpoint=use_checkpoint,
-                            num_heads=num_heads,
-                            num_head_channels=num_head_channels,
-                            use_new_attention_order=use_new_attention_order,
-                        )
-                    )
-                self.input_blocks.append(TimestepEmbedSequential(*layers))
-                self._feature_size += ch
-                input_block_chans.append(ch)
-            if level != len(channel_mult) - 1:
-                out_ch = ch
-                self.input_blocks.append(
-                    TimestepEmbedSequential(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
-                            out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
-                            down=True,
-                        )
-                        if resblock_updown
-                        else Downsample(
-                            ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
-                        )
-                    )
-                )
-                ch = out_ch
-                input_block_chans.append(ch)
-                ds *= 2
-                self._feature_size += ch
-
-        self.middle_block = TimestepEmbedSequential(
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
-            ),
-            AttentionBlock(
-                ch,
-                use_checkpoint=use_checkpoint,
-                num_heads=num_heads,
-                num_head_channels=num_head_channels,
-                use_new_attention_order=use_new_attention_order,
-            ),
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
-            ),
-        )
-        self._feature_size += ch
-        self.pool = pool
-        if pool == "adaptive":
-            self.out = nn.Sequential(
-                normalization(ch),
-                nn.SiLU(),
-                nn.AdaptiveAvgPool2d((1, 1)),
-                zero_module(conv_nd(dims, ch, out_channels, 1)),
-                nn.Flatten(),
-            )
-        elif pool == "attention":
-            assert num_head_channels != -1
-            self.out = nn.Sequential(
-                normalization(ch),
-                nn.SiLU(),
-                AttentionPool2d((image_size // ds), ch, num_head_channels, out_channels),
-            )
-        elif pool == "spatial":
-            self.out = nn.Sequential(
-                nn.Linear(self._feature_size, 2048),
-                nn.ReLU(),
-                nn.Linear(2048, self.out_channels),
-            )
-        elif pool == "spatial_v2":
-            self.out = nn.Sequential(
-                nn.Linear(self._feature_size, 2048),
-                normalization(2048),
-                nn.SiLU(),
-                nn.Linear(2048, self.out_channels),
-            )
-        else:
-            raise NotImplementedError(f"Unexpected {pool} pooling")
-
-    def convert_to_fp16(self):
-        """
-        Convert the torso of the model to float16.
-        """
-        self.input_blocks.apply(convert_module_to_f16)
-        self.middle_block.apply(convert_module_to_f16)
-
-    def convert_to_fp32(self):
-        """
-        Convert the torso of the model to float32.
-        """
-        self.input_blocks.apply(convert_module_to_f32)
-        self.middle_block.apply(convert_module_to_f32)
-
-    def forward(self, x, timesteps):
-        """
-        Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. :param timesteps: a 1-D batch
-        of timesteps. :return: an [N x K] Tensor of outputs.
-        """
-        emb = self.time_embed(
-            get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
-        )
-
-        results = []
-        h = x.type(self.dtype)
-        for module in self.input_blocks:
-            h = module(h, emb)
-            if self.pool.startswith("spatial"):
-                results.append(h.type(x.dtype).mean(dim=(2, 3)))
-        h = self.middle_block(h, emb)
-        if self.pool.startswith("spatial"):
-            results.append(h.type(x.dtype).mean(dim=(2, 3)))
-            h = torch.cat(results, axis=-1)
-            return self.out(h)
-        else:
-            h = h.type(x.dtype)
-            return self.out(h)
+class SpatialTransformer(nn.Module):
+    """
+    Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
+    standard transformer action. Finally, reshape to image
+    """
+
+    def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
+        super().__init__()
+        self.in_channels = in_channels
+        inner_dim = n_heads * d_head
+        self.norm = Normalize(in_channels)
+
+        self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
+                for d in range(depth)
+            ]
+        )
+
+        self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
+
+    def forward(self, x, context=None):
+        # note: if no context is given, cross-attention defaults to self-attention
+        b, c, h, w = x.shape
+        x_in = x
+        x = self.norm(x)
+        x = self.proj_in(x)
+        x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
+        for block in self.transformer_blocks:
+            x = block(x, context=context)
+        x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
+        x = self.proj_out(x)
+        return x + x_in
+
+
+class BasicTransformerBlock(nn.Module):
+    def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
+        super().__init__()
+        self.attn1 = CrossAttention(
+            query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
+        )  # is a self-attention
+        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+        self.attn2 = CrossAttention(
+            query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
+        )  # is self-attn if context is none
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim)
+        self.norm3 = nn.LayerNorm(dim)
+        self.checkpoint = checkpoint
+
+    def forward(self, x, context=None):
+        x = self.attn1(self.norm1(x)) + x
+        x = self.attn2(self.norm2(x), context=context) + x
+        x = self.ff(self.norm3(x)) + x
+        return x
+
+
+class CrossAttention(nn.Module):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+        super().__init__()
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.scale = dim_head**-0.5
+        self.heads = heads
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+
+    def reshape_heads_to_batch_dim(self, tensor):
+        batch_size, seq_len, dim = tensor.shape
+        head_size = self.heads
+        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        return tensor
+
+    def reshape_batch_dim_to_heads(self, tensor):
+        batch_size, seq_len, dim = tensor.shape
+        head_size = self.heads
+        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+        return tensor
+
+    def forward(self, x, context=None, mask=None):
+        batch_size, sequence_length, dim = x.shape
+        h = self.heads
+
+        q = self.to_q(x)
+        context = default(context, x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        q = self.reshape_heads_to_batch_dim(q)
+        k = self.reshape_heads_to_batch_dim(k)
+        v = self.reshape_heads_to_batch_dim(v)
+
+        sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
+
+        if exists(mask):
+            mask = mask.reshape(batch_size, -1)
+            max_neg_value = -torch.finfo(sim.dtype).max
+            mask = mask[:, None, :].repeat(h, 1, 1)
+            sim.masked_fill_(~mask, max_neg_value)
+
+        # attention, what we cannot get enough of
+        attn = sim.softmax(dim=-1)
+
+        out = torch.einsum("b i j, b j d -> b i d", attn, v)
+        out = self.reshape_batch_dim_to_heads(out)
+        return self.to_out(out)
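
A shape walkthrough (not part of the commit) of the head reshaping in `CrossAttention` above: `reshape_heads_to_batch_dim` folds the head dimension into the batch dimension so a single batched matmul computes all heads at once, and `reshape_batch_dim_to_heads` undoes it. Sizes below are hypothetical:

import torch

batch, seq_len, heads, dim_head = 2, 5, 8, 64
t = torch.randn(batch, seq_len, heads * dim_head)

# fold heads into batch: (b, n, h*d) -> (b*h, n, d)
t = t.reshape(batch, seq_len, heads, dim_head)
t = t.permute(0, 2, 1, 3).reshape(batch * heads, seq_len, dim_head)
print(t.shape)  # torch.Size([16, 5, 64])

# unfold: (b*h, n, d) -> (b, n, h*d)
t = t.reshape(batch, heads, seq_len, dim_head)
t = t.permute(0, 2, 1, 3).reshape(batch, seq_len, heads * dim_head)
print(t.shape)  # torch.Size([2, 5, 512])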
@@ -6,7 +6,7 @@ import torch.nn as nn
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
-from .resnet import ResidualTemporalBlock
+from .resnet import Downsample, ResidualTemporalBlock, Upsample

class SinusoidalPosEmb(nn.Module):
@@ -18,24 +18,6 @@ class SinusoidalPosEmb(nn.Module):
        return get_timestep_embedding(x, self.dim)
class Downsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
def forward(self, x):
return self.conv(x)
class Upsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
def forward(self, x):
return self.conv(x)
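These two removed helpers are folded into the shared `Upsample`/`Downsample` blocks imported from `resnet.py` with `dims=1`. A hedged sketch of the intended equivalence, assuming the constructor arguments used in this diff and the default padding of the shared blocks:

import torch

down = Downsample(32, use_conv=True, dims=1)        # wraps nn.Conv1d(32, 32, 3, stride=2, padding=1)
up = Upsample(32, use_conv_transpose=True, dims=1)  # wraps nn.ConvTranspose1d(32, 32, 4, stride=2, padding=1)

x = torch.randn(2, 32, 16)          # (batch, channels, horizon)
assert down(x).shape == (2, 32, 8)  # halves the horizon, like Downsample1d
assert up(x).shape == (2, 32, 32)   # doubles the horizon, like Upsample1d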
class RearrangeDim(nn.Module): class RearrangeDim(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
...@@ -114,7 +96,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): ...@@ -114,7 +96,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
[ [
ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon), ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon),
ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon), ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon),
Downsample1d(dim_out) if not is_last else nn.Identity(), Downsample(dim_out, use_conv=True, dims=1) if not is_last else nn.Identity(),
] ]
) )
) )
...@@ -134,7 +116,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): ...@@ -134,7 +116,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
[ [
ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon), ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon),
ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon), ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon),
Upsample1d(dim_in) if not is_last else nn.Identity(), Upsample(dim_in, use_conv_transpose=True, dims=1) if not is_last else nn.Identity(),
] ]
) )
) )
......
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
import functools import functools
import math import math
import string
import numpy as np import numpy as np
import torch import torch
...@@ -28,116 +27,21 @@ from ..configuration_utils import ConfigMixin ...@@ -28,116 +27,21 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .attention import AttentionBlock from .attention import AttentionBlock
from .embeddings import GaussianFourierProjection, get_timestep_embedding from .embeddings import GaussianFourierProjection, get_timestep_embedding
from .resnet import ResnetBlockBigGANpp, ResnetBlockDDPMpp from .resnet import Downsample, ResnetBlockBigGANpp, Upsample, downsample_2d, upfirdn2d, upsample_2d
# Added in this commit:
def _setup_kernel(k):
    k = np.asarray(k, dtype=np.float32)
    if k.ndim == 1:
        k = np.outer(k, k)
    k /= np.sum(k)
    assert k.ndim == 2
    assert k.shape[0] == k.shape[1]
    return k


# Removed in this commit (now imported from .resnet):
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])


def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
    _, channel, in_h, in_w = input.shape
    input = input.reshape(-1, in_h, in_w, 1)
    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    # Upsample by inserting (up - 1) zeros between pixels along each axis.
    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    # Pad (or crop, for negative padding) the spatial dimensions.
    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    # Convolve with the flipped FIR kernel (true convolution, not cross-correlation).
    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)

    # Downsample by keeping every down-th sample.
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)
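A hedged worked example of `upfirdn2d`: with a normalized 2x2 box kernel and `down=2`, the pad-filter-decimate pipeline reduces to plain 2x2 average pooling (the kernel and the numbers below are illustrative, not from the source):

import torch

x = torch.arange(16.0).reshape(1, 1, 4, 4)
k = torch.full((2, 2), 0.25)   # normalized box filter
y = upfirdn2d(x, k, down=2)    # default up=1, pad=(0, 0)
assert torch.allclose(y, torch.tensor([[[[2.5, 4.5], [10.5, 12.5]]]]))  # block means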
# Function ported from StyleGAN2
def get_weight(module, shape, weight_var="weight", kernel_init=None):
"""Get/create weight tensor for a convolution or fully-connected layer."""
return module.param(weight_var, kernel_init, shape)
class Conv2d(nn.Module):
"""Conv2d layer with optimal upsampling and downsampling (StyleGAN2)."""
def __init__(
self,
in_ch,
out_ch,
kernel,
up=False,
down=False,
resample_kernel=(1, 3, 3, 1),
use_bias=True,
kernel_init=None,
):
super().__init__()
assert not (up and down)
assert kernel >= 1 and kernel % 2 == 1
self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel))
if kernel_init is not None:
self.weight.data = kernel_init(self.weight.data.shape)
if use_bias:
self.bias = nn.Parameter(torch.zeros(out_ch))
self.up = up
self.down = down
self.resample_kernel = resample_kernel
self.kernel = kernel
self.use_bias = use_bias
def forward(self, x):
if self.up:
x = upsample_conv_2d(x, self.weight, k=self.resample_kernel)
elif self.down:
x = conv_downsample_2d(x, self.weight, k=self.resample_kernel)
else:
x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2)
if self.use_bias:
x = x + self.bias.reshape(1, -1, 1, 1)
return x
def naive_upsample_2d(x, factor=2):
_N, C, H, W = x.shape
x = torch.reshape(x, (-1, C, H, 1, W, 1))
x = x.repeat(1, 1, 1, factor, 1, factor)
return torch.reshape(x, (-1, C, H * factor, W * factor))
def naive_downsample_2d(x, factor=2):
_N, C, H, W = x.shape
x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
return torch.mean(x, dim=(3, 5))
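The naive paths are reference implementations; a hedged check that they match the usual `torch.nn.functional` operators:

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)
assert torch.allclose(naive_upsample_2d(x), F.interpolate(x, scale_factor=2, mode="nearest"))
assert torch.allclose(naive_downsample_2d(x), F.avg_pool2d(x, 2))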
def upsample_conv_2d(x, w, k=None, factor=2, gain=1): def _upsample_conv_2d(x, w, k=None, factor=2, gain=1):
"""Fused `upsample_2d()` followed by `tf.nn.conv2d()`. """Fused `upsample_2d()` followed by `Conv2d()`.
Args: Args:
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
...@@ -176,13 +80,13 @@ def upsample_conv_2d(x, w, k=None, factor=2, gain=1): ...@@ -176,13 +80,13 @@ def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
# Determine data dimensions. # Determine data dimensions.
stride = [1, 1, factor, factor] stride = [1, 1, factor, factor]
output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW) output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
output_padding = ( output_padding = (
output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH, output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW, output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
) )
assert output_padding[0] >= 0 and output_padding[1] >= 0 assert output_padding[0] >= 0 and output_padding[1] >= 0
num_groups = _shape(x, 1) // inC num_groups = x.shape[1] // inC
# Transpose weights. # Transpose weights.
w = torch.reshape(w, (num_groups, -1, inC, convH, convW)) w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
...@@ -190,21 +94,12 @@ def upsample_conv_2d(x, w, k=None, factor=2, gain=1): ...@@ -190,21 +94,12 @@ def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
w = torch.reshape(w, (num_groups * inC, -1, convH, convW)) w = torch.reshape(w, (num_groups * inC, -1, convH, convW))
x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0) x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
# Original TF code.
# x = tf.nn.conv2d_transpose(
# x,
# w,
# output_shape=output_shape,
# strides=stride,
# padding='VALID',
# data_format=data_format)
# JAX equivalent
return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1)) return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
def conv_downsample_2d(x, w, k=None, factor=2, gain=1): def _conv_downsample_2d(x, w, k=None, factor=2, gain=1):
"""Fused `tf.nn.conv2d()` followed by `downsample_2d()`. """Fused `Conv2d()` followed by `downsample_2d()`.
Args: Args:
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
...@@ -235,138 +130,9 @@ def conv_downsample_2d(x, w, k=None, factor=2, gain=1): ...@@ -235,138 +130,9 @@ def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
return F.conv2d(x, w, stride=s, padding=0) return F.conv2d(x, w, stride=s, padding=0)
def _setup_kernel(k): def _variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
k = np.asarray(k, dtype=np.float32)
if k.ndim == 1:
k = np.outer(k, k)
k /= np.sum(k)
assert k.ndim == 2
assert k.shape[0] == k.shape[1]
return k
def _shape(x, dim):
return x.shape[dim]
def upsample_2d(x, k=None, factor=2, gain=1):
r"""Upsample a batch of 2D images with the given filter.
Args:
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a:
multiple of the upsampling factor.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H * factor, W * factor]`
"""
assert isinstance(factor, int) and factor >= 1
if k is None:
k = [1] * factor
k = _setup_kernel(k) * (gain * (factor**2))
p = k.shape[0] - factor
return upfirdn2d(x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
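As the docstring notes, the default kernel `[1] * factor` makes this nearest-neighbor upsampling; a hedged sanity check:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 4, 4)
assert torch.allclose(upsample_2d(x, factor=2), F.interpolate(x, scale_factor=2, mode="nearest"))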
def downsample_2d(x, k=None, factor=2, gain=1):
r"""Downsample a batch of 2D images with the given filter.
Args:
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
shape is a multiple of the downsampling factor.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to average pooling.
factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H // factor, W // factor]`
"""
assert isinstance(factor, int) and factor >= 1
if k is None:
k = [1] * factor
k = _setup_kernel(k) * gain
p = k.shape[0] - factor
return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
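Likewise, the default kernel turns `downsample_2d` into average pooling; a hedged sanity check:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
assert torch.allclose(downsample_2d(x, factor=2), F.avg_pool2d(x, 2))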
def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
"""1x1 convolution with DDPM initialization."""
conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
nn.init.zeros_(conv.bias)
return conv
def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
"""3x3 convolution with DDPM initialization."""
conv = nn.Conv2d(
in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias
)
conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
nn.init.zeros_(conv.bias)
return conv
def _einsum(a, b, c, x, y):
einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
return torch.einsum(einsum_str, x, y)
def contract_inner(x, y):
"""tensordot(x, y, 1)."""
x_chars = list(string.ascii_lowercase[: len(x.shape)])
y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)])
y_chars[0] = x_chars[-1] # first axis of y and last of x get summed
out_chars = x_chars[:-1] + y_chars[1:]
return _einsum(x_chars, y_chars, out_chars, x, y)
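`contract_inner` is `tensordot(x, y, 1)` spelled as an einsum; a hedged check with arbitrary illustrative shapes:

import torch

x = torch.randn(2, 5, 5, 7)
y = torch.randn(7, 3)
assert torch.allclose(contract_inner(x, y), torch.tensordot(x, y, dims=1))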
class NIN(nn.Module):
def __init__(self, in_dim, num_units, init_scale=0.1):
super().__init__()
self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
def forward(self, x):
x = x.permute(0, 2, 3, 1)
y = contract_inner(x, self.W) + self.b
return y.permute(0, 3, 1, 2)
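`NIN` is a 1x1 convolution written as a matmul over the channel axis; a hedged equivalence check against `nn.Conv2d` (the transposed-weight copy is the only subtlety, since NIN stores `W` as `(in_dim, num_units)`):

import torch

nin = NIN(in_dim=4, num_units=6)
conv = torch.nn.Conv2d(4, 6, kernel_size=1)
conv.weight.data = nin.W.t().reshape(6, 4, 1, 1)  # (num_units, in_dim, 1, 1)
conv.bias.data = nin.b

x = torch.randn(2, 4, 8, 8)
assert torch.allclose(nin(x), conv(x), atol=1e-6)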
def get_act(nonlinearity):
"""Get activation functions from the config file."""
if nonlinearity.lower() == "elu":
return nn.ELU()
elif nonlinearity.lower() == "relu":
return nn.ReLU()
elif nonlinearity.lower() == "lrelu":
return nn.LeakyReLU(negative_slope=0.2)
elif nonlinearity.lower() == "swish":
return nn.SiLU()
else:
raise NotImplementedError("activation function does not exist!")
def default_init(scale=1.0):
"""The same initialization used in DDPM."""
scale = 1e-10 if scale == 0 else scale
return variance_scaling(scale, "fan_avg", "uniform")
def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
"""Ported from JAX.""" """Ported from JAX."""
scale = 1e-10 if scale == 0 else scale
def _compute_fans(shape, in_axis=1, out_axis=0): def _compute_fans(shape, in_axis=1, out_axis=0):
receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis] receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
...@@ -376,31 +142,35 @@ def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=tor ...@@ -376,31 +142,35 @@ def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=tor
def init(shape, dtype=dtype, device=device): def init(shape, dtype=dtype, device=device):
fan_in, fan_out = _compute_fans(shape, in_axis, out_axis) fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
if mode == "fan_in":
denominator = fan_in
elif mode == "fan_out":
denominator = fan_out
elif mode == "fan_avg":
denominator = (fan_in + fan_out) / 2 denominator = (fan_in + fan_out) / 2
else:
raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
variance = scale / denominator variance = scale / denominator
if distribution == "normal":
return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
elif distribution == "uniform":
return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance) return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
else:
raise ValueError("invalid distribution for variance scaling initializer")
return init return init
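A hedged worked example of the fan computation above: for an `(8, 4, 3, 3)` conv weight, the receptive field is `3 * 3 = 9`, so `fan_in = 4 * 9 = 36` and `fan_out = 8 * 9 = 72`; with `scale=1.0` and `"fan_avg"` the variance is `1 / 54` and the uniform bound is `sqrt(3 / 54) ≈ 0.236`:

import numpy as np

init = variance_scaling(1.0, "fan_avg", "uniform")
w = init((8, 4, 3, 3))
assert w.abs().max().item() <= np.sqrt(3 / 54)  # samples stay inside the bound
# Empirically, w.var() should be close to 1 / 54 ≈ 0.0185.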
def Conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
"""nXn convolution with DDPM initialization."""
conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
conv.weight.data = _variance_scaling(init_scale)(conv.weight.data.shape)
nn.init.zeros_(conv.bias)
return conv
def Linear(dim_in, dim_out):
linear = nn.Linear(dim_in, dim_out)
linear.weight.data = _variance_scaling()(linear.weight.shape)
nn.init.zeros_(linear.bias)
return linear
class Combine(nn.Module): class Combine(nn.Module):
"""Combine information from skip connections.""" """Combine information from skip connections."""
def __init__(self, dim1, dim2, method="cat"): def __init__(self, dim1, dim2, method="cat"):
super().__init__() super().__init__()
self.Conv_0 = conv1x1(dim1, dim2) # 1x1 convolution with DDPM initialization.
self.Conv_0 = Conv2d(dim1, dim2, kernel_size=1, padding=0)
self.method = method self.method = method
def forward(self, x, y): def forward(self, x, y):
...@@ -413,80 +183,42 @@ class Combine(nn.Module): ...@@ -413,80 +183,42 @@ class Combine(nn.Module):
raise ValueError(f"Method {self.method} not recognized.") raise ValueError(f"Method {self.method} not recognized.")
# Removed in this commit:
class Upsample(nn.Module):
    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_ch = out_ch if out_ch else in_ch
        if not fir:
            if with_conv:
                self.Conv_0 = conv3x3(in_ch, out_ch)
        else:
            if with_conv:
                self.Conv2d_0 = Conv2d(
                    in_ch,
                    out_ch,
                    kernel=3,
                    up=True,
                    resample_kernel=fir_kernel,
                    use_bias=True,
                    kernel_init=default_init(),
                )
        self.fir = fir
        self.with_conv = with_conv
        self.fir_kernel = fir_kernel
        self.out_ch = out_ch

    def forward(self, x):
        B, C, H, W = x.shape
        if not self.fir:
            h = F.interpolate(x, (H * 2, W * 2), "nearest")
            if self.with_conv:
                h = self.Conv_0(h)
        else:
            if not self.with_conv:
                h = upsample_2d(x, self.fir_kernel, factor=2)
            else:
                h = self.Conv2d_0(x)
        return h


# Added in this commit:
class FirUpsample(nn.Module):
    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.use_conv = use_conv
        self.fir_kernel = fir_kernel
        self.out_channels = out_channels

    def forward(self, x):
        if self.use_conv:
            h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
            h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            h = upsample_2d(x, self.fir_kernel, factor=2)
        return h
# Removed in this commit:
class Downsample(nn.Module):
    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_ch = out_ch if out_ch else in_ch
        if not fir:
            if with_conv:
                self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0)
        else:
            if with_conv:
                self.Conv2d_0 = Conv2d(
                    in_ch,
                    out_ch,
                    kernel=3,
                    down=True,
                    resample_kernel=fir_kernel,
                    use_bias=True,
                    kernel_init=default_init(),
                )
        self.fir = fir
        self.fir_kernel = fir_kernel
        self.with_conv = with_conv
        self.out_ch = out_ch

    def forward(self, x):
        B, C, H, W = x.shape
        if not self.fir:
            if self.with_conv:
                x = F.pad(x, (0, 1, 0, 1))
                x = self.Conv_0(x)
            else:
                x = F.avg_pool2d(x, 2, stride=2)
        else:
            if not self.with_conv:
                x = downsample_2d(x, self.fir_kernel, factor=2)
            else:
                x = self.Conv2d_0(x)
        return x


# Added in this commit:
class FirDownsample(nn.Module):
    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.use_conv = use_conv
        self.out_channels = out_channels

    def forward(self, x):
        if self.use_conv:
            x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
            x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            x = downsample_2d(x, self.fir_kernel, factor=2)
        return x
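A hedged shape check for the two FIR layers above (channel count invented for illustration): with `use_conv=True` they double or halve the spatial resolution through the fused filter-plus-conv path.

import torch

x = torch.randn(1, 4, 8, 8)
assert FirUpsample(channels=4, use_conv=True)(x).shape == (1, 4, 16, 16)
assert FirDownsample(channels=4, use_conv=True)(x).shape == (1, 4, 4, 4)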
...@@ -496,10 +228,9 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -496,10 +228,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
def __init__( def __init__(
self, self,
centered=False,
image_size=1024, image_size=1024,
num_channels=3, num_channels=3,
attention_type="ddpm", centered=False,
attn_resolutions=(16,), attn_resolutions=(16,),
ch_mult=(1, 2, 4, 8, 16, 32, 32, 32), ch_mult=(1, 2, 4, 8, 16, 32, 32, 32),
conditional=True, conditional=True,
...@@ -511,24 +242,20 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -511,24 +242,20 @@ class NCSNpp(ModelMixin, ConfigMixin):
fourier_scale=16, fourier_scale=16,
init_scale=0.0, init_scale=0.0,
nf=16, nf=16,
nonlinearity="swish",
normalization="GroupNorm",
num_res_blocks=1, num_res_blocks=1,
progressive="output_skip", progressive="output_skip",
progressive_combine="sum", progressive_combine="sum",
progressive_input="input_skip", progressive_input="input_skip",
resamp_with_conv=True, resamp_with_conv=True,
resblock_type="biggan",
scale_by_sigma=True, scale_by_sigma=True,
skip_rescale=True, skip_rescale=True,
continuous=True, continuous=True,
): ):
super().__init__() super().__init__()
self.register_to_config( self.register_to_config(
centered=centered,
image_size=image_size, image_size=image_size,
num_channels=num_channels, num_channels=num_channels,
attention_type=attention_type, centered=centered,
attn_resolutions=attn_resolutions, attn_resolutions=attn_resolutions,
ch_mult=ch_mult, ch_mult=ch_mult,
conditional=conditional, conditional=conditional,
...@@ -540,19 +267,16 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -540,19 +267,16 @@ class NCSNpp(ModelMixin, ConfigMixin):
fourier_scale=fourier_scale, fourier_scale=fourier_scale,
init_scale=init_scale, init_scale=init_scale,
nf=nf, nf=nf,
nonlinearity=nonlinearity,
normalization=normalization,
num_res_blocks=num_res_blocks, num_res_blocks=num_res_blocks,
progressive=progressive, progressive=progressive,
progressive_combine=progressive_combine, progressive_combine=progressive_combine,
progressive_input=progressive_input, progressive_input=progressive_input,
resamp_with_conv=resamp_with_conv, resamp_with_conv=resamp_with_conv,
resblock_type=resblock_type,
scale_by_sigma=scale_by_sigma, scale_by_sigma=scale_by_sigma,
skip_rescale=skip_rescale, skip_rescale=skip_rescale,
continuous=continuous, continuous=continuous,
) )
self.act = act = get_act(nonlinearity) self.act = act = nn.SiLU()
self.nf = nf self.nf = nf
self.num_res_blocks = num_res_blocks self.num_res_blocks = num_res_blocks
...@@ -562,7 +286,6 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -562,7 +286,6 @@ class NCSNpp(ModelMixin, ConfigMixin):
self.conditional = conditional self.conditional = conditional
self.skip_rescale = skip_rescale self.skip_rescale = skip_rescale
self.resblock_type = resblock_type
self.progressive = progressive self.progressive = progressive
self.progressive_input = progressive_input self.progressive_input = progressive_input
self.embedding_type = embedding_type self.embedding_type = embedding_type
...@@ -585,40 +308,31 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -585,40 +308,31 @@ class NCSNpp(ModelMixin, ConfigMixin):
else: else:
raise ValueError(f"embedding type {embedding_type} unknown.") raise ValueError(f"embedding type {embedding_type} unknown.")
        # Removed in this commit:
        if conditional:
            modules.append(nn.Linear(embed_dim, nf * 4))
            modules[-1].weight.data = default_init()(modules[-1].weight.shape)
            nn.init.zeros_(modules[-1].bias)
            modules.append(nn.Linear(nf * 4, nf * 4))
            modules[-1].weight.data = default_init()(modules[-1].weight.shape)
            nn.init.zeros_(modules[-1].bias)

        # Added in this commit:
        modules.append(Linear(embed_dim, nf * 4))
        modules.append(Linear(nf * 4, nf * 4))
AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0)) AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
if self.fir:
Up_sample = functools.partial(FirUpsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
else:
Up_sample = functools.partial(Upsample, name="Conv2d_0")
if progressive == "output_skip": if progressive == "output_skip":
self.pyramid_upsample = Up_sample(fir=fir, fir_kernel=fir_kernel, with_conv=False) self.pyramid_upsample = Up_sample(channels=None, use_conv=False)
elif progressive == "residual": elif progressive == "residual":
pyramid_upsample = functools.partial(Up_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True) pyramid_upsample = functools.partial(Up_sample, use_conv=True)
        # Removed in this commit:
        Down_sample = functools.partial(Downsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)

        # Added in this commit:
        if self.fir:
            Down_sample = functools.partial(FirDownsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
        else:
            Down_sample = functools.partial(Downsample, padding=0, name="Conv2d_0")
if progressive_input == "input_skip": if progressive_input == "input_skip":
self.pyramid_downsample = Down_sample(fir=fir, fir_kernel=fir_kernel, with_conv=False) self.pyramid_downsample = Down_sample(channels=None, use_conv=False)
elif progressive_input == "residual": elif progressive_input == "residual":
pyramid_downsample = functools.partial(Down_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True) pyramid_downsample = functools.partial(Down_sample, use_conv=True)
if resblock_type == "ddpm":
ResnetBlock = functools.partial(
ResnetBlockDDPMpp,
act=act,
dropout=dropout,
init_scale=init_scale,
skip_rescale=skip_rescale,
temb_dim=nf * 4,
)
elif resblock_type == "biggan":
ResnetBlock = functools.partial( ResnetBlock = functools.partial(
ResnetBlockBigGANpp, ResnetBlockBigGANpp,
act=act, act=act,
...@@ -630,16 +344,13 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -630,16 +344,13 @@ class NCSNpp(ModelMixin, ConfigMixin):
temb_dim=nf * 4, temb_dim=nf * 4,
) )
else:
raise ValueError(f"resblock type {resblock_type} unrecognized.")
# Downsampling block # Downsampling block
channels = num_channels channels = num_channels
if progressive_input != "none": if progressive_input != "none":
input_pyramid_ch = channels input_pyramid_ch = channels
modules.append(conv3x3(channels, nf)) modules.append(Conv2d(channels, nf, kernel_size=3, padding=1))
hs_c = [nf] hs_c = [nf]
in_ch = nf in_ch = nf
...@@ -655,9 +366,6 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -655,9 +366,6 @@ class NCSNpp(ModelMixin, ConfigMixin):
hs_c.append(in_ch) hs_c.append(in_ch)
if i_level != self.num_resolutions - 1: if i_level != self.num_resolutions - 1:
if resblock_type == "ddpm":
modules.append(Downsample(in_ch=in_ch))
else:
modules.append(ResnetBlock(down=True, in_ch=in_ch)) modules.append(ResnetBlock(down=True, in_ch=in_ch))
if progressive_input == "input_skip": if progressive_input == "input_skip":
...@@ -666,7 +374,7 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -666,7 +374,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
in_ch *= 2 in_ch *= 2
elif progressive_input == "residual": elif progressive_input == "residual":
modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch)) modules.append(pyramid_downsample(channels=input_pyramid_ch, out_channels=in_ch))
input_pyramid_ch = in_ch input_pyramid_ch = in_ch
hs_c.append(in_ch) hs_c.append(in_ch)
...@@ -691,36 +399,35 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -691,36 +399,35 @@ class NCSNpp(ModelMixin, ConfigMixin):
if i_level == self.num_resolutions - 1: if i_level == self.num_resolutions - 1:
if progressive == "output_skip": if progressive == "output_skip":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(conv3x3(in_ch, channels, init_scale=init_scale)) modules.append(Conv2d(in_ch, channels, init_scale=init_scale, kernel_size=3, padding=1))
pyramid_ch = channels pyramid_ch = channels
elif progressive == "residual": elif progressive == "residual":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(conv3x3(in_ch, in_ch, bias=True)) modules.append(Conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1))
pyramid_ch = in_ch pyramid_ch = in_ch
else: else:
raise ValueError(f"{progressive} is not a valid name.") raise ValueError(f"{progressive} is not a valid name.")
else: else:
if progressive == "output_skip": if progressive == "output_skip":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale)) modules.append(
Conv2d(in_ch, channels, bias=True, init_scale=init_scale, kernel_size=3, padding=1)
)
pyramid_ch = channels pyramid_ch = channels
elif progressive == "residual": elif progressive == "residual":
modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch)) modules.append(pyramid_upsample(channels=pyramid_ch, out_channels=in_ch))
pyramid_ch = in_ch pyramid_ch = in_ch
else: else:
raise ValueError(f"{progressive} is not a valid name") raise ValueError(f"{progressive} is not a valid name")
if i_level != 0: if i_level != 0:
if resblock_type == "ddpm":
modules.append(Upsample(in_ch=in_ch))
else:
modules.append(ResnetBlock(in_ch=in_ch, up=True)) modules.append(ResnetBlock(in_ch=in_ch, up=True))
assert not hs_c assert not hs_c
if progressive != "output_skip": if progressive != "output_skip":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(conv3x3(in_ch, channels, init_scale=init_scale)) modules.append(Conv2d(in_ch, channels, init_scale=init_scale))
self.all_modules = nn.ModuleList(modules) self.all_modules = nn.ModuleList(modules)
...@@ -751,8 +458,8 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -751,8 +458,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
else: else:
temb = None temb = None
if not self.config.centered:
# If input data is in [0, 1] # If input data is in [0, 1]
if not self.config.centered:
x = 2 * x - 1.0 x = 2 * x - 1.0
# Downsampling block # Downsampling block
...@@ -774,10 +481,6 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -774,10 +481,6 @@ class NCSNpp(ModelMixin, ConfigMixin):
hs.append(h) hs.append(h)
if i_level != self.num_resolutions - 1: if i_level != self.num_resolutions - 1:
if self.resblock_type == "ddpm":
h = modules[m_idx](hs[-1])
m_idx += 1
else:
h = modules[m_idx](hs[-1], temb) h = modules[m_idx](hs[-1], temb)
m_idx += 1 m_idx += 1
...@@ -851,10 +554,6 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -851,10 +554,6 @@ class NCSNpp(ModelMixin, ConfigMixin):
raise ValueError(f"{self.progressive} is not a valid name") raise ValueError(f"{self.progressive} is not a valid name")
if i_level != 0: if i_level != 0:
if self.resblock_type == "ddpm":
h = modules[m_idx](h)
m_idx += 1
else:
h = modules[m_idx](h, temb) h = modules[m_idx](h, temb)
m_idx += 1 m_idx += 1
......
...@@ -259,7 +259,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -259,7 +259,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
# fmt: off # fmt: off
expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053]) expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
# fmt: on # fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase): class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
...@@ -607,7 +607,7 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -607,7 +607,7 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
expected_output_slice = torch.tensor([-0.0690, -0.0531, 0.0633, -0.0660, -0.0541, 0.0650, -0.0656, -0.0555, 0.0617]) expected_output_slice = torch.tensor([-0.0690, -0.0531, 0.0633, -0.0660, -0.0541, 0.0650, -0.0656, -0.0555, 0.0617])
# fmt: on # fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase): class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase):
...@@ -678,7 +678,7 @@ class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -678,7 +678,7 @@ class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase):
expected_output_slice = torch.tensor([-0.2714, 0.1042, -0.0794, -0.2820, 0.0803, -0.0811, -0.2345, 0.0580, -0.0584]) expected_output_slice = torch.tensor([-0.2714, 0.1042, -0.0794, -0.2820, 0.0803, -0.0811, -0.2345, 0.0580, -0.0584])
# fmt: on # fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
...@@ -742,18 +742,18 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -742,18 +742,18 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
num_channels = 3 num_channels = 3
sizes = (32, 32) sizes = (32, 32)
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor(batch_size * [10]).to(torch_device) time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)
with torch.no_grad(): with torch.no_grad():
output = model(noise, time_step) output = model(noise, time_step)
output_slice = output[0, -3:, -3:, -1].flatten().cpu() output_slice = output[0, -3:, -3:, -1].flatten().cpu()
# fmt: off # fmt: off
expected_output_slice = torch.tensor([3.1909e-07, -8.5393e-08, 4.8460e-07, -4.5550e-07, -1.3205e-06, -6.3475e-07, 9.7837e-07, 2.9974e-07, 1.2345e-06]) expected_output_slice = torch.tensor([0.1315, 0.0741, 0.0393, 0.0455, 0.0556, 0.0180, -0.0832, -0.0644, -0.0856])
# fmt: on # fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
def test_output_pretrained_ve_large(self): def test_output_pretrained_ve_large(self):
model = NCSNpp.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy") model = NCSNpp.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy")
...@@ -768,21 +768,21 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -768,21 +768,21 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
num_channels = 3 num_channels = 3
sizes = (32, 32) sizes = (32, 32)
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor(batch_size * [10]).to(torch_device) time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)
with torch.no_grad(): with torch.no_grad():
output = model(noise, time_step) output = model(noise, time_step)
output_slice = output[0, -3:, -3:, -1].flatten().cpu() output_slice = output[0, -3:, -3:, -1].flatten().cpu()
# fmt: off # fmt: off
expected_output_slice = torch.tensor([-8.3299e-07, -9.0431e-07, 4.0585e-08, 9.7563e-07, 1.0280e-06, 1.0133e-06, 1.4979e-06, -2.9716e-07, -6.1817e-07]) expected_output_slice = torch.tensor([-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256])
# fmt: on # fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
def test_output_pretrained_vp(self): def test_output_pretrained_vp(self):
model = NCSNpp.from_pretrained("fusing/ddpm-cifar10-vp-dummy") model = NCSNpp.from_pretrained("fusing/cifar10-ddpmpp-vp")
model.eval() model.eval()
model.to(torch_device) model.to(torch_device)
...@@ -794,18 +794,18 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -794,18 +794,18 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
num_channels = 3 num_channels = 3
sizes = (32, 32) sizes = (32, 32)
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) noise = torch.randn((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor(batch_size * [10]).to(torch_device) time_step = torch.tensor(batch_size * [9.0]).to(torch_device)
with torch.no_grad(): with torch.no_grad():
output = model(noise, time_step) output = model(noise, time_step)
output_slice = output[0, -3:, -3:, -1].flatten().cpu() output_slice = output[0, -3:, -3:, -1].flatten().cpu()
# fmt: off # fmt: off
expected_output_slice = torch.tensor([-3.9086e-07, -1.1001e-05, 1.8881e-06, 1.1106e-05, 1.6629e-06, 2.9820e-06, 8.4978e-06, 8.0253e-07, 1.5435e-06]) expected_output_slice = torch.tensor([0.3303, -0.2275, -2.8872, -0.1309, -1.2861, 3.4567, -1.0083, 2.5325, -1.3866])
# fmt: on # fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
class VQModelTests(ModelTesterMixin, unittest.TestCase): class VQModelTests(ModelTesterMixin, unittest.TestCase):
...@@ -878,10 +878,9 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -878,10 +878,9 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
output_slice = output[0, -1, -3:, -3:].flatten() output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off # fmt: off
expected_output_slice = torch.tensor([-1.1321, 0.1056, 0.3505, -0.6461, -0.2014, 0.0419, -0.5763, -0.8462, -0.4218]) expected_output_slice = torch.tensor([-1.1321, 0.1056, 0.3505, -0.6461, -0.2014, 0.0419, -0.5763, -0.8462, -0.4218])
# fmt: on # fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase): class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
...@@ -950,10 +949,9 @@ class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase): ...@@ -950,10 +949,9 @@ class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
output_slice = output[0, -1, -3:, -3:].flatten() output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off # fmt: off
expected_output_slice = torch.tensor([-0.0814, -0.0229, -0.1320, -0.4123, -0.0366, -0.3473, 0.0438, -0.1662, 0.1750]) expected_output_slice = torch.tensor([-0.0814, -0.0229, -0.1320, -0.4123, -0.0366, -0.3473, 0.0438, -0.1662, 0.1750])
# fmt: on # fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
class PipelineTesterMixin(unittest.TestCase): class PipelineTesterMixin(unittest.TestCase):
......