Commit 3a15afac authored by Patrick von Platen's avatar Patrick von Platen
Browse files

delete bogus files

parent 571e4062
from abc import abstractmethod
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def conv_transpose_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.ConvTranspose1d(*args, **kwargs)
elif dims == 2:
return nn.ConvTranspose2d(*args, **kwargs)
elif dims == 3:
return nn.ConvTranspose3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def Normalize(in_channels, num_groups=32, eps=1e-6):
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=eps, affine=True)
def nonlinearity(x, swish=1.0):
# swish
if swish == 1.0:
return F.silu(x)
else:
return x * F.sigmoid(x * float(swish))
class TimestepBlock(nn.Module):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@abstractmethod
def forward(self, x, emb):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
self.use_conv_transpose = use_conv_transpose
if use_conv_transpose:
self.conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
elif use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
def forward(self, x):
assert x.shape[1] == self.channels
if self.use_conv_transpose:
return self.conv(x)
if self.dims == 3:
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
else:
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv=False, dims=2, out_channels=None, padding=1, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
self.padding = padding
stride = 2 if dims != 3 else (1, 2, 2)
self.name = name
if use_conv:
conv = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
conv = avg_pool_nd(dims, kernel_size=stride, stride=stride)
if name == "conv":
self.conv = conv
else:
self.op = conv
def forward(self, x):
assert x.shape[1] == self.channels
if self.use_conv and self.padding == 0 and self.dims == 2:
pad = (0, 1, 0, 1)
x = F.pad(x, pad, mode="constant", value=0)
if self.name == "conv":
return self.conv(x)
else:
return self.op(x)
# TODO (patil-suraj): needs test
# class Upsample1d(nn.Module):
# def __init__(self, dim):
# super().__init__()
# self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
#
# def forward(self, x):
# return self.conv(x)
# RESNETS
# unet_score_estimation.py
class ResnetBlockBigGANppNew(nn.Module):
def __init__(
self,
act,
in_ch,
out_ch=None,
temb_dim=None,
up=False,
down=False,
dropout=0.1,
fir_kernel=(1, 3, 3, 1),
skip_rescale=True,
init_scale=0.0,
overwrite=True,
):
super().__init__()
out_ch = out_ch if out_ch else in_ch
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
self.up = up
self.down = down
self.fir_kernel = fir_kernel
self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
if temb_dim is not None:
self.Dense_0 = nn.Linear(temb_dim, out_ch)
self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
nn.init.zeros_(self.Dense_0.bias)
self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
self.Dropout_0 = nn.Dropout(dropout)
self.Conv_1 = conv2d(out_ch, out_ch, init_scale=init_scale, kernel_size=3, padding=1)
if in_ch != out_ch or up or down:
# 1x1 convolution with DDPM initialization.
self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
self.skip_rescale = skip_rescale
self.act = act
self.in_ch = in_ch
self.out_ch = out_ch
if self.overwrite:
in_channels = in_ch
out_channels = out_ch
groups = min(in_ch // 4, 32)
eps = 1e-6
self.pre_norm = True
temb_channels = temb_dim
non_linearity = "silu"
time_embedding_norm = "default"
if self.pre_norm:
self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
else:
self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if time_embedding_norm == "default":
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
elif time_embedding_norm == "scale_shift":
self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = nonlinearity
elif non_linearity == "mish":
self.nonlinearity = Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
if up:
self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
elif down:
self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
if self.in_channels != self.out_channels:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, temb=None):
h = self.act(self.GroupNorm_0(x))
if self.up:
h = upsample_2d(h, self.fir_kernel, factor=2)
x = upsample_2d(x, self.fir_kernel, factor=2)
elif self.down:
h = downsample_2d(h, self.fir_kernel, factor=2)
x = downsample_2d(x, self.fir_kernel, factor=2)
h = self.Conv_0(h)
# Add bias to each feature map conditioned on the time embedding
if temb is not None:
h += self.Dense_0(self.act(temb))[:, :, None, None]
h = self.act(self.GroupNorm_1(h))
h = self.Dropout_0(h)
h = self.Conv_1(h)
if self.in_ch != self.out_ch or self.up or self.down:
x = self.Conv_2(x)
if not self.skip_rescale:
return x + h
else:
return (x + h) / np.sqrt(2.0)
def forward_2(self, x, temb, mask=1.0):
# TODO(Patrick) eventually this class should be split into multiple classes
# too many if else statements
if self.overwrite_for_grad_tts and not self.is_overwritten:
self.set_weights_grad_tts()
self.is_overwritten = True
elif self.overwrite_for_ldm and not self.is_overwritten:
self.set_weights_ldm()
self.is_overwritten = True
h = x
h = h * mask
if self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
if self.up or self.down:
x = self.x_upd(x)
h = self.h_upd(h)
h = self.conv1(h)
if not self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
h = h * mask
temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
if self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
h = self.norm2(h)
h = h + h * scale + shift
h = self.nonlinearity(h)
elif self.time_embedding_norm == "default":
h = h + temb
h = h * mask
if self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if not self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
h = h * mask
x = x * mask
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x + h
# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py
class ResnetBlock(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
up=False,
down=False,
overwrite_for_grad_tts=False,
overwrite_for_ldm=False,
overwrite_for_glide=False,
):
super().__init__()
self.pre_norm = pre_norm
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.time_embedding_norm = time_embedding_norm
self.up = up
self.down = down
if self.pre_norm:
self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
else:
self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if time_embedding_norm == "default":
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
elif time_embedding_norm == "scale_shift":
self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = nonlinearity
elif non_linearity == "mish":
self.nonlinearity = Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
if up:
self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
elif down:
self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
if self.in_channels != self.out_channels:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
# TODO(SURAJ, PATRICK): ALL OF THE FOLLOWING OF THE INIT METHOD CAN BE DELETED ONCE WEIGHTS ARE CONVERTED
self.is_overwritten = False
self.overwrite_for_glide = overwrite_for_glide
self.overwrite_for_grad_tts = overwrite_for_grad_tts
self.overwrite_for_ldm = overwrite_for_ldm or overwrite_for_glide
if self.overwrite_for_grad_tts:
dim = in_channels
dim_out = out_channels
time_emb_dim = temb_channels
self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
self.pre_norm = pre_norm
self.block1 = Block(dim, dim_out, groups=groups)
self.block2 = Block(dim_out, dim_out, groups=groups)
if dim != dim_out:
self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
else:
self.res_conv = torch.nn.Identity()
elif self.overwrite_for_ldm:
dims = 2
channels = in_channels
emb_channels = temb_channels
use_scale_shift_norm = False
self.in_layers = nn.Sequential(
normalization(channels, swish=1.0),
nn.Identity(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
)
self.emb_layers = nn.Sequential(
nn.SiLU(),
linear(
emb_channels,
2 * self.out_channels if self.time_embedding_norm == "scale_shift" else self.out_channels,
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
nn.SiLU() if use_scale_shift_norm else nn.Identity(),
nn.Dropout(p=dropout),
zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
)
if self.out_channels == in_channels:
self.skip_connection = nn.Identity()
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
def set_weights_grad_tts(self):
self.conv1.weight.data = self.block1.block[0].weight.data
self.conv1.bias.data = self.block1.block[0].bias.data
self.norm1.weight.data = self.block1.block[1].weight.data
self.norm1.bias.data = self.block1.block[1].bias.data
self.conv2.weight.data = self.block2.block[0].weight.data
self.conv2.bias.data = self.block2.block[0].bias.data
self.norm2.weight.data = self.block2.block[1].weight.data
self.norm2.bias.data = self.block2.block[1].bias.data
self.temb_proj.weight.data = self.mlp[1].weight.data
self.temb_proj.bias.data = self.mlp[1].bias.data
if self.in_channels != self.out_channels:
self.nin_shortcut.weight.data = self.res_conv.weight.data
self.nin_shortcut.bias.data = self.res_conv.bias.data
def set_weights_ldm(self):
self.norm1.weight.data = self.in_layers[0].weight.data
self.norm1.bias.data = self.in_layers[0].bias.data
self.conv1.weight.data = self.in_layers[-1].weight.data
self.conv1.bias.data = self.in_layers[-1].bias.data
self.temb_proj.weight.data = self.emb_layers[-1].weight.data
self.temb_proj.bias.data = self.emb_layers[-1].bias.data
self.norm2.weight.data = self.out_layers[0].weight.data
self.norm2.bias.data = self.out_layers[0].bias.data
self.conv2.weight.data = self.out_layers[-1].weight.data
self.conv2.bias.data = self.out_layers[-1].bias.data
if self.in_channels != self.out_channels:
self.nin_shortcut.weight.data = self.skip_connection.weight.data
self.nin_shortcut.bias.data = self.skip_connection.bias.data
def forward(self, x, temb, mask=1.0):
# TODO(Patrick) eventually this class should be split into multiple classes
# too many if else statements
if self.overwrite_for_grad_tts and not self.is_overwritten:
self.set_weights_grad_tts()
self.is_overwritten = True
elif self.overwrite_for_ldm and not self.is_overwritten:
self.set_weights_ldm()
self.is_overwritten = True
h = x
h = h * mask
if self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
if self.up or self.down:
x = self.x_upd(x)
h = self.h_upd(h)
h = self.conv1(h)
if not self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
h = h * mask
temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
if self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
h = self.norm2(h)
h = h + h * scale + shift
h = self.nonlinearity(h)
elif self.time_embedding_norm == "default":
h = h + temb
h = h * mask
if self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if not self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
h = h * mask
x = x * mask
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x + h
# TODO(Patrick) - just there to convert the weights; can delete afterward
class Block(torch.nn.Module):
def __init__(self, dim, dim_out, groups=8):
super(Block, self).__init__()
self.block = torch.nn.Sequential(
torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
)
# unet_score_estimation.py
class ResnetBlockBigGANpp(nn.Module):
def __init__(
self,
act,
in_ch,
out_ch=None,
temb_dim=None,
up=False,
down=False,
dropout=0.1,
fir_kernel=(1, 3, 3, 1),
skip_rescale=True,
init_scale=0.0,
):
super().__init__()
out_ch = out_ch if out_ch else in_ch
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
self.up = up
self.down = down
self.fir_kernel = fir_kernel
self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
if temb_dim is not None:
self.Dense_0 = nn.Linear(temb_dim, out_ch)
self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
nn.init.zeros_(self.Dense_0.bias)
self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
self.Dropout_0 = nn.Dropout(dropout)
self.Conv_1 = conv2d(out_ch, out_ch, init_scale=init_scale, kernel_size=3, padding=1)
if in_ch != out_ch or up or down:
# 1x1 convolution with DDPM initialization.
self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
self.skip_rescale = skip_rescale
self.act = act
self.in_ch = in_ch
self.out_ch = out_ch
def forward(self, x, temb=None):
h = self.act(self.GroupNorm_0(x))
if self.up:
h = upsample_2d(h, self.fir_kernel, factor=2)
x = upsample_2d(x, self.fir_kernel, factor=2)
elif self.down:
h = downsample_2d(h, self.fir_kernel, factor=2)
x = downsample_2d(x, self.fir_kernel, factor=2)
h = self.Conv_0(h)
# Add bias to each feature map conditioned on the time embedding
if temb is not None:
h += self.Dense_0(self.act(temb))[:, :, None, None]
h = self.act(self.GroupNorm_1(h))
h = self.Dropout_0(h)
h = self.Conv_1(h)
if self.in_ch != self.out_ch or self.up or self.down:
x = self.Conv_2(x)
if not self.skip_rescale:
return x + h
else:
return (x + h) / np.sqrt(2.0)
# unet_rl.py
class ResidualTemporalBlock(nn.Module):
def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
super().__init__()
self.blocks = nn.ModuleList(
[
Conv1dBlock(inp_channels, out_channels, kernel_size),
Conv1dBlock(out_channels, out_channels, kernel_size),
]
)
self.time_mlp = nn.Sequential(
nn.Mish(),
nn.Linear(embed_dim, out_channels),
RearrangeDim(),
# Rearrange("batch t -> batch t 1"),
)
self.residual_conv = (
nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
)
def forward(self, x, t):
"""
x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x
out_channels x horizon ]
"""
out = self.blocks[0](x) + self.time_mlp(t)
out = self.blocks[1](out)
return out + self.residual_conv(x)
# HELPER Modules
def normalization(channels, swish=0.0):
"""
Make a standard normalization layer, with an optional swish activation.
:param channels: number of input channels. :return: an nn.Module for normalization.
"""
return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
class GroupNorm32(nn.GroupNorm):
def __init__(self, num_groups, num_channels, swish, eps=1e-5):
super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
self.swish = swish
def forward(self, x):
y = super().forward(x.float()).to(x.dtype)
if self.swish == 1.0:
y = F.silu(y)
elif self.swish:
y = y * F.sigmoid(y * float(self.swish))
return y
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
class Mish(torch.nn.Module):
def forward(self, x):
return x * torch.tanh(torch.nn.functional.softplus(x))
class Conv1dBlock(nn.Module):
"""
Conv1d --> GroupNorm --> Mish
"""
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
RearrangeDim(),
# Rearrange("batch channels horizon -> batch channels 1 horizon"),
nn.GroupNorm(n_groups, out_channels),
RearrangeDim(),
# Rearrange("batch channels 1 horizon -> batch channels horizon"),
nn.Mish(),
)
def forward(self, x):
return self.block(x)
class RearrangeDim(nn.Module):
def __init__(self):
super().__init__()
def forward(self, tensor):
if len(tensor.shape) == 2:
return tensor[:, :, None]
if len(tensor.shape) == 3:
return tensor[:, :, None, :]
elif len(tensor.shape) == 4:
return tensor[:, :, 0, :]
else:
raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
"""nXn convolution with DDPM initialization."""
conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape)
nn.init.zeros_(conv.bias)
return conv
def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
"""Ported from JAX."""
scale = 1e-10 if scale == 0 else scale
def _compute_fans(shape, in_axis=1, out_axis=0):
receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
fan_in = shape[in_axis] * receptive_field_size
fan_out = shape[out_axis] * receptive_field_size
return fan_in, fan_out
def init(shape, dtype=dtype, device=device):
fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
denominator = (fan_in + fan_out) / 2
variance = scale / denominator
return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
return init
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
_, channel, in_h, in_w = input.shape
input = input.reshape(-1, in_h, in_w, 1)
_, in_h, in_w, minor = input.shape
kernel_h, kernel_w = kernel.shape
out = input.view(-1, in_h, 1, in_w, 1, minor)
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
out = out[
:,
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
:,
]
out = out.permute(0, 3, 1, 2)
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
out = F.conv2d(out, w)
out = out.reshape(
-1,
minor,
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
)
out = out.permute(0, 2, 3, 1)
out = out[:, ::down_y, ::down_x, :]
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
return out.view(-1, channel, out_h, out_w)
def upsample_2d(x, k=None, factor=2, gain=1):
r"""Upsample a batch of 2D images with the given filter.
Args:
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a:
multiple of the upsampling factor.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H * factor, W * factor]`
"""
assert isinstance(factor, int) and factor >= 1
if k is None:
k = [1] * factor
k = _setup_kernel(k) * (gain * (factor**2))
p = k.shape[0] - factor
return upfirdn2d(x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
def downsample_2d(x, k=None, factor=2, gain=1):
r"""Downsample a batch of 2D images with the given filter.
Args:
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
shape is a multiple of the downsampling factor.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to average pooling.
factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H // factor, W // factor]`
"""
assert isinstance(factor, int) and factor >= 1
if k is None:
k = [1] * factor
k = _setup_kernel(k) * gain
p = k.shape[0] - factor
return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
def _setup_kernel(k):
k = np.asarray(k, dtype=np.float32)
if k.ndim == 1:
k = np.outer(k, k)
k /= np.sum(k)
assert k.ndim == 2
assert k.shape[0] == k.shape[1]
return k
from abc import abstractmethod
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def conv_transpose_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.ConvTranspose1d(*args, **kwargs)
elif dims == 2:
return nn.ConvTranspose2d(*args, **kwargs)
elif dims == 3:
return nn.ConvTranspose3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def Normalize(in_channels, num_groups=32, eps=1e-6):
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=eps, affine=True)
def nonlinearity(x, swish=1.0):
# swish
if swish == 1.0:
return F.silu(x)
else:
return x * F.sigmoid(x * float(swish))
class TimestepBlock(nn.Module):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@abstractmethod
def forward(self, x, emb):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
self.use_conv_transpose = use_conv_transpose
if use_conv_transpose:
self.conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
elif use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
def forward(self, x):
assert x.shape[1] == self.channels
if self.use_conv_transpose:
return self.conv(x)
if self.dims == 3:
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
else:
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv=False, dims=2, out_channels=None, padding=1, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
self.padding = padding
stride = 2 if dims != 3 else (1, 2, 2)
self.name = name
if use_conv:
conv = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
conv = avg_pool_nd(dims, kernel_size=stride, stride=stride)
if name == "conv":
self.conv = conv
else:
self.op = conv
def forward(self, x):
assert x.shape[1] == self.channels
if self.use_conv and self.padding == 0 and self.dims == 2:
pad = (0, 1, 0, 1)
x = F.pad(x, pad, mode="constant", value=0)
if self.name == "conv":
return self.conv(x)
else:
return self.op(x)
# TODO (patil-suraj): needs test
# class Upsample1d(nn.Module):
# def __init__(self, dim):
# super().__init__()
# self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
#
# def forward(self, x):
# return self.conv(x)
# RESNETS
# unet_score_estimation.py
class ResnetBlockBigGANppNew(nn.Module):
def __init__(
self,
act,
in_ch,
out_ch=None,
temb_dim=None,
up=False,
down=False,
dropout=0.1,
fir_kernel=(1, 3, 3, 1),
skip_rescale=True,
init_scale=0.0,
overwrite=True,
):
super().__init__()
out_ch = out_ch if out_ch else in_ch
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
self.up = up
self.down = down
self.fir_kernel = fir_kernel
self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
if temb_dim is not None:
self.Dense_0 = nn.Linear(temb_dim, out_ch)
self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
nn.init.zeros_(self.Dense_0.bias)
self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
self.Dropout_0 = nn.Dropout(dropout)
self.Conv_1 = conv2d(out_ch, out_ch, init_scale=init_scale, kernel_size=3, padding=1)
if in_ch != out_ch or up or down:
# 1x1 convolution with DDPM initialization.
self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
self.skip_rescale = skip_rescale
self.act = act
self.in_ch = in_ch
self.out_ch = out_ch
self.is_overwritten = False
self.overwrite = overwrite
if overwrite:
self.in_channels = in_channels = in_ch
self.out_channels = out_channels = out_ch
groups = min(in_ch // 4, 32)
eps = 1e-6
self.pre_norm = True
temb_channels = temb_dim
non_linearity = "silu"
self.time_embedding_norm = time_embedding_norm = "default"
if self.pre_norm:
self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
else:
self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if time_embedding_norm == "default":
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
elif time_embedding_norm == "scale_shift":
self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = nonlinearity
elif non_linearity == "mish":
self.nonlinearity = Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
if up:
self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
elif down:
self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
if self.in_channels != self.out_channels:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def set_weights(self):
self.conv1.weight.data = self.Conv_0.weight.data
self.conv1.bias.data = self.Conv_0.bias.data
self.norm1.weight.data = self.GroupNorm_0.weight.data
self.norm1.bias.data = self.GroupNorm_0.bias.data
self.conv2.weight.data = self.Conv_1.weight.data
self.conv2.bias.data = self.Conv_1.bias.data
self.norm2.weight.data = self.GroupNorm_1.weight.data
self.norm2.bias.data = self.GroupNorm_1.bias.data
self.temb_proj.weight.data = self.Dense_0.weight.data
self.temb_proj.bias.data = self.Dense_0.bias.data
if self.in_channels != self.out_channels:
self.nin_shortcut.weight.data = self.Conv_2.weight.data
self.nin_shortcut.bias.data = self.Conv_2.bias.data
def set_weights_ldm(self):
import ipdb; ipdb.set_trace()
def forward(self, x, temb=None):
if self.overwrite and not self.is_overwritten:
self.set_weights()
self.is_overwritten = True
orig_x = x
h = self.act(self.GroupNorm_0(x))
if self.up:
h = upsample_2d(h, self.fir_kernel, factor=2)
x = upsample_2d(x, self.fir_kernel, factor=2)
elif self.down:
h = downsample_2d(h, self.fir_kernel, factor=2)
x = downsample_2d(x, self.fir_kernel, factor=2)
h = self.Conv_0(h)
# Add bias to each feature map conditioned on the time embedding
if temb is not None:
h += self.Dense_0(self.act(temb))[:, :, None, None]
h = self.act(self.GroupNorm_1(h))
h = self.Dropout_0(h)
h = self.Conv_1(h)
if self.in_ch != self.out_ch or self.up or self.down:
x = self.Conv_2(x)
if not self.skip_rescale:
result = x + h
else:
result = (x + h) / np.sqrt(2.0)
result_2 = self.forward_2(orig_x, temb)
print("Diff", (result - result_2).abs().sum())
import ipdb; ipdb.set_trace()
return result
def forward_2(self, x, temb, mask=1.0):
h = x
h = h * mask
if self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
if self.up or self.down:
x = self.x_upd(x)
h = self.h_upd(h)
h = self.conv1(h)
if not self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
h = h * mask
temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
if self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
h = self.norm2(h)
h = h + h * scale + shift
h = self.nonlinearity(h)
elif self.time_embedding_norm == "default":
h = h + temb
h = h * mask
if self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if not self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
h = h * mask
x = x * mask
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x + h
# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py
class ResnetBlock(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
up=False,
down=False,
overwrite_for_grad_tts=False,
overwrite_for_ldm=False,
overwrite_for_glide=False,
):
super().__init__()
self.pre_norm = pre_norm
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.time_embedding_norm = time_embedding_norm
self.up = up
self.down = down
if self.pre_norm:
self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
else:
self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if time_embedding_norm == "default":
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
elif time_embedding_norm == "scale_shift":
self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = nonlinearity
elif non_linearity == "mish":
self.nonlinearity = Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
if up:
self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
elif down:
self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
if self.in_channels != self.out_channels:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
# TODO(SURAJ, PATRICK): ALL OF THE FOLLOWING OF THE INIT METHOD CAN BE DELETED ONCE WEIGHTS ARE CONVERTED
self.is_overwritten = False
self.overwrite_for_glide = overwrite_for_glide
self.overwrite_for_grad_tts = overwrite_for_grad_tts
self.overwrite_for_ldm = overwrite_for_ldm or overwrite_for_glide
if self.overwrite_for_grad_tts:
dim = in_channels
dim_out = out_channels
time_emb_dim = temb_channels
self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
self.pre_norm = pre_norm
self.block1 = Block(dim, dim_out, groups=groups)
self.block2 = Block(dim_out, dim_out, groups=groups)
if dim != dim_out:
self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
else:
self.res_conv = torch.nn.Identity()
elif self.overwrite_for_ldm:
dims = 2
channels = in_channels
emb_channels = temb_channels
use_scale_shift_norm = False
self.in_layers = nn.Sequential(
normalization(channels, swish=1.0),
nn.Identity(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
)
self.emb_layers = nn.Sequential(
nn.SiLU(),
linear(
emb_channels,
2 * self.out_channels if self.time_embedding_norm == "scale_shift" else self.out_channels,
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
nn.SiLU() if use_scale_shift_norm else nn.Identity(),
nn.Dropout(p=dropout),
zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
)
if self.out_channels == in_channels:
self.skip_connection = nn.Identity()
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
def set_weights_grad_tts(self):
self.conv1.weight.data = self.block1.block[0].weight.data
self.conv1.bias.data = self.block1.block[0].bias.data
self.norm1.weight.data = self.block1.block[1].weight.data
self.norm1.bias.data = self.block1.block[1].bias.data
self.conv2.weight.data = self.block2.block[0].weight.data
self.conv2.bias.data = self.block2.block[0].bias.data
self.norm2.weight.data = self.block2.block[1].weight.data
self.norm2.bias.data = self.block2.block[1].bias.data
self.temb_proj.weight.data = self.mlp[1].weight.data
self.temb_proj.bias.data = self.mlp[1].bias.data
if self.in_channels != self.out_channels:
self.nin_shortcut.weight.data = self.res_conv.weight.data
self.nin_shortcut.bias.data = self.res_conv.bias.data
def set_weights_ldm(self):
self.norm1.weight.data = self.in_layers[0].weight.data
self.norm1.bias.data = self.in_layers[0].bias.data
self.conv1.weight.data = self.in_layers[-1].weight.data
self.conv1.bias.data = self.in_layers[-1].bias.data
self.temb_proj.weight.data = self.emb_layers[-1].weight.data
self.temb_proj.bias.data = self.emb_layers[-1].bias.data
self.norm2.weight.data = self.out_layers[0].weight.data
self.norm2.bias.data = self.out_layers[0].bias.data
self.conv2.weight.data = self.out_layers[-1].weight.data
self.conv2.bias.data = self.out_layers[-1].bias.data
if self.in_channels != self.out_channels:
self.nin_shortcut.weight.data = self.skip_connection.weight.data
self.nin_shortcut.bias.data = self.skip_connection.bias.data
def forward(self, x, temb, mask=1.0):
# TODO(Patrick) eventually this class should be split into multiple classes
# too many if else statements
if self.overwrite_for_grad_tts and not self.is_overwritten:
self.set_weights_grad_tts()
self.is_overwritten = True
elif self.overwrite_for_ldm and not self.is_overwritten:
self.set_weights_ldm()
self.is_overwritten = True
h = x
h = h * mask
if self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
if self.up or self.down:
x = self.x_upd(x)
h = self.h_upd(h)
h = self.conv1(h)
if not self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
h = h * mask
temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
if self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
h = self.norm2(h)
h = h + h * scale + shift
h = self.nonlinearity(h)
elif self.time_embedding_norm == "default":
h = h + temb
h = h * mask
if self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if not self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
h = h * mask
x = x * mask
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x + h
# TODO(Patrick) - just there to convert the weights; can delete afterward
class Block(torch.nn.Module):
def __init__(self, dim, dim_out, groups=8):
super(Block, self).__init__()
self.block = torch.nn.Sequential(
torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
)
# unet_score_estimation.py
class ResnetBlockBigGANpp(nn.Module):
def __init__(
self,
act,
in_ch,
out_ch=None,
temb_dim=None,
up=False,
down=False,
dropout=0.1,
fir_kernel=(1, 3, 3, 1),
skip_rescale=True,
init_scale=0.0,
):
super().__init__()
out_ch = out_ch if out_ch else in_ch
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
self.up = up
self.down = down
self.fir_kernel = fir_kernel
self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
if temb_dim is not None:
self.Dense_0 = nn.Linear(temb_dim, out_ch)
self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
nn.init.zeros_(self.Dense_0.bias)
self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
self.Dropout_0 = nn.Dropout(dropout)
self.Conv_1 = conv2d(out_ch, out_ch, init_scale=init_scale, kernel_size=3, padding=1)
if in_ch != out_ch or up or down:
# 1x1 convolution with DDPM initialization.
self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
self.skip_rescale = skip_rescale
self.act = act
self.in_ch = in_ch
self.out_ch = out_ch
def forward(self, x, temb=None):
h = self.act(self.GroupNorm_0(x))
if self.up:
h = upsample_2d(h, self.fir_kernel, factor=2)
x = upsample_2d(x, self.fir_kernel, factor=2)
elif self.down:
h = downsample_2d(h, self.fir_kernel, factor=2)
x = downsample_2d(x, self.fir_kernel, factor=2)
h = self.Conv_0(h)
# Add bias to each feature map conditioned on the time embedding
if temb is not None:
h += self.Dense_0(self.act(temb))[:, :, None, None]
h = self.act(self.GroupNorm_1(h))
h = self.Dropout_0(h)
h = self.Conv_1(h)
if self.in_ch != self.out_ch or self.up or self.down:
x = self.Conv_2(x)
if not self.skip_rescale:
return x + h
else:
return (x + h) / np.sqrt(2.0)
# unet_rl.py
class ResidualTemporalBlock(nn.Module):
def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
super().__init__()
self.blocks = nn.ModuleList(
[
Conv1dBlock(inp_channels, out_channels, kernel_size),
Conv1dBlock(out_channels, out_channels, kernel_size),
]
)
self.time_mlp = nn.Sequential(
nn.Mish(),
nn.Linear(embed_dim, out_channels),
RearrangeDim(),
# Rearrange("batch t -> batch t 1"),
)
self.residual_conv = (
nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
)
def forward(self, x, t):
"""
x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x
out_channels x horizon ]
"""
out = self.blocks[0](x) + self.time_mlp(t)
out = self.blocks[1](out)
return out + self.residual_conv(x)
# HELPER Modules
def normalization(channels, swish=0.0):
"""
Make a standard normalization layer, with an optional swish activation.
:param channels: number of input channels. :return: an nn.Module for normalization.
"""
return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
class GroupNorm32(nn.GroupNorm):
def __init__(self, num_groups, num_channels, swish, eps=1e-5):
super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
self.swish = swish
def forward(self, x):
y = super().forward(x.float()).to(x.dtype)
if self.swish == 1.0:
y = F.silu(y)
elif self.swish:
y = y * F.sigmoid(y * float(self.swish))
return y
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
class Mish(torch.nn.Module):
def forward(self, x):
return x * torch.tanh(torch.nn.functional.softplus(x))
class Conv1dBlock(nn.Module):
"""
Conv1d --> GroupNorm --> Mish
"""
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
RearrangeDim(),
# Rearrange("batch channels horizon -> batch channels 1 horizon"),
nn.GroupNorm(n_groups, out_channels),
RearrangeDim(),
# Rearrange("batch channels 1 horizon -> batch channels horizon"),
nn.Mish(),
)
def forward(self, x):
return self.block(x)
class RearrangeDim(nn.Module):
def __init__(self):
super().__init__()
def forward(self, tensor):
if len(tensor.shape) == 2:
return tensor[:, :, None]
if len(tensor.shape) == 3:
return tensor[:, :, None, :]
elif len(tensor.shape) == 4:
return tensor[:, :, 0, :]
else:
raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
"""nXn convolution with DDPM initialization."""
conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape)
nn.init.zeros_(conv.bias)
return conv
def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
"""Ported from JAX."""
scale = 1e-10 if scale == 0 else scale
def _compute_fans(shape, in_axis=1, out_axis=0):
receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
fan_in = shape[in_axis] * receptive_field_size
fan_out = shape[out_axis] * receptive_field_size
return fan_in, fan_out
def init(shape, dtype=dtype, device=device):
fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
denominator = (fan_in + fan_out) / 2
variance = scale / denominator
return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
return init
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
_, channel, in_h, in_w = input.shape
input = input.reshape(-1, in_h, in_w, 1)
_, in_h, in_w, minor = input.shape
kernel_h, kernel_w = kernel.shape
out = input.view(-1, in_h, 1, in_w, 1, minor)
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
out = out[
:,
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
:,
]
out = out.permute(0, 3, 1, 2)
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
out = F.conv2d(out, w)
out = out.reshape(
-1,
minor,
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
)
out = out.permute(0, 2, 3, 1)
out = out[:, ::down_y, ::down_x, :]
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
return out.view(-1, channel, out_h, out_w)
def upsample_2d(x, k=None, factor=2, gain=1):
r"""Upsample a batch of 2D images with the given filter.
Args:
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a:
multiple of the upsampling factor.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H * factor, W * factor]`
"""
assert isinstance(factor, int) and factor >= 1
if k is None:
k = [1] * factor
k = _setup_kernel(k) * (gain * (factor**2))
p = k.shape[0] - factor
return upfirdn2d(x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
def downsample_2d(x, k=None, factor=2, gain=1):
r"""Downsample a batch of 2D images with the given filter.
Args:
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
shape is a multiple of the downsampling factor.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to average pooling.
factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H // factor, W // factor]`
"""
assert isinstance(factor, int) and factor >= 1
if k is None:
k = [1] * factor
k = _setup_kernel(k) * gain
p = k.shape[0] - factor
return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
def _setup_kernel(k):
k = np.asarray(k, dtype=np.float32)
if k.ndim == 1:
k = np.outer(k, k)
k /= np.sum(k)
assert k.ndim == 2
assert k.shape[0] == k.shape[1]
return k
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment