Commit ebf3717c authored by Patrick von Platen

resnet in one file

parent e5d9baf0
import string
from abc import abstractmethod
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

...@@ -54,6 +58,18 @@ def nonlinearity(x, swish=1.0):
return x * F.sigmoid(x * float(swish))
class TimestepBlock(nn.Module):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@abstractmethod
def forward(self, x, emb):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
...@@ -134,154 +150,713 @@ class Downsample(nn.Module):
return self.op(x)
# class UNetUpsample(nn.Module):
# def __init__(self, in_channels, with_conv):
# super().__init__()
# self.with_conv = with_conv
# if self.with_conv:
# self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
#
# def forward(self, x):
# x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
# if self.with_conv:
# x = self.conv(x)
# return x
#
#
# class GlideUpsample(nn.Module):
# """
# An upsampling layer with an optional convolution.
# :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
# applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
# upsampling occurs in the inner-two dimensions.
# """
#
# def __init__(self, channels, use_conv, dims=2, out_channels=None):
# super().__init__()
# self.channels = channels
# self.out_channels = out_channels or channels
# self.use_conv = use_conv
# self.dims = dims
# if use_conv:
# self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
#
# def forward(self, x):
# assert x.shape[1] == self.channels
# if self.dims == 3:
# x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
# else:
# x = F.interpolate(x, scale_factor=2, mode="nearest")
# if self.use_conv:
# x = self.conv(x)
# return x
#
#
# class LDMUpsample(nn.Module):
# """
# An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param
# use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D.
# If 3D, then upsampling occurs in the inner-two dimensions.
# """
#
# def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
# super().__init__()
# self.channels = channels
# self.out_channels = out_channels or channels
# self.use_conv = use_conv
# self.dims = dims
# if use_conv:
# self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
#
# def forward(self, x):
# assert x.shape[1] == self.channels
# if self.dims == 3:
# x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
# else:
# x = F.interpolate(x, scale_factor=2, mode="nearest")
# if self.use_conv:
# x = self.conv(x)
# return x
#
#
# class GradTTSUpsample(torch.nn.Module):
# def __init__(self, dim):
# super(Upsample, self).__init__()
# self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)
#
# def forward(self, x):
# return self.conv(x)
#
#
# TODO (patil-suraj): needs test
# class Upsample1d(nn.Module):
# def __init__(self, dim):
# super().__init__()
# self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
#
# def forward(self, x):
# return self.conv(x)
# RESNETS
# unet_glide.py & unet_ldm.py
class ResBlock(TimestepBlock):
"""
A residual block that can optionally change the number of channels.
:param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout. :param out_channels: if specified, the number of out channels. :param
use_conv: if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing
on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for
downsampling.
"""

def __init__(
self,
channels,
emb_channels,
dropout,
out_channels=None,
use_conv=False,
use_scale_shift_norm=False,
dims=2,
use_checkpoint=False,
up=False,
down=False,
):
super().__init__()
self.channels = channels
self.emb_channels = emb_channels
self.dropout = dropout
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_checkpoint = use_checkpoint
self.use_scale_shift_norm = use_scale_shift_norm
self.in_layers = nn.Sequential(
normalization(channels, swish=1.0),
nn.Identity(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
)
self.updown = up or down
if up:
self.h_upd = Upsample(channels, use_conv=False, dims=dims)
self.x_upd = Upsample(channels, use_conv=False, dims=dims)
elif down:
self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
else:
self.h_upd = self.x_upd = nn.Identity()
self.emb_layers = nn.Sequential(
nn.SiLU(),
linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
nn.SiLU() if use_scale_shift_norm else nn.Identity(),
nn.Dropout(p=dropout),
zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
def forward(self, x, emb):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
:param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
h = self.h_upd(h)
x = self.x_upd(x)
h = in_conv(h)
else:
h = self.in_layers(x)
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = torch.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
h = out_rest(h)
else:
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h
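For orientation, a minimal usage sketch of the block above (not part of this commit; it assumes the file is importable as diffusers.models.resnet and that the helpers it references, such as conv_nd and normalization, resolve):

import torch
from diffusers.models.resnet import ResBlock  # assumed import path for this file

block = ResBlock(channels=64, emb_channels=256, dropout=0.1, out_channels=128, down=True)
x = torch.randn(2, 64, 32, 32)   # [N x C x H x W]
emb = torch.randn(2, 256)        # [N x emb_channels]
h = block(x, emb)                # down=True halves the spatial resolution
print(h.shape)                   # torch.Size([2, 128, 16, 16])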
# unet.py
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
else:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
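A hedged usage sketch of the DDPM-style ResnetBlock above (keyword-only signature, so named arguments are required; the import path and the availability of Normalize/nonlinearity are assumptions):

import torch
from diffusers.models.resnet import ResnetBlock  # assumed import path

block = ResnetBlock(in_channels=64, out_channels=128, dropout=0.0, temb_channels=512)
x = torch.randn(4, 64, 32, 32)
temb = torch.randn(4, 512)
out = block(x, temb)             # channel change goes through the 1x1 nin_shortcut
print(out.shape)                 # torch.Size([4, 128, 32, 32])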
# unet_grad_tts.py
class ResnetBlockGradTTS(torch.nn.Module):
def __init__(self, dim, dim_out, time_emb_dim, groups=8):
super(ResnetBlockGradTTS, self).__init__()
self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
self.block1 = Block(dim, dim_out, groups=groups)
self.block2 = Block(dim_out, dim_out, groups=groups)
if dim != dim_out:
self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
else:
self.res_conv = torch.nn.Identity()

def forward(self, x, mask, time_emb):
h = self.block1(x, mask)
h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
h = self.block2(h, mask)
output = h + self.res_conv(x * mask)
return output
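A hedged usage sketch of ResnetBlockGradTTS; the mask shape below is illustrative, any tensor broadcastable against x works with the code above (import path assumed):

import torch
from diffusers.models.resnet import ResnetBlockGradTTS  # assumed import path

block = ResnetBlockGradTTS(dim=64, dim_out=64, time_emb_dim=128)
x = torch.randn(2, 64, 80, 100)      # e.g. [batch, channels, mel bins, frames]
mask = torch.ones(2, 1, 80, 100)     # 1 = keep, 0 = padding
t = torch.randn(2, 128)
print(block(x, mask, t).shape)       # torch.Size([2, 64, 80, 100])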
# unet_rl.py
class ResidualTemporalBlock(nn.Module):
def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
super().__init__()
self.blocks = nn.ModuleList(
[
Conv1dBlock(inp_channels, out_channels, kernel_size),
Conv1dBlock(out_channels, out_channels, kernel_size),
]
)
self.time_mlp = nn.Sequential(
nn.Mish(),
nn.Linear(embed_dim, out_channels),
RearrangeDim(),
# Rearrange("batch t -> batch t 1"),
)
self.residual_conv = (
nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
)
def forward(self, x, t):
"""
x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x
out_channels x horizon ]
"""
out = self.blocks[0](x) + self.time_mlp(t)
out = self.blocks[1](out)
return out + self.residual_conv(x)
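A hedged usage sketch of ResidualTemporalBlock (note that horizon is accepted by the constructor but not used in the body shown above; import path assumed):

import torch
from diffusers.models.resnet import ResidualTemporalBlock  # assumed import path

block = ResidualTemporalBlock(inp_channels=14, out_channels=32, embed_dim=128, horizon=64)
x = torch.randn(8, 14, 64)           # [batch_size x inp_channels x horizon]
t = torch.randn(8, 128)              # [batch_size x embed_dim]
print(block(x, t).shape)             # torch.Size([8, 32, 64])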
# unet_score_estimation.py
class ResnetBlockBigGANpp(nn.Module):
def __init__(
self,
act,
in_ch,
out_ch=None,
temb_dim=None,
up=False,
down=False,
dropout=0.1,
fir=False,
fir_kernel=(1, 3, 3, 1),
skip_rescale=True,
init_scale=0.0,
):
super().__init__()
out_ch = out_ch if out_ch else in_ch
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
self.up = up
self.down = down
self.fir = fir
self.fir_kernel = fir_kernel
self.Conv_0 = conv3x3(in_ch, out_ch)
if temb_dim is not None:
self.Dense_0 = nn.Linear(temb_dim, out_ch)
self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
nn.init.zeros_(self.Dense_0.bias)
self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
self.Dropout_0 = nn.Dropout(dropout)
self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
if in_ch != out_ch or up or down:
self.Conv_2 = conv1x1(in_ch, out_ch)
self.skip_rescale = skip_rescale
self.act = act
self.in_ch = in_ch
self.out_ch = out_ch
def forward(self, x, temb=None):
h = self.act(self.GroupNorm_0(x))
if self.up:
if self.fir:
h = upsample_2d(h, self.fir_kernel, factor=2)
x = upsample_2d(x, self.fir_kernel, factor=2)
else:
h = naive_upsample_2d(h, factor=2)
x = naive_upsample_2d(x, factor=2)
elif self.down:
if self.fir:
h = downsample_2d(h, self.fir_kernel, factor=2)
x = downsample_2d(x, self.fir_kernel, factor=2)
else:
h = naive_downsample_2d(h, factor=2)
x = naive_downsample_2d(x, factor=2)
h = self.Conv_0(h)
# Add bias to each feature map conditioned on the time embedding
if temb is not None:
h += self.Dense_0(self.act(temb))[:, :, None, None]
h = self.act(self.GroupNorm_1(h))
h = self.Dropout_0(h)
h = self.Conv_1(h)
if self.in_ch != self.out_ch or self.up or self.down:
x = self.Conv_2(x)
if not self.skip_rescale:
return x + h
else:
return (x + h) / np.sqrt(2.0)
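A hedged usage sketch of ResnetBlockBigGANpp on the non-FIR (naive) upsampling path; act can be any callable activation (import path assumed):

import torch
import torch.nn.functional as F
from diffusers.models.resnet import ResnetBlockBigGANpp  # assumed import path

block = ResnetBlockBigGANpp(act=F.silu, in_ch=64, out_ch=128, temb_dim=256, up=True, fir=False)
x = torch.randn(2, 64, 16, 16)
temb = torch.randn(2, 256)
print(block(x, temb).shape)          # torch.Size([2, 128, 32, 32]); up=True doubles H and W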
# unet_score_estimation.py
class ResnetBlockDDPMpp(nn.Module):
"""ResBlock adapted from DDPM."""
def __init__(
self,
act,
in_ch,
out_ch=None,
temb_dim=None,
conv_shortcut=False,
dropout=0.1,
skip_rescale=False,
init_scale=0.0,
):
super().__init__()
out_ch = out_ch if out_ch else in_ch
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
self.Conv_0 = conv3x3(in_ch, out_ch)
if temb_dim is not None:
self.Dense_0 = nn.Linear(temb_dim, out_ch)
self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
nn.init.zeros_(self.Dense_0.bias)
self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
self.Dropout_0 = nn.Dropout(dropout)
self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
if in_ch != out_ch:
if conv_shortcut:
self.Conv_2 = conv3x3(in_ch, out_ch)
else:
self.NIN_0 = NIN(in_ch, out_ch)
self.skip_rescale = skip_rescale
self.act = act
self.out_ch = out_ch
self.conv_shortcut = conv_shortcut
def forward(self, x, temb=None):
h = self.act(self.GroupNorm_0(x))
h = self.Conv_0(h)
if temb is not None:
h += self.Dense_0(self.act(temb))[:, :, None, None]
h = self.act(self.GroupNorm_1(h))
h = self.Dropout_0(h)
h = self.Conv_1(h)
if x.shape[1] != self.out_ch:
if self.conv_shortcut:
x = self.Conv_2(x)
else:
x = self.NIN_0(x)
if not self.skip_rescale:
return x + h
else:
return (x + h) / np.sqrt(2.0)
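A hedged usage sketch of ResnetBlockDDPMpp; with skip_rescale=True the residual sum is divided by sqrt(2) (import path assumed):

import torch
import torch.nn.functional as F
from diffusers.models.resnet import ResnetBlockDDPMpp  # assumed import path

block = ResnetBlockDDPMpp(act=F.silu, in_ch=128, temb_dim=512, skip_rescale=True)
x = torch.randn(2, 128, 32, 32)
temb = torch.randn(2, 512)
print(block(x, temb).shape)          # torch.Size([2, 128, 32, 32])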
# HELPER Modules
def normalization(channels, swish=0.0):
"""
Make a standard normalization layer, with an optional swish activation.
:param channels: number of input channels. :return: an nn.Module for normalization.
"""
return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
class GroupNorm32(nn.GroupNorm):
def __init__(self, num_groups, num_channels, swish, eps=1e-5):
super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
self.swish = swish
def forward(self, x):
y = super().forward(x.float()).to(x.dtype)
if self.swish == 1.0:
y = F.silu(y)
elif self.swish:
y = y * F.sigmoid(y * float(self.swish))
return y
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
class Mish(torch.nn.Module):
def forward(self, x):
return x * torch.tanh(torch.nn.functional.softplus(x))
class Block(torch.nn.Module):
def __init__(self, dim, dim_out, groups=8):
super(Block, self).__init__()
self.block = torch.nn.Sequential(
torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
)
def forward(self, x, mask):
output = self.block(x * mask)
return output * mask
class Conv1dBlock(nn.Module):
"""
Conv1d --> GroupNorm --> Mish
"""

def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
super().__init__()

self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
RearrangeDim(),
# Rearrange("batch channels horizon -> batch channels 1 horizon"),
nn.GroupNorm(n_groups, out_channels),
RearrangeDim(),
# Rearrange("batch channels 1 horizon -> batch channels horizon"),
nn.Mish(),
)

def forward(self, x):
return self.block(x)
class RearrangeDim(nn.Module):
def __init__(self):
super().__init__()
def forward(self, tensor):
if len(tensor.shape) == 2:
return tensor[:, :, None]
if len(tensor.shape) == 3:
return tensor[:, :, None, :]
elif len(tensor.shape) == 4:
return tensor[:, :, 0, :]
else:
raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")


def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
"""1x1 convolution with DDPM initialization."""
conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
nn.init.zeros_(conv.bias)
return conv
def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
"""3x3 convolution with DDPM initialization."""
conv = nn.Conv2d(
in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias
)
conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
nn.init.zeros_(conv.bias)
return conv
def default_init(scale=1.0):
"""The same initialization used in DDPM."""
scale = 1e-10 if scale == 0 else scale
return variance_scaling(scale, "fan_avg", "uniform")
def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
"""Ported from JAX."""
def _compute_fans(shape, in_axis=1, out_axis=0):
receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
fan_in = shape[in_axis] * receptive_field_size
fan_out = shape[out_axis] * receptive_field_size
return fan_in, fan_out
def init(shape, dtype=dtype, device=device):
fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
if mode == "fan_in":
denominator = fan_in
elif mode == "fan_out":
denominator = fan_out
elif mode == "fan_avg":
denominator = (fan_in + fan_out) / 2
else:
raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
variance = scale / denominator
if distribution == "normal":
return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
elif distribution == "uniform":
return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
else:
raise ValueError("invalid distribution for variance scaling initializer")
return init
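A hedged sketch of how default_init is used by conv1x1/conv3x3 above: it returns an initializer callable that is applied to a weight shape (import path assumed):

import torch
from diffusers.models.resnet import default_init  # assumed import path

w = default_init(scale=1.0)((64, 32, 3, 3))       # fan_avg / uniform, as in DDPM
print(w.shape, w.dtype)                           # torch.Size([64, 32, 3, 3]) torch.float32
w0 = default_init(scale=0.0)((64, 32, 3, 3))      # scale=0 is clamped to 1e-10, i.e. near-zero weights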
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])


def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
_, channel, in_h, in_w = input.shape
input = input.reshape(-1, in_h, in_w, 1)
_, in_h, in_w, minor = input.shape
kernel_h, kernel_w = kernel.shape
out = input.view(-1, in_h, 1, in_w, 1, minor)
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
out = out[
:,
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
:,
]
out = out.permute(0, 3, 1, 2)
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
out = F.conv2d(out, w)
out = out.reshape(
-1,
minor,
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
)
out = out.permute(0, 2, 3, 1)
out = out[:, ::down_y, ::down_x, :]
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
return out.view(-1, channel, out_h, out_w)
def upsample_2d(x, k=None, factor=2, gain=1):
r"""Upsample a batch of 2D images with the given filter.
Args:
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a:
multiple of the upsampling factor.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H * factor, W * factor]`
"""
assert isinstance(factor, int) and factor >= 1
if k is None:
k = [1] * factor
k = _setup_kernel(k) * (gain * (factor**2))
p = k.shape[0] - factor
return upfirdn2d(x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
def downsample_2d(x, k=None, factor=2, gain=1):
r"""Downsample a batch of 2D images with the given filter.
Args:
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
shape is a multiple of the downsampling factor.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to average pooling.
factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H // factor, W // factor]`
"""
assert isinstance(factor, int) and factor >= 1
if k is None:
k = [1] * factor
k = _setup_kernel(k) * gain
p = k.shape[0] - factor
return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
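A hedged sketch of the two FIR helpers with the default box kernel, k=None (import path assumed):

import torch
from diffusers.models.resnet import upsample_2d, downsample_2d  # assumed import path

x = torch.randn(1, 3, 8, 8)
print(upsample_2d(x, factor=2).shape)    # torch.Size([1, 3, 16, 16]); default k acts as nearest-neighbor
print(downsample_2d(x, factor=2).shape)  # torch.Size([1, 3, 4, 4]); default k acts as average pooling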
def naive_upsample_2d(x, factor=2):
_N, C, H, W = x.shape
x = torch.reshape(x, (-1, C, H, 1, W, 1))
x = x.repeat(1, 1, 1, factor, 1, factor)
return torch.reshape(x, (-1, C, H * factor, W * factor))
def naive_downsample_2d(x, factor=2):
_N, C, H, W = x.shape
x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
return torch.mean(x, dim=(3, 5))
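For comparison, a hedged sketch of the naive (non-FIR) variants, which are plain repeat and mean pooling (import path assumed):

import torch
from diffusers.models.resnet import naive_upsample_2d, naive_downsample_2d  # assumed import path

x = torch.randn(1, 3, 4, 4)
assert naive_upsample_2d(x, factor=2).shape == (1, 3, 8, 8)
assert naive_downsample_2d(x, factor=2).shape == (1, 3, 2, 2)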
class NIN(nn.Module):
def __init__(self, in_dim, num_units, init_scale=0.1):
super().__init__()
self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)

def forward(self, x):
x = x.permute(0, 2, 3, 1)
y = contract_inner(x, self.W) + self.b
return y.permute(0, 3, 1, 2)
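A hedged sketch of NIN, which acts as a per-pixel (1x1) linear projection implemented with einsum (import path assumed):

import torch
from diffusers.models.resnet import NIN  # assumed import path

nin = NIN(in_dim=64, num_units=128)
x = torch.randn(2, 64, 16, 16)
print(nin(x).shape)                      # torch.Size([2, 128, 16, 16])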
def _setup_kernel(k):
k = np.asarray(k, dtype=np.float32)
if k.ndim == 1:
k = np.outer(k, k)
k /= np.sum(k)
assert k.ndim == 2
assert k.shape[0] == k.shape[1]
return k


def contract_inner(x, y):
"""tensordot(x, y, 1)."""
x_chars = list(string.ascii_lowercase[: len(x.shape)])
y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)])
y_chars[0] = x_chars[-1]  # first axis of y and last of x get summed
out_chars = x_chars[:-1] + y_chars[1:]
return _einsum(x_chars, y_chars, out_chars, x, y)


def _einsum(a, b, c, x, y):
einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
return torch.einsum(einsum_str, x, y)


# class ResnetBlock(nn.Module):
# def __init__(
# self,
# *,
# in_channels,
# out_channels=None,
# conv_shortcut=False,
# dropout,
# temb_channels=512,
# use_scale_shift_norm=False,
# ):
# super().__init__()
# self.in_channels = in_channels
# out_channels = in_channels if out_channels is None else out_channels
# self.out_channels = out_channels
# self.use_conv_shortcut = conv_shortcut
# self.use_scale_shift_norm = use_scale_shift_norm
#
# self.norm1 = Normalize(in_channels)
# self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
# temp_out_channles = 2 * out_channels if use_scale_shift_norm else out_channels
# self.temb_proj = torch.nn.Linear(temb_channels, temp_out_channles)
#
# self.norm2 = Normalize(out_channels)
# self.dropout = torch.nn.Dropout(dropout)
# self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
# if self.in_channels != self.out_channels:
# if self.use_conv_shortcut:
# self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
# else:
# self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
# def forward(self, x, temb):
# h = x
# h = self.norm1(h)
# h = nonlinearity(h)
# h = self.conv1(h)
# # TODO: check if this broadcasting works correctly for 1D and 3D
# temb = self.temb_proj(nonlinearity(temb))[:, :, None, None]
# if self.use_scale_shift_norm:
# out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
# scale, shift = torch.chunk(temb, 2, dim=1)
# h = self.norm2(h) * (1 + scale) + shift
# h = out_rest(h)
# else:
# h = h + temb
# h = self.norm2(h)
# h = nonlinearity(h)
# h = self.dropout(h)
# h = self.conv2(h)
# if self.in_channels != self.out_channels:
# if self.use_conv_shortcut:
# x = self.conv_shortcut(x)
# else:
# x = self.nin_shortcut(x)
# return x + h
...@@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample, ResnetBlock, Upsample
def nonlinearity(x):
...@@ -34,46 +34,46 @@ def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
# class ResnetBlock(nn.Module):
# def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
# super().__init__()
# self.in_channels = in_channels
# out_channels = in_channels if out_channels is None else out_channels
# self.out_channels = out_channels
# self.use_conv_shortcut = conv_shortcut
#
# self.norm1 = Normalize(in_channels)
# self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
# self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
# self.norm2 = Normalize(out_channels)
# self.dropout = torch.nn.Dropout(dropout)
# self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
# if self.in_channels != self.out_channels:
# if self.use_conv_shortcut:
# self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
# else:
# self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
#
# def forward(self, x, temb):
# h = x
# h = self.norm1(h)
# h = nonlinearity(h)
# h = self.conv1(h)
#
# h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
#
# h = self.norm2(h)
# h = nonlinearity(h)
# h = self.dropout(h)
# h = self.conv2(h)
#
# if self.in_channels != self.out_channels:
# if self.use_conv_shortcut:
# x = self.conv_shortcut(x)
# else:
# x = self.nin_shortcut(x)
#
# return x + h
class UNetModel(ModelMixin, ConfigMixin):
...@@ -127,7 +127,6 @@ class UNetModel(ModelMixin, ConfigMixin):
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
attn_2 = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
...@@ -142,7 +141,6 @@ class UNetModel(ModelMixin, ConfigMixin):
down = nn.Module()
down.block = block
down.attn = attn
down.attn_2 = attn_2
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0)
curr_res = curr_res // 2
...
...@@ -8,7 +8,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample, ResBlock, TimestepBlock, Upsample
def convert_module_to_f16(l):
...@@ -96,16 +96,14 @@ def zero_module(module):
return module
# class TimestepBlock(nn.Module):
# """
# Any module where forward() takes timestep embeddings as a second argument.
# """
#
# @abstractmethod
# def forward(self, x, emb):
# """
# Apply the module to `x` given `emb` timestep embeddings.
# """
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
...@@ -124,106 +122,99 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
return x
# class ResBlock(TimestepBlock):
# """
# A residual block that can optionally change the number of channels.
# :param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels.
# :param dropout: the rate of dropout. :param out_channels: if specified, the number of out channels. :param
# use_conv: if True and out_channels is specified, use a spatial
# convolution instead of a smaller 1x1 convolution to change the channels in the skip connection.
# :param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing
# on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for
# downsampling.
# """
#
# def __init__(
# self,
# channels,
# emb_channels,
# dropout,
# out_channels=None,
# use_conv=False,
# use_scale_shift_norm=False,
# dims=2,
# use_checkpoint=False,
# up=False,
# down=False,
# ):
# super().__init__()
# self.channels = channels
# self.emb_channels = emb_channels
# self.dropout = dropout
# self.out_channels = out_channels or channels
# self.use_conv = use_conv
# self.use_checkpoint = use_checkpoint
# self.use_scale_shift_norm = use_scale_shift_norm
#
# self.in_layers = nn.Sequential(
# normalization(channels, swish=1.0),
# nn.Identity(),
# conv_nd(dims, channels, self.out_channels, 3, padding=1),
# )
#
# self.updown = up or down
#
# if up:
# self.h_upd = Upsample(channels, use_conv=False, dims=dims)
# self.x_upd = Upsample(channels, use_conv=False, dims=dims)
# elif down:
# self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
# self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
# else:
# self.h_upd = self.x_upd = nn.Identity()
#
# self.emb_layers = nn.Sequential(
# nn.SiLU(),
# linear(
# emb_channels,
# 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
# ),
# )
# self.out_layers = nn.Sequential(
# normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
# nn.SiLU() if use_scale_shift_norm else nn.Identity(),
# nn.Dropout(p=dropout),
# zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
# )
#
# if self.out_channels == channels:
# self.skip_connection = nn.Identity()
# elif use_conv:
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
# else:
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
#
# def forward(self, x, emb):
# """
# Apply the block to a Tensor, conditioned on a timestep embedding.
# :param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep embeddings.
# :return: an [N x C x ...] Tensor of outputs.
# """
# if self.updown:
# in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
# h = in_rest(x)
# h = self.h_upd(h)
# x = self.x_upd(x)
# h = in_conv(h)
# else:
# h = self.in_layers(x)
# emb_out = self.emb_layers(emb).type(h.dtype)
# while len(emb_out.shape) < len(h.shape):
# emb_out = emb_out[..., None]
# if self.use_scale_shift_norm:
# out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
# scale, shift = torch.chunk(emb_out, 2, dim=1)
# h = out_norm(h) * (1 + scale) + shift
# h = out_rest(h)
# else:
# h = h + emb_out
# h = self.out_layers(h)
# return self.skip_connection(x) + h
class GlideUNetModel(ModelMixin, ConfigMixin):
...
...@@ -4,7 +4,9 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import LinearAttention
from .embeddings import get_timestep_embedding
from .resnet import Downsample
from .resnet import ResnetBlockGradTTS as ResnetBlock
from .resnet import Upsample
class Mish(torch.nn.Module):
...@@ -34,24 +36,24 @@ class Block(torch.nn.Module):
return output * mask
# class ResnetBlock(torch.nn.Module):
# def __init__(self, dim, dim_out, time_emb_dim, groups=8):
# super(ResnetBlock, self).__init__()
# self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
#
# self.block1 = Block(dim, dim_out, groups=groups)
# self.block2 = Block(dim_out, dim_out, groups=groups)
# if dim != dim_out:
# self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
# else:
# self.res_conv = torch.nn.Identity()
#
# def forward(self, x, mask, time_emb):
# h = self.block1(x, mask)
# h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
# h = self.block2(h, mask)
# output = h + self.res_conv(x * mask)
# return output
class Residual(torch.nn.Module):
...
...@@ -11,7 +11,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample, ResBlock, TimestepBlock, Upsample
def exists(val):
...@@ -359,16 +359,14 @@ class AttentionPool2d(nn.Module):
return x[:, :, 0]
# class TimestepBlock(nn.Module):
# """
# Any module where forward() takes timestep embeddings as a second argument.
# """
#
# @abstractmethod
# def forward(self, x, emb):
# """
# Apply the module to `x` given `emb` timestep embeddings.
# """
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
...@@ -387,99 +385,97 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
return x
# class A_ResBlock(TimestepBlock):
# """
# A residual block that can optionally change the number of channels. :param channels: the number of input channels.
# :param emb_channels: the number of timestep embedding channels. :param dropout: the rate of dropout. :param
# out_channels: if specified, the number of out channels. :param use_conv: if True and out_channels is specified, use
# a spatial convolution instead of a smaller 1x1 convolution to change the channels in the skip connection.
# :param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing
# on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for
# downsampling.
# """
#
# def __init__(
# self,
# channels,
# emb_channels,
# dropout,
# out_channels=None,
# use_conv=False,
# use_scale_shift_norm=False,
# dims=2,
# use_checkpoint=False,
# up=False,
# down=False,
# ):
# super().__init__()
# self.channels = channels
# self.emb_channels = emb_channels
# self.dropout = dropout
# self.out_channels = out_channels or channels
# self.use_conv = use_conv
# self.use_checkpoint = use_checkpoint
# self.use_scale_shift_norm = use_scale_shift_norm
#
# self.in_layers = nn.Sequential(
# normalization(channels),
# nn.SiLU(),
# conv_nd(dims, channels, self.out_channels, 3, padding=1),
# )
#
# self.updown = up or down
#
# if up:
# self.h_upd = Upsample(channels, use_conv=False, dims=dims)
# self.x_upd = Upsample(channels, use_conv=False, dims=dims)
# elif down:
# self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
# self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
# else:
# self.h_upd = self.x_upd = nn.Identity()
#
# self.emb_layers = nn.Sequential(
# nn.SiLU(),
# linear(
# emb_channels,
# 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
# ),
# )
# self.out_layers = nn.Sequential(
# normalization(self.out_channels),
# nn.SiLU(),
# nn.Dropout(p=dropout),
# zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
# )
#
# if self.out_channels == channels:
# self.skip_connection = nn.Identity()
# elif use_conv:
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
# else:
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
#
# def forward(self, x, emb):
# if self.updown:
# in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
# h = in_rest(x)
# h = self.h_upd(h)
# x = self.x_upd(x)
# h = in_conv(h)
# else:
# h = self.in_layers(x)
# emb_out = self.emb_layers(emb).type(h.dtype)
# while len(emb_out.shape) < len(h.shape):
# emb_out = emb_out[..., None]
# if self.use_scale_shift_norm:
# out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
# scale, shift = torch.chunk(emb_out, 2, dim=1)
# h = out_norm(h) * (1 + scale) + shift
# h = out_rest(h)
# else:
# h = h + emb_out
# h = self.out_layers(h)
# return self.skip_connection(x) + h
#
class QKVAttention(nn.Module):
...
...@@ -6,6 +6,7 @@ import torch.nn as nn
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
from .resnet import ResidualTemporalBlock
class SinusoidalPosEmb(nn.Module):
...@@ -72,36 +73,35 @@ class Conv1dBlock(nn.Module):
return self.block(x)
# class ResidualTemporalBlock(nn.Module):
# def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
# super().__init__()
#
# self.blocks = nn.ModuleList(
# [
# Conv1dBlock(inp_channels, out_channels, kernel_size),
# Conv1dBlock(out_channels, out_channels, kernel_size),
# ]
# )
#
# self.time_mlp = nn.Sequential(
# nn.Mish(),
# nn.Linear(embed_dim, out_channels),
# RearrangeDim(),
# # Rearrange("batch t -> batch t 1"),
# )
#
# self.residual_conv = (
# nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
# )
#
# def forward(self, x, t):
# """
# x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x
# out_channels x horizon ]
# """
# out = self.blocks[0](x) + self.time_mlp(t)
# out = self.blocks[1](out)
# return out + self.residual_conv(x)
class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
...
...@@ -28,6 +28,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import GaussianFourierProjection, get_timestep_embedding
from .resnet import ResnetBlockBigGANpp, ResnetBlockDDPMpp
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
...@@ -299,7 +300,7 @@ def downsample_2d(x, k=None, factor=2, gain=1):
return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
"""1x1 convolution with DDPM initialization."""
conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
...@@ -307,7 +308,7 @@ def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, pad
return conv
def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
"""3x3 convolution with DDPM initialization."""
conv = nn.Conv2d(
in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias
...@@ -317,10 +318,6 @@ def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_sc
return conv
conv1x1 = ddpm_conv1x1
conv3x3 = ddpm_conv3x3
def _einsum(a, b, c, x, y):
einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
return torch.einsum(einsum_str, x, y)
...@@ -494,135 +491,135 @@ class Downsample(nn.Module):
return x
# class ResnetBlockDDPMpp(nn.Module):
# """ResBlock adapted from DDPM."""
#
# def __init__(
# self,
# act,
# in_ch,
# out_ch=None,
# temb_dim=None,
# conv_shortcut=False,
# dropout=0.1,
# skip_rescale=False,
# init_scale=0.0,
# ):
# super().__init__()
# out_ch = out_ch if out_ch else in_ch
# self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
# self.Conv_0 = conv3x3(in_ch, out_ch)
# if temb_dim is not None:
# self.Dense_0 = nn.Linear(temb_dim, out_ch)
# self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
# nn.init.zeros_(self.Dense_0.bias)
# self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
# self.Dropout_0 = nn.Dropout(dropout)
# self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
# if in_ch != out_ch:
# if conv_shortcut:
# self.Conv_2 = conv3x3(in_ch, out_ch)
# else:
# self.NIN_0 = NIN(in_ch, out_ch)
#
# self.skip_rescale = skip_rescale
# self.act = act
# self.out_ch = out_ch
# self.conv_shortcut = conv_shortcut
#
# def forward(self, x, temb=None):
# h = self.act(self.GroupNorm_0(x))
# h = self.Conv_0(h)
# if temb is not None:
# h += self.Dense_0(self.act(temb))[:, :, None, None]
# h = self.act(self.GroupNorm_1(h))
# h = self.Dropout_0(h)
# h = self.Conv_1(h)
# if x.shape[1] != self.out_ch:
# if self.conv_shortcut:
# x = self.Conv_2(x)
# else:
# x = self.NIN_0(x)
# if not self.skip_rescale:
# return x + h
# else:
# return (x + h) / np.sqrt(2.0)


# class ResnetBlockBigGANpp(nn.Module):
# def __init__(
# self,
# act,
# in_ch,
# out_ch=None,
# temb_dim=None,
# up=False,
# down=False,
# dropout=0.1,
# fir=False,
# fir_kernel=(1, 3, 3, 1),
# skip_rescale=True,
# init_scale=0.0,
# ):
# super().__init__()
#
# out_ch = out_ch if out_ch else in_ch
# self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
# self.up = up
# self.down = down
# self.fir = fir
# self.fir_kernel = fir_kernel
#
# self.Conv_0 = conv3x3(in_ch, out_ch)
# if temb_dim is not None:
# self.Dense_0 = nn.Linear(temb_dim, out_ch)
# self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
# nn.init.zeros_(self.Dense_0.bias)
#
# self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
# self.Dropout_0 = nn.Dropout(dropout)
# self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
# if in_ch != out_ch or up or down:
# self.Conv_2 = conv1x1(in_ch, out_ch)
#
# self.skip_rescale = skip_rescale
# self.act = act
# self.in_ch = in_ch
# self.out_ch = out_ch
#
# def forward(self, x, temb=None):
# h = self.act(self.GroupNorm_0(x))
#
# if self.up:
# if self.fir:
# h = upsample_2d(h, self.fir_kernel, factor=2)
# x = upsample_2d(x, self.fir_kernel, factor=2)
# else:
# h = naive_upsample_2d(h, factor=2)
# x = naive_upsample_2d(x, factor=2)
# elif self.down:
# if self.fir:
# h = downsample_2d(h, self.fir_kernel, factor=2)
# x = downsample_2d(x, self.fir_kernel, factor=2)
# else:
# h = naive_downsample_2d(h, factor=2)
# x = naive_downsample_2d(x, factor=2)
#
# h = self.Conv_0(h)
# # Add bias to each feature map conditioned on the time embedding
# if temb is not None:
# h += self.Dense_0(self.act(temb))[:, :, None, None]
# h = self.act(self.GroupNorm_1(h))
# h = self.Dropout_0(h)
# h = self.Conv_1(h)
#
# if self.in_ch != self.out_ch or self.up or self.down:
# x = self.Conv_2(x)
#
# if not self.skip_rescale:
# return x + h
# else:
# return (x + h) / np.sqrt(2.0)
class NCSNpp(ModelMixin, ConfigMixin):
...