Commit 26ce60c4 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

up

parent 358531be
...@@ -330,8 +330,8 @@ class ResBlock(TimestepBlock): ...@@ -330,8 +330,8 @@ class ResBlock(TimestepBlock):
result = self.skip_connection(x) + h result = self.skip_connection(x) + h
# TODO(Patrick) Use for glide at later stage # TODO(Patrick) Use for glide at later stage
# result = self.forward_2(x, emb) # result = self.forward_2(x, emb)
return result return result
...@@ -439,9 +439,9 @@ class ResnetBlock(nn.Module): ...@@ -439,9 +439,9 @@ class ResnetBlock(nn.Module):
self.res_conv = torch.nn.Identity() self.res_conv = torch.nn.Identity()
elif self.overwrite_for_ldm: elif self.overwrite_for_ldm:
dims = 2 dims = 2
# eps = 1e-5 # eps = 1e-5
# non_linearity = "silu" # non_linearity = "silu"
# overwrite_for_ldm # overwrite_for_ldm
channels = in_channels channels = in_channels
emb_channels = temb_channels emb_channels = temb_channels
use_scale_shift_norm = False use_scale_shift_norm = False
...@@ -466,8 +466,8 @@ class ResnetBlock(nn.Module): ...@@ -466,8 +466,8 @@ class ResnetBlock(nn.Module):
) )
if self.out_channels == in_channels: if self.out_channels == in_channels:
self.skip_connection = nn.Identity() self.skip_connection = nn.Identity()
# elif use_conv: # elif use_conv:
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1) # self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
else: else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
......
...@@ -10,9 +10,10 @@ from ..configuration_utils import ConfigMixin ...@@ -10,9 +10,10 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .attention import AttentionBlock from .attention import AttentionBlock
from .embeddings import get_timestep_embedding from .embeddings import get_timestep_embedding
from .resnet import Downsample, TimestepBlock, Upsample from .resnet import Downsample, ResnetBlock, TimestepBlock, Upsample
from .resnet import ResnetBlock
#from .resnet import ResBlock
# from .resnet import ResBlock
def exists(val): def exists(val):
...@@ -561,14 +562,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin): ...@@ -561,14 +562,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
for level, mult in enumerate(channel_mult): for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks): for _ in range(num_res_blocks):
layers = [ layers = [
ResnetBlock( ResnetBlock(
in_channels=ch, in_channels=ch,
out_channels=mult * model_channels, out_channels=mult * model_channels,
dropout=dropout, dropout=dropout,
temb_channels=time_embed_dim, temb_channels=time_embed_dim,
eps=1e-5, eps=1e-5,
non_linearity="silu", non_linearity="silu",
overwrite_for_ldm=True, overwrite_for_ldm=True,
) )
] ]
ch = mult * model_channels ch = mult * model_channels
...@@ -601,16 +602,16 @@ class UNetLDMModel(ModelMixin, ConfigMixin): ...@@ -601,16 +602,16 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
out_ch = ch out_ch = ch
self.input_blocks.append( self.input_blocks.append(
TimestepEmbedSequential( TimestepEmbedSequential(
# ResBlock( # ResBlock(
# ch, # ch,
# time_embed_dim, # time_embed_dim,
# dropout, # dropout,
# out_channels=out_ch, # out_channels=out_ch,
# dims=dims, # dims=dims,
# use_checkpoint=use_checkpoint, # use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm, # use_scale_shift_norm=use_scale_shift_norm,
# down=True, # down=True,
# ) # )
None None
if resblock_updown if resblock_updown
else Downsample( else Downsample(
...@@ -703,16 +704,16 @@ class UNetLDMModel(ModelMixin, ConfigMixin): ...@@ -703,16 +704,16 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
if level and i == num_res_blocks: if level and i == num_res_blocks:
out_ch = ch out_ch = ch
layers.append( layers.append(
# ResBlock( # ResBlock(
# ch, # ch,
# time_embed_dim, # time_embed_dim,
# dropout, # dropout,
# out_channels=out_ch, # out_channels=out_ch,
# dims=dims, # dims=dims,
# use_checkpoint=use_checkpoint, # use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm, # use_scale_shift_norm=use_scale_shift_norm,
# up=True, # up=True,
# ) # )
None None
if resblock_updown if resblock_updown
else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch) else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
...@@ -876,16 +877,16 @@ class EncoderUNetModel(nn.Module): ...@@ -876,16 +877,16 @@ class EncoderUNetModel(nn.Module):
out_ch = ch out_ch = ch
self.input_blocks.append( self.input_blocks.append(
TimestepEmbedSequential( TimestepEmbedSequential(
# ResBlock( # ResBlock(
# ch, # ch,
# time_embed_dim, # time_embed_dim,
# dropout, # dropout,
# out_channels=out_ch, # out_channels=out_ch,
# dims=dims, # dims=dims,
# use_checkpoint=use_checkpoint, # use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm, # use_scale_shift_norm=use_scale_shift_norm,
# down=True, # down=True,
# ) # )
None None
if resblock_updown if resblock_updown
else Downsample( else Downsample(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment