Commit 0eac7bd6 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

small fix

parent 1e7e23a9
...@@ -237,12 +237,12 @@ class ResnetBlock(nn.Module): ...@@ -237,12 +237,12 @@ class ResnetBlock(nn.Module):
elif non_linearity == "silu": elif non_linearity == "silu":
self.nonlinearity = nn.SiLU() self.nonlinearity = nn.SiLU()
# if up: # if up:
# self.h_upd = Upsample(in_channels, use_conv=False, dims=2) # self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
# self.x_upd = Upsample(in_channels, use_conv=False, dims=2) # self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
# elif down: # elif down:
# self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op") # self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
# self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op") # self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
self.upsample = self.downsample = None self.upsample = self.downsample = None
if self.up and kernel == "fir": if self.up and kernel == "fir":
...@@ -318,9 +318,9 @@ class ResnetBlock(nn.Module): ...@@ -318,9 +318,9 @@ class ResnetBlock(nn.Module):
num_groups = min(in_ch // 4, 32) num_groups = min(in_ch // 4, 32)
num_groups_out = min(out_ch // 4, 32) num_groups_out = min(out_ch // 4, 32)
temb_dim = temb_channels temb_dim = temb_channels
# output_scale_factor = np.sqrt(2.0) # output_scale_factor = np.sqrt(2.0)
# non_linearity = "silu" # non_linearity = "silu"
# use_nin_shortcut = in_channels != out_channels or use_nin_shortcut = True # use_nin_shortcut = in_channels != out_channels or use_nin_shortcut = True
self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=in_ch, eps=eps) self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=in_ch, eps=eps)
self.up = up self.up = up
...@@ -338,7 +338,7 @@ class ResnetBlock(nn.Module): ...@@ -338,7 +338,7 @@ class ResnetBlock(nn.Module):
# 1x1 convolution with DDPM initialization. # 1x1 convolution with DDPM initialization.
self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0) self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
# self.skip_rescale = skip_rescale # self.skip_rescale = skip_rescale
self.in_ch = in_ch self.in_ch = in_ch
self.out_ch = out_ch self.out_ch = out_ch
......
...@@ -27,8 +27,7 @@ from ..configuration_utils import ConfigMixin ...@@ -27,8 +27,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .attention import AttentionBlock from .attention import AttentionBlock
from .embeddings import GaussianFourierProjection, get_timestep_embedding from .embeddings import GaussianFourierProjection, get_timestep_embedding
from .resnet import downsample_2d, upfirdn2d, upsample_2d, Downsample, Upsample from .resnet import Downsample, ResnetBlock, Upsample, downsample_2d, upfirdn2d, upsample_2d
from .resnet import ResnetBlock
def _setup_kernel(k): def _setup_kernel(k):
...@@ -277,8 +276,6 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -277,8 +276,6 @@ class NCSNpp(ModelMixin, ConfigMixin):
skip_rescale=skip_rescale, skip_rescale=skip_rescale,
continuous=continuous, continuous=continuous,
) )
self.act = act = nn.SiLU()
self.nf = nf self.nf = nf
self.num_res_blocks = num_res_blocks self.num_res_blocks = num_res_blocks
self.attn_resolutions = attn_resolutions self.attn_resolutions = attn_resolutions
...@@ -421,9 +418,10 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -421,9 +418,10 @@ class NCSNpp(ModelMixin, ConfigMixin):
for i_level in reversed(range(self.num_resolutions)): for i_level in reversed(range(self.num_resolutions)):
for i_block in range(num_res_blocks + 1): for i_block in range(num_res_blocks + 1):
out_ch = nf * ch_mult[i_level] out_ch = nf * ch_mult[i_level]
in_ch = in_ch + hs_c.pop()
modules.append( modules.append(
ResnetBlock( ResnetBlock(
in_channels=in_ch + hs_c.pop(), in_channels=in_ch,
out_channels=out_ch, out_channels=out_ch,
temb_channels=4 * nf, temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0), output_scale_factor=np.sqrt(2.0),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment