"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "2345481c0e21f1bd84c0d85b866b57d34506d836"
Unverified Commit 53a42d0a authored by Suraj Patil, committed by GitHub

Simplify FirUp/down, unet sde (#71)

* refactor fir up/down sample

* remove variance scaling

* remove variance scaling from unet sde

* refactor Linear

* style

* actually remove variance scaling

* add back upsample_2d, downsample_2d

* style

* fix FirUpsample2D
parent 321f9791
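In effect, the commit drops the JAX-style `variance_scaling` initializer and the thin `Conv2d`/`Linear` wrappers built on it, leaving the stock `torch.nn` layers with their default initialization. A rough sketch of the shape of the change (not the literal diff; names taken from the hunks below):

```python
import torch.nn as nn

# Before: wrappers that re-initialized every weight with variance_scaling().
#   self.Dense_0 = Linear(temb_dim, out_ch)        # weight <- variance_scaling()(shape)
#   self.Conv_0 = Conv2d(in_ch, out_ch, kernel_size=3, padding=1)

# After: plain torch.nn modules; weights keep PyTorch's default Kaiming-uniform init.
dense = nn.Linear(128, 256)
conv = nn.Conv2d(64, 64, kernel_size=3, padding=1)
```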
@@ -175,12 +175,81 @@ class FirUpsample2D(nn.Module):
self.fir_kernel = fir_kernel
self.out_channels = out_channels
def _upsample_2d(self, x, w=None, k=None, factor=2, gain=1):
"""Fused `upsample_2d()` followed by `Conv2d()`.
Args:
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
order.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
w: Weight tensor of the shape `[filterH, filterW, inChannels,
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
`x`.
"""
assert isinstance(factor, int) and factor >= 1
# Setup filter kernel.
if k is None:
k = [1] * factor
# setup kernel
k = np.asarray(k, dtype=np.float32)
if k.ndim == 1:
k = np.outer(k, k)
k /= np.sum(k)
k = k * (gain * (factor**2))
if self.use_conv:
convH = w.shape[2]
convW = w.shape[3]
inC = w.shape[1]
p = (k.shape[0] - factor) - (convW - 1)
stride = (factor, factor)
# Determine data dimensions.
stride = [1, 1, factor, factor]
output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
output_padding = (
output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
)
assert output_padding[0] >= 0 and output_padding[1] >= 0
inC = w.shape[1]
num_groups = x.shape[1] // inC
# Transpose weights.
w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)
w = torch.reshape(w, (num_groups * inC, -1, convH, convW))
x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
x = upfirdn2d_native(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
else:
p = k.shape[0] - factor
x = upfirdn2d_native(
x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
)
return x
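For reference, the kernel preparation at the top of `_upsample_2d` can be reproduced with plain NumPy. The taps `[1, 3, 3, 1]` below are only an illustrative choice of `fir_kernel`; the constructor default is outside this hunk:

```python
import numpy as np

factor, gain = 2, 1
k = np.asarray([1, 3, 3, 1], dtype=np.float32)
if k.ndim == 1:
    k = np.outer(k, k)        # separable 1-D taps -> 4x4 2-D filter
k /= np.sum(k)                # normalize the filter to unit sum
k = k * (gain * (factor**2))  # rescale so upsampling preserves average magnitude

print(k.shape)  # (4, 4)
print(k.sum())  # ~4.0 == gain * factor**2
```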
def forward(self, x):
if self.use_conv:
- h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
+ h = self._upsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else:
- h = upsample_2d(x, self.fir_kernel, factor=2)
+ h = self._upsample_2d(x, k=self.fir_kernel, factor=2)
return h
@@ -190,109 +259,61 @@ class FirDownsample2D(nn.Module):
super().__init__()
out_channels = out_channels if out_channels else channels
if use_conv:
- self.Conv2d_0 = self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
self.fir_kernel = fir_kernel
self.use_conv = use_conv
self.out_channels = out_channels
- def forward(self, x):
- if self.use_conv:
- x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
- x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
- else:
- x = downsample_2d(x, self.fir_kernel, factor=2)
- return x
- def _conv_downsample_2d(x, w, k=None, factor=2, gain=1):
- """Fused `Conv2d()` followed by `downsample_2d()`.
- Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
- efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
- order.
- Args:
- x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
- w: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
- k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which corresponds to average pooling.
- factor: Integer downsampling factor (default: 2).
- gain: Scaling factor for signal magnitude (default: 1.0).
- Returns:
- Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same datatype as `x`.
- """
- assert isinstance(factor, int) and factor >= 1
- _outC, _inC, convH, convW = w.shape
- assert convW == convH
- if k is None:
- k = [1] * factor
- k = _setup_kernel(k) * gain
- p = (k.shape[0] - factor) + (convW - 1)
- s = [factor, factor]
- x = upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2, p // 2))
- return F.conv2d(x, w, stride=s, padding=0)
- def _upsample_conv_2d(x, w, k=None, factor=2, gain=1):
- """Fused `upsample_2d()` followed by `Conv2d()`.
- Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
- efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
- order.
- Args:
- x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
- w: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
- k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
- factor: Integer upsampling factor (default: 2).
- gain: Scaling factor for signal magnitude (default: 1.0).
- Returns:
- Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as `x`.
- """
- assert isinstance(factor, int) and factor >= 1
- # Check weight shape.
- assert len(w.shape) == 4
- convH = w.shape[2]
- convW = w.shape[3]
- inC = w.shape[1]
- assert convW == convH
- # Setup filter kernel.
- if k is None:
- k = [1] * factor
- k = _setup_kernel(k) * (gain * (factor**2))
- p = (k.shape[0] - factor) - (convW - 1)
- stride = (factor, factor)
- # Determine data dimensions.
- stride = [1, 1, factor, factor]
- output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
- output_padding = (
- output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
- output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
- )
- assert output_padding[0] >= 0 and output_padding[1] >= 0
- num_groups = x.shape[1] // inC
- # Transpose weights.
- w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
- w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)
- w = torch.reshape(w, (num_groups * inC, -1, convH, convW))
- x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
- return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
+ def _downsample_2d(self, x, w=None, k=None, factor=2, gain=1):
+ """Fused `Conv2d()` followed by `downsample_2d()`.
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
+ order.
+ Args:
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ w: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
+ k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which corresponds to average pooling.
+ factor: Integer downsampling factor (default: 2).
+ gain: Scaling factor for signal magnitude (default: 1.0).
+ Returns:
+ Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same datatype as `x`.
+ """
+ assert isinstance(factor, int) and factor >= 1
+ if k is None:
+ k = [1] * factor
+ # setup kernel
+ k = np.asarray(k, dtype=np.float32)
+ if k.ndim == 1:
+ k = np.outer(k, k)
+ k /= np.sum(k)
+ k = k * gain
+ if self.use_conv:
+ _, _, convH, convW = w.shape
+ p = (k.shape[0] - factor) + (convW - 1)
+ s = [factor, factor]
+ x = upfirdn2d_native(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2, p // 2))
+ x = F.conv2d(x, w, stride=s, padding=0)
+ else:
+ p = k.shape[0] - factor
+ x = upfirdn2d_native(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
+ return x
+ def forward(self, x):
+ if self.use_conv:
+ x = self._downsample_2d(x, w=self.Conv2d_0.weight, k=self.fir_kernel)
+ x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
+ else:
+ x = self._downsample_2d(x, k=self.fir_kernel, factor=2)
+ return x
# TODO (patil-suraj): needs test
@@ -452,18 +473,17 @@ class ResnetBlock2D(nn.Module):
self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=in_ch, eps=eps)
self.up = up
self.down = down
- self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
+ self.Conv_0 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
if temb_dim is not None:
self.Dense_0 = nn.Linear(temb_dim, out_ch)
- self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
nn.init.zeros_(self.Dense_0.bias)
self.GroupNorm_1 = nn.GroupNorm(num_groups=num_groups_out, num_channels=out_ch, eps=eps)
self.Dropout_0 = nn.Dropout(dropout)
- self.Conv_1 = conv2d(out_ch, out_ch, init_scale=0.0, kernel_size=3, padding=1)
+ self.Conv_1 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
if in_ch != out_ch or up or down:
# 1x1 convolution with DDPM initialization.
- self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
+ self.Conv_2 = nn.Conv2d(in_ch, out_ch, kernel_size=1, padding=0)
self.in_ch = in_ch
self.out_ch = out_ch
@@ -716,75 +736,6 @@ class RearrangeDim(nn.Module):
raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
"""nXn convolution with DDPM initialization."""
conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape)
nn.init.zeros_(conv.bias)
return conv
def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
"""Ported from JAX."""
scale = 1e-10 if scale == 0 else scale
def _compute_fans(shape, in_axis=1, out_axis=0):
receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
fan_in = shape[in_axis] * receptive_field_size
fan_out = shape[out_axis] * receptive_field_size
return fan_in, fan_out
def init(shape, dtype=dtype, device=device):
fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
denominator = (fan_in + fan_out) / 2
variance = scale / denominator
return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
return init
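The deleted initializer above draws uniformly with variance `scale / ((fan_in + fan_out) / 2)`, i.e. "fan_avg" variance scaling. Assuming current `torch.nn.init` behavior, the same distribution is obtained from Xavier/Glorot uniform with `gain = sqrt(scale)`; a quick check of that claim:

```python
import math

import torch
import torch.nn as nn

scale = 1.0
out_ch, in_ch, kh, kw = 128, 256, 3, 3
w = torch.empty(out_ch, in_ch, kh, kw)

# Same fan computation as the removed _compute_fans helper.
fan_in, fan_out = in_ch * kh * kw, out_ch * kh * kw
target_var = scale / ((fan_in + fan_out) / 2)

# Xavier/Glorot uniform with gain=sqrt(scale) targets exactly this variance.
nn.init.xavier_uniform_(w, gain=math.sqrt(scale))
print(w.var().item(), target_var)  # the two should agree closely
```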
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
_, channel, in_h, in_w = input.shape
input = input.reshape(-1, in_h, in_w, 1)
_, in_h, in_w, minor = input.shape
kernel_h, kernel_w = kernel.shape
out = input.view(-1, in_h, 1, in_w, 1, minor)
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
out = out[
:,
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
:,
]
out = out.permute(0, 3, 1, 2)
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
out = F.conv2d(out, w)
out = out.reshape(
-1,
minor,
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
)
out = out.permute(0, 2, 3, 1)
out = out[:, ::down_y, ::down_x, :]
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
return out.view(-1, channel, out_h, out_w)
def upsample_2d(x, k=None, factor=2, gain=1):
r"""Upsample2D a batch of 2D images with the given filter.
@@ -805,9 +756,15 @@ def upsample_2d(x, k=None, factor=2, gain=1):
assert isinstance(factor, int) and factor >= 1
if k is None:
k = [1] * factor
- k = _setup_kernel(k) * (gain * (factor**2))
+ k = np.asarray(k, dtype=np.float32)
+ if k.ndim == 1:
+ k = np.outer(k, k)
+ k /= np.sum(k)
+ k = k * (gain * (factor**2))
p = k.shape[0] - factor
- return upfirdn2d(x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
+ return upfirdn2d_native(x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
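As a usage sketch for the helper kept above; the import path and the exact outputs are assumptions about the surrounding library, not part of the diff:

```python
import torch

from diffusers.models.resnet import upsample_2d  # assumed module path

x = torch.arange(16.0).reshape(1, 1, 4, 4)  # [N, C, H, W]

# Default kernel ([1] * factor) reduces to nearest-neighbor upsampling.
y = upsample_2d(x, factor=2)
print(y.shape)  # torch.Size([1, 1, 8, 8])
print(torch.allclose(y, x.repeat_interleave(2, dim=-1).repeat_interleave(2, dim=-2)))  # expected: True

# A smoothing kernel such as [1, 3, 3, 1] gives an anti-aliased upsample instead.
y_fir = upsample_2d(x, k=[1, 3, 3, 1], factor=2)
print(y_fir.shape)  # torch.Size([1, 1, 8, 8])
```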
def downsample_2d(x, k=None, factor=2, gain=1):
@@ -831,16 +788,55 @@ def downsample_2d(x, k=None, factor=2, gain=1):
assert isinstance(factor, int) and factor >= 1
if k is None:
k = [1] * factor
- k = _setup_kernel(k) * gain
- p = k.shape[0] - factor
- return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
- def _setup_kernel(k):
- k = np.asarray(k, dtype=np.float32)
- if k.ndim == 1:
- k = np.outer(k, k)
- k /= np.sum(k)
- assert k.ndim == 2
- assert k.shape[0] == k.shape[1]
- return k
+ k = np.asarray(k, dtype=np.float32)
+ if k.ndim == 1:
+ k = np.outer(k, k)
+ k /= np.sum(k)
+ k = k * gain
+ p = k.shape[0] - factor
+ return upfirdn2d_native(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
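Similarly for `downsample_2d`: with the default kernel the FIR path should reduce to plain 2x2 average pooling (import path assumed, not part of the diff):

```python
import torch
import torch.nn.functional as F

from diffusers.models.resnet import downsample_2d  # assumed module path

x = torch.randn(1, 3, 8, 8)

y = downsample_2d(x, factor=2)  # default kernel [1, 1]
print(y.shape)  # torch.Size([1, 3, 4, 4])
print(torch.allclose(y, F.avg_pool2d(x, 2), atol=1e-6))  # expected: True
```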
def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
up_x = up_y = up
down_x = down_y = down
pad_x0 = pad_y0 = pad[0]
pad_x1 = pad_y1 = pad[1]
_, channel, in_h, in_w = input.shape
input = input.reshape(-1, in_h, in_w, 1)
_, in_h, in_w, minor = input.shape
kernel_h, kernel_w = kernel.shape
out = input.view(-1, in_h, 1, in_w, 1, minor)
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
out = out[
:,
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
:,
]
out = out.permute(0, 3, 1, 2)
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
out = F.conv2d(out, w)
out = out.reshape(
-1,
minor,
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
)
out = out.permute(0, 2, 3, 1)
out = out[:, ::down_y, ::down_x, :]
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
return out.view(-1, channel, out_h, out_w)
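The rewritten `upfirdn2d_native` keeps the classic upsample-FIR-downsample semantics: zero-insert by `up`, pad, convolve with the flipped `kernel`, then keep every `down`-th sample, so each spatial size follows `(in * up + pad[0] + pad[1] - kernel_size) // down + 1`. A small shape check (import path assumed):

```python
import torch

from diffusers.models.resnet import upfirdn2d_native  # assumed module path

x = torch.randn(1, 3, 5, 7)
kernel = torch.ones(4, 4) / 16.0  # normalized 4x4 FIR filter
up, down, pad = 2, 1, (2, 1)

y = upfirdn2d_native(x, kernel, up=up, down=down, pad=pad)

expected_h = (5 * up + pad[0] + pad[1] - 4) // down + 1  # 10
expected_w = (7 * up + pad[0] + pad[1] - 4) // down + 1  # 14
print(y.shape)  # torch.Size([1, 3, 10, 14])
```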
@@ -29,57 +29,13 @@ from .embeddings import GaussianFourierProjection, get_timestep_embedding
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
def _setup_kernel(k):
k = np.asarray(k, dtype=np.float32)
if k.ndim == 1:
k = np.outer(k, k)
k /= np.sum(k)
assert k.ndim == 2
assert k.shape[0] == k.shape[1]
return k
def _variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
"""Ported from JAX."""
scale = 1e-10 if scale == 0 else scale
def _compute_fans(shape, in_axis=1, out_axis=0):
receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
fan_in = shape[in_axis] * receptive_field_size
fan_out = shape[out_axis] * receptive_field_size
return fan_in, fan_out
def init(shape, dtype=dtype, device=device):
fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
denominator = (fan_in + fan_out) / 2
variance = scale / denominator
return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
return init
def Conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
"""nXn convolution with DDPM initialization."""
conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
conv.weight.data = _variance_scaling(init_scale)(conv.weight.data.shape)
nn.init.zeros_(conv.bias)
return conv
def Linear(dim_in, dim_out):
linear = nn.Linear(dim_in, dim_out)
linear.weight.data = _variance_scaling()(linear.weight.shape)
nn.init.zeros_(linear.bias)
return linear
class Combine(nn.Module):
"""Combine information from skip connections."""
def __init__(self, dim1, dim2, method="cat"):
super().__init__()
# 1x1 convolution with DDPM initialization.
- self.Conv_0 = Conv2d(dim1, dim2, kernel_size=1, padding=0)
+ self.Conv_0 = nn.Conv2d(dim1, dim2, kernel_size=1, padding=0)
self.method = method
def forward(self, x, y):
@@ -176,8 +132,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
else:
raise ValueError(f"embedding type {embedding_type} unknown.")
- modules.append(Linear(embed_dim, nf * 4))
- modules.append(Linear(nf * 4, nf * 4))
+ modules.append(nn.Linear(embed_dim, nf * 4))
+ modules.append(nn.Linear(nf * 4, nf * 4))
AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
@@ -205,7 +161,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
if progressive_input != "none":
input_pyramid_ch = channels
- modules.append(Conv2d(channels, nf, kernel_size=3, padding=1))
+ modules.append(nn.Conv2d(channels, nf, kernel_size=3, padding=1))
hs_c = [nf]
in_ch = nf
@@ -310,20 +266,18 @@ class NCSNpp(ModelMixin, ConfigMixin):
if i_level == self.num_resolutions - 1:
if progressive == "output_skip":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
- modules.append(Conv2d(in_ch, channels, init_scale=init_scale, kernel_size=3, padding=1))
+ modules.append(nn.Conv2d(in_ch, channels, kernel_size=3, padding=1))
pyramid_ch = channels
elif progressive == "residual":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
- modules.append(Conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1))
+ modules.append(nn.Conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1))
pyramid_ch = in_ch
else:
raise ValueError(f"{progressive} is not a valid name.")
else:
if progressive == "output_skip":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
- modules.append(
- Conv2d(in_ch, channels, bias=True, init_scale=init_scale, kernel_size=3, padding=1)
- )
+ modules.append(nn.Conv2d(in_ch, channels, bias=True, kernel_size=3, padding=1))
pyramid_ch = channels
elif progressive == "residual":
modules.append(pyramid_upsample(channels=pyramid_ch, out_channels=in_ch))
@@ -351,7 +305,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
if progressive != "output_skip":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
- modules.append(Conv2d(in_ch, channels, init_scale=init_scale))
+ modules.append(nn.Conv2d(in_ch, channels, kernel_size=3, padding=1))
self.all_modules = nn.ModuleList(modules)