"vscode:/vscode.git/clone" did not exist on "2746aac3aeac5a0a50fff00d75f22bbf5a0948ca"
Commit fd6f93b2 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

all glide passes

parent db934c67
......@@ -378,9 +378,6 @@ class ResBlock(TimestepBlock):
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
......@@ -426,7 +423,7 @@ class ResnetBlock(nn.Module):
if time_embedding_norm == "default":
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
if time_embedding_norm == "scale_shift":
elif time_embedding_norm == "scale_shift":
self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
......@@ -489,7 +486,7 @@ class ResnetBlock(nn.Module):
nn.SiLU(),
linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
2 * self.out_channels if self.time_embedding_norm == "scale_shift" else self.out_channels,
),
)
self.out_layers = nn.Sequential(
......@@ -551,9 +548,6 @@ class ResnetBlock(nn.Module):
self.set_weights_ldm()
self.is_overwritten = True
if self.up or self.down:
x = self.x_upd(x)
h = x
h = h * mask
if self.pre_norm:
......@@ -561,6 +555,7 @@ class ResnetBlock(nn.Module):
h = self.nonlinearity(h)
if self.up or self.down:
x = self.x_upd(x)
h = self.h_upd(h)
h = self.conv1(h)
......@@ -571,7 +566,6 @@ class ResnetBlock(nn.Module):
h = h * mask
temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
if self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
......@@ -595,9 +589,9 @@ class ResnetBlock(nn.Module):
x = x * mask
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
# if self.use_conv_shortcut:
# x = self.conv_shortcut(x)
# else:
x = self.nin_shortcut(x)
return x + h
......
......@@ -7,6 +7,7 @@ from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample, ResBlock, TimestepBlock, Upsample
from .resnet import ResnetBlock
def convert_module_to_f16(l):
......@@ -101,7 +102,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
def forward(self, x, emb, encoder_out=None):
for layer in self:
if isinstance(layer, TimestepBlock):
if isinstance(layer, TimestepBlock) or isinstance(layer, ResnetBlock):
x = layer(x, emb)
elif isinstance(layer, AttentionBlock):
x = layer(x, encoder_out)
......@@ -190,14 +191,24 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=int(mult * model_channels),
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
# ResBlock(
# ch,
# time_embed_dim,
# dropout,
# out_channels=int(mult * model_channels),
# dims=dims,
# use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm,
# )
ResnetBlock(
in_channels=ch,
out_channels=mult * model_channels,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
time_embedding_norm="scale_shift",
overwrite_for_glide=True,
)
]
ch = int(mult * model_channels)
......@@ -218,15 +229,26 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
# ResBlock(
# ch,
# time_embed_dim,
# dropout,
# out_channels=out_ch,
# dims=dims,
# use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm,
# down=True,
# )
ResnetBlock(
in_channels=ch,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
time_embedding_norm="scale_shift",
overwrite_for_glide=True,
down=True
)
if resblock_updown
else Downsample(
......@@ -240,13 +262,22 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
self._feature_size += ch
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
# ResBlock(
# ch,
# time_embed_dim,
# dropout,
# dims=dims,
# use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm,
# ),
ResnetBlock(
in_channels=ch,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
time_embedding_norm="scale_shift",
overwrite_for_glide=True,
),
AttentionBlock(
ch,
......@@ -255,14 +286,23 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
num_head_channels=num_head_channels,
encoder_channels=transformer_dim,
),
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
# ResBlock(
# ch,
# time_embed_dim,
# dropout,
# dims=dims,
# use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm,
# ),
ResnetBlock(
in_channels=ch,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
time_embedding_norm="scale_shift",
overwrite_for_glide=True,
)
)
self._feature_size += ch
......@@ -271,15 +311,25 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
for i in range(num_res_blocks + 1):
ich = input_block_chans.pop()
layers = [
ResBlock(
ch + ich,
time_embed_dim,
dropout,
out_channels=int(model_channels * mult),
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
# ResBlock(
# ch + ich,
# time_embed_dim,
# dropout,
# out_channels=int(model_channels * mult),
# dims=dims,
# use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm,
# )
ResnetBlock(
in_channels=ch + ich,
out_channels=model_channels * mult,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
time_embedding_norm="scale_shift",
overwrite_for_glide=True,
),
]
ch = int(model_channels * mult)
if ds in attention_resolutions:
......@@ -295,14 +345,25 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
if level and i == num_res_blocks:
out_ch = ch
layers.append(
ResBlock(
ch,
time_embed_dim,
dropout,
# ResBlock(
# ch,
# time_embed_dim,
# dropout,
# out_channels=out_ch,
# dims=dims,
# use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm,
# up=True,
# )
ResnetBlock(
in_channels=ch,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
time_embedding_norm="scale_shift",
overwrite_for_glide=True,
up=True,
)
if resblock_updown
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment