Commit fa7443c8 authored by Patrick von Platen

finish resnet

parent 8d7771d8
@@ -380,7 +380,7 @@ class ResnetBlock(nn.Module):
         eps=1e-6,
         non_linearity="swish",
         time_embedding_norm="default",
-        fir_kernel=(1, 3, 3, 1),
+        kernel=None,
         output_scale_factor=1.0,
         use_nin_shortcut=None,
         up=False,
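The hunk above swaps the hard-coded fir_kernel tuple for a kernel mode switch. A hypothetical usage sketch of the two modes (the exact constructor arguments beyond those visible in this diff are assumptions, not a stable API):

    # kernel=None  -> learned Upsample/Downsample modules (the previous default path)
    # kernel="fir" -> fixed separable (1, 3, 3, 1) filter via upsample_2d/downsample_2d
    block_learned = ResnetBlock(in_channels=64, out_channels=64, up=True, kernel=None)
    block_fir = ResnetBlock(in_channels=64, out_channels=64, up=True, kernel="fir")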
@@ -433,8 +433,18 @@ class ResnetBlock(nn.Module):
         # elif down:
         #     self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
         #     self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
-        self.upsample = Upsample(in_channels, use_conv=False, dims=2) if self.up else None
-        self.downsample = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op") if self.down else None
+        self.upsample = self.downsample = None
+        if self.up and kernel == "fir":
+            fir_kernel = (1, 3, 3, 1)
+            self.upsample = lambda x: upsample_2d(x, k=fir_kernel)
+        elif self.up and kernel is None:
+            self.upsample = Upsample(in_channels, use_conv=False, dims=2)
+        elif self.down and kernel == "fir":
+            fir_kernel = (1, 3, 3, 1)
+            self.downsample = lambda x: downsample_2d(x, k=fir_kernel)
+        elif self.down and kernel is None:
+            self.downsample = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")

         self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
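upsample_2d and downsample_2d are the StyleGAN2-style FIR resampling helpers this file ports. As a rough, self-contained sketch of the technique (an illustration under the stated assumptions, not the library's exact implementation), FIR upsampling zero-stuffs the input by the factor and then low-pass filters it with the separable (1, 3, 3, 1) kernel:

    import torch
    import torch.nn.functional as F

    def fir_upsample_2d_sketch(x, k=(1, 3, 3, 1), factor=2, gain=1.0):
        # Build the 2D low-pass filter from the separable 1D kernel; the
        # gain * factor**2 term compensates for the inserted zeros.
        n, c, h, w = x.shape
        k1d = torch.tensor(k, dtype=x.dtype, device=x.device)
        k2d = torch.outer(k1d, k1d)
        k2d = k2d / k2d.sum() * (gain * factor**2)
        # Zero-stuffing: place each input sample at stride `factor`.
        up = torch.zeros(n, c, h, factor, w, factor, dtype=x.dtype, device=x.device)
        up[:, :, :, 0, :, 0] = x
        up = up.reshape(n, c, h * factor, w * factor)
        # Asymmetric padding keeps the output at exactly (h * factor, w * factor)
        # for a 4-tap kernel; a depthwise conv then applies the filter per channel.
        p = len(k) - factor  # = 2 for the default kernel
        lo, hi = (p + 1) // 2 + factor - 1, p // 2
        up = F.pad(up, (lo, hi, lo, hi))
        weight = k2d.expand(c, 1, len(k), len(k)).contiguous()
        return F.conv2d(up, weight, groups=c)

With this sketch a (1, 64, 32, 32) input comes back as (1, 64, 64, 64); the library helpers additionally handle other factors, gains, and kernel shapes.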
@@ -505,8 +515,6 @@ class ResnetBlock(nn.Module):
         self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=in_ch, eps=eps)
         self.up = up
         self.down = down
-        self.fir_kernel = fir_kernel
-
         self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
         if temb_dim is not None:
             self.Dense_0 = nn.Linear(temb_dim, out_ch)
@@ -525,11 +533,6 @@ class ResnetBlock(nn.Module):
         self.out_ch = out_ch

         # TODO(Patrick) - move to main init
-        if self.up:
-            self.upsample = functools.partial(upsample_2d, k=self.fir_kernel)
-        if self.down:
-            self.downsample = functools.partial(downsample_2d, k=self.fir_kernel)
-
         self.is_overwritten = False

     def set_weights_grad_tts(self):
@@ -348,7 +348,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
             for i_block in range(num_res_blocks):
                 out_ch = nf * ch_mult[i_level]
                 # modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
-                modules.append(ResnetNew(
+                modules.append(
+                    ResnetNew(
                         in_channels=in_ch,
                         out_channels=out_ch,
                         temb_channels=4 * nf,
@@ -357,7 +358,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
                         groups=min(in_ch // 4, 32),
                         groups_out=min(out_ch // 4, 32),
                         overwrite_for_score_vde=True,
-                ))
+                    )
+                )
                 in_ch = out_ch

                 if all_resolutions[i_level] in attn_resolutions:
@@ -365,7 +367,21 @@ class NCSNpp(ModelMixin, ConfigMixin):
                 hs_c.append(in_ch)

             if i_level != self.num_resolutions - 1:
-                modules.append(ResnetBlock(down=True, in_ch=in_ch))
+                # modules.append(ResnetBlock(down=True, in_ch=in_ch))
+                modules.append(
+                    ResnetNew(
+                        in_channels=in_ch,
+                        temb_channels=4 * nf,
+                        output_scale_factor=np.sqrt(2.0),
+                        non_linearity="silu",
+                        groups=min(in_ch // 4, 32),
+                        groups_out=min(out_ch // 4, 32),
+                        overwrite_for_score_vde=True,
+                        down=True,
+                        kernel="fir",  # TODO(Patrick) - it seems like both fir and non-fir kernels are fine
+                        use_nin_shortcut=True,
+                    )
+                )

                 if progressive_input == "input_skip":
                     modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
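The down=True, kernel="fir" path added above mirrors the upsampling case: low-pass filter first, then subsample. A minimal sketch of FIR downsampling, under the same assumptions as the upsampling sketch earlier:

    import torch
    import torch.nn.functional as F

    def fir_downsample_2d_sketch(x, k=(1, 3, 3, 1), factor=2, gain=1.0):
        # Filter with the separable FIR kernel, then keep every `factor`-th
        # sample by running the depthwise convolution with stride=factor.
        n, c, h, w = x.shape
        k1d = torch.tensor(k, dtype=x.dtype, device=x.device)
        k2d = torch.outer(k1d, k1d)
        k2d = k2d / k2d.sum() * gain
        p = len(k) - factor  # = 2 for the default kernel
        x = F.pad(x, ((p + 1) // 2, p // 2, (p + 1) // 2, p // 2))
        weight = k2d.expand(c, 1, len(k), len(k)).contiguous()
        return F.conv2d(x, weight, stride=factor, groups=c)

Here a (1, 64, 32, 32) input becomes (1, 64, 16, 16).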
@@ -379,16 +395,50 @@ class NCSNpp(ModelMixin, ConfigMixin):
             hs_c.append(in_ch)

         in_ch = hs_c[-1]
-        modules.append(ResnetBlock(in_ch=in_ch))
+        # modules.append(ResnetBlock(in_ch=in_ch))
+        modules.append(
+            ResnetNew(
+                in_channels=in_ch,
+                temb_channels=4 * nf,
+                output_scale_factor=np.sqrt(2.0),
+                non_linearity="silu",
+                groups=min(in_ch // 4, 32),
+                groups_out=min(out_ch // 4, 32),
+                overwrite_for_score_vde=True,
+            )
+        )
         modules.append(AttnBlock(channels=in_ch))
-        modules.append(ResnetBlock(in_ch=in_ch))
+        # modules.append(ResnetBlock(in_ch=in_ch))
+        modules.append(
+            ResnetNew(
+                in_channels=in_ch,
+                temb_channels=4 * nf,
+                output_scale_factor=np.sqrt(2.0),
+                non_linearity="silu",
+                groups=min(in_ch // 4, 32),
+                groups_out=min(out_ch // 4, 32),
+                overwrite_for_score_vde=True,
+            )
+        )

         pyramid_ch = 0
         # Upsampling block
         for i_level in reversed(range(self.num_resolutions)):
             for i_block in range(num_res_blocks + 1):
                 out_ch = nf * ch_mult[i_level]
-                modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
+                # modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
+                modules.append(
+                    ResnetNew(
+                        in_channels=in_ch + hs_c.pop(),
+                        out_channels=out_ch,
+                        temb_channels=4 * nf,
+                        output_scale_factor=np.sqrt(2.0),
+                        non_linearity="silu",
+                        groups=min(in_ch // 4, 32),
+                        groups_out=min(out_ch // 4, 32),
+                        overwrite_for_score_vde=True,
+                    )
+                )
                 in_ch = out_ch

                 if all_resolutions[i_level] in attn_resolutions:
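Every ResnetNew call above sizes its GroupNorm layers as min(ch // 4, 32): roughly four channels per group for narrow layers, capped at 32 groups for wide ones. A quick check of the arithmetic (hypothetical channel counts):

    import torch.nn as nn

    for ch in (32, 128, 256, 512):
        # num_groups must divide num_channels; ch // 4 guarantees that here.
        gn = nn.GroupNorm(num_groups=min(ch // 4, 32), num_channels=ch)
        print(ch, gn.num_groups)  # 32 -> 8, 128 -> 32, 256 -> 32, 512 -> 32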
@@ -420,7 +470,21 @@ class NCSNpp(ModelMixin, ConfigMixin):
                         raise ValueError(f"{progressive} is not a valid name")

             if i_level != 0:
-                modules.append(ResnetBlock(in_ch=in_ch, up=True))
+                # modules.append(ResnetBlock(in_ch=in_ch, up=True))
+                modules.append(
+                    ResnetNew(
+                        in_channels=in_ch,
+                        temb_channels=4 * nf,
+                        output_scale_factor=np.sqrt(2.0),
+                        non_linearity="silu",
+                        groups=min(in_ch // 4, 32),
+                        groups_out=min(out_ch // 4, 32),
+                        overwrite_for_score_vde=True,
+                        up=True,
+                        kernel="fir",
+                        use_nin_shortcut=True,
+                    )
+                )

         assert not hs_c
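All of the ResnetNew blocks constructed in this file pass output_scale_factor=np.sqrt(2.0), and the resnet block divides its residual sum by that factor. A minimal sketch of the idea (assuming the block ends in a plain residual add):

    import math
    import torch

    def residual_out(x, h, output_scale_factor=math.sqrt(2.0)):
        # If x and h are independent with unit variance, x + h has variance 2;
        # dividing by sqrt(2) restores roughly unit variance at the block output.
        return (x + h) / output_scale_factor

    x, h = torch.randn(4, 64, 32, 32), torch.randn(4, 64, 32, 32)
    print(residual_out(x, h).var())  # ~1.0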