Commit 61dc6574 authored by Patrick von Platen

more fixes

parent f1aade05
from abc import abstractmethod
import functools
import numpy as np
import torch
import torch.nn as nn
@@ -374,15 +375,20 @@ class ResnetBlock(nn.Module):
dropout=0.0,
temb_channels=512,
groups=32,
groups_out=None,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
fir_kernel=(1, 3, 3, 1),
output_scale_factor=1.0,
use_nin_shortcut=None,
up=False,
down=False,
overwrite_for_grad_tts=False,
overwrite_for_ldm=False,
overwrite_for_glide=False,
overwrite_for_score_vde=False,
):
super().__init__()
self.pre_norm = pre_norm
@@ -393,6 +399,13 @@ class ResnetBlock(nn.Module):
self.time_embedding_norm = time_embedding_norm
self.up = up
self.down = down
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
if use_nin_shortcut is None:
use_nin_shortcut = self.in_channels != self.out_channels
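# When in_channels != out_channels the residual branch h and the skip branch x carry different
# channel counts, so a 1x1 "network-in-network" shortcut (built below as self.nin_shortcut)
# projects x to out_channels before the final x + h addition.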
if self.pre_norm:
self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
@@ -406,7 +419,7 @@ class ResnetBlock(nn.Module):
elif time_embedding_norm == "scale_shift":
self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
self.norm2 = Normalize(out_channels, num_groups=groups_out, eps=eps)
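# With time_embedding_norm == "scale_shift" the projection is 2 * out_channels wide so that
# forward() can split it into a (scale, shift) pair and modulate the normalized activations
# (the guided-diffusion style h = norm2(h) * (1 + scale) + shift) instead of adding the
# embedding to h directly; that branch of forward() lies outside this hunk.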
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
@@ -417,14 +430,17 @@ class ResnetBlock(nn.Module):
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
if up:
self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
elif down:
self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
if self.in_channels != self.out_channels:
# if up:
# self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
# self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
# elif down:
# self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
# self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
self.upsample = Upsample(in_channels, use_conv=False, dims=2) if self.up else None
self.downsample = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op") if self.down else None
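# The separate h_upd / x_upd modules (commented out above) are replaced by a single upsample or
# downsample module that forward() applies to both the hidden state h and the skip connection x.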
self.nin_shortcut = None
if use_nin_shortcut:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
# TODO(SURAJ, PATRICK): ALL OF THE FOLLOWING OF THE INIT METHOD CAN BE DELETED ONCE WEIGHTS ARE CONVERTED
@@ -432,6 +448,7 @@ class ResnetBlock(nn.Module):
self.overwrite_for_glide = overwrite_for_glide
self.overwrite_for_grad_tts = overwrite_for_grad_tts
self.overwrite_for_ldm = overwrite_for_ldm or overwrite_for_glide
self.overwrite_for_score_vde = overwrite_for_score_vde
if self.overwrite_for_grad_tts:
dim = in_channels
dim_out = out_channels
@@ -450,6 +467,7 @@ class ResnetBlock(nn.Module):
channels = in_channels
emb_channels = temb_channels
use_scale_shift_norm = False
non_linearity = "silu"
self.in_layers = nn.Sequential(
normalization(channels, swish=1.0),
@@ -473,6 +491,45 @@ class ResnetBlock(nn.Module):
self.skip_connection = nn.Identity()
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
elif self.overwrite_for_score_vde:
in_ch = in_channels
out_ch = out_channels
eps = 1e-6
num_groups = min(in_ch // 4, 32)
num_groups_out = min(out_ch // 4, 32)
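# GroupNorm group counts follow the score_sde convention: one group per four channels, capped at 32.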
temb_dim = temb_channels
# output_scale_factor = np.sqrt(2.0)
# non_linearity = "silu"
# use_nin_shortcut = (in_channels != out_channels) or (use_nin_shortcut is True)
self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=in_ch, eps=eps)
self.up = up
self.down = down
self.fir_kernel = fir_kernel
self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
if temb_dim is not None:
self.Dense_0 = nn.Linear(temb_dim, out_ch)
self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
nn.init.zeros_(self.Dense_0.bias)
self.GroupNorm_1 = nn.GroupNorm(num_groups=num_groups_out, num_channels=out_ch, eps=eps)
self.Dropout_0 = nn.Dropout(dropout)
self.Conv_1 = conv2d(out_ch, out_ch, init_scale=0.0, kernel_size=3, padding=1)
if in_ch != out_ch or up or down:
# 1x1 convolution with DDPM initialization.
self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
# self.skip_rescale = skip_rescale
self.in_ch = in_ch
self.out_ch = out_ch
# TODO(Patrick) - move to main init
self.upsample = functools.partial(upsample_2d, k=self.fir_kernel)
self.downsample = functools.partial(downsample_2d, k=self.fir_kernel)
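# In the score_vde path, resampling uses FIR filtering with self.fir_kernel (default (1, 3, 3, 1))
# through the upsample_2d / downsample_2d partials bound above, rather than the plain
# Upsample / Downsample modules built earlier for the other code paths.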
self.is_overwritten = False
def set_weights_grad_tts(self):
self.conv1.weight.data = self.block1.block[0].weight.data
@@ -512,6 +569,24 @@ class ResnetBlock(nn.Module):
self.nin_shortcut.weight.data = self.skip_connection.weight.data
self.nin_shortcut.bias.data = self.skip_connection.bias.data
def set_weights_score_vde(self):
self.conv1.weight.data = self.Conv_0.weight.data
self.conv1.bias.data = self.Conv_0.bias.data
self.norm1.weight.data = self.GroupNorm_0.weight.data
self.norm1.bias.data = self.GroupNorm_0.bias.data
self.conv2.weight.data = self.Conv_1.weight.data
self.conv2.bias.data = self.Conv_1.bias.data
self.norm2.weight.data = self.GroupNorm_1.weight.data
self.norm2.bias.data = self.GroupNorm_1.bias.data
self.temb_proj.weight.data = self.Dense_0.weight.data
self.temb_proj.bias.data = self.Dense_0.bias.data
if self.in_channels != self.out_channels or self.up or self.down:
self.nin_shortcut.weight.data = self.Conv_2.weight.data
self.nin_shortcut.bias.data = self.Conv_2.bias.data
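# Like the grad_tts and ldm variants above, this helper copies the tensors of the legacy
# score_vde sub-modules (Conv_0, GroupNorm_0, Dense_0, ...) into the unified conv1 / norm1 /
# conv2 / norm2 / temb_proj / nin_shortcut layers the first time forward() runs, so existing
# checkpoints keep loading until the weights are converted (see the TODO in __init__).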
def forward(self, x, temb, mask=1.0):
# TODO(Patrick) eventually this class should be split into multiple classes
# too many if else statements
@@ -521,6 +596,9 @@ class ResnetBlock(nn.Module):
elif self.overwrite_for_ldm and not self.is_overwritten:
self.set_weights_ldm()
self.is_overwritten = True
elif self.overwrite_for_score_vde and not self.is_overwritten:
self.set_weights_score_vde()
self.is_overwritten = True
h = x
h = h * mask
@@ -528,10 +606,17 @@ class ResnetBlock(nn.Module):
h = self.norm1(h)
h = self.nonlinearity(h)
if self.up or self.down:
x = self.x_upd(x)
h = self.h_upd(h)
if self.upsample is not None:
x = self.upsample(x)
h = self.upsample(h)
elif self.downsample is not None:
x = self.downsample(x)
h = self.downsample(h)
# if self.up or self.down:
# x = self.x_upd(x)
# h = self.h_upd(h)
#
h = self.conv1(h)
if not self.pre_norm:
@@ -563,7 +648,7 @@ class ResnetBlock(nn.Module):
h = h * mask
x = x * mask
if self.in_channels != self.out_channels:
if self.nin_shortcut is not None:
x = self.nin_shortcut(x)
return x + h
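# Minimal usage sketch for the default configuration, assuming the constructor also exposes the
# in_channels / out_channels keyword arguments referenced throughout __init__ (they sit above
# this hunk); shapes are illustrative only.
block = ResnetBlock(in_channels=64, out_channels=128, temb_channels=512, groups=32)
x = torch.randn(2, 64, 32, 32)  # (batch, channels, height, width)
temb = torch.randn(2, 512)      # timestep embedding matching temb_channels
out = block(x, temb)            # conv path plus 1x1-projected skip -> shape (2, 128, 32, 32)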