Commit 26ce60c4 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

up

parent 358531be
......@@ -330,8 +330,8 @@ class ResBlock(TimestepBlock):
result = self.skip_connection(x) + h
# TODO(Patrick) Use for glide at later stage
# result = self.forward_2(x, emb)
# TODO(Patrick) Use for glide at later stage
# result = self.forward_2(x, emb)
return result
......@@ -439,9 +439,9 @@ class ResnetBlock(nn.Module):
self.res_conv = torch.nn.Identity()
elif self.overwrite_for_ldm:
dims = 2
# eps = 1e-5
# non_linearity = "silu"
# overwrite_for_ldm
# eps = 1e-5
# non_linearity = "silu"
# overwrite_for_ldm
channels = in_channels
emb_channels = temb_channels
use_scale_shift_norm = False
......@@ -466,8 +466,8 @@ class ResnetBlock(nn.Module):
)
if self.out_channels == in_channels:
self.skip_connection = nn.Identity()
# elif use_conv:
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
# elif use_conv:
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
......
......@@ -10,9 +10,10 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample, TimestepBlock, Upsample
from .resnet import ResnetBlock
#from .resnet import ResBlock
from .resnet import Downsample, ResnetBlock, TimestepBlock, Upsample
# from .resnet import ResBlock
def exists(val):
......@@ -601,16 +602,16 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
# ResBlock(
# ch,
# time_embed_dim,
# dropout,
# out_channels=out_ch,
# dims=dims,
# use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm,
# down=True,
# )
# ResBlock(
# ch,
# time_embed_dim,
# dropout,
# out_channels=out_ch,
# dims=dims,
# use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm,
# down=True,
# )
None
if resblock_updown
else Downsample(
......@@ -703,16 +704,16 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
if level and i == num_res_blocks:
out_ch = ch
layers.append(
# ResBlock(
# ch,
# time_embed_dim,
# dropout,
# out_channels=out_ch,
# dims=dims,
# use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm,
# up=True,
# )
# ResBlock(
# ch,
# time_embed_dim,
# dropout,
# out_channels=out_ch,
# dims=dims,
# use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm,
# up=True,
# )
None
if resblock_updown
else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
......@@ -876,16 +877,16 @@ class EncoderUNetModel(nn.Module):
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
# ResBlock(
# ch,
# time_embed_dim,
# dropout,
# out_channels=out_ch,
# dims=dims,
# use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm,
# down=True,
# )
# ResBlock(
# ch,
# time_embed_dim,
# dropout,
# out_channels=out_ch,
# dims=dims,
# use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm,
# down=True,
# )
None
if resblock_updown
else Downsample(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment