Unverified Commit 1a431ae8 authored by Rashmi Margani's avatar Rashmi Margani Committed by GitHub
Browse files

Rename variables from single letter to meaningful name fix (#395)


Co-authored-by: default avatarRashmi S <rashmis@Rashmis-MacBook-Pro.local>
parent 8d14edf2
...@@ -107,7 +107,7 @@ class FirUpsample2D(nn.Module): ...@@ -107,7 +107,7 @@ class FirUpsample2D(nn.Module):
self.fir_kernel = fir_kernel self.fir_kernel = fir_kernel
self.out_channels = out_channels self.out_channels = out_channels
def _upsample_2d(self, x, w=None, k=None, factor=2, gain=1): def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
"""Fused `upsample_2d()` followed by `Conv2d()`. """Fused `upsample_2d()` followed by `Conv2d()`.
Args: Args:
...@@ -116,9 +116,9 @@ class FirUpsample2D(nn.Module): ...@@ -116,9 +116,9 @@ class FirUpsample2D(nn.Module):
order. order.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`. C]`.
w: Weight tensor of the shape `[filterH, filterW, inChannels, weight: Weight tensor of the shape `[filterH, filterW, inChannels,
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
k: FIR filter of the shape `[firH, firW]` or `[firN]` kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
...@@ -130,23 +130,23 @@ class FirUpsample2D(nn.Module): ...@@ -130,23 +130,23 @@ class FirUpsample2D(nn.Module):
assert isinstance(factor, int) and factor >= 1 assert isinstance(factor, int) and factor >= 1
# Setup filter kernel. # Setup filter kernel.
if k is None: if kernel is None:
k = [1] * factor kernel = [1] * factor
# setup kernel # setup kernel
k = np.asarray(k, dtype=np.float32) kernel = np.asarray(kernel, dtype=np.float32)
if k.ndim == 1: if kernel.ndim == 1:
k = np.outer(k, k) kernel = np.outer(kernel, kernel)
k /= np.sum(k) kernel /= np.sum(kernel)
k = k * (gain * (factor**2)) kernel = kernel * (gain * (factor**2))
if self.use_conv: if self.use_conv:
convH = w.shape[2] convH = weight.shape[2]
convW = w.shape[3] convW = weight.shape[3]
inC = w.shape[1] inC = weight.shape[1]
p = (k.shape[0] - factor) - (convW - 1) p = (kernel.shape[0] - factor) - (convW - 1)
stride = (factor, factor) stride = (factor, factor)
# Determine data dimensions. # Determine data dimensions.
...@@ -157,33 +157,33 @@ class FirUpsample2D(nn.Module): ...@@ -157,33 +157,33 @@ class FirUpsample2D(nn.Module):
output_shape[1] - (x.shape[3] - 1) * stride[1] - convW, output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
) )
assert output_padding[0] >= 0 and output_padding[1] >= 0 assert output_padding[0] >= 0 and output_padding[1] >= 0
inC = w.shape[1] inC = weight.shape[1]
num_groups = x.shape[1] // inC num_groups = x.shape[1] // inC
# Transpose weights. # Transpose weights.
w = torch.reshape(w, (num_groups, -1, inC, convH, convW)) weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4) weight = weight[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)
w = torch.reshape(w, (num_groups * inC, -1, convH, convW)) weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0) x = F.conv_transpose2d(x, weight, stride=stride, output_padding=output_padding, padding=0)
x = upfirdn2d_native(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1)) x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
else: else:
p = k.shape[0] - factor p = kernel.shape[0] - factor
x = upfirdn2d_native( x = upfirdn2d_native(
x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2) x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
) )
return x return x
def forward(self, x): def forward(self, x):
if self.use_conv: if self.use_conv:
h = self._upsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel) height = self._upsample_2d(x, self.Conv2d_0.weight, kernel=self.fir_kernel)
h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1) height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else: else:
h = self._upsample_2d(x, k=self.fir_kernel, factor=2) height = self._upsample_2d(x, kernel=self.fir_kernel, factor=2)
return h return height
class FirDownsample2D(nn.Module): class FirDownsample2D(nn.Module):
...@@ -196,7 +196,7 @@ class FirDownsample2D(nn.Module): ...@@ -196,7 +196,7 @@ class FirDownsample2D(nn.Module):
self.use_conv = use_conv self.use_conv = use_conv
self.out_channels = out_channels self.out_channels = out_channels
def _downsample_2d(self, x, w=None, k=None, factor=2, gain=1): def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
"""Fused `Conv2d()` followed by `downsample_2d()`. """Fused `Conv2d()` followed by `downsample_2d()`.
Args: Args:
...@@ -215,35 +215,35 @@ class FirDownsample2D(nn.Module): ...@@ -215,35 +215,35 @@ class FirDownsample2D(nn.Module):
""" """
assert isinstance(factor, int) and factor >= 1 assert isinstance(factor, int) and factor >= 1
if k is None: if kernel is None:
k = [1] * factor kernel = [1] * factor
# setup kernel # setup kernel
k = np.asarray(k, dtype=np.float32) kernel = np.asarray(kernel, dtype=np.float32)
if k.ndim == 1: if kernel.ndim == 1:
k = np.outer(k, k) kernel = np.outer(kernel, kernel)
k /= np.sum(k) kernel /= np.sum(kernel)
k = k * gain kernel = kernel * gain
if self.use_conv: if self.use_conv:
_, _, convH, convW = w.shape _, _, convH, convW = weight.shape
p = (k.shape[0] - factor) + (convW - 1) p = (kernel.shape[0] - factor) + (convW - 1)
s = [factor, factor] s = [factor, factor]
x = upfirdn2d_native(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2, p // 2)) x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2, p // 2))
x = F.conv2d(x, w, stride=s, padding=0) x = F.conv2d(x, weight, stride=s, padding=0)
else: else:
p = k.shape[0] - factor p = kernel.shape[0] - factor
x = upfirdn2d_native(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2)) x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
return x return x
def forward(self, x): def forward(self, x):
if self.use_conv: if self.use_conv:
x = self._downsample_2d(x, w=self.Conv2d_0.weight, k=self.fir_kernel) x = self._downsample_2d(x, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1) x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else: else:
x = self._downsample_2d(x, k=self.fir_kernel, factor=2) x = self._downsample_2d(x, kernel=self.fir_kernel, factor=2)
return x return x
...@@ -308,7 +308,7 @@ class ResnetBlock2D(nn.Module): ...@@ -308,7 +308,7 @@ class ResnetBlock2D(nn.Module):
if self.up: if self.up:
if kernel == "fir": if kernel == "fir":
fir_kernel = (1, 3, 3, 1) fir_kernel = (1, 3, 3, 1)
self.upsample = lambda x: upsample_2d(x, k=fir_kernel) self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
elif kernel == "sde_vp": elif kernel == "sde_vp":
self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest") self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
else: else:
...@@ -316,7 +316,7 @@ class ResnetBlock2D(nn.Module): ...@@ -316,7 +316,7 @@ class ResnetBlock2D(nn.Module):
elif self.down: elif self.down:
if kernel == "fir": if kernel == "fir":
fir_kernel = (1, 3, 3, 1) fir_kernel = (1, 3, 3, 1)
self.downsample = lambda x: downsample_2d(x, k=fir_kernel) self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
elif kernel == "sde_vp": elif kernel == "sde_vp":
self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2) self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
else: else:
...@@ -370,7 +370,7 @@ class Mish(torch.nn.Module): ...@@ -370,7 +370,7 @@ class Mish(torch.nn.Module):
return x * torch.tanh(torch.nn.functional.softplus(x)) return x * torch.tanh(torch.nn.functional.softplus(x))
def upsample_2d(x, k=None, factor=2, gain=1): def upsample_2d(x, kernel=None, factor=2, gain=1):
r"""Upsample2D a batch of 2D images with the given filter. r"""Upsample2D a batch of 2D images with the given filter.
Args: Args:
...@@ -388,20 +388,22 @@ def upsample_2d(x, k=None, factor=2, gain=1): ...@@ -388,20 +388,22 @@ def upsample_2d(x, k=None, factor=2, gain=1):
Tensor of the shape `[N, C, H * factor, W * factor]` Tensor of the shape `[N, C, H * factor, W * factor]`
""" """
assert isinstance(factor, int) and factor >= 1 assert isinstance(factor, int) and factor >= 1
if k is None: if kernel is None:
k = [1] * factor kernel = [1] * factor
k = np.asarray(k, dtype=np.float32) kernel = np.asarray(kernel, dtype=np.float32)
if k.ndim == 1: if kernel.ndim == 1:
k = np.outer(k, k) kernel = np.outer(kernel, kernel)
k /= np.sum(k) kernel /= np.sum(kernel)
k = k * (gain * (factor**2)) kernel = kernel * (gain * (factor**2))
p = k.shape[0] - factor p = kernel.shape[0] - factor
return upfirdn2d_native(x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)) return upfirdn2d_native(
x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
)
def downsample_2d(x, k=None, factor=2, gain=1): def downsample_2d(x, kernel=None, factor=2, gain=1):
r"""Downsample2D a batch of 2D images with the given filter. r"""Downsample2D a batch of 2D images with the given filter.
Args: Args:
...@@ -411,7 +413,7 @@ def downsample_2d(x, k=None, factor=2, gain=1): ...@@ -411,7 +413,7 @@ def downsample_2d(x, k=None, factor=2, gain=1):
shape is a multiple of the downsampling factor. shape is a multiple of the downsampling factor.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`. C]`.
k: FIR filter of the shape `[firH, firW]` or `[firN]` kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to average pooling. (separable). The default is `[1] * factor`, which corresponds to average pooling.
factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
...@@ -420,17 +422,17 @@ def downsample_2d(x, k=None, factor=2, gain=1): ...@@ -420,17 +422,17 @@ def downsample_2d(x, k=None, factor=2, gain=1):
""" """
assert isinstance(factor, int) and factor >= 1 assert isinstance(factor, int) and factor >= 1
if k is None: if kernel is None:
k = [1] * factor kernel = [1] * factor
k = np.asarray(k, dtype=np.float32) kernel = np.asarray(kernel, dtype=np.float32)
if k.ndim == 1: if kernel.ndim == 1:
k = np.outer(k, k) kernel = np.outer(kernel, kernel)
k /= np.sum(k) kernel /= np.sum(kernel)
k = k * gain kernel = kernel * gain
p = k.shape[0] - factor p = kernel.shape[0] - factor
return upfirdn2d_native(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2)) return upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)): def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment