Unverified commit 53a42d0a, authored by Suraj Patil and committed by GitHub
Browse files

Simplify FirUp/down, unet sde (#71)

* refactor fir up/down sample

* remove variance scaling

* remove variance scaling from unet sde

* refactor Linear

* style

* actually remove variance scaling

* add back upsample_2d, downsample_2d

* style

* fix FirUpsample2D
parent 321f9791
......@@ -175,12 +175,81 @@ class FirUpsample2D(nn.Module):
self.fir_kernel = fir_kernel
self.out_channels = out_channels
def _upsample_2d(self, x, w=None, k=None, factor=2, gain=1):
    """Upsample `x` with an FIR filter, optionally fused with a convolution.

    When `self.use_conv` is true, `x` first goes through a strided transposed
    convolution built from `w` and is then smoothed with the FIR kernel.
    Otherwise only the FIR kernel is applied, with `up=factor`.

    Args:
        x: Input tensor, indexed by this method as `[N, C, H, W]`.
        w: Convolution weight, indexed as `[outC, inC, convH, convW]`.
            Only read when `self.use_conv` is true.
        k: FIR filter, either 2-D or 1-D (a 1-D filter is expanded with an
            outer product). The default is `[1] * factor`.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        The upsampled tensor. With the default `factor=2` its spatial size
        is `[2 * H, 2 * W]`.
    """
    assert isinstance(factor, int) and factor >= 1

    if k is None:
        k = [1] * factor

    # Build the normalised 2-D kernel. The division is out-of-place because
    # `np.asarray` hands back the caller's own array when it is already
    # float32, and an in-place `/=` would silently rewrite it.
    k = np.asarray(k, dtype=np.float32)
    if k.ndim == 1:
        k = np.outer(k, k)
    k = k / np.sum(k)
    k = k * (gain * (factor**2))

    if self.use_conv:
        convH = w.shape[2]
        convW = w.shape[3]
        inC = w.shape[1]

        p = (k.shape[0] - factor) - (convW - 1)

        # `conv_transpose2d` takes a 2-element stride for 2-D inputs; the
        # same tuple drives the output-padding arithmetic below.
        stride = (factor, factor)

        # Determine data dimensions.
        output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
        output_padding = (
            output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
            output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
        )
        assert output_padding[0] >= 0 and output_padding[1] >= 0
        num_groups = x.shape[1] // inC

        # Transpose weights. Torch tensors reject negative slice steps, so
        # the spatial flip is done with `torch.flip`.
        w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
        w = torch.flip(w, dims=[3, 4]).permute(0, 2, 1, 3, 4)
        w = torch.reshape(w, (num_groups * inC, -1, convH, convW))

        x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
        x = upfirdn2d_native(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
    else:
        p = k.shape[0] - factor
        x = upfirdn2d_native(
            x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
        )

    return x
def forward(self, x):
    """Upsample `x` by a factor of 2 through `self._upsample_2d`.

    With `self.use_conv` set, the weight of `self.Conv2d_0` is handed to the
    fused upsample+convolution and the layer's bias is added afterwards,
    because only the weight is passed to `_upsample_2d`.
    """
    if self.use_conv:
        # The earlier call to the legacy `_upsample_conv_2d` helper was a dead
        # store (its result was overwritten on the next line), so only the
        # method-based call is kept.
        h = self._upsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
        h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
    else:
        h = self._upsample_2d(x, k=self.fir_kernel, factor=2)

    return h
......@@ -190,109 +259,61 @@ class FirDownsample2D(nn.Module):
super().__init__()
out_channels = out_channels if out_channels else channels
if use_conv:
self.Conv2d_0 = self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
self.fir_kernel = fir_kernel
self.use_conv = use_conv
self.out_channels = out_channels
def forward(self, x):
    """Downsample `x` by a factor of 2 with the layer's FIR kernel.

    Without a convolution the module-level `downsample_2d` is used. With one,
    `_conv_downsample_2d` receives the weight of `self.Conv2d_0` and the
    layer's bias is added to its result.
    """
    if not self.use_conv:
        return downsample_2d(x, self.fir_kernel, factor=2)

    filtered = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
    bias = self.Conv2d_0.bias.reshape(1, -1, 1, 1)
    return filtered + bias
def _conv_downsample_2d(x, w, k=None, factor=2, gain=1):
    """Apply an FIR filter to `x`, then a strided convolution with `w`.

    All padding is applied once, by the FIR step; the convolution itself runs
    with `padding=0` and `stride=factor`.

    Args:
        x: Input tensor passed to `upfirdn2d` and then to `F.conv2d`.
        w: 4-D convolution weight; its last two axes are the kernel height
            and width, which must be equal.
        k: FIR filter handed to `_setup_kernel`. The default is `[1] * factor`.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        The result of `F.conv2d` on the filtered input.
    """
    assert isinstance(factor, int) and factor >= 1
    _out_ch, _in_ch, kernel_h, kernel_w = w.shape
    assert kernel_w == kernel_h

    fir = [1] * factor if k is None else k
    fir = _setup_kernel(fir) * gain

    # Total padding that makes the strided convolution land on H // factor.
    total_pad = (fir.shape[0] - factor) + (kernel_w - 1)
    pad_before = (total_pad + 1) // 2
    pad_after = total_pad // 2

    fir_tensor = torch.tensor(fir, device=x.device)
    filtered = upfirdn2d(x, fir_tensor, pad=(pad_before, pad_after))
    return F.conv2d(filtered, w, stride=[factor, factor], padding=0)
def _upsample_conv_2d(x, w, k=None, factor=2, gain=1):
"""Fused `upsample_2d()` followed by `Conv2d()`.
Args:
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
order.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
w: Weight tensor of the shape `[filterH, filterW, inChannels,
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
`x`.
"""
assert isinstance(factor, int) and factor >= 1
# Check weight shape.
assert len(w.shape) == 4
convH = w.shape[2]
convW = w.shape[3]
inC = w.shape[1]
def _downsample_2d(self, x, w=None, k=None, factor=2, gain=1):
"""Fused `Conv2d()` followed by `downsample_2d()`.
Args:
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
order.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. w: Weight tensor of the shape `[filterH,
filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] //
numGroups`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
factor`, which corresponds to average pooling. factor: Integer downsampling factor (default: 2). gain:
Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
datatype as `x`.
"""
assert convW == convH
assert isinstance(factor, int) and factor >= 1
if k is None:
k = [1] * factor
# Setup filter kernel.
if k is None:
k = [1] * factor
k = _setup_kernel(k) * (gain * (factor**2))
p = (k.shape[0] - factor) - (convW - 1)
# setup kernel
k = np.asarray(k, dtype=np.float32)
if k.ndim == 1:
k = np.outer(k, k)
k /= np.sum(k)
stride = (factor, factor)
k = k * gain
# Determine data dimensions.
stride = [1, 1, factor, factor]
output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
output_padding = (
output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
)
assert output_padding[0] >= 0 and output_padding[1] >= 0
num_groups = x.shape[1] // inC
if self.use_conv:
_, _, convH, convW = w.shape
p = (k.shape[0] - factor) + (convW - 1)
s = [factor, factor]
x = upfirdn2d_native(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2, p // 2))
x = F.conv2d(x, w, stride=s, padding=0)
else:
p = k.shape[0] - factor
x = upfirdn2d_native(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
# Transpose weights.
w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)
w = torch.reshape(w, (num_groups * inC, -1, convH, convW))
return x
x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
def forward(self, x):
if self.use_conv:
x = self._downsample_2d(x, w=self.Conv2d_0.weight, k=self.fir_kernel)
x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else:
x = self._downsample_2d(x, k=self.fir_kernel, factor=2)
return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
return x
# TODO (patil-suraj): needs test
......@@ -452,18 +473,17 @@ class ResnetBlock2D(nn.Module):
self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=in_ch, eps=eps)
self.up = up
self.down = down
self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
self.Conv_0 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
if temb_dim is not None:
self.Dense_0 = nn.Linear(temb_dim, out_ch)
self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
nn.init.zeros_(self.Dense_0.bias)
self.GroupNorm_1 = nn.GroupNorm(num_groups=num_groups_out, num_channels=out_ch, eps=eps)
self.Dropout_0 = nn.Dropout(dropout)
self.Conv_1 = conv2d(out_ch, out_ch, init_scale=0.0, kernel_size=3, padding=1)
self.Conv_1 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
if in_ch != out_ch or up or down:
# 1x1 convolution with DDPM initialization.
self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
self.Conv_2 = nn.Conv2d(in_ch, out_ch, kernel_size=1, padding=0)
self.in_ch = in_ch
self.out_ch = out_ch
......@@ -716,75 +736,6 @@ class RearrangeDim(nn.Module):
raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
    """nXn convolution with DDPM initialization.

    Builds an `nn.Conv2d`, replaces its weight with a sample drawn by
    `variance_scaling(init_scale)`, and zeroes the bias when the layer has one.
    """
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
    conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape)
    # With `bias=False` the layer's bias is `None`; zeroing it would raise.
    if conv.bias is not None:
        nn.init.zeros_(conv.bias)
    return conv
def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
    """Return an initializer that samples uniform weights scaled by the average fan (ported from JAX).

    A `scale` of exactly 0 is replaced by `1e-10`. The returned callable
    takes a shape (plus optional `dtype` and `device`) and returns a tensor
    drawn uniformly from `[-b, b)` with `b = sqrt(3 * scale / fan_avg)`,
    where `fan_avg` is the mean of fan-in and fan-out.
    """
    if scale == 0:
        scale = 1e-10

    def init(shape, dtype=dtype, device=device):
        # Everything that is neither the input nor the output axis counts as
        # receptive field.
        receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
        fan_in = shape[in_axis] * receptive_field_size
        fan_out = shape[out_axis] * receptive_field_size
        variance = scale / ((fan_in + fan_out) / 2)

        centered = torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0
        return centered * np.sqrt(3 * variance)

    return init
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    """Call `upfirdn2d_native` with the same factors and padding on both spatial axes.

    `up` and `down` are used for x and y alike; `pad[0]` is the leading and
    `pad[1]` the trailing padding on each axis.
    """
    pad_before, pad_after = pad[0], pad[1]
    return upfirdn2d_native(
        input,
        kernel,
        up,
        up,
        down,
        down,
        pad_before,
        pad_after,
        pad_before,
        pad_after,
    )
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
    """Upsample by zero insertion, pad or crop, FIR-filter, then downsample.

    Every (sample, channel) plane of `input` is filtered independently with
    the same 2-D `kernel`.

    Args:
        input: 4-D tensor unpacked as `(batch, channel, height, width)`.
        kernel: 2-D filter tensor of shape `(kernel_h, kernel_w)`.
        up_x, up_y: Zero-insertion factors along width and height.
        down_x, down_y: Keep every n-th sample along width and height.
        pad_x0, pad_x1: Padding before/after along the width; negative values crop.
        pad_y0, pad_y1: Padding before/after along the height; negative values crop.

    Returns:
        The filtered tensor viewed as `(-1, channel, out_h, out_w)`.
    """
    _, channel, in_h, in_w = input.shape
    kernel_h, kernel_w = kernel.shape
    minor = 1  # trailing singleton axis carried through the pad/crop steps

    up_h = in_h * up_y
    up_w = in_w * up_x
    padded_h = up_h + pad_y0 + pad_y1
    padded_w = up_w + pad_x0 + pad_x1

    # Zero insertion: give every pixel its own (1 x 1) cell, grow the cell to
    # (up_y x up_x) with zeros, then merge the cells back into the image.
    planes = input.reshape(-1, in_h, in_w, minor)
    cells = planes.view(-1, in_h, 1, in_w, 1, minor)
    cells = F.pad(cells, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    stuffed = cells.view(-1, up_h, up_w, minor)

    # Positive padding adds zeros; negative padding is applied as a crop.
    padded = F.pad(stuffed, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    row_stop = padded.shape[1] - max(-pad_y1, 0)
    col_stop = padded.shape[2] - max(-pad_x1, 0)
    padded = padded[:, max(-pad_y0, 0) : row_stop, max(-pad_x0, 0) : col_stop, :]

    # `F.conv2d` computes a correlation, so the kernel is flipped first.
    image = padded.permute(0, 3, 1, 2).reshape([-1, 1, padded_h, padded_w])
    flipped = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    filtered = F.conv2d(image, flipped)
    filtered = filtered.reshape(-1, minor, padded_h - kernel_h + 1, padded_w - kernel_w + 1)

    # Downsample by keeping every down_y-th row and down_x-th column.
    decimated = filtered.permute(0, 2, 3, 1)[:, ::down_y, ::down_x, :]
    out_h = (padded_h - kernel_h) // down_y + 1
    out_w = (padded_w - kernel_w) // down_x + 1
    return decimated.view(-1, channel, out_h, out_w)
def upsample_2d(x, k=None, factor=2, gain=1):
r"""Upsample2D a batch of 2D images with the given filter.
......@@ -805,9 +756,15 @@ def upsample_2d(x, k=None, factor=2, gain=1):
assert isinstance(factor, int) and factor >= 1
if k is None:
k = [1] * factor
k = _setup_kernel(k) * (gain * (factor**2))
k = np.asarray(k, dtype=np.float32)
if k.ndim == 1:
k = np.outer(k, k)
k /= np.sum(k)
k = k * (gain * (factor**2))
p = k.shape[0] - factor
return upfirdn2d(x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
return upfirdn2d_native(x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
def downsample_2d(x, k=None, factor=2, gain=1):
......@@ -831,16 +788,55 @@ def downsample_2d(x, k=None, factor=2, gain=1):
assert isinstance(factor, int) and factor >= 1
if k is None:
k = [1] * factor
k = _setup_kernel(k) * gain
p = k.shape[0] - factor
return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
def _setup_kernel(k):
k = np.asarray(k, dtype=np.float32)
if k.ndim == 1:
k = np.outer(k, k)
k /= np.sum(k)
assert k.ndim == 2
assert k.shape[0] == k.shape[1]
return k
k = k * gain
p = k.shape[0] - factor
return upfirdn2d_native(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
    """Upsample by zero insertion, pad or crop, FIR-filter, then downsample.

    The same factors and padding are used on both spatial axes, and every
    (sample, channel) plane of `input` is filtered independently with the
    same 2-D `kernel`.

    Args:
        input: 4-D tensor unpacked as `(batch, channel, height, width)`.
        kernel: 2-D filter tensor of shape `(kernel_h, kernel_w)`.
        up: Zero-insertion upsampling factor.
        down: Keep every `down`-th sample after filtering.
        pad: `(before, after)` padding applied along both axes; negative
            values crop instead of padding.

    Returns:
        The filtered tensor viewed as `(-1, channel, out_h, out_w)`.
    """
    pad_before, pad_after = pad[0], pad[1]

    _, channel, in_h, in_w = input.shape
    kernel_h, kernel_w = kernel.shape
    minor = 1  # trailing singleton axis carried through the pad/crop steps

    up_h = in_h * up
    up_w = in_w * up
    padded_h = up_h + pad_before + pad_after
    padded_w = up_w + pad_before + pad_after

    # Zero insertion: give every pixel its own (1 x 1) cell, grow the cell to
    # (up x up) with zeros, then merge the cells back into the image.
    planes = input.reshape(-1, in_h, in_w, minor)
    cells = planes.view(-1, in_h, 1, in_w, 1, minor)
    cells = F.pad(cells, [0, 0, 0, up - 1, 0, 0, 0, up - 1])
    stuffed = cells.view(-1, up_h, up_w, minor)

    # Positive padding adds zeros; negative padding is applied as a crop.
    grow_before = max(pad_before, 0)
    grow_after = max(pad_after, 0)
    crop_before = max(-pad_before, 0)
    crop_after = max(-pad_after, 0)
    padded = F.pad(stuffed, [0, 0, grow_before, grow_after, grow_before, grow_after])
    padded = padded[
        :,
        crop_before : padded.shape[1] - crop_after,
        crop_before : padded.shape[2] - crop_after,
        :,
    ]

    # `F.conv2d` computes a correlation, so the kernel is flipped first.
    image = padded.permute(0, 3, 1, 2).reshape([-1, 1, padded_h, padded_w])
    flipped = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    filtered = F.conv2d(image, flipped)
    filtered = filtered.reshape(-1, minor, padded_h - kernel_h + 1, padded_w - kernel_w + 1)

    # Downsample by keeping every `down`-th row and column.
    decimated = filtered.permute(0, 2, 3, 1)[:, ::down, ::down, :]
    out_h = (padded_h - kernel_h) // down + 1
    out_w = (padded_w - kernel_w) // down + 1
    return decimated.view(-1, channel, out_h, out_w)
......@@ -29,57 +29,13 @@ from .embeddings import GaussianFourierProjection, get_timestep_embedding
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
def _setup_kernel(k):
k = np.asarray(k, dtype=np.float32)
if k.ndim == 1:
k = np.outer(k, k)
k /= np.sum(k)
assert k.ndim == 2
assert k.shape[0] == k.shape[1]
return k
def _variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
"""Ported from JAX."""
scale = 1e-10 if scale == 0 else scale
def _compute_fans(shape, in_axis=1, out_axis=0):
receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
fan_in = shape[in_axis] * receptive_field_size
fan_out = shape[out_axis] * receptive_field_size
return fan_in, fan_out
def init(shape, dtype=dtype, device=device):
fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
denominator = (fan_in + fan_out) / 2
variance = scale / denominator
return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
return init
def Conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
    """nXn convolution with DDPM initialization.

    Builds an `nn.Conv2d`, replaces its weight with a sample drawn by
    `_variance_scaling(init_scale)`, and zeroes the bias when the layer has one.
    """
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
    conv.weight.data = _variance_scaling(init_scale)(conv.weight.data.shape)
    # With `bias=False` the layer's bias is `None`; zeroing it would raise.
    if conv.bias is not None:
        nn.init.zeros_(conv.bias)
    return conv
def Linear(dim_in, dim_out):
    """Build an `nn.Linear` whose weight comes from `_variance_scaling()` and whose bias is zero."""
    layer = nn.Linear(dim_in, dim_out)
    weight_init = _variance_scaling()
    layer.weight.data = weight_init(layer.weight.shape)
    nn.init.zeros_(layer.bias)
    return layer
class Combine(nn.Module):
"""Combine information from skip connections."""
def __init__(self, dim1, dim2, method="cat"):
super().__init__()
# 1x1 convolution with DDPM initialization.
self.Conv_0 = Conv2d(dim1, dim2, kernel_size=1, padding=0)
self.Conv_0 = nn.Conv2d(dim1, dim2, kernel_size=1, padding=0)
self.method = method
def forward(self, x, y):
......@@ -176,8 +132,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
else:
raise ValueError(f"embedding type {embedding_type} unknown.")
modules.append(Linear(embed_dim, nf * 4))
modules.append(Linear(nf * 4, nf * 4))
modules.append(nn.Linear(embed_dim, nf * 4))
modules.append(nn.Linear(nf * 4, nf * 4))
AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
......@@ -205,7 +161,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
if progressive_input != "none":
input_pyramid_ch = channels
modules.append(Conv2d(channels, nf, kernel_size=3, padding=1))
modules.append(nn.Conv2d(channels, nf, kernel_size=3, padding=1))
hs_c = [nf]
in_ch = nf
......@@ -310,20 +266,18 @@ class NCSNpp(ModelMixin, ConfigMixin):
if i_level == self.num_resolutions - 1:
if progressive == "output_skip":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(Conv2d(in_ch, channels, init_scale=init_scale, kernel_size=3, padding=1))
modules.append(nn.Conv2d(in_ch, channels, kernel_size=3, padding=1))
pyramid_ch = channels
elif progressive == "residual":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(Conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1))
modules.append(nn.Conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1))
pyramid_ch = in_ch
else:
raise ValueError(f"{progressive} is not a valid name.")
else:
if progressive == "output_skip":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(
Conv2d(in_ch, channels, bias=True, init_scale=init_scale, kernel_size=3, padding=1)
)
modules.append(nn.Conv2d(in_ch, channels, bias=True, kernel_size=3, padding=1))
pyramid_ch = channels
elif progressive == "residual":
modules.append(pyramid_upsample(channels=pyramid_ch, out_channels=in_ch))
......@@ -351,7 +305,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
if progressive != "output_skip":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(Conv2d(in_ch, channels, init_scale=init_scale))
modules.append(nn.Conv2d(in_ch, channels, kernel_size=3, padding=1))
self.all_modules = nn.ModuleList(modules)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment