Commit 0eac7bd6 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

small fix

parent 1e7e23a9
......@@ -237,12 +237,12 @@ class ResnetBlock(nn.Module):
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
# if up:
# self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
# self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
# elif down:
# self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
# self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
# if up:
# self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
# self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
# elif down:
# self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
# self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
self.upsample = self.downsample = None
if self.up and kernel == "fir":
......@@ -318,9 +318,9 @@ class ResnetBlock(nn.Module):
num_groups = min(in_ch // 4, 32)
num_groups_out = min(out_ch // 4, 32)
temb_dim = temb_channels
# output_scale_factor = np.sqrt(2.0)
# non_linearity = "silu"
# use_nin_shortcut = in_channels != out_channels or use_nin_shortcut = True
# output_scale_factor = np.sqrt(2.0)
# non_linearity = "silu"
# use_nin_shortcut = in_channels != out_channels or use_nin_shortcut = True
self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=in_ch, eps=eps)
self.up = up
......@@ -338,7 +338,7 @@ class ResnetBlock(nn.Module):
# 1x1 convolution with DDPM initialization.
self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
# self.skip_rescale = skip_rescale
# self.skip_rescale = skip_rescale
self.in_ch = in_ch
self.out_ch = out_ch
......
......@@ -27,8 +27,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import GaussianFourierProjection, get_timestep_embedding
from .resnet import downsample_2d, upfirdn2d, upsample_2d, Downsample, Upsample
from .resnet import ResnetBlock
from .resnet import Downsample, ResnetBlock, Upsample, downsample_2d, upfirdn2d, upsample_2d
def _setup_kernel(k):
......@@ -277,8 +276,6 @@ class NCSNpp(ModelMixin, ConfigMixin):
skip_rescale=skip_rescale,
continuous=continuous,
)
self.act = act = nn.SiLU()
self.nf = nf
self.num_res_blocks = num_res_blocks
self.attn_resolutions = attn_resolutions
......@@ -421,9 +418,10 @@ class NCSNpp(ModelMixin, ConfigMixin):
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(num_res_blocks + 1):
out_ch = nf * ch_mult[i_level]
in_ch = in_ch + hs_c.pop()
modules.append(
ResnetBlock(
in_channels=in_ch + hs_c.pop(),
in_channels=in_ch,
out_channels=out_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment