Commit 3e2cff4d authored by patil-suraj

better names and more cleanup

parent 639b8611
@@ -40,12 +40,8 @@ def _setup_kernel(k):
     return k


-def _shape(x, dim):
-    return x.shape[dim]
-
-
-def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
-    """Fused `upsample_2d()` followed by `tf.nn.conv2d()`.
+def _upsample_conv_2d(x, w, k=None, factor=2, gain=1):
+    """Fused `upsample_2d()` followed by `Conv2d()`.

     Args:
     Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
@@ -84,13 +80,13 @@ def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
     # Determine data dimensions.
     stride = [1, 1, factor, factor]
-    output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW)
+    output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
     output_padding = (
-        output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH,
-        output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW,
+        output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
+        output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
     )
     assert output_padding[0] >= 0 and output_padding[1] >= 0
-    num_groups = _shape(x, 1) // inC
+    num_groups = x.shape[1] // inC

     # Transpose weights.
     w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
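The `output_shape` arithmetic above is the standard transposed-convolution size formula: with `padding=0`, the output side length is `(in - 1) * stride + kernel`. A minimal standalone check, with arbitrary shapes chosen purely for illustration (not part of the commit):

import torch
import torch.nn.functional as F

# With padding=0, conv_transpose2d output size is (in - 1) * stride + kernel.
x = torch.randn(1, 4, 8, 8)   # NCHW input
w = torch.randn(4, 4, 3, 3)   # (in_ch, out_ch, kH, kW) layout for conv_transpose2d
y = F.conv_transpose2d(x, w, stride=2, padding=0)
print(y.shape)                # torch.Size([1, 4, 17, 17]); (8 - 1) * 2 + 3 == 17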
@@ -98,21 +94,12 @@ def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
     w = torch.reshape(w, (num_groups * inC, -1, convH, convW))

     x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
-    # Original TF code.
-    # x = tf.nn.conv2d_transpose(
-    #     x,
-    #     w,
-    #     output_shape=output_shape,
-    #     strides=stride,
-    #     padding='VALID',
-    #     data_format=data_format)
-    # JAX equivalent

     return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))


-def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
-    """Fused `tf.nn.conv2d()` followed by `downsample_2d()`.
+def _conv_downsample_2d(x, w, k=None, factor=2, gain=1):
+    """Fused `Conv2d()` followed by `downsample_2d()`.

     Args:
     Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
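In the downsampling counterpart, the resolution reduction comes from the final `F.conv2d(x, w, stride=s, padding=0)` call; the FIR kernel `k` is applied beforehand via `upfirdn2d`. A minimal sketch of the strided-convolution half, with illustrative shapes only:

import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 16, 16)
w = torch.randn(8, 4, 3, 3)   # (out_ch, in_ch, kH, kW) layout for conv2d
y = F.conv2d(x, w, stride=2, padding=1)
print(y.shape)                # torch.Size([1, 8, 8, 8]); spatial dims halved by stride=2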
@@ -143,15 +130,7 @@ def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
     return F.conv2d(x, w, stride=s, padding=0)


-def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
-    """nXn convolution with DDPM initialization."""
-    conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
-    conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape)
-    nn.init.zeros_(conv.bias)
-    return conv
-
-
-def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
+def _variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
     """Ported from JAX."""
     scale = 1e-10 if scale == 0 else scale
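`_variance_scaling` mirrors JAX's `jax.nn.initializers.variance_scaling`, and the commit only renames it. As a reminder of what that initializer family does, here is a hedged sketch of the uniform fan-average variant; the exact mode and distribution used in this file may differ:

import torch

def variance_scaling_sketch(scale, shape, in_axis=1, out_axis=0):
    # Receptive-field size for conv weights, 1 for plain linear weights.
    receptive = 1
    for i, s in enumerate(shape):
        if i not in (in_axis, out_axis):
            receptive *= s
    fan_in = shape[in_axis] * receptive
    fan_out = shape[out_axis] * receptive
    variance = scale / ((fan_in + fan_out) / 2)   # "fan_avg"
    # A uniform(-a, a) draw has variance a**2 / 3, hence a = sqrt(3 * variance).
    a = (3 * variance) ** 0.5
    return torch.empty(shape).uniform_(-a, a)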
@@ -170,13 +149,21 @@ def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
     return init


+def Conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
+    """nXn convolution with DDPM initialization."""
+    conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
+    conv.weight.data = _variance_scaling(init_scale)(conv.weight.data.shape)
+    nn.init.zeros_(conv.bias)
+    return conv
+
+
 class Combine(nn.Module):
     """Combine information from skip connections."""

     def __init__(self, dim1, dim2, method="cat"):
         super().__init__()
         # 1x1 convolution with DDPM initialization.
-        self.Conv_0 = conv2d(dim1, dim2, kernel_size=1, padding=0)
+        self.Conv_0 = Conv2d(dim1, dim2, kernel_size=1, padding=0)
         self.method = method

     def forward(self, x, y):
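Call sites are unchanged apart from the capitalized name. For example, `Combine`'s 1x1 projection above behaves like this (hypothetical channel sizes, assuming the module's `Conv2d` helper is in scope):

conv = Conv2d(64, 128, kernel_size=1, padding=0)   # DDPM-initialized 1x1 conv
assert conv.weight.shape == (128, 64, 1, 1)
assert conv.bias.abs().sum() == 0                  # bias is zero-initialized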
@@ -189,38 +176,38 @@ class Combine(nn.Module):
             raise ValueError(f"Method {self.method} not recognized.")


-class Upsample(nn.Module):
+class FirUpsample(nn.Module):
     def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
         super().__init__()
         out_ch = out_ch if out_ch else in_ch
         if with_conv:
-            self.Conv2d_0 = conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
+            self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
         self.with_conv = with_conv
         self.fir_kernel = fir_kernel
         self.out_ch = out_ch

     def forward(self, x):
         if self.with_conv:
-            h = upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
+            h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
         else:
             h = upsample_2d(x, self.fir_kernel, factor=2)

         return h


-class Downsample(nn.Module):
+class FirDownsample(nn.Module):
     def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
         super().__init__()
         out_ch = out_ch if out_ch else in_ch
         if with_conv:
-            self.Conv2d_0 = self.Conv2d_0 = conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
+            self.Conv2d_0 = self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
         self.fir_kernel = fir_kernel
         self.with_conv = with_conv
         self.out_ch = out_ch

     def forward(self, x):
         if self.with_conv:
-            x = conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
+            x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
         else:
             x = downsample_2d(x, self.fir_kernel, factor=2)
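The renamed modules behave exactly as before; a quick illustrative usage, with arbitrary channel and spatial sizes:

import torch

up = FirUpsample(in_ch=64, with_conv=True)     # fused 3x3 conv + 2x FIR upsample
down = FirDownsample(in_ch=64, with_conv=True) # fused 3x3 conv + 2x FIR downsample
x = torch.randn(1, 64, 32, 32)
print(up(x).shape)    # expected (1, 64, 64, 64): spatial dims doubled
print(down(x).shape)  # expected (1, 64, 16, 16): spatial dims halved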
@@ -311,21 +298,21 @@ class NCSNpp(ModelMixin, ConfigMixin):
         if conditional:
             modules.append(nn.Linear(embed_dim, nf * 4))
-            modules[-1].weight.data = variance_scaling()(modules[-1].weight.shape)
+            modules[-1].weight.data = _variance_scaling()(modules[-1].weight.shape)
             nn.init.zeros_(modules[-1].bias)
             modules.append(nn.Linear(nf * 4, nf * 4))
-            modules[-1].weight.data = variance_scaling()(modules[-1].weight.shape)
+            modules[-1].weight.data = _variance_scaling()(modules[-1].weight.shape)
             nn.init.zeros_(modules[-1].bias)

         AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))

-        Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
+        Up_sample = functools.partial(FirUpsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)

         if progressive == "output_skip":
             self.pyramid_upsample = Up_sample(fir_kernel=fir_kernel, with_conv=False)
         elif progressive == "residual":
             pyramid_upsample = functools.partial(Up_sample, fir_kernel=fir_kernel, with_conv=True)

-        Down_sample = functools.partial(Downsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
+        Down_sample = functools.partial(FirDownsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)

         if progressive_input == "input_skip":
             self.pyramid_downsample = Down_sample(fir_kernel=fir_kernel, with_conv=False)
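The `functools.partial` calls pre-bind the FIR settings so the many construction sites below only pass channel counts. Conceptually, with assumed argument values:

import functools

Up_sample = functools.partial(FirUpsample, with_conv=True, fir_kernel=(1, 3, 3, 1))
block = Up_sample(in_ch=128, out_ch=128)
# equivalent to: FirUpsample(in_ch=128, out_ch=128, with_conv=True, fir_kernel=(1, 3, 3, 1))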
@@ -348,7 +335,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
         if progressive_input != "none":
             input_pyramid_ch = channels

-        modules.append(conv2d(channels, nf, kernel_size=3, padding=1))
+        modules.append(Conv2d(channels, nf, kernel_size=3, padding=1))
         hs_c = [nf]

         in_ch = nf
@@ -397,11 +384,11 @@ class NCSNpp(ModelMixin, ConfigMixin):
             if i_level == self.num_resolutions - 1:
                 if progressive == "output_skip":
                     modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
-                    modules.append(conv2d(in_ch, channels, init_scale=init_scale, kernel_size=3, padding=1))
+                    modules.append(Conv2d(in_ch, channels, init_scale=init_scale, kernel_size=3, padding=1))
                     pyramid_ch = channels
                 elif progressive == "residual":
                     modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
-                    modules.append(conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1))
+                    modules.append(Conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1))
                     pyramid_ch = in_ch
                 else:
                     raise ValueError(f"{progressive} is not a valid name.")
@@ -409,7 +396,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
                 if progressive == "output_skip":
                     modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
                     modules.append(
-                        conv2d(in_ch, channels, bias=True, init_scale=init_scale, kernel_size=3, padding=1)
+                        Conv2d(in_ch, channels, bias=True, init_scale=init_scale, kernel_size=3, padding=1)
                     )
                     pyramid_ch = channels
                 elif progressive == "residual":
@@ -425,7 +412,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
         if progressive != "output_skip":
             modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
-            modules.append(conv2d(in_ch, channels, init_scale=init_scale))
+            modules.append(Conv2d(in_ch, channels, init_scale=init_scale))

         self.all_modules = nn.ModuleList(modules)