Unverified commit 1a431ae8 authored by Rashmi Margani, committed by GitHub

Rename single-letter variables to meaningful names (#395)


Co-authored-by: Rashmi S <rashmis@Rashmis-MacBook-Pro.local>
parent 8d14edf2
@@ -107,7 +107,7 @@ class FirUpsample2D(nn.Module):
         self.fir_kernel = fir_kernel
         self.out_channels = out_channels
 
-    def _upsample_2d(self, x, w=None, k=None, factor=2, gain=1):
+    def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
         """Fused `upsample_2d()` followed by `Conv2d()`.
 
         Args:
@@ -116,9 +116,9 @@ class FirUpsample2D(nn.Module):
             order.
             x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
             C]`.
-            w: Weight tensor of the shape `[filterH, filterW, inChannels,
+            weight: Weight tensor of the shape `[filterH, filterW, inChannels,
             outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
-            k: FIR filter of the shape `[firH, firW]` or `[firN]`
+            kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
             (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
             factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
@@ -130,23 +130,23 @@ class FirUpsample2D(nn.Module):
         assert isinstance(factor, int) and factor >= 1
 
         # Setup filter kernel.
-        if k is None:
-            k = [1] * factor
+        if kernel is None:
+            kernel = [1] * factor
 
         # setup kernel
-        k = np.asarray(k, dtype=np.float32)
-        if k.ndim == 1:
-            k = np.outer(k, k)
-        k /= np.sum(k)
+        kernel = np.asarray(kernel, dtype=np.float32)
+        if kernel.ndim == 1:
+            kernel = np.outer(kernel, kernel)
+        kernel /= np.sum(kernel)
 
-        k = k * (gain * (factor**2))
+        kernel = kernel * (gain * (factor**2))
 
         if self.use_conv:
-            convH = w.shape[2]
-            convW = w.shape[3]
-            inC = w.shape[1]
+            convH = weight.shape[2]
+            convW = weight.shape[3]
+            inC = weight.shape[1]
 
-            p = (k.shape[0] - factor) - (convW - 1)
+            p = (kernel.shape[0] - factor) - (convW - 1)
 
             stride = (factor, factor)
             # Determine data dimensions.
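The block above is the heart of both FIR paths: a 1-D separable filter is expanded to 2-D, normalized, then rescaled. A minimal standalone sketch of that arithmetic, assuming the repo's default `fir_kernel = (1, 3, 3, 1)` and `factor = 2` (the worked numbers are illustrative, not part of this diff):

```python
import numpy as np

kernel = np.asarray((1, 3, 3, 1), dtype=np.float32)  # 1-D separable FIR filter
kernel = np.outer(kernel, kernel)                    # expand to a 4x4 2-D kernel
kernel /= np.sum(kernel)                             # normalize to unit sum
kernel = kernel * (1 * (2**2))                       # gain * factor**2 for the upsample path

# kernel.sum() is now 4.0: the factor**2 rescaling compensates for the zeros
# that upsampling by 2 inserts between input samples.
```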
@@ -157,33 +157,33 @@ class FirUpsample2D(nn.Module):
                 output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
             )
             assert output_padding[0] >= 0 and output_padding[1] >= 0
-            inC = w.shape[1]
+            inC = weight.shape[1]
             num_groups = x.shape[1] // inC
 
             # Transpose weights.
-            w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
-            w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)
-            w = torch.reshape(w, (num_groups * inC, -1, convH, convW))
+            weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
+            weight = weight[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)
+            weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
 
-            x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
+            x = F.conv_transpose2d(x, weight, stride=stride, output_padding=output_padding, padding=0)
 
-            x = upfirdn2d_native(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
+            x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
         else:
-            p = k.shape[0] - factor
+            p = kernel.shape[0] - factor
             x = upfirdn2d_native(
-                x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
+                x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
             )
 
         return x
 
     def forward(self, x):
         if self.use_conv:
-            h = self._upsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
-            h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
+            height = self._upsample_2d(x, self.Conv2d_0.weight, kernel=self.fir_kernel)
+            height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
         else:
-            h = self._upsample_2d(x, k=self.fir_kernel, factor=2)
+            height = self._upsample_2d(x, kernel=self.fir_kernel, factor=2)
 
-        return h
+        return height
 
 
 class FirDownsample2D(nn.Module):
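A hedged usage sketch of the renamed upsample path. The constructor arguments below are assumptions inferred from the attributes this file assigns (`use_conv`, `out_channels`, `fir_kernel`); the exact `__init__` signature is not shown in this diff:

```python
import torch

# Hypothetical construction; argument names are assumed, not taken from this diff.
upsample = FirUpsample2D(channels=64, out_channels=64, use_conv=True, fir_kernel=(1, 3, 3, 1))
x = torch.randn(1, 64, 16, 16)
y = upsample(x)  # forward() now passes kernel=self.fir_kernel to _upsample_2d
# Expected output: spatial dims doubled, y.shape == (1, 64, 32, 32)
```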
@@ -196,7 +196,7 @@ class FirDownsample2D(nn.Module):
         self.use_conv = use_conv
         self.out_channels = out_channels
 
-    def _downsample_2d(self, x, w=None, k=None, factor=2, gain=1):
+    def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
         """Fused `Conv2d()` followed by `downsample_2d()`.
 
         Args:
@@ -215,35 +215,35 @@ class FirDownsample2D(nn.Module):
         """
         assert isinstance(factor, int) and factor >= 1
-        if k is None:
-            k = [1] * factor
+        if kernel is None:
+            kernel = [1] * factor
 
         # setup kernel
-        k = np.asarray(k, dtype=np.float32)
-        if k.ndim == 1:
-            k = np.outer(k, k)
-        k /= np.sum(k)
+        kernel = np.asarray(kernel, dtype=np.float32)
+        if kernel.ndim == 1:
+            kernel = np.outer(kernel, kernel)
+        kernel /= np.sum(kernel)
 
-        k = k * gain
+        kernel = kernel * gain
 
         if self.use_conv:
-            _, _, convH, convW = w.shape
-            p = (k.shape[0] - factor) + (convW - 1)
+            _, _, convH, convW = weight.shape
+            p = (kernel.shape[0] - factor) + (convW - 1)
             s = [factor, factor]
-            x = upfirdn2d_native(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2, p // 2))
-            x = F.conv2d(x, w, stride=s, padding=0)
+            x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2, p // 2))
+            x = F.conv2d(x, weight, stride=s, padding=0)
         else:
-            p = k.shape[0] - factor
-            x = upfirdn2d_native(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
+            p = kernel.shape[0] - factor
+            x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
 
         return x
 
     def forward(self, x):
         if self.use_conv:
-            x = self._downsample_2d(x, w=self.Conv2d_0.weight, k=self.fir_kernel)
+            x = self._downsample_2d(x, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
             x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
         else:
-            x = self._downsample_2d(x, k=self.fir_kernel, factor=2)
+            x = self._downsample_2d(x, kernel=self.fir_kernel, factor=2)
 
         return x
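In the non-convolutional branch the whole operation is one `upfirdn2d_native` call, and the padding arithmetic is easy to check by hand. A sketch with the default `kernel = (1, 3, 3, 1)` and `factor = 2` (numbers worked out here for illustration):

```python
# After np.outer the kernel is 4x4, so kernel.shape[0] == 4 and factor == 2.
p = 4 - 2                     # p = 2
pad = ((p + 1) // 2, p // 2)  # (1, 1): near-symmetric padding around the filter
# upfirdn2d_native(x, kernel, down=2, pad=(1, 1)) filters, then keeps every
# 2nd pixel, taking [N, C, H, W] to [N, C, H // 2, W // 2].
```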
@@ -308,7 +308,7 @@ class ResnetBlock2D(nn.Module):
         if self.up:
             if kernel == "fir":
                 fir_kernel = (1, 3, 3, 1)
-                self.upsample = lambda x: upsample_2d(x, k=fir_kernel)
+                self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
             elif kernel == "sde_vp":
                 self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
             else:
@@ -316,7 +316,7 @@ class ResnetBlock2D(nn.Module):
         elif self.down:
             if kernel == "fir":
                 fir_kernel = (1, 3, 3, 1)
-                self.downsample = lambda x: downsample_2d(x, k=fir_kernel)
+                self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
             elif kernel == "sde_vp":
                 self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
             else:
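These lambdas have to change in lockstep with the new signatures: after the rename, any caller still passing the old `k=` keyword fails at call time. A quick sketch, assuming `upsample_2d` is imported from this module:

```python
import torch

x = torch.randn(1, 4, 8, 8)
fir_kernel = (1, 3, 3, 1)
# upsample_2d(x, k=fir_kernel)         # old keyword: TypeError after this change
y = upsample_2d(x, kernel=fir_kernel)  # renamed keyword, identical behavior
```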
@@ -370,7 +370,7 @@ class Mish(torch.nn.Module):
         return x * torch.tanh(torch.nn.functional.softplus(x))
 
 
-def upsample_2d(x, k=None, factor=2, gain=1):
+def upsample_2d(x, kernel=None, factor=2, gain=1):
     r"""Upsample2D a batch of 2D images with the given filter.
 
     Args:
@@ -388,20 +388,22 @@ def upsample_2d(x, k=None, factor=2, gain=1):
         Tensor of the shape `[N, C, H * factor, W * factor]`
     """
     assert isinstance(factor, int) and factor >= 1
-    if k is None:
-        k = [1] * factor
-    k = np.asarray(k, dtype=np.float32)
-    if k.ndim == 1:
-        k = np.outer(k, k)
-    k /= np.sum(k)
-
-    k = k * (gain * (factor**2))
-    p = k.shape[0] - factor
-    return upfirdn2d_native(x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
+    if kernel is None:
+        kernel = [1] * factor
+    kernel = np.asarray(kernel, dtype=np.float32)
+    if kernel.ndim == 1:
+        kernel = np.outer(kernel, kernel)
+    kernel /= np.sum(kernel)
+
+    kernel = kernel * (gain * (factor**2))
+    p = kernel.shape[0] - factor
+    return upfirdn2d_native(
+        x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
+    )
 
 
-def downsample_2d(x, k=None, factor=2, gain=1):
+def downsample_2d(x, kernel=None, factor=2, gain=1):
     r"""Downsample2D a batch of 2D images with the given filter.
 
     Args:
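A minimal usage sketch for the renamed `upsample_2d`, checking the output shape `[N, C, H * factor, W * factor]` documented above (the tensor sizes are made up for illustration):

```python
import torch

x = torch.randn(2, 3, 16, 16)
y = upsample_2d(x, kernel=(1, 3, 3, 1), factor=2)
assert y.shape == (2, 3, 32, 32)  # H and W doubled, matching the docstring
```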
@@ -411,7 +413,7 @@ def downsample_2d(x, k=None, factor=2, gain=1):
         shape is a multiple of the downsampling factor.
         x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
         C]`.
-        k: FIR filter of the shape `[firH, firW]` or `[firN]`
+        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
         (separable). The default is `[1] * factor`, which corresponds to average pooling.
         factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
@@ -420,17 +422,17 @@ def downsample_2d(x, k=None, factor=2, gain=1):
     """
 
     assert isinstance(factor, int) and factor >= 1
-    if k is None:
-        k = [1] * factor
+    if kernel is None:
+        kernel = [1] * factor
 
-    k = np.asarray(k, dtype=np.float32)
-    if k.ndim == 1:
-        k = np.outer(k, k)
-    k /= np.sum(k)
+    kernel = np.asarray(kernel, dtype=np.float32)
+    if kernel.ndim == 1:
+        kernel = np.outer(kernel, kernel)
+    kernel /= np.sum(kernel)
 
-    k = k * gain
-    p = k.shape[0] - factor
-    return upfirdn2d_native(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
+    kernel = kernel * gain
+    p = kernel.shape[0] - factor
+    return upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
 
 
 def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
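And the mirror sketch for the renamed `downsample_2d`; with the default `kernel=None` it falls back to `[1] * factor`, which the docstring above identifies as average pooling:

```python
import torch

x = torch.randn(2, 3, 32, 32)
y = downsample_2d(x, kernel=(1, 3, 3, 1), factor=2)
assert y.shape == (2, 3, 16, 16)  # spatial dims halved

y_avg = downsample_2d(x)  # kernel=None -> [1, 1], i.e. 2x2 average pooling
```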