Unverified commit ba3c9a9a authored by Patrick von Platen, committed by GitHub

[SDE] Merge to unconditional model (#89)

* up

* more

* uP

* make dummy test pass

* save intermediate

* p

* p

* finish

* finish

* finish
parent b5c684f0
@@ -100,7 +100,7 @@ def test_output_pretrained_ldm():
# 2. DDPM
def get_model(model_id):
model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", ldm=True)
model = UNetUnconditionalModel.from_pretrained(model_id, ldm=True)
noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
time_step = torch.tensor([10] * noise.shape[0])
@@ -123,3 +123,16 @@ def get_model(model_id):
# e.g.
get_model("fusing/ddpm-cifar10")
# 3. NCSNpp
# Repos to convert and port to google (part of https://github.com/yang-song/score_sde)
# - https://huggingface.co/fusing/ffhq_ncsnpp
# - https://huggingface.co/fusing/church_256-ncsnpp-ve
# - https://huggingface.co/fusing/celebahq_256-ncsnpp-ve
# - https://huggingface.co/fusing/bedroom_256-ncsnpp-ve
# - https://huggingface.co/fusing/ffhq_256-ncsnpp-ve
# tests that need to pass
# - test_score_sde_ve_pipeline (in PipelineTesterMixin)
# - test_output_pretrained_ve_mid, test_output_pretrained_ve_large (in NCSNppModelTests)
@@ -6,166 +6,6 @@ import torch.nn.functional as F
from torch import nn
# unet_grad_tts.py
# TODO(Patrick) - weird linear attention layer. Check with: https://github.com/huawei-noah/Speech-Backbones/issues/15
class LinearAttention(torch.nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.heads = heads
self.dim_head = dim_head
hidden_dim = dim_head * heads
self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x, encoder_states=None):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = (
qkv.reshape(b, 3, self.heads, self.dim_head, h, w)
.permute(1, 0, 2, 3, 4, 5)
.reshape(3, b, self.heads, self.dim_head, -1)
)
k = k.softmax(dim=-1)
context = torch.einsum("bhdn,bhen->bhde", k, v)
out = torch.einsum("bhde,bhdn->bhen", context, q)
out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w)
return self.to_out(out)
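A quick shape round-trip for LinearAttention (editor's sketch; assumes only the class above):

import torch

attn = LinearAttention(dim=64, heads=4, dim_head=32)
x = torch.randn(2, 64, 8, 8)
out = attn(x)
assert out.shape == x.shape  # linear attention preserves [b, c, h, w]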
# the main attention block that is used for all models
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def __init__(
self,
channels,
num_heads=1,
num_head_channels=None,
num_groups=32,
encoder_channels=None,
overwrite_qkv=False,
overwrite_linear=False,
rescale_output_factor=1.0,
eps=1e-5,
):
super().__init__()
self.channels = channels
if num_head_channels is None:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
self.qkv = nn.Conv1d(channels, channels * 3, 1)
self.n_heads = self.num_heads
self.rescale_output_factor = rescale_output_factor
if encoder_channels is not None:
self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
self.proj = zero_module(nn.Conv1d(channels, channels, 1))
self.overwrite_qkv = overwrite_qkv
self.overwrite_linear = overwrite_linear
if overwrite_qkv:
in_channels = channels
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
elif self.overwrite_linear:
num_groups = min(channels // 4, 32)
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
self.NIN_0 = NIN(channels, channels)
self.NIN_1 = NIN(channels, channels)
self.NIN_2 = NIN(channels, channels)
self.NIN_3 = NIN(channels, channels)
self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
else:
self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
self.set_weights(self)
self.is_overwritten = False
def set_weights(self, module):
if self.overwrite_qkv:
qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[
:, :, :, 0
]
qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)
self.qkv.weight.data = qkv_weight
self.qkv.bias.data = qkv_bias
proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
proj_out.bias.data = module.proj_out.bias.data
self.proj = proj_out
elif self.overwrite_linear:
self.qkv.weight.data = torch.concat(
[self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0
)[:, :, None]
self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
self.proj.weight.data = self.NIN_3.W.data.T[:, :, None]
self.proj.bias.data = self.NIN_3.b.data
self.norm.weight.data = self.GroupNorm_0.weight.data
self.norm.bias.data = self.GroupNorm_0.bias.data
else:
self.proj.weight.data = self.proj_out.weight.data
self.proj.bias.data = self.proj_out.bias.data
def forward(self, x, encoder_out=None):
if not self.is_overwritten and (self.overwrite_qkv or self.overwrite_linear):
self.set_weights(self)
self.is_overwritten = True
b, c, *spatial = x.shape
hid_states = self.norm(x).view(b, c, -1)
qkv = self.qkv(hid_states)
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
if encoder_out is not None:
encoder_kv = self.encoder_kv(encoder_out)
assert encoder_kv.shape[1] == self.n_heads * ch * 2
ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
k = torch.cat([ek, k], dim=-1)
v = torch.cat([ev, v], dim=-1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
a = torch.einsum("bts,bcs->bct", weight, v)
h = a.reshape(bs, -1, length)
h = self.proj(h)
h = h.reshape(b, c, *spatial)
result = x + h
result = result / self.rescale_output_factor
return result
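A minimal usage sketch for the default (non-overwrite) path of the legacy AttentionBlock (hedged; assumes the zero_module helper defined elsewhere in this file). channels must be divisible by num_groups (32 here) and by num_head_channels:

import torch

block = AttentionBlock(channels=64, num_head_channels=32)
x = torch.randn(1, 64, 8, 8)
out = block(x)
assert out.shape == (1, 64, 8, 8)  # residual connection keeps the input shape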
class AttentionBlockNew(nn.Module):
"""
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
@@ -216,6 +56,7 @@ class AttentionBlockNew(nn.Module):
# norm
hidden_states = self.group_norm(hidden_states)
hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
# proj to q, k, v
@@ -229,9 +70,9 @@ class AttentionBlockNew(nn.Module):
value_states = self.transpose_for_scores(value_proj)
# get scores
attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.channels // self.num_heads)
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
# compute attention output
context_states = torch.matmul(attention_probs, value_states)
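A numerical check (editor's sketch, plain torch) that the new pre-scaling above matches the old post-division: multiplying q and k each by d**-0.25 is algebraically the same as dividing the score matrix by sqrt(d), but keeps intermediate values smaller, which is friendlier to fp16:

import math
import torch

d = 64
q, k = torch.randn(2, 16, d), torch.randn(2, 16, d)
scale = 1 / math.sqrt(math.sqrt(d))
scores_new = torch.matmul(q * scale, k.transpose(-1, -2) * scale)
scores_old = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d)
assert torch.allclose(scores_new, scores_old, atol=1e-5)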
@@ -263,6 +104,20 @@ class AttentionBlockNew(nn.Module):
self.proj_attn.weight.data = attn_layer.proj_out.weight.data[:, :, 0, 0]
self.proj_attn.bias.data = attn_layer.proj_out.bias.data
elif hasattr(attn_layer, "NIN_0"):
self.query.weight.data = attn_layer.NIN_0.W.data.T
self.key.weight.data = attn_layer.NIN_1.W.data.T
self.value.weight.data = attn_layer.NIN_2.W.data.T
self.query.bias.data = attn_layer.NIN_0.b.data
self.key.bias.data = attn_layer.NIN_1.b.data
self.value.bias.data = attn_layer.NIN_2.b.data
self.proj_attn.weight.data = attn_layer.NIN_3.W.data.T
self.proj_attn.bias.data = attn_layer.NIN_3.b.data
self.group_norm.weight.data = attn_layer.GroupNorm_0.weight.data
self.group_norm.bias.data = attn_layer.GroupNorm_0.bias.data
else:
qkv_weight = attn_layer.qkv.weight.data.reshape(
self.num_heads, 3 * self.channels // self.num_heads, self.channels
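Why the .T above (editor's note, hedged: assumes NIN applies x @ W + b over the channel axis with W of shape [in, out]): nn.Linear stores its weight as [out, in], so a NIN kernel ports over transposed:

import torch

W, b = torch.randn(32, 32), torch.randn(32)  # NIN-style [in, out] kernel
lin = torch.nn.Linear(32, 32)
lin.weight.data = W.T
lin.bias.data = b
x = torch.randn(5, 32)
assert torch.allclose(lin(x), x @ W + b, atol=1e-5)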
@@ -452,3 +307,137 @@ class GEGLU(nn.Module):
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
# the main attention block that is used for all models
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def __init__(
self,
channels,
num_heads=1,
num_head_channels=None,
num_groups=32,
encoder_channels=None,
overwrite_qkv=False,
overwrite_linear=False,
rescale_output_factor=1.0,
eps=1e-5,
):
super().__init__()
self.channels = channels
if num_head_channels is None:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
self.qkv = nn.Conv1d(channels, channels * 3, 1)
self.n_heads = self.num_heads
self.rescale_output_factor = rescale_output_factor
if encoder_channels is not None:
self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
self.proj = zero_module(nn.Conv1d(channels, channels, 1))
self.overwrite_qkv = overwrite_qkv
self.overwrite_linear = overwrite_linear
if overwrite_qkv:
in_channels = channels
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
elif self.overwrite_linear:
num_groups = min(channels // 4, 32)
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
self.NIN_0 = NIN(channels, channels)
self.NIN_1 = NIN(channels, channels)
self.NIN_2 = NIN(channels, channels)
self.NIN_3 = NIN(channels, channels)
self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
else:
self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
self.set_weights(self)
self.is_overwritten = False
def set_weights(self, module):
if self.overwrite_qkv:
qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[
:, :, :, 0
]
qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)
self.qkv.weight.data = qkv_weight
self.qkv.bias.data = qkv_bias
proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
proj_out.bias.data = module.proj_out.bias.data
self.proj = proj_out
elif self.overwrite_linear:
self.qkv.weight.data = torch.concat(
[self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0
)[:, :, None]
self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
self.proj.weight.data = self.NIN_3.W.data.T[:, :, None]
self.proj.bias.data = self.NIN_3.b.data
self.norm.weight.data = self.GroupNorm_0.weight.data
self.norm.bias.data = self.GroupNorm_0.bias.data
else:
self.proj.weight.data = self.proj_out.weight.data
self.proj.bias.data = self.proj_out.bias.data
def forward(self, x, encoder_out=None):
if not self.is_overwritten and (self.overwrite_qkv or self.overwrite_linear):
self.set_weights(self)
self.is_overwritten = True
b, c, *spatial = x.shape
hid_states = self.norm(x).view(b, c, -1)
qkv = self.qkv(hid_states)
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
if encoder_out is not None:
encoder_kv = self.encoder_kv(encoder_out)
assert encoder_kv.shape[1] == self.n_heads * ch * 2
ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
k = torch.cat([ek, k], dim=-1)
v = torch.cat([ev, v], dim=-1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
a = torch.einsum("bts,bcs->bct", weight, v)
h = a.reshape(bs, -1, length)
h = self.proj(h)
h = h.reshape(b, c, *spatial)
result = x + h
result = result / self.rescale_output_factor
return result
@@ -54,14 +54,20 @@ def get_timestep_embedding(
return emb
# unet_sde_score_estimation.py
class GaussianFourierProjection(nn.Module):
"""Gaussian Fourier embeddings for noise levels."""
def __init__(self, embedding_size=256, scale=1.0):
super().__init__()
self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
# to delete later
self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
self.weight = self.W
def forward(self, x):
x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
x = torch.log(x)
x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
return out
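Hedged usage sketch: a batch of continuous noise levels of shape [B] becomes a Fourier feature of shape [B, 2 * embedding_size]; note the forward now takes raw sigmas and applies log() itself, so inputs must be positive:

import torch

proj = GaussianFourierProjection(embedding_size=256, scale=16.0)
sigmas = torch.rand(4) + 0.1
emb = proj(sigmas)
assert emb.shape == (4, 512)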
@@ -87,96 +87,15 @@ class Downsample2D(nn.Module):
self.conv = conv
def forward(self, x):
assert x.shape[1] == self.channels
if self.use_conv and self.padding == 0:
pad = (0, 1, 0, 1)
x = F.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
return x
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
# if self.name == "conv":
# return self.conv(x)
# elif self.name == "Conv2d_0":
# return self.Conv2d_0(x)
# else:
# return self.op(x)
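The padding == 0 branch above pads only on the right/bottom before the strided conv; a standalone sketch (plain torch, hypothetical sizes) of why that keeps a kernel-3, stride-2 conv aligned:

import torch
import torch.nn.functional as F

x = torch.randn(1, 32, 8, 8)
x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)  # -> [1, 32, 9, 9]
conv = torch.nn.Conv2d(32, 32, 3, stride=2, padding=0)
assert conv(x).shape == (1, 32, 4, 4)  # floor((9 - 3) / 2) + 1 = 4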
class Upsample1D(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
self.name = name
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
self.conv = None
if use_conv_transpose:
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
elif use_conv:
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
def forward(self, x):
assert x.shape[1] == self.channels
if self.use_conv_transpose:
return self.conv(x)
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class Downsample1D(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.padding = padding
stride = 2
self.name = name
if use_conv:
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
return self.conv(x)
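Shape sketch for the two 1D resampling layers above (assumes only the classes as defined):

import torch

up = Upsample1D(channels=32, use_conv=True)
down = Downsample1D(channels=32, use_conv=True)
x = torch.randn(1, 32, 64)
assert up(x).shape == (1, 32, 128)  # nearest 2x upsample, then 3-tap conv
assert down(x).shape == (1, 32, 32)  # stride-2 conv halves the length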
class FirUpsample2D(nn.Module):
@@ -330,15 +249,137 @@ class FirDownsample2D(nn.Module):
return x
# TODO (patil-suraj): needs test
# class Upsample2D1d(nn.Module):
# def __init__(self, dim):
# super().__init__()
# self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
#
# def forward(self, x):
# return self.conv(x)
class ResnetBlock(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
groups_out=None,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
kernel=None,
output_scale_factor=1.0,
use_nin_shortcut=None,
up=False,
down=False,
):
super().__init__()
self.pre_norm = pre_norm
self.pre_norm = True
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.time_embedding_norm = time_embedding_norm
self.up = up
self.down = down
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = F.silu
elif non_linearity == "mish":
self.nonlinearity = Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
self.upsample = self.downsample = None
if self.up:
if kernel == "fir":
fir_kernel = (1, 3, 3, 1)
self.upsample = lambda x: upsample_2d(x, k=fir_kernel)
elif kernel == "sde_vp":
self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
else:
self.upsample = Upsample2D(in_channels, use_conv=False)
elif self.down:
if kernel == "fir":
fir_kernel = (1, 3, 3, 1)
self.downsample = lambda x: downsample_2d(x, k=fir_kernel)
elif kernel == "sde_vp":
self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
else:
self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
self.conv_shortcut = None
if self.use_nin_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = self.nonlinearity(h)
if self.upsample is not None:
x = self.upsample(x)
h = self.upsample(h)
elif self.downsample is not None:
x = self.downsample(x)
h = self.downsample(h)
h = self.conv1(h)
if temb is not None:
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
h = h + temb
h = self.norm2(h)
h = self.nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.conv_shortcut is not None:
x = self.conv_shortcut(x)
out = (x + h) / self.output_scale_factor
return out
def set_weight(self, resnet):
self.norm1.weight.data = resnet.norm1.weight.data
self.norm1.bias.data = resnet.norm1.bias.data
self.conv1.weight.data = resnet.conv1.weight.data
self.conv1.bias.data = resnet.conv1.bias.data
self.time_emb_proj.weight.data = resnet.temb_proj.weight.data
self.time_emb_proj.bias.data = resnet.temb_proj.bias.data
self.norm2.weight.data = resnet.norm2.weight.data
self.norm2.bias.data = resnet.norm2.bias.data
self.conv2.weight.data = resnet.conv2.weight.data
self.conv2.bias.data = resnet.conv2.bias.data
if self.use_nin_shortcut:
self.conv_shortcut.weight.data = resnet.nin_shortcut.weight.data
self.conv_shortcut.bias.data = resnet.nin_shortcut.bias.data
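Hedged usage sketch for the rewritten ResnetBlock (defaults: no up/down resampling, swish nonlinearity; the 1x1 shortcut is added automatically when in/out channels differ):

import torch

block = ResnetBlock(in_channels=32, out_channels=64, temb_channels=128)
x = torch.randn(1, 32, 16, 16)
temb = torch.randn(1, 128)
out = block(x, temb)
assert out.shape == (1, 64, 16, 16)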
# THE FOLLOWING SHOULD BE DELETED ONCE ALL CHECKPOINTS ARE CONVERTED
# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py, unet_score_vde.py
# => All 2D-Resnets are included here now!
@@ -502,6 +543,7 @@ class ResnetBlock2D(nn.Module):
self.in_ch = in_ch
self.out_ch = out_ch
self.set_weights_score_vde()
def set_weights_grad_tts(self):
self.conv1.weight.data = self.block1.block[0].weight.data
@@ -559,17 +601,21 @@ class ResnetBlock2D(nn.Module):
self.nin_shortcut.weight.data = self.Conv_2.weight.data
self.nin_shortcut.bias.data = self.Conv_2.bias.data
def forward(self, x, temb, mask=1.0):
# TODO(Patrick) eventually this class should be split into multiple classes
# too many if else statements
if self.overwrite_for_grad_tts and not self.is_overwritten:
self.set_weights_grad_tts()
self.is_overwritten = True
elif self.overwrite_for_score_vde and not self.is_overwritten:
self.set_weights_score_vde()
self.is_overwritten = True
# elif self.overwrite_for_score_vde and not self.is_overwritten:
# self.set_weights_score_vde()
# self.is_overwritten = True
h = x
h = h * mask
if self.pre_norm:
h = self.norm1(h)
@@ -619,154 +665,9 @@ class ResnetBlock2D(nn.Module):
if self.nin_shortcut is not None:
x = self.nin_shortcut(x)
return (x + h) / self.output_scale_factor
class ResnetBlock(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
groups_out=None,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
kernel=None,
output_scale_factor=1.0,
use_nin_shortcut=None,
up=False,
down=False,
overwrite_for_grad_tts=False,
overwrite_for_ldm=False,
overwrite_for_glide=False,
overwrite_for_score_vde=False,
):
super().__init__()
self.pre_norm = pre_norm
self.pre_norm = True
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.time_embedding_norm = time_embedding_norm
self.up = up
self.down = down
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
if self.pre_norm:
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
else:
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if time_embedding_norm == "default" and temb_channels > 0:
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
elif time_embedding_norm == "scale_shift" and temb_channels > 0:
self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = F.silu
elif non_linearity == "mish":
self.nonlinearity = Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
self.upsample = self.downsample = None
if self.up:
if kernel == "fir":
fir_kernel = (1, 3, 3, 1)
self.upsample = lambda x: upsample_2d(x, k=fir_kernel)
elif kernel == "sde_vp":
self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
else:
self.upsample = Upsample2D(in_channels, use_conv=False)
elif self.down:
if kernel == "fir":
fir_kernel = (1, 3, 3, 1)
self.downsample = lambda x: downsample_2d(x, k=fir_kernel)
elif kernel == "sde_vp":
self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
else:
self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
self.conv_shortcut = None
if self.use_nin_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = self.nonlinearity(h)
if self.upsample is not None:
x = self.upsample(x)
h = self.upsample(h)
elif self.downsample is not None:
x = self.downsample(x)
h = self.downsample(h)
h = self.conv1(h)
if temb is not None:
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
else:
temb = 0
if self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
h = self.norm2(h)
h = h + h * scale + shift
h = self.nonlinearity(h)
elif self.time_embedding_norm == "default":
h = h + temb
h = self.norm2(h)
h = self.nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.conv_shortcut is not None:
x = self.conv_shortcut(x)
return (x + h) / self.output_scale_factor
def set_weight(self, resnet):
self.norm1.weight.data = resnet.norm1.weight.data
self.norm1.bias.data = resnet.norm1.bias.data
out = (x + h) / self.output_scale_factor
self.conv1.weight.data = resnet.conv1.weight.data
self.conv1.bias.data = resnet.conv1.bias.data
self.time_emb_proj.weight.data = resnet.temb_proj.weight.data
self.time_emb_proj.bias.data = resnet.temb_proj.bias.data
self.norm2.weight.data = resnet.norm2.weight.data
self.norm2.bias.data = resnet.norm2.bias.data
self.conv2.weight.data = resnet.conv2.weight.data
self.conv2.bias.data = resnet.conv2.bias.data
if self.use_nin_shortcut:
self.conv_shortcut.weight.data = resnet.nin_shortcut.weight.data
self.conv_shortcut.bias.data = resnet.nin_shortcut.bias.data
return out
# TODO(Patrick) - just there to convert the weights; can delete afterward
@@ -778,39 +679,6 @@ class Block(torch.nn.Module):
)
# unet_rl.py
class ResidualTemporalBlock(nn.Module):
def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
super().__init__()
self.blocks = nn.ModuleList(
[
Conv1dBlock(inp_channels, out_channels, kernel_size),
Conv1dBlock(out_channels, out_channels, kernel_size),
]
)
self.time_mlp = nn.Sequential(
nn.Mish(),
nn.Linear(embed_dim, out_channels),
RearrangeDim(),
# Rearrange("batch t -> batch t 1"),
)
self.residual_conv = (
nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
)
def forward(self, x, t):
"""
x : [ batch_size x inp_channels x horizon ]
t : [ batch_size x embed_dim ]
returns: out : [ batch_size x out_channels x horizon ]
"""
out = self.blocks[0](x) + self.time_mlp(t)
out = self.blocks[1](out)
return out + self.residual_conv(x)
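The time embedding enters the residual block via broadcasting; a plain-torch sketch of what the RearrangeDim ("batch t -> batch t 1") step achieves:

import torch

feat = torch.randn(4, 32, 128)  # blocks[0](x): [batch, out_channels, horizon]
temb = torch.randn(4, 32)  # time MLP output: [batch, out_channels]
out = feat + temb[:, :, None]  # trailing axis broadcasts over the horizon
assert out.shape == (4, 32, 128)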
# HELPER Modules
@@ -11,12 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import numpy as np
# limitations under the License.
import torch
from torch import nn
from .attention import AttentionBlockNew
from .resnet import Downsample2D, ResnetBlock, Upsample2D
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock, Upsample2D
def get_down_block(
@@ -54,6 +56,29 @@ def get_down_block(
downsample_padding=downsample_padding,
attn_num_head_channels=attn_num_head_channels,
)
elif down_block_type == "UNetResSkipDownBlock2D":
return UNetResSkipDownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
downsample_padding=downsample_padding,
)
elif down_block_type == "UNetResAttnSkipDownBlock2D":
return UNetResAttnSkipDownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
downsample_padding=downsample_padding,
attn_num_head_channels=attn_num_head_channels,
)
def get_up_block(
@@ -91,6 +116,30 @@ def get_up_block(
resnet_act_fn=resnet_act_fn,
attn_num_head_channels=attn_num_head_channels,
)
elif up_block_type == "UNetResSkipUpBlock2D":
return UNetResSkipUpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
)
elif up_block_type == "UNetResAttnSkipUpBlock2D":
return UNetResAttnSkipUpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
attn_num_head_channels=attn_num_head_channels,
)
raise ValueError(f"{up_block_type} does not exist.")
class UNetMidBlock2D(nn.Module):
@@ -113,6 +162,7 @@ class UNetMidBlock2D(nn.Module):
super().__init__()
self.attention_type = attention_type
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# there is always at least one resnet
resnets = [
@@ -138,6 +188,7 @@ class UNetMidBlock2D(nn.Module):
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
num_groups=resnet_groups,
)
)
resnets.append(
@@ -160,7 +211,6 @@ class UNetMidBlock2D(nn.Module):
def forward(self, hidden_states, temb=None, encoder_states=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if self.attention_type == "default":
hidden_states = attn(hidden_states)
@@ -318,6 +368,178 @@ class UNetResDownBlock2D(nn.Module):
return hidden_states, output_states
class UNetResAttnSkipDownBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
attention_type="default",
output_scale_factor=np.sqrt(2.0),
downsample_padding=1,
add_downsample=True,
):
super().__init__()
self.attentions = nn.ModuleList([])
self.resnets = nn.ModuleList([])
self.attention_type = attention_type
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
self.resnets.append(
ResnetBlock(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=min(in_channels // 4, 32),
groups_out=min(out_channels // 4, 32),
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.attentions.append(
AttentionBlockNew(
out_channels,
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
)
)
if add_downsample:
self.resnet_down = ResnetBlock(
in_channels=out_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=min(out_channels // 4, 32),
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_nin_shortcut=True,
down=True,
kernel="fir",
)
self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
else:
self.resnet_down = None
self.downsamplers = None
self.skip_conv = None
def forward(self, hidden_states, temb=None, skip_sample=None):
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states)
output_states += (hidden_states,)
if self.downsamplers is not None:
hidden_states = self.resnet_down(hidden_states, temb)
for downsampler in self.downsamplers:
skip_sample = downsampler(skip_sample)
hidden_states = self.skip_conv(skip_sample) + hidden_states
output_states += (hidden_states,)
return hidden_states, output_states, skip_sample
class UNetResSkipDownBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_pre_norm: bool = True,
output_scale_factor=np.sqrt(2.0),
add_downsample=True,
downsample_padding=1,
):
super().__init__()
self.resnets = nn.ModuleList([])
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
self.resnets.append(
ResnetBlock(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=min(in_channels // 4, 32),
groups_out=min(out_channels // 4, 32),
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
if add_downsample:
self.resnet_down = ResnetBlock(
in_channels=out_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=min(out_channels // 4, 32),
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_nin_shortcut=True,
down=True,
kernel="fir",
)
self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
else:
self.resnet_down = None
self.downsamplers = None
self.skip_conv = None
def forward(self, hidden_states, temb=None, skip_sample=None):
output_states = ()
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb)
output_states += (hidden_states,)
if self.downsamplers is not None:
hidden_states = self.resnet_down(hidden_states, temb)
for downsampler in self.downsamplers:
skip_sample = downsampler(skip_sample)
hidden_states = self.skip_conv(skip_sample) + hidden_states
output_states += (hidden_states,)
return hidden_states, output_states, skip_sample
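Hedged usage sketch for the skip-style down blocks (assumes ResnetBlock, AttentionBlockNew, and FirDownsample2D as defined in this PR): an image-resolution skip_sample is threaded alongside the hidden states and re-injected through skip_conv after each downsample:

import torch

block = UNetResSkipDownBlock2D(in_channels=32, out_channels=64, temb_channels=128)
h = torch.randn(1, 32, 16, 16)
temb = torch.randn(1, 128)
skip = torch.randn(1, 3, 16, 16)  # RGB pyramid input
h, states, skip = block(h, temb=temb, skip_sample=skip)
print(h.shape, skip.shape)  # expected: (1, 64, 8, 8) and (1, 3, 8, 8)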
class UNetResAttnUpBlock2D(nn.Module):
def __init__(
self,
@@ -457,3 +679,213 @@ class UNetResUpBlock2D(nn.Module):
hidden_states = upsampler(hidden_states)
return hidden_states
class UNetResAttnSkipUpBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
prev_output_channel: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
attention_type="default",
output_scale_factor=np.sqrt(2.0),
upsample_padding=1,
add_upsample=True,
):
super().__init__()
self.attentions = nn.ModuleList([])
self.resnets = nn.ModuleList([])
self.attention_type = attention_type
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
self.resnets.append(
ResnetBlock(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
groups_out=min(out_channels // 4, 32),
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.attentions.append(
AttentionBlockNew(
out_channels,
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
)
)
self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
if add_upsample:
self.resnet_up = ResnetBlock(
in_channels=out_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=min(out_channels // 4, 32),
groups_out=min(out_channels // 4, 32),
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_nin_shortcut=True,
up=True,
kernel="fir",
)
self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.skip_norm = torch.nn.GroupNorm(
num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
)
self.act = nn.SiLU()
else:
self.resnet_up = None
self.skip_conv = None
self.skip_norm = None
self.act = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
output_states = ()
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
hidden_states = self.attentions[0](hidden_states)
if skip_sample is not None:
skip_sample = self.upsampler(skip_sample)
else:
skip_sample = 0
if self.resnet_up is not None:
skip_sample_states = self.skip_norm(hidden_states)
skip_sample_states = self.act(skip_sample_states)
skip_sample_states = self.skip_conv(skip_sample_states)
skip_sample = skip_sample + skip_sample_states
hidden_states = self.resnet_up(hidden_states, temb)
return hidden_states, skip_sample
class UNetResSkipUpBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
prev_output_channel: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_pre_norm: bool = True,
output_scale_factor=np.sqrt(2.0),
add_upsample=True,
upsample_padding=1,
):
super().__init__()
self.resnets = nn.ModuleList([])
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
self.resnets.append(
ResnetBlock(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
groups_out=min(out_channels // 4, 32),
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
if add_upsample:
self.resnet_up = ResnetBlock(
in_channels=out_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=min(out_channels // 4, 32),
groups_out=min(out_channels // 4, 32),
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_nin_shortcut=True,
up=True,
kernel="fir",
)
self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.skip_norm = torch.nn.GroupNorm(
num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
)
self.act = nn.SiLU()
else:
self.resnet_up = None
self.skip_conv = None
self.skip_norm = None
self.act = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
output_states = ()
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
if skip_sample is not None:
skip_sample = self.upsampler(skip_sample)
else:
skip_sample = 0
if self.resnet_up is not None:
skip_sample_states = self.skip_norm(hidden_states)
skip_sample_states = self.act(skip_sample_states)
skip_sample_states = self.skip_conv(skip_sample_states)
skip_sample = skip_sample + skip_sample_states
hidden_states = self.resnet_up(hidden_states, temb)
return hidden_states, skip_sample
@@ -249,9 +249,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
overwrite_for_score_vde=True,
)
)
self.mid.resnets[0] = modules[len(modules) - 3]
self.mid.attentions[0] = modules[len(modules) - 2]
self.mid.resnets[1] = modules[len(modules) - 1]
# self.mid.resnets[0] = modules[len(modules) - 3]
# self.mid.attentions[0] = modules[len(modules) - 2]
# self.mid.resnets[1] = modules[len(modules) - 1]
pyramid_ch = 0
# Upsampling block
@@ -282,22 +282,22 @@ class NCSNpp(ModelMixin, ConfigMixin):
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(nn.Conv2d(in_ch, channels, kernel_size=3, padding=1))
pyramid_ch = channels
elif progressive == "residual":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(nn.Conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1))
pyramid_ch = in_ch
else:
raise ValueError(f"{progressive} is not a valid name.")
# elif progressive == "residual":
# modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
# modules.append(nn.Conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1))
# pyramid_ch = in_ch
# else:
# raise ValueError(f"{progressive} is not a valid name.")
else:
if progressive == "output_skip":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(nn.Conv2d(in_ch, channels, bias=True, kernel_size=3, padding=1))
pyramid_ch = channels
elif progressive == "residual":
modules.append(pyramid_upsample(channels=pyramid_ch, out_channels=in_ch))
pyramid_ch = in_ch
else:
raise ValueError(f"{progressive} is not a valid name")
# elif progressive == "residual":
# modules.append(pyramid_upsample(channels=pyramid_ch, out_channels=in_ch))
# pyramid_ch = in_ch
# else:
# raise ValueError(f"{progressive} is not a valid name")
if i_level != 0:
modules.append(
@@ -332,7 +332,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
if self.embedding_type == "fourier":
# Gaussian Fourier feature embeddings.
used_sigmas = timesteps
temb = modules[m_idx](torch.log(used_sigmas))
temb = modules[m_idx](used_sigmas)
m_idx += 1
elif self.embedding_type == "positional":
@@ -363,6 +363,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
hs = [modules[m_idx](x)]
m_idx += 1
for i_level in range(self.num_resolutions):
# Residual blocks for this resolution
for i_block in range(self.num_res_blocks):
@@ -394,16 +395,13 @@ class NCSNpp(ModelMixin, ConfigMixin):
hs.append(h)
# h = hs[-1]
# h = modules[m_idx](h, temb)
# m_idx += 1
# h = modules[m_idx](h)
# m_idx += 1
# h = modules[m_idx](h, temb)
# m_idx += 1
h = self.mid(h, temb)
m_idx += 3
h = hs[-1]
h = modules[m_idx](h, temb)
m_idx += 1
h = modules[m_idx](h)
m_idx += 1
h = modules[m_idx](h, temb)
m_idx += 1
pyramid = None
@@ -424,31 +422,32 @@ class NCSNpp(ModelMixin, ConfigMixin):
m_idx += 1
pyramid = modules[m_idx](pyramid)
m_idx += 1
elif self.progressive == "residual":
pyramid = self.act(modules[m_idx](h))
m_idx += 1
pyramid = modules[m_idx](pyramid)
m_idx += 1
else:
raise ValueError(f"{self.progressive} is not a valid name.")
# elif self.progressive == "residual":
# pyramid = self.act(modules[m_idx](h))
# m_idx += 1
# pyramid = modules[m_idx](pyramid)
# m_idx += 1
# else:
# raise ValueError(f"{self.progressive} is not a valid name.")
else:
if self.progressive == "output_skip":
pyramid = self.pyramid_upsample(pyramid)
pyramid_h = self.act(modules[m_idx](h))
m_idx += 1
pyramid_h = modules[m_idx](pyramid_h)
m_idx += 1
pyramid = pyramid + pyramid_h
elif self.progressive == "residual":
pyramid = modules[m_idx](pyramid)
m_idx += 1
if self.skip_rescale:
pyramid = (pyramid + h) / np.sqrt(2.0)
else:
pyramid = pyramid + h
h = pyramid
else:
raise ValueError(f"{self.progressive} is not a valid name")
skip_sample = self.pyramid_upsample(pyramid)
pyramid = skip_sample + pyramid_h
# elif self.progressive == "residual":
# pyramid = modules[m_idx](pyramid)
# m_idx += 1
# if self.skip_rescale:
# pyramid = (pyramid + h) / np.sqrt(2.0)
# else:
# pyramid = pyramid + h
# h = pyramid
# else:
# raise ValueError(f"{self.progressive} is not a valid name")
if i_level != 0:
h = modules[m_idx](h, temb)
import functools
import math
from typing import Dict, Union
import numpy as np
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
from .embeddings import GaussianFourierProjection, get_timestep_embedding
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
from .unet_new import UNetMidBlock2D, get_down_block, get_up_block
class Combine(nn.Module):
"""Combine information from skip connections."""
def __init__(self, dim1, dim2, method="cat"):
super().__init__()
# 1x1 convolution with DDPM initialization.
self.Conv_0 = nn.Conv2d(dim1, dim2, kernel_size=1, padding=0)
self.method = method
# def forward(self, x, y):
# h = self.Conv_0(x)
# if self.method == "cat":
# return torch.cat([h, y], dim=1)
# elif self.method == "sum":
# return h + y
# else:
# raise ValueError(f"Method {self.method} not recognized.")
class TimestepEmbedding(nn.Module):
def __init__(self, channel, time_embed_dim):
def __init__(self, channel, time_embed_dim, act_fn="silu"):
super().__init__()
self.linear_1 = nn.Linear(channel, time_embed_dim)
self.act = nn.SiLU()
self.act = None
if act_fn == "silu":
self.act = nn.SiLU()
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
def forward(self, sample):
sample = self.linear_1(sample)
sample = self.act(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
class Timesteps(nn.Module):
def __init__(self, num_channels, flip_sin_to_cos, downscale_freq_shift):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
)
return t_emb
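Hedged wiring sketch for the two embedding stages above (assumes get_timestep_embedding returns a [N, num_channels] tensor):

import torch

time_steps = Timesteps(num_channels=128, flip_sin_to_cos=True, downscale_freq_shift=0)
t_emb = time_steps(torch.tensor([10, 500]))  # [2, 128]
emb = TimestepEmbedding(channel=128, time_embed_dim=512)(t_emb)
assert emb.shape == (2, 512)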
class UNetUnconditionalModel(ModelMixin, ConfigMixin):
"""
The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param
@@ -72,6 +117,9 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
num_head_channels=32,
flip_sin_to_cos=True,
downscale_freq_shift=0,
time_embedding_type="positional",
mid_block_scale_factor=1,
center_input_sample=False,
# TODO(PVP) - to delete later at release
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
# ======================================
@@ -86,7 +134,24 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
ch_mult=None,
ch=None,
ddpm=False,
# ======================================
# SDE
sde=False,
nf=None,
fir=None,
progressive=None,
progressive_combine=None,
scale_by_sigma=None,
skip_rescale=None,
num_channels=None,
centered=False,
conditional=True,
conv_size=3,
fir_kernel=(1, 3, 3, 1),
fourier_scale=16,
init_scale=0.0,
progressive_input="input_skip",
resnet_num_groups=32,
continuous=True,
):
super().__init__()
# register all __init__ params to be accessible via `self.config.<...>`
@@ -101,19 +166,43 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
down_blocks=down_blocks,
up_blocks=up_blocks,
dropout=dropout,
resnet_eps=resnet_eps,
conv_resample=conv_resample,
num_head_channels=num_head_channels,
flip_sin_to_cos=flip_sin_to_cos,
downscale_freq_shift=downscale_freq_shift,
# TODO(PVP) - to delete later at release
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
# ======================================
time_embedding_type=time_embedding_type,
attention_resolutions=attention_resolutions,
attn_resolutions=attn_resolutions,
mid_block_scale_factor=mid_block_scale_factor,
resnet_num_groups=resnet_num_groups,
center_input_sample=center_input_sample,
# to delete later
ldm=ldm,
ddpm=ddpm,
# ======================================
sde=sde,
)
# if sde:
# block_channels = [nf * x for x in ch_mult]
# in_channels = out_channels = num_channels
# conv_resample = resamp_with_conv
# time_embedding_type = "fourier"
# self.config.time_embedding_type = time_embedding_type
# self.config.resnet_eps = 1e-6
# self.config.mid_block_scale_factor = math.sqrt(2.0)
# self.config.resnet_num_groups = None
# down_blocks = (
# "UNetResSkipDownBlock2D",
# "UNetResAttnSkipDownBlock2D",
# "UNetResSkipDownBlock2D",
# "UNetResSkipDownBlock2D",
# )
# up_blocks = (
# "UNetResSkipUpBlock2D",
# "UNetResSkipUpBlock2D",
# "UNetResAttnSkipUpBlock2D",
# "UNetResSkipUpBlock2D",
# )
# TODO(PVP) - to delete later at release
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
@@ -122,11 +211,18 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
time_embed_dim = block_channels[0] * 4
# ======================================
# # input
# input
self.conv_in = nn.Conv2d(in_channels, block_channels[0], kernel_size=3, padding=(1, 1))
# # time
self.time_embedding = TimestepEmbedding(block_channels[0], time_embed_dim)
# time
if time_embedding_type == "fourier":
self.time_steps = GaussianFourierProjection(embedding_size=block_channels[0], scale=fourier_scale)
timestep_input_dim = 2 * block_channels[0]
elif time_embedding_type == "positional":
self.time_steps = Timesteps(block_channels[0], flip_sin_to_cos, downscale_freq_shift)
timestep_input_dim = block_channels[0]
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
self.downsample_blocks = nn.ModuleList([])
self.mid = None
@@ -154,15 +250,17 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
self.downsample_blocks.append(down_block)
# mid
if self.config.ddpm:
if ddpm:
self.mid_new_2 = UNetMidBlock2D(
in_channels=block_channels[-1],
dropout=dropout,
temb_channels=time_embed_dim,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift="default",
attn_num_head_channels=num_head_channels,
resnet_groups=resnet_num_groups,
)
else:
self.mid = UNetMidBlock2D(
@@ -171,8 +269,10 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
temb_channels=time_embed_dim,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift="default",
attn_num_head_channels=num_head_channels,
resnet_groups=resnet_num_groups,
)
# up
@@ -201,16 +301,19 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
prev_output_channel = output_channel
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_channels[0], num_groups=32, eps=resnet_eps)
num_groups_out = resnet_num_groups if resnet_num_groups is not None else min(block_channels[0] // 4, 32)
self.conv_norm_out = nn.GroupNorm(num_channels=block_channels[0], num_groups=num_groups_out, eps=resnet_eps)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_channels[0], out_channels, 3, padding=1)
# ======================== Out ====================
# =========== TO DELETE AFTER CONVERSION ==========
# TODO(PVP) - to delete later at release
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
# ======================================
self.is_overwritten = False
if ldm:
# =========== TO DELETE AFTER CONVERSION ==========
transformer_depth = 1
context_dim = None
legacy = True
@@ -234,7 +337,11 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
conv_resample,
out_channels,
)
if ddpm:
elif ddpm:
out_channels = out_ch
image_size = resolution
block_channels = [x * ch for x in ch_mult]
conv_resample = resamp_with_conv
out_ch = out_channels
resolution = image_size
ch = block_channels[0]
@@ -251,7 +358,54 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
out_ch,
dropout=0.1,
)
# ======================================
elif sde:
nf = block_channels[0]
ch_mult = [x // nf for x in block_channels]
num_channels = in_channels
# in_channels = out_channels = num_channels = in_channels
# block_channels = [nf * x for x in ch_mult]
# conv_resample = resamp_with_conv
resamp_with_conv = conv_resample
time_embedding_type = self.config.time_embedding_type
# time_embedding_type = "fourier"
# self.config.time_embedding_type = time_embedding_type
fir = True
progressive = "output_skip"
progressive_combine = "sum"
scale_by_sigma = True
skip_rescale = True
centered = False
conditional = True
conv_size = 3
fir_kernel = (1, 3, 3, 1)
fourier_scale = 16
init_scale = 0.0
progressive_input = "input_skip"
continuous = True
self.init_for_sde(
image_size,
num_channels,
centered,
attn_resolutions,
ch_mult,
conditional,
conv_size,
dropout,
time_embedding_type,
fir,
fir_kernel,
fourier_scale,
init_scale,
nf,
num_res_blocks,
progressive,
progressive_combine,
progressive_input,
resamp_with_conv,
scale_by_sigma,
skip_rescale,
continuous,
)
def forward(
self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int]
@@ -261,54 +415,70 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
# ======================================
if not self.is_overwritten:
self.set_weights()
# ======================================
# 1. time step embeddings -> make correct tensor
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
t_emb = get_timestep_embedding(
timesteps,
self.config.block_channels[0],
flip_sin_to_cos=self.config.flip_sin_to_cos,
downscale_freq_shift=self.config.downscale_freq_shift,
)
t_emb = self.time_steps(timesteps)
emb = self.time_embedding(t_emb)
# 2. pre-process sample
# 2. pre-process
skip_sample = sample
sample = self.conv_in(sample)
# 3. down blocks
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.downsample_blocks:
sample, res_samples = downsample_block(sample, emb)
if hasattr(downsample_block, "skip_conv"):
sample, res_samples, skip_sample = downsample_block(
hidden_states=sample, temb=emb, skip_sample=skip_sample
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
# append to tuple
down_block_res_samples += res_samples
# 4. mid block
# 4. mid
if self.config.ddpm:
sample = self.mid_new_2(sample, emb)
else:
sample = self.mid(sample, emb)
# 5. up blocks
# 5. up
skip_sample = None
for upsample_block in self.upsample_blocks:
# pop from tuple
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
sample = upsample_block(sample, res_samples, emb)
if hasattr(upsample_block, "skip_conv"):
sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
else:
sample = upsample_block(sample, res_samples, emb)
# 6. post-process
# 6. post-process sample
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if skip_sample is not None:
sample += skip_sample
if (
self.config.time_embedding_type == "fourier"
or self.time_steps.__class__.__name__ == "GaussianFourierProjection"
):
timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
sample = sample / timesteps
output = {"sample": sample}
return output
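The fourier-path rescaling above divides each sample by its own sigma; a standalone sketch of the reshape-then-broadcast (plain torch):

import torch

sample = torch.ones(2, 3, 8, 8)
timesteps = torch.tensor([2.0, 4.0])
timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))  # [2, 1, 1, 1]
scaled = sample / timesteps
assert scaled[0].max().item() == 0.5 and scaled[1].max().item() == 0.25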
@@ -319,7 +489,6 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
def set_weights(self):
self.is_overwritten = True
if self.config.ldm:
self.time_embedding.linear_1.weight.data = self.time_embed[0].weight.data
self.time_embedding.linear_1.bias.data = self.time_embed[0].bias.data
self.time_embedding.linear_2.weight.data = self.time_embed[2].weight.data
@@ -373,8 +542,6 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
self.remove_ldm()
elif self.config.ddpm:
# =============== SET WEIGHTS ===============
# =============== TIME ======================
self.time_embedding.linear_1.weight.data = self.temb.dense[0].weight.data
self.time_embedding.linear_1.bias.data = self.temb.dense[0].bias.data
self.time_embedding.linear_2.weight.data = self.temb.dense[1].weight.data
@@ -411,6 +578,73 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
self.conv_norm_out.bias.data = self.norm_out.bias.data
self.remove_ddpm()
elif self.config.sde:
self.time_steps.weight = self.all_modules[0].weight
self.time_embedding.linear_1.weight.data = self.all_modules[1].weight.data
self.time_embedding.linear_1.bias.data = self.all_modules[1].bias.data
self.time_embedding.linear_2.weight.data = self.all_modules[2].weight.data
self.time_embedding.linear_2.bias.data = self.all_modules[2].bias.data
self.conv_in.weight.data = self.all_modules[3].weight.data
self.conv_in.bias.data = self.all_modules[3].bias.data
module_index = 4
for i, block in enumerate(self.downsample_blocks):
has_attentions = hasattr(block, "attentions")
if has_attentions:
for j in range(len(block.attentions)):
block.resnets[j].set_weight(self.all_modules[module_index])
module_index += 1
block.attentions[j].set_weight(self.all_modules[module_index])
module_index += 1
if hasattr(block, "downsamplers") and block.downsamplers is not None:
block.resnet_down.set_weight(self.all_modules[module_index])
module_index += 1
block.skip_conv.weight.data = self.all_modules[module_index].Conv_0.weight.data
block.skip_conv.bias.data = self.all_modules[module_index].Conv_0.bias.data
module_index += 1
else:
for j in range(len(block.resnets)):
block.resnets[j].set_weight(self.all_modules[module_index])
module_index += 1
if hasattr(block, "downsamplers") and block.downsamplers is not None:
block.resnet_down.set_weight(self.all_modules[module_index])
module_index += 1
block.skip_conv.weight.data = self.all_modules[module_index].Conv_0.weight.data
block.skip_conv.bias.data = self.all_modules[module_index].Conv_0.bias.data
module_index += 1
self.mid.resnets[0].set_weight(self.all_modules[module_index])
module_index += 1
self.mid.attentions[0].set_weight(self.all_modules[module_index])
module_index += 1
self.mid.resnets[1].set_weight(self.all_modules[module_index])
module_index += 1
for i, block in enumerate(self.upsample_blocks):
for j in range(len(block.resnets)):
block.resnets[j].set_weight(self.all_modules[module_index])
module_index += 1
if hasattr(block, "attentions") and block.attentions is not None:
block.attentions[0].set_weight(self.all_modules[module_index])
module_index += 1
if hasattr(block, "resnet_up") and block.resnet_up is not None:
block.skip_norm.weight.data = self.all_modules[module_index].weight.data
block.skip_norm.bias.data = self.all_modules[module_index].bias.data
module_index += 1
block.skip_conv.weight.data = self.all_modules[module_index].weight.data
block.skip_conv.bias.data = self.all_modules[module_index].bias.data
module_index += 1
block.resnet_up.set_weight(self.all_modules[module_index])
module_index += 1
self.conv_norm_out.weight.data = self.all_modules[module_index].weight.data
self.conv_norm_out.bias.data = self.all_modules[module_index].bias.data
module_index += 1
self.conv_out.weight.data = self.all_modules[module_index].weight.data
self.conv_out.bias.data = self.all_modules[module_index].bias.data
self.remove_sde()
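# Hedged sanity check (illustrative addition): the sequential walk above should
# consume self.all_modules exactly once; conv_out is read at the final index
# without a trailing increment, so just before the remove_sde() call one would
# expect:
#
#     assert module_index + 1 == len(self.all_modules)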
def init_for_ddpm(
self,
......@@ -700,6 +934,240 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
nn.Conv2d(model_channels, out_channels, 3, padding=1),
)
def init_for_sde(
self,
image_size,
num_channels,
centered,
attn_resolutions,
ch_mult,
conditional,
conv_size,
dropout,
embedding_type,
fir,
fir_kernel,
fourier_scale,
init_scale,
nf,
num_res_blocks,
progressive,
progressive_combine,
progressive_input,
resamp_with_conv,
scale_by_sigma,
skip_rescale,
continuous,
):
self.act = nn.SiLU()
self.nf = nf
self.num_res_blocks = num_res_blocks
self.attn_resolutions = attn_resolutions
self.num_resolutions = len(ch_mult)
self.all_resolutions = all_resolutions = [image_size // (2**i) for i in range(self.num_resolutions)]
self.conditional = conditional
self.skip_rescale = skip_rescale
self.progressive = progressive
self.progressive_input = progressive_input
self.embedding_type = embedding_type
assert progressive in ["none", "output_skip", "residual"]
assert progressive_input in ["none", "input_skip", "residual"]
assert embedding_type in ["fourier", "positional"]
combine_method = progressive_combine.lower()
combiner = functools.partial(Combine, method=combine_method)
modules = []
# timestep/noise_level embedding; only for continuous training
if embedding_type == "fourier":
# Gaussian Fourier features embeddings.
modules.append(GaussianFourierProjection(embedding_size=nf, scale=fourier_scale))
embed_dim = 2 * nf
elif embedding_type == "positional":
embed_dim = nf
else:
raise ValueError(f"embedding type {embedding_type} unknown.")
modules.append(nn.Linear(embed_dim, nf * 4))
modules.append(nn.Linear(nf * 4, nf * 4))
AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
if fir:
Up_sample = functools.partial(FirUpsample2D, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
else:
Up_sample = functools.partial(Upsample2D, name="Conv2d_0")
if progressive == "output_skip":
self.pyramid_upsample = Up_sample(channels=None, use_conv=False)
elif progressive == "residual":
pyramid_upsample = functools.partial(Up_sample, use_conv=True)
if fir:
Down_sample = functools.partial(FirDownsample2D, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
else:
Down_sample = functools.partial(Downsample2D, padding=0, name="Conv2d_0")
if progressive_input == "input_skip":
self.pyramid_downsample = Down_sample(channels=None, use_conv=False)
elif progressive_input == "residual":
pyramid_downsample = functools.partial(Down_sample, use_conv=True)
channels = num_channels
if progressive_input != "none":
input_pyramid_ch = channels
modules.append(nn.Conv2d(channels, nf, kernel_size=3, padding=1))
hs_c = [nf]
in_ch = nf
for i_level in range(self.num_resolutions):
# Residual blocks for this resolution
for i_block in range(num_res_blocks):
out_ch = nf * ch_mult[i_level]
modules.append(
ResnetBlock2D(
in_channels=in_ch,
out_channels=out_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
non_linearity="silu",
groups=min(in_ch // 4, 32),
groups_out=min(out_ch // 4, 32),
overwrite_for_score_vde=True,
)
)
in_ch = out_ch
if all_resolutions[i_level] in attn_resolutions:
modules.append(AttnBlock(channels=in_ch))
hs_c.append(in_ch)
if i_level != self.num_resolutions - 1:
modules.append(
ResnetBlock2D(
in_channels=in_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
non_linearity="silu",
groups=min(in_ch // 4, 32),
groups_out=min(out_ch // 4, 32),
overwrite_for_score_vde=True,
down=True,
kernel="fir" if fir else "sde_vp",
use_nin_shortcut=True,
)
)
if progressive_input == "input_skip":
modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
if combine_method == "cat":
in_ch *= 2
elif progressive_input == "residual":
modules.append(pyramid_downsample(channels=input_pyramid_ch, out_channels=in_ch))
input_pyramid_ch = in_ch
hs_c.append(in_ch)
# mid
in_ch = hs_c[-1]
modules.append(
ResnetBlock2D(
in_channels=in_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
non_linearity="silu",
groups=min(in_ch // 4, 32),
groups_out=min(out_ch // 4, 32),
overwrite_for_score_vde=True,
)
)
modules.append(AttnBlock(channels=in_ch))
modules.append(
ResnetBlock2D(
in_channels=in_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
non_linearity="silu",
groups=min(in_ch // 4, 32),
groups_out=min(out_ch // 4, 32),
overwrite_for_score_vde=True,
)
)
pyramid_ch = 0
# Upsampling block
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(num_res_blocks + 1):
out_ch = nf * ch_mult[i_level]
in_ch = in_ch + hs_c.pop()
modules.append(
ResnetBlock2D(
in_channels=in_ch,
out_channels=out_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
non_linearity="silu",
groups=min(in_ch // 4, 32),
groups_out=min(out_ch // 4, 32),
overwrite_for_score_vde=True,
)
)
in_ch = out_ch
if all_resolutions[i_level] in attn_resolutions:
modules.append(AttnBlock(channels=in_ch))
if progressive != "none":
if i_level == self.num_resolutions - 1:
if progressive == "output_skip":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(nn.Conv2d(in_ch, channels, kernel_size=3, padding=1))
pyramid_ch = channels
elif progressive == "residual":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(nn.Conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1))
pyramid_ch = in_ch
else:
raise ValueError(f"{progressive} is not a valid name.")
else:
if progressive == "output_skip":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(nn.Conv2d(in_ch, channels, bias=True, kernel_size=3, padding=1))
pyramid_ch = channels
elif progressive == "residual":
modules.append(pyramid_upsample(channels=pyramid_ch, out_channels=in_ch))
pyramid_ch = in_ch
else:
raise ValueError(f"{progressive} is not a valid name")
if i_level != 0:
modules.append(
ResnetBlock2D(
in_channels=in_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
non_linearity="silu",
groups=min(in_ch // 4, 32),
groups_out=min(out_ch // 4, 32),
overwrite_for_score_vde=True,
up=True,
kernel="fir" if fir else "sde_vp",
use_nin_shortcut=True,
)
)
assert not hs_c
if progressive != "output_skip":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(nn.Conv2d(in_ch, channels, kernel_size=3, padding=1))
self.all_modules = nn.ModuleList(modules)
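# Worked example (illustrative, assumed values): for image_size=256 and
# ch_mult=[1, 2, 2, 2], the down path walks
#     all_resolutions = [256 // (2**i) for i in range(4)]  # [256, 128, 64, 32]
# and with embedding_type="fourier" and nf=128, GaussianFourierProjection emits
# concatenated sin/cos features, so the time MLP maps
#     embed_dim = 2 * nf = 256  ->  nf * 4 = 512  ->  nf * 4 = 512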
def remove_ldm(self):
del self.time_embed
del self.input_blocks
......@@ -714,6 +1182,9 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
del self.up
del self.norm_out
def remove_sde(self):
del self.all_modules
def nonlinearity(x):
# swish
......
......@@ -14,8 +14,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
img_size = self.model.config.image_size
shape = (1, 3, img_size, img_size)
model = self.model.to(device)
......@@ -34,11 +33,18 @@ class ScoreSdeVePipeline(DiffusionPipeline):
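# Corrector: at the current noise level sigma_t, run n_steps of Langevin-style
# correction driven by the model's score estimate (scheduler.step_correct).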
for _ in range(n_steps):
with torch.no_grad():
result = self.model(x, sigma_t)
if isinstance(result, dict):
result = result["sample"]
x = self.scheduler.step_correct(result, x)
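# Predictor: a single reverse-SDE step; step_pred returns both the new noisy
# sample x and its expected denoised mean x_mean, which the pipeline returns.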
with torch.no_grad():
result = model(x, sigma_t)
if isinstance(result, dict):
result = result["sample"]
x, x_mean = self.scheduler.step_pred(result, x, t)
return x_mean
......@@ -15,6 +15,7 @@
import inspect
import math
import tempfile
import unittest
......@@ -590,7 +591,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
model_class = UNetUnconditionalModel
@property
def dummy_input(self):
......@@ -613,22 +614,34 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"image_size": 32,
"ch_mult": [1, 2, 2, 2],
"nf": 32,
"fir": True,
"progressive": "output_skip",
"progressive_combine": "sum",
"progressive_input": "input_skip",
"scale_by_sigma": True,
"skip_rescale": True,
"embedding_type": "fourier",
"block_channels": [32, 64, 64, 64],
"in_channels": 3,
"num_res_blocks": 1,
"out_channels": 3,
"time_embedding_type": "fourier",
"resnet_eps": 1e-6,
"mid_block_scale_factor": math.sqrt(2.0),
"resnet_num_groups": None,
"down_blocks": [
"UNetResSkipDownBlock2D",
"UNetResAttnSkipDownBlock2D",
"UNetResSkipDownBlock2D",
"UNetResSkipDownBlock2D",
],
"up_blocks": [
"UNetResSkipUpBlock2D",
"UNetResSkipUpBlock2D",
"UNetResAttnSkipUpBlock2D",
"UNetResSkipUpBlock2D",
],
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
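# The shared tests presumably instantiate the model as
#     model = self.model_class(**init_dict)
# i.e. a UNetUnconditionalModel configured as an NCSN++-style VE net through the
# block-based arguments above (an assumption about ModelTesterMixin, not shown here).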
def test_from_pretrained_hub(self):
model, loading_info = UNetUnconditionalModel.from_pretrained(
    "fusing/ncsnpp-ffhq-ve-dummy", sde=True, output_loading_info=True
)
self.assertIsNotNone(model)
# self.assertEqual(len(loading_info["missing_keys"]), 0)
......@@ -663,9 +676,33 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
def test_output_pretrained_ve_mid(self):
model = UNetUnconditionalModel.from_pretrained("fusing/celebahq_256-ncsnpp-ve", sde=True)
model.to(torch_device)
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
batch_size = 4
num_channels = 3
sizes = (256, 256)
noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)
with torch.no_grad():
output = model(noise, time_step)["sample"]
output_slice = output[0, -3:, -3:, -1].flatten().cpu()
# fmt: off
expected_output_slice = torch.tensor([-4836.2231, -6487.1387, -3816.7969, -7964.9253, -10966.2842, -20043.6016, 8137.0571, 2340.3499, 544.6114])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
def test_output_pretrained_ve_large(self):
model = UNetUnconditionalModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy", sde=True)
model.to(torch_device)
torch.manual_seed(0)
......@@ -680,7 +717,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)
with torch.no_grad():
output = model(noise, time_step)["sample"]
output_slice = output[0, -3:, -3:, -1].flatten().cpu()
# fmt: off
......@@ -691,7 +728,6 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
def test_output_pretrained_vp(self):
model = NCSNpp.from_pretrained("fusing/cifar10-ddpmpp-vp")
model.to(torch_device)
torch.manual_seed(0)
......@@ -874,7 +910,6 @@ class PipelineTesterMixin(unittest.TestCase):
out_channels=3,
down_blocks=("UNetResDownBlock2D", "UNetResAttnDownBlock2D"),
up_blocks=("UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
)
schedular = DDPMScheduler(timesteps=10)
......@@ -1038,7 +1073,12 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
def test_score_sde_ve_pipeline(self):
model = UNetUnconditionalModel.from_pretrained("fusing/ffhq_ncsnpp", sde=True)
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
scheduler = ScoreSdeVeScheduler.from_config("fusing/ffhq_ncsnpp")
sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)
......@@ -1047,11 +1087,11 @@ class PipelineTesterMixin(unittest.TestCase):
image = sde_ve(num_inference_steps=2)
if model.device.type == "cpu":
expected_image_sum = 3384805632.0
expected_image_mean = 1076.000732421875
else:
expected_image_sum = 3382849024.0
expected_image_mean = 1075.3787841796875
assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2
assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4
......