Commit 1468f754 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

finish resnet

parent fa7443c8
...@@ -28,8 +28,7 @@ from ..modeling_utils import ModelMixin ...@@ -28,8 +28,7 @@ from ..modeling_utils import ModelMixin
from .attention import AttentionBlock from .attention import AttentionBlock
from .embeddings import GaussianFourierProjection, get_timestep_embedding from .embeddings import GaussianFourierProjection, get_timestep_embedding
from .resnet import downsample_2d, upfirdn2d, upsample_2d from .resnet import downsample_2d, upfirdn2d, upsample_2d
from .resnet import ResnetBlockBigGANppNew as ResnetBlockBigGANpp from .resnet import ResnetBlock
from .resnet import ResnetBlock as ResnetNew
def _setup_kernel(k): def _setup_kernel(k):
...@@ -323,16 +322,6 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -323,16 +322,6 @@ class NCSNpp(ModelMixin, ConfigMixin):
elif progressive_input == "residual": elif progressive_input == "residual":
pyramid_downsample = functools.partial(Down_sample, fir_kernel=fir_kernel, with_conv=True) pyramid_downsample = functools.partial(Down_sample, fir_kernel=fir_kernel, with_conv=True)
ResnetBlock = functools.partial(
ResnetBlockBigGANpp,
act=act,
dropout=dropout,
fir_kernel=fir_kernel,
init_scale=init_scale,
skip_rescale=skip_rescale,
temb_dim=nf * 4,
)
# Downsampling block # Downsampling block
channels = num_channels channels = num_channels
...@@ -347,9 +336,8 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -347,9 +336,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
# Residual blocks for this resolution # Residual blocks for this resolution
for i_block in range(num_res_blocks): for i_block in range(num_res_blocks):
out_ch = nf * ch_mult[i_level] out_ch = nf * ch_mult[i_level]
# modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
modules.append( modules.append(
ResnetNew( ResnetBlock(
in_channels=in_ch, in_channels=in_ch,
out_channels=out_ch, out_channels=out_ch,
temb_channels=4 * nf, temb_channels=4 * nf,
...@@ -367,9 +355,8 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -367,9 +355,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
hs_c.append(in_ch) hs_c.append(in_ch)
if i_level != self.num_resolutions - 1: if i_level != self.num_resolutions - 1:
# modules.append(ResnetBlock(down=True, in_ch=in_ch))
modules.append( modules.append(
ResnetNew( ResnetBlock(
in_channels=in_ch, in_channels=in_ch,
temb_channels=4 * nf, temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0), output_scale_factor=np.sqrt(2.0),
...@@ -395,9 +382,8 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -395,9 +382,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
hs_c.append(in_ch) hs_c.append(in_ch)
in_ch = hs_c[-1] in_ch = hs_c[-1]
# modules.append(ResnetBlock(in_ch=in_ch))
modules.append( modules.append(
ResnetNew( ResnetBlock(
in_channels=in_ch, in_channels=in_ch,
temb_channels=4 * nf, temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0), output_scale_factor=np.sqrt(2.0),
...@@ -408,9 +394,8 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -408,9 +394,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
) )
) )
modules.append(AttnBlock(channels=in_ch)) modules.append(AttnBlock(channels=in_ch))
# modules.append(ResnetBlock(in_ch=in_ch))
modules.append( modules.append(
ResnetNew( ResnetBlock(
in_channels=in_ch, in_channels=in_ch,
temb_channels=4 * nf, temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0), output_scale_factor=np.sqrt(2.0),
...@@ -426,9 +411,8 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -426,9 +411,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
for i_level in reversed(range(self.num_resolutions)): for i_level in reversed(range(self.num_resolutions)):
for i_block in range(num_res_blocks + 1): for i_block in range(num_res_blocks + 1):
out_ch = nf * ch_mult[i_level] out_ch = nf * ch_mult[i_level]
# modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
modules.append( modules.append(
ResnetNew( ResnetBlock(
in_channels=in_ch + hs_c.pop(), in_channels=in_ch + hs_c.pop(),
out_channels=out_ch, out_channels=out_ch,
temb_channels=4 * nf, temb_channels=4 * nf,
...@@ -470,9 +454,8 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -470,9 +454,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
raise ValueError(f"{progressive} is not a valid name") raise ValueError(f"{progressive} is not a valid name")
if i_level != 0: if i_level != 0:
# modules.append(ResnetBlock(in_ch=in_ch, up=True))
modules.append( modules.append(
ResnetNew( ResnetBlock(
in_channels=in_ch, in_channels=in_ch,
temb_channels=4 * nf, temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0), output_scale_factor=np.sqrt(2.0),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment