Unverified commit 8b0bc596, authored by Suraj Patil, committed by GitHub

Merge pull request #52 from huggingface/clean-unet-sde

Clean UNetNCSNpp
parents 7e0fd19f f35387b3
@@ -579,7 +579,6 @@ class ResnetBlockBigGANpp(nn.Module):
        up=False,
        down=False,
        dropout=0.1,
-        fir=False,
        fir_kernel=(1, 3, 3, 1),
        skip_rescale=True,
        init_scale=0.0,
@@ -590,20 +589,20 @@ class ResnetBlockBigGANpp(nn.Module):
        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
        self.up = up
        self.down = down
-        self.fir = fir
        self.fir_kernel = fir_kernel
-        self.Conv_0 = conv3x3(in_ch, out_ch)
+        self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        if temb_dim is not None:
            self.Dense_0 = nn.Linear(temb_dim, out_ch)
-            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
+            self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
            nn.init.zeros_(self.Dense_0.bias)
        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
        self.Dropout_0 = nn.Dropout(dropout)
-        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
+        self.Conv_1 = conv2d(out_ch, out_ch, init_scale=init_scale, kernel_size=3, padding=1)
        if in_ch != out_ch or up or down:
-            self.Conv_2 = conv1x1(in_ch, out_ch)
+            # 1x1 convolution with DDPM initialization.
+            self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
        self.skip_rescale = skip_rescale
        self.act = act
@@ -614,19 +613,11 @@ class ResnetBlockBigGANpp(nn.Module):
        h = self.act(self.GroupNorm_0(x))

        if self.up:
-            if self.fir:
-                h = upsample_2d(h, self.fir_kernel, factor=2)
-                x = upsample_2d(x, self.fir_kernel, factor=2)
-            else:
-                h = naive_upsample_2d(h, factor=2)
-                x = naive_upsample_2d(x, factor=2)
+            h = upsample_2d(h, self.fir_kernel, factor=2)
+            x = upsample_2d(x, self.fir_kernel, factor=2)
        elif self.down:
-            if self.fir:
-                h = downsample_2d(h, self.fir_kernel, factor=2)
-                x = downsample_2d(x, self.fir_kernel, factor=2)
-            else:
-                h = naive_downsample_2d(h, factor=2)
-                x = naive_downsample_2d(x, factor=2)
+            h = downsample_2d(h, self.fir_kernel, factor=2)
+            x = downsample_2d(x, self.fir_kernel, factor=2)

        h = self.Conv_0(h)
        # Add bias to each feature map conditioned on the time embedding
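Note on the hunk above: with the `fir` switch gone, `ResnetBlockBigGANpp` always resamples through the FIR path, so the nearest-neighbour `naive_*` helpers become dead code (deleted further down). A small standalone sketch of the difference, ignoring `upfirdn2d`'s exact asymmetric padding (all names here are local to the snippet):

```python
import torch
import torch.nn.functional as F

x = torch.arange(16.0).reshape(1, 1, 4, 4)

# Nearest-neighbour path (the removed naive_upsample_2d): every pixel becomes a 2x2 block.
naive = F.interpolate(x, scale_factor=2, mode="nearest")

# FIR path (what the block now always does): upsample, then smooth with the normalized
# outer product of the (1, 3, 3, 1) taps. upfirdn2d fuses this; the plain conv below
# only illustrates the smoothing step.
k = torch.tensor([1.0, 3.0, 3.0, 1.0])
k2d = torch.outer(k, k)
k2d = k2d / k2d.sum()
smoothed = F.conv2d(naive, k2d.view(1, 1, 4, 4), padding=2)

print(naive.shape, smoothed.shape)  # (1, 1, 8, 8) and (1, 1, 9, 9) from the symmetric padding
```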
@@ -645,62 +636,6 @@ class ResnetBlockBigGANpp(nn.Module):
        return (x + h) / np.sqrt(2.0)

-# unet_score_estimation.py
-class ResnetBlockDDPMpp(nn.Module):
-    """ResBlock adapted from DDPM."""
-
-    def __init__(
-        self,
-        act,
-        in_ch,
-        out_ch=None,
-        temb_dim=None,
-        conv_shortcut=False,
-        dropout=0.1,
-        skip_rescale=False,
-        init_scale=0.0,
-    ):
-        super().__init__()
-        out_ch = out_ch if out_ch else in_ch
-        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
-        self.Conv_0 = conv3x3(in_ch, out_ch)
-        if temb_dim is not None:
-            self.Dense_0 = nn.Linear(temb_dim, out_ch)
-            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
-            nn.init.zeros_(self.Dense_0.bias)
-        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
-        self.Dropout_0 = nn.Dropout(dropout)
-        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
-        if in_ch != out_ch:
-            if conv_shortcut:
-                self.Conv_2 = conv3x3(in_ch, out_ch)
-            else:
-                self.NIN_0 = NIN(in_ch, out_ch)
-        self.skip_rescale = skip_rescale
-        self.act = act
-        self.out_ch = out_ch
-        self.conv_shortcut = conv_shortcut
-
-    def forward(self, x, temb=None):
-        h = self.act(self.GroupNorm_0(x))
-        h = self.Conv_0(h)
-        if temb is not None:
-            h += self.Dense_0(self.act(temb))[:, :, None, None]
-        h = self.act(self.GroupNorm_1(h))
-        h = self.Dropout_0(h)
-        h = self.Conv_1(h)
-        if x.shape[1] != self.out_ch:
-            if self.conv_shortcut:
-                x = self.Conv_2(x)
-            else:
-                x = self.NIN_0(x)
-        if not self.skip_rescale:
-            return x + h
-        else:
-            return (x + h) / np.sqrt(2.0)
-

# unet_rl.py
class ResidualTemporalBlock(nn.Module):
    def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
@@ -818,32 +753,17 @@ class RearrangeDim(nn.Module):
            raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")

-def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
-    """1x1 convolution with DDPM initialization."""
-    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
-    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
+def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
+    """nXn convolution with DDPM initialization."""
+    conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
+    conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape)
    nn.init.zeros_(conv.bias)
    return conv

-def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
-    """3x3 convolution with DDPM initialization."""
-    conv = nn.Conv2d(
-        in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias
-    )
-    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
-    nn.init.zeros_(conv.bias)
-    return conv
-
-
-def default_init(scale=1.0):
-    """The same initialization used in DDPM."""
-    scale = 1e-10 if scale == 0 else scale
-    return variance_scaling(scale, "fan_avg", "uniform")
-
-
-def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
+def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
    """Ported from JAX."""
-    scale = 1e-10 if scale == 0 else scale

    def _compute_fans(shape, in_axis=1, out_axis=0):
        receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
@@ -853,21 +773,9 @@ def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
    def init(shape, dtype=dtype, device=device):
        fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
-        if mode == "fan_in":
-            denominator = fan_in
-        elif mode == "fan_out":
-            denominator = fan_out
-        elif mode == "fan_avg":
-            denominator = (fan_in + fan_out) / 2
-        else:
-            raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
+        denominator = (fan_in + fan_out) / 2
        variance = scale / denominator
-        if distribution == "normal":
-            return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
-        elif distribution == "uniform":
-            return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
-        else:
-            raise ValueError("invalid distribution for variance scaling initializer")
+        return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)

    return init
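The collapsed initializer keeps only the `fan_avg` / `uniform` branch that `default_init` used to select. A hedged re-statement (a local helper, not the library function) that also flags one behavioural nit: without the old `scale = 1e-10 if scale == 0` guard, `init_scale=0.0` now yields exactly-zero weights rather than near-zero ones:

```python
import numpy as np
import torch

def uniform_fan_avg_init(shape, scale=1.0):
    # Same math as the simplified initializer, for conv-style shapes (out_ch, in_ch, kH, kW).
    receptive_field = np.prod(shape) / shape[0] / shape[1]
    fan_in, fan_out = shape[1] * receptive_field, shape[0] * receptive_field
    variance = scale / ((fan_in + fan_out) / 2)
    return (torch.rand(*shape) * 2.0 - 1.0) * np.sqrt(3 * variance)

w = uniform_fan_avg_init((64, 32, 3, 3))
print(w.std().item(), np.sqrt(1.0 / ((32 * 9 + 64 * 9) / 2)))  # empirical std vs. target, should be close

# With the guard removed, scale=0.0 (used for the final conv) gives exactly-zero weights.
print(uniform_fan_avg_init((64, 32, 3, 3), scale=0.0).abs().max().item())  # 0.0
```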
@@ -965,31 +873,6 @@ def downsample_2d(x, k=None, factor=2, gain=1):
    return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))

-def naive_upsample_2d(x, factor=2):
-    _N, C, H, W = x.shape
-    x = torch.reshape(x, (-1, C, H, 1, W, 1))
-    x = x.repeat(1, 1, 1, factor, 1, factor)
-    return torch.reshape(x, (-1, C, H * factor, W * factor))
-
-
-def naive_downsample_2d(x, factor=2):
-    _N, C, H, W = x.shape
-    x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
-    return torch.mean(x, dim=(3, 5))
-
-
-class NIN(nn.Module):
-    def __init__(self, in_dim, num_units, init_scale=0.1):
-        super().__init__()
-        self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
-        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
-
-    def forward(self, x):
-        x = x.permute(0, 2, 3, 1)
-        y = contract_inner(x, self.W) + self.b
-        return y.permute(0, 3, 1, 2)
-

def _setup_kernel(k):
    k = np.asarray(k, dtype=np.float32)
    if k.ndim == 1:
@@ -998,17 +881,3 @@ def _setup_kernel(k):
    assert k.ndim == 2
    assert k.shape[0] == k.shape[1]
    return k
-
-
-def contract_inner(x, y):
-    """tensordot(x, y, 1)."""
-    x_chars = list(string.ascii_lowercase[: len(x.shape)])
-    y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)])
-    y_chars[0] = x_chars[-1]  # first axis of y and last of x get summed
-    out_chars = x_chars[:-1] + y_chars[1:]
-    return _einsum(x_chars, y_chars, out_chars, x, y)
-
-
-def _einsum(a, b, c, x, y):
-    einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
-    return torch.einsum(einsum_str, x, y)
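The `NIN` / `contract_inner` / `_einsum` trio removed here is just a 1x1 convolution written as an einsum over the channel axis, so nothing is lost by routing callers through `conv2d(..., kernel_size=1)`. A quick standalone equivalence check (names local to the snippet, not repo code). The hunks that follow belong to the second file touched by this PR, the NCSNpp score-model module:

```python
import torch
import torch.nn.functional as F

N, C_in, C_out, H, W = 2, 8, 16, 5, 5
x = torch.randn(N, C_in, H, W)
W_nin = torch.randn(C_in, C_out)   # NIN weight: (in_dim, num_units)
b_nin = torch.randn(C_out)

# NIN: permute to NHWC, contract the channel axis, permute back -- written as a single einsum here.
y_nin = torch.einsum("nchw,cd->ndhw", x, W_nin) + b_nin.view(1, -1, 1, 1)

# Equivalent 1x1 convolution with the transposed weight.
y_conv = F.conv2d(x, W_nin.t().reshape(C_out, C_in, 1, 1), bias=b_nin)

print(torch.allclose(y_nin, y_conv, atol=1e-5))  # True
```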
@@ -17,7 +17,6 @@
import functools
import math
-import string

import numpy as np
import torch
@@ -28,116 +27,21 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import GaussianFourierProjection, get_timestep_embedding
-from .resnet import ResnetBlockBigGANpp, ResnetBlockDDPMpp
+from .resnet import ResnetBlockBigGANpp, downsample_2d, upfirdn2d, upsample_2d

-def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
-    return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
-
-
-def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
-    _, channel, in_h, in_w = input.shape
-    input = input.reshape(-1, in_h, in_w, 1)
-
-    _, in_h, in_w, minor = input.shape
-    kernel_h, kernel_w = kernel.shape
-
-    out = input.view(-1, in_h, 1, in_w, 1, minor)
-    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
-    out = out.view(-1, in_h * up_y, in_w * up_x, minor)
-
-    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
-    out = out[
-        :,
-        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
-        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
-        :,
-    ]
-
-    out = out.permute(0, 3, 1, 2)
-    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
-    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
-    out = F.conv2d(out, w)
-    out = out.reshape(
-        -1,
-        minor,
-        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
-        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
-    )
-    out = out.permute(0, 2, 3, 1)
-    out = out[:, ::down_y, ::down_x, :]
-
-    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
-    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
-
-    return out.view(-1, channel, out_h, out_w)
-
-
-# Function ported from StyleGAN2
-def get_weight(module, shape, weight_var="weight", kernel_init=None):
-    """Get/create weight tensor for a convolution or fully-connected layer."""
-
-    return module.param(weight_var, kernel_init, shape)
-
-
-class Conv2d(nn.Module):
-    """Conv2d layer with optimal upsampling and downsampling (StyleGAN2)."""
-
-    def __init__(
-        self,
-        in_ch,
-        out_ch,
-        kernel,
-        up=False,
-        down=False,
-        resample_kernel=(1, 3, 3, 1),
-        use_bias=True,
-        kernel_init=None,
-    ):
-        super().__init__()
-        assert not (up and down)
-        assert kernel >= 1 and kernel % 2 == 1
-        self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel))
-        if kernel_init is not None:
-            self.weight.data = kernel_init(self.weight.data.shape)
-        if use_bias:
-            self.bias = nn.Parameter(torch.zeros(out_ch))
-
-        self.up = up
-        self.down = down
-        self.resample_kernel = resample_kernel
-        self.kernel = kernel
-        self.use_bias = use_bias
-
-    def forward(self, x):
-        if self.up:
-            x = upsample_conv_2d(x, self.weight, k=self.resample_kernel)
-        elif self.down:
-            x = conv_downsample_2d(x, self.weight, k=self.resample_kernel)
-        else:
-            x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2)
-
-        if self.use_bias:
-            x = x + self.bias.reshape(1, -1, 1, 1)
-
-        return x
-
-
-def naive_upsample_2d(x, factor=2):
-    _N, C, H, W = x.shape
-    x = torch.reshape(x, (-1, C, H, 1, W, 1))
-    x = x.repeat(1, 1, 1, factor, 1, factor)
-    return torch.reshape(x, (-1, C, H * factor, W * factor))
-
-
-def naive_downsample_2d(x, factor=2):
-    _N, C, H, W = x.shape
-    x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
-    return torch.mean(x, dim=(3, 5))
+def _setup_kernel(k):
+    k = np.asarray(k, dtype=np.float32)
+    if k.ndim == 1:
+        k = np.outer(k, k)
+    k /= np.sum(k)
+    assert k.ndim == 2
+    assert k.shape[0] == k.shape[1]
+    return k

-def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
-    """Fused `upsample_2d()` followed by `tf.nn.conv2d()`.
+def _upsample_conv_2d(x, w, k=None, factor=2, gain=1):
+    """Fused `upsample_2d()` followed by `Conv2d()`.

    Args:
    Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
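For reference, this is what `_setup_kernel` produces for the default `(1, 3, 3, 1)` taps: a separable 2D low-pass kernel normalized to sum to one, so constant inputs pass through unchanged; `upsample_2d` additionally scales it by `gain * factor**2` to compensate for the inserted zeros. A standalone sketch:

```python
import numpy as np

k = np.asarray([1, 3, 3, 1], dtype=np.float32)
k2d = np.outer(k, k)      # separable: outer product of the 1D taps
k2d /= k2d.sum()          # normalize so filtering preserves the mean

print(k2d.shape)          # (4, 4)
print(float(k2d.sum()))   # ~1.0
print(k2d * 4)            # what upsample_2d feeds to upfirdn2d for factor=2, gain=1
```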
@@ -176,13 +80,13 @@ def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
    # Determine data dimensions.
    stride = [1, 1, factor, factor]
-    output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW)
+    output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
    output_padding = (
-        output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH,
-        output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW,
+        output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
+        output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
    )
    assert output_padding[0] >= 0 and output_padding[1] >= 0
-    num_groups = _shape(x, 1) // inC
+    num_groups = x.shape[1] // inC

    # Transpose weights.
    w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
@@ -190,21 +94,12 @@ def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
    w = torch.reshape(w, (num_groups * inC, -1, convH, convW))
    x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)

-    # Original TF code.
-    # x = tf.nn.conv2d_transpose(
-    #     x,
-    #     w,
-    #     output_shape=output_shape,
-    #     strides=stride,
-    #     padding='VALID',
-    #     data_format=data_format)
-    # JAX equivalent
-
    return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))

-def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
-    """Fused `tf.nn.conv2d()` followed by `downsample_2d()`.
+def _conv_downsample_2d(x, w, k=None, factor=2, gain=1):
+    """Fused `Conv2d()` followed by `downsample_2d()`.

    Args:
    Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
@@ -235,138 +130,9 @@ def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
    return F.conv2d(x, w, stride=s, padding=0)

-def _setup_kernel(k):
-    k = np.asarray(k, dtype=np.float32)
-    if k.ndim == 1:
-        k = np.outer(k, k)
-    k /= np.sum(k)
-    assert k.ndim == 2
-    assert k.shape[0] == k.shape[1]
-    return k
-
-
-def _shape(x, dim):
-    return x.shape[dim]
-
-
-def upsample_2d(x, k=None, factor=2, gain=1):
-    r"""Upsample a batch of 2D images with the given filter.
-
-    Args:
-    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
-    filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
-    `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
-    a multiple of the upsampling factor.
-        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-        k: FIR filter of the shape `[firH, firW]` or `[firN]`
-          (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
-        factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
-
-    Returns:
-        Tensor of the shape `[N, C, H * factor, W * factor]`
-    """
-    assert isinstance(factor, int) and factor >= 1
-    if k is None:
-        k = [1] * factor
-    k = _setup_kernel(k) * (gain * (factor**2))
-    p = k.shape[0] - factor
-    return upfirdn2d(x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
-
-
-def downsample_2d(x, k=None, factor=2, gain=1):
-    r"""Downsample a batch of 2D images with the given filter.
-
-    Args:
-    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
-    given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
-    specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
-    shape is a multiple of the downsampling factor.
-        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-        k: FIR filter of the shape `[firH, firW]` or `[firN]`
-          (separable). The default is `[1] * factor`, which corresponds to average pooling.
-        factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
-
-    Returns:
-        Tensor of the shape `[N, C, H // factor, W // factor]`
-    """
-    assert isinstance(factor, int) and factor >= 1
-    if k is None:
-        k = [1] * factor
-    k = _setup_kernel(k) * gain
-    p = k.shape[0] - factor
-    return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
-
-
-def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
-    """1x1 convolution with DDPM initialization."""
-    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
-    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
-    nn.init.zeros_(conv.bias)
-    return conv
-
-
-def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
-    """3x3 convolution with DDPM initialization."""
-    conv = nn.Conv2d(
-        in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias
-    )
-    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
-    nn.init.zeros_(conv.bias)
-    return conv
-
-
-def _einsum(a, b, c, x, y):
-    einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
-    return torch.einsum(einsum_str, x, y)
-
-
-def contract_inner(x, y):
-    """tensordot(x, y, 1)."""
-    x_chars = list(string.ascii_lowercase[: len(x.shape)])
-    y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)])
-    y_chars[0] = x_chars[-1]  # first axis of y and last of x get summed
-    out_chars = x_chars[:-1] + y_chars[1:]
-    return _einsum(x_chars, y_chars, out_chars, x, y)
-
-
-class NIN(nn.Module):
-    def __init__(self, in_dim, num_units, init_scale=0.1):
-        super().__init__()
-        self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
-        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
-
-    def forward(self, x):
-        x = x.permute(0, 2, 3, 1)
-        y = contract_inner(x, self.W) + self.b
-        return y.permute(0, 3, 1, 2)
-
-
-def get_act(nonlinearity):
-    """Get activation functions from the config file."""
-
-    if nonlinearity.lower() == "elu":
-        return nn.ELU()
-    elif nonlinearity.lower() == "relu":
-        return nn.ReLU()
-    elif nonlinearity.lower() == "lrelu":
-        return nn.LeakyReLU(negative_slope=0.2)
-    elif nonlinearity.lower() == "swish":
-        return nn.SiLU()
-    else:
-        raise NotImplementedError("activation function does not exist!")
-
-
-def default_init(scale=1.0):
-    """The same initialization used in DDPM."""
-    scale = 1e-10 if scale == 0 else scale
-    return variance_scaling(scale, "fan_avg", "uniform")
-
-
-def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
+def _variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
    """Ported from JAX."""
-    scale = 1e-10 if scale == 0 else scale

    def _compute_fans(shape, in_axis=1, out_axis=0):
        receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
@@ -376,31 +142,35 @@ def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
    def init(shape, dtype=dtype, device=device):
        fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
-        if mode == "fan_in":
-            denominator = fan_in
-        elif mode == "fan_out":
-            denominator = fan_out
-        elif mode == "fan_avg":
-            denominator = (fan_in + fan_out) / 2
-        else:
-            raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
+        denominator = (fan_in + fan_out) / 2
        variance = scale / denominator
-        if distribution == "normal":
-            return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
-        elif distribution == "uniform":
-            return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
-        else:
-            raise ValueError("invalid distribution for variance scaling initializer")
+        return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)

    return init

+def Conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
+    """nXn convolution with DDPM initialization."""
+    conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
+    conv.weight.data = _variance_scaling(init_scale)(conv.weight.data.shape)
+    nn.init.zeros_(conv.bias)
+    return conv
+
+
+def Linear(dim_in, dim_out):
+    linear = nn.Linear(dim_in, dim_out)
+    linear.weight.data = _variance_scaling()(linear.weight.shape)
+    nn.init.zeros_(linear.bias)
+    return linear
+

class Combine(nn.Module):
    """Combine information from skip connections."""

    def __init__(self, dim1, dim2, method="cat"):
        super().__init__()
-        self.Conv_0 = conv1x1(dim1, dim2)
+        # 1x1 convolution with DDPM initialization.
+        self.Conv_0 = Conv2d(dim1, dim2, kernel_size=1, padding=0)
        self.method = method

    def forward(self, x, y):
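`Combine` now builds its projection with the new `Conv2d` helper. Its `forward` is collapsed above, but the idea is: project the input-pyramid branch to the feature width with a 1x1 conv, then either concatenate ("cat") or add ("sum"). A sketch with a plain `nn.Conv2d` standing in for the DDPM-initialized one:

```python
import torch
import torch.nn as nn

conv_1x1 = nn.Conv2d(3, 64, kernel_size=1)   # dim1 -> dim2 projection
x = torch.randn(2, 3, 32, 32)                # progressive input pyramid (RGB)
y = torch.randn(2, 64, 32, 32)               # feature branch

h = conv_1x1(x)
combined_cat = torch.cat([h, y], dim=1)      # "cat": channel counts add up -> 128
combined_sum = h + y                         # "sum": channel count stays at 64
print(combined_cat.shape, combined_sum.shape)
```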
@@ -413,80 +183,40 @@ class Combine(nn.Module):
            raise ValueError(f"Method {self.method} not recognized.")

-class Upsample(nn.Module):
-    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
+class FirUpsample(nn.Module):
+    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_ch = out_ch if out_ch else in_ch
-        if not fir:
-            if with_conv:
-                self.Conv_0 = conv3x3(in_ch, out_ch)
-        else:
-            if with_conv:
-                self.Conv2d_0 = Conv2d(
-                    in_ch,
-                    out_ch,
-                    kernel=3,
-                    up=True,
-                    resample_kernel=fir_kernel,
-                    use_bias=True,
-                    kernel_init=default_init(),
-                )
-        self.fir = fir
+        if with_conv:
+            self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
        self.with_conv = with_conv
        self.fir_kernel = fir_kernel
        self.out_ch = out_ch

    def forward(self, x):
-        B, C, H, W = x.shape
-        if not self.fir:
-            h = F.interpolate(x, (H * 2, W * 2), "nearest")
-            if self.with_conv:
-                h = self.Conv_0(h)
+        if self.with_conv:
+            h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
        else:
-            if not self.with_conv:
-                h = upsample_2d(x, self.fir_kernel, factor=2)
-            else:
-                h = self.Conv2d_0(x)
+            h = upsample_2d(x, self.fir_kernel, factor=2)

        return h

-class Downsample(nn.Module):
-    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
+class FirDownsample(nn.Module):
+    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_ch = out_ch if out_ch else in_ch
-        if not fir:
-            if with_conv:
-                self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0)
-        else:
-            if with_conv:
-                self.Conv2d_0 = Conv2d(
-                    in_ch,
-                    out_ch,
-                    kernel=3,
-                    down=True,
-                    resample_kernel=fir_kernel,
-                    use_bias=True,
-                    kernel_init=default_init(),
-                )
-        self.fir = fir
+        if with_conv:
+            self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.with_conv = with_conv
        self.out_ch = out_ch

    def forward(self, x):
-        B, C, H, W = x.shape
-        if not self.fir:
-            if self.with_conv:
-                x = F.pad(x, (0, 1, 0, 1))
-                x = self.Conv_0(x)
-            else:
-                x = F.avg_pool2d(x, 2, stride=2)
+        if self.with_conv:
+            x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
        else:
-            if not self.with_conv:
-                x = downsample_2d(x, self.fir_kernel, factor=2)
-            else:
-                x = self.Conv2d_0(x)
+            x = downsample_2d(x, self.fir_kernel, factor=2)

        return x
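`FirDownsample` with the `(1, 3, 3, 1)` kernel is essentially "blur, then take every second pixel". Ignoring `upfirdn2d`'s exact asymmetric border handling, the effect can be sketched with a depthwise strided convolution (snippet-local names only):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 16, 16)
k = torch.tensor([1.0, 3.0, 3.0, 1.0])
k2d = torch.outer(k, k)
k2d = (k2d / k2d.sum()).view(1, 1, 4, 4).repeat(3, 1, 1, 1)  # one copy of the filter per channel

down = F.conv2d(x, k2d, stride=2, padding=1, groups=3)       # blur + stride-2 subsample
print(down.shape)  # torch.Size([1, 3, 8, 8])
```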
@@ -496,63 +226,52 @@ class NCSNpp(ModelMixin, ConfigMixin):
    def __init__(
        self,
-        centered=False,
        image_size=1024,
        num_channels=3,
-        attention_type="ddpm",
        attn_resolutions=(16,),
        ch_mult=(1, 2, 4, 8, 16, 32, 32, 32),
        conditional=True,
        conv_size=3,
        dropout=0.0,
        embedding_type="fourier",
-        fir=True,
+        fir=True,  # TODO (patil-suraj) remove this option from here and pre-trained model configs
        fir_kernel=(1, 3, 3, 1),
        fourier_scale=16,
        init_scale=0.0,
        nf=16,
-        nonlinearity="swish",
-        normalization="GroupNorm",
        num_res_blocks=1,
        progressive="output_skip",
        progressive_combine="sum",
        progressive_input="input_skip",
        resamp_with_conv=True,
-        resblock_type="biggan",
        scale_by_sigma=True,
        skip_rescale=True,
        continuous=True,
    ):
        super().__init__()
        self.register_to_config(
-            centered=centered,
            image_size=image_size,
            num_channels=num_channels,
-            attention_type=attention_type,
            attn_resolutions=attn_resolutions,
            ch_mult=ch_mult,
            conditional=conditional,
            conv_size=conv_size,
            dropout=dropout,
            embedding_type=embedding_type,
-            fir=fir,
            fir_kernel=fir_kernel,
            fourier_scale=fourier_scale,
            init_scale=init_scale,
            nf=nf,
-            nonlinearity=nonlinearity,
-            normalization=normalization,
            num_res_blocks=num_res_blocks,
            progressive=progressive,
            progressive_combine=progressive_combine,
            progressive_input=progressive_input,
            resamp_with_conv=resamp_with_conv,
-            resblock_type=resblock_type,
            scale_by_sigma=scale_by_sigma,
            skip_rescale=skip_rescale,
            continuous=continuous,
        )
-        self.act = act = get_act(nonlinearity)
+        self.act = act = nn.SiLU()
        self.nf = nf
        self.num_res_blocks = num_res_blocks
@@ -562,7 +281,6 @@ class NCSNpp(ModelMixin, ConfigMixin):
        self.conditional = conditional
        self.skip_rescale = skip_rescale
-        self.resblock_type = resblock_type
        self.progressive = progressive
        self.progressive_input = progressive_input
        self.embedding_type = embedding_type
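With `embedding_type="fourier"` (the default kept here), continuous noise levels are encoded with random Fourier features before the two `Linear` layers. A minimal sketch of the idea; the actual `GaussianFourierProjection` module in `.embeddings` may differ in details:

```python
import numpy as np
import torch

fourier_scale, nf = 16, 16
W = torch.randn(nf) * fourier_scale   # fixed random frequencies (not trained)
t = torch.rand(4)                     # a batch of continuous timesteps / sigmas

proj = t[:, None] * W[None, :] * 2 * np.pi
emb = torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)
print(emb.shape)                      # torch.Size([4, 32]) -> embed_dim = 2 * nf
```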
@@ -585,53 +303,33 @@ class NCSNpp(ModelMixin, ConfigMixin):
        else:
            raise ValueError(f"embedding type {embedding_type} unknown.")

-        if conditional:
-            modules.append(nn.Linear(embed_dim, nf * 4))
-            modules[-1].weight.data = default_init()(modules[-1].weight.shape)
-            nn.init.zeros_(modules[-1].bias)
-            modules.append(nn.Linear(nf * 4, nf * 4))
-            modules[-1].weight.data = default_init()(modules[-1].weight.shape)
-            nn.init.zeros_(modules[-1].bias)
+        modules.append(Linear(embed_dim, nf * 4))
+        modules.append(Linear(nf * 4, nf * 4))

        AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))

-        Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
+        Up_sample = functools.partial(FirUpsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)

        if progressive == "output_skip":
-            self.pyramid_upsample = Up_sample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
+            self.pyramid_upsample = Up_sample(fir_kernel=fir_kernel, with_conv=False)
        elif progressive == "residual":
-            pyramid_upsample = functools.partial(Up_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True)
+            pyramid_upsample = functools.partial(Up_sample, fir_kernel=fir_kernel, with_conv=True)

-        Down_sample = functools.partial(Downsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
+        Down_sample = functools.partial(FirDownsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)

        if progressive_input == "input_skip":
-            self.pyramid_downsample = Down_sample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
+            self.pyramid_downsample = Down_sample(fir_kernel=fir_kernel, with_conv=False)
        elif progressive_input == "residual":
-            pyramid_downsample = functools.partial(Down_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True)
+            pyramid_downsample = functools.partial(Down_sample, fir_kernel=fir_kernel, with_conv=True)

-        if resblock_type == "ddpm":
-            ResnetBlock = functools.partial(
-                ResnetBlockDDPMpp,
-                act=act,
-                dropout=dropout,
-                init_scale=init_scale,
-                skip_rescale=skip_rescale,
-                temb_dim=nf * 4,
-            )
-        elif resblock_type == "biggan":
-            ResnetBlock = functools.partial(
-                ResnetBlockBigGANpp,
-                act=act,
-                dropout=dropout,
-                fir=fir,
-                fir_kernel=fir_kernel,
-                init_scale=init_scale,
-                skip_rescale=skip_rescale,
-                temb_dim=nf * 4,
-            )
-        else:
-            raise ValueError(f"resblock type {resblock_type} unrecognized.")
+        ResnetBlock = functools.partial(
+            ResnetBlockBigGANpp,
+            act=act,
+            dropout=dropout,
+            fir_kernel=fir_kernel,
+            init_scale=init_scale,
+            skip_rescale=skip_rescale,
+            temb_dim=nf * 4,
+        )

        # Downsampling block
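Since only the BigGAN-style block survives, the `functools.partial` simply pre-binds the options shared by every residual block; per-block arguments (`in_ch`, `out_ch`, `up`, `down`) stay free. The pattern, illustrated with a stand-in function:

```python
import functools

def make_block(in_ch, out_ch=None, up=False, down=False, *, dropout, init_scale):
    return (in_ch, out_ch, up, down, dropout, init_scale)

ResnetBlock = functools.partial(make_block, dropout=0.1, init_scale=0.0)
print(ResnetBlock(64, 128))            # shared kwargs already applied
print(ResnetBlock(128, down=True))     # downsampling variant, same shared kwargs
```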
@@ -639,7 +337,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
        if progressive_input != "none":
            input_pyramid_ch = channels

-        modules.append(conv3x3(channels, nf))
+        modules.append(Conv2d(channels, nf, kernel_size=3, padding=1))
        hs_c = [nf]

        in_ch = nf
@@ -655,10 +353,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
                hs_c.append(in_ch)

            if i_level != self.num_resolutions - 1:
-                if resblock_type == "ddpm":
-                    modules.append(Downsample(in_ch=in_ch))
-                else:
-                    modules.append(ResnetBlock(down=True, in_ch=in_ch))
+                modules.append(ResnetBlock(down=True, in_ch=in_ch))

                if progressive_input == "input_skip":
                    modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
@@ -691,18 +386,20 @@ class NCSNpp(ModelMixin, ConfigMixin):
            if i_level == self.num_resolutions - 1:
                if progressive == "output_skip":
                    modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
-                    modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
+                    modules.append(Conv2d(in_ch, channels, init_scale=init_scale, kernel_size=3, padding=1))
                    pyramid_ch = channels
                elif progressive == "residual":
                    modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
-                    modules.append(conv3x3(in_ch, in_ch, bias=True))
+                    modules.append(Conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1))
                    pyramid_ch = in_ch
                else:
                    raise ValueError(f"{progressive} is not a valid name.")
            else:
                if progressive == "output_skip":
                    modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
-                    modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
+                    modules.append(
+                        Conv2d(in_ch, channels, bias=True, init_scale=init_scale, kernel_size=3, padding=1)
+                    )
                    pyramid_ch = channels
                elif progressive == "residual":
                    modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
@@ -711,16 +408,13 @@ class NCSNpp(ModelMixin, ConfigMixin):
                    raise ValueError(f"{progressive} is not a valid name")

            if i_level != 0:
-                if resblock_type == "ddpm":
-                    modules.append(Upsample(in_ch=in_ch))
-                else:
-                    modules.append(ResnetBlock(in_ch=in_ch, up=True))
+                modules.append(ResnetBlock(in_ch=in_ch, up=True))

        assert not hs_c

        if progressive != "output_skip":
            modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
-            modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
+            modules.append(Conv2d(in_ch, channels, init_scale=init_scale))

        self.all_modules = nn.ModuleList(modules)
@@ -751,9 +445,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
        else:
            temb = None

-        if not self.config.centered:
-            # If input data is in [0, 1]
-            x = 2 * x - 1.0
+        # If input data is in [0, 1]
+        x = 2 * x - 1.0

        # Downsampling block
        input_pyramid = None
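With the `centered` switch removed, the model now unconditionally assumes inputs in `[0, 1]` and maps them to `[-1, 1]`; callers that already feed centered data would need to adjust. The rescaling itself is just:

```python
import torch

x = torch.rand(1, 3, 8, 8)      # assumed to lie in [0, 1]
x = 2 * x - 1.0                 # now in [-1, 1]
assert x.min() >= -1.0 and x.max() <= 1.0
```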
@@ -774,12 +467,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
                hs.append(h)

            if i_level != self.num_resolutions - 1:
-                if self.resblock_type == "ddpm":
-                    h = modules[m_idx](hs[-1])
-                    m_idx += 1
-                else:
-                    h = modules[m_idx](hs[-1], temb)
-                    m_idx += 1
+                h = modules[m_idx](hs[-1], temb)
+                m_idx += 1

                if self.progressive_input == "input_skip":
                    input_pyramid = self.pyramid_downsample(input_pyramid)
@@ -851,12 +540,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
                    raise ValueError(f"{self.progressive} is not a valid name")

            if i_level != 0:
-                if self.resblock_type == "ddpm":
-                    h = modules[m_idx](h)
-                    m_idx += 1
-                else:
-                    h = modules[m_idx](h, temb)
-                    m_idx += 1
+                h = modules[m_idx](h, temb)
+                m_idx += 1

        assert not hs
...