"docs/vscode:/vscode.git/clone" did not exist on "e45489b1baf76c7127d65a43f44a2c269daf2ae6"
Unverified commit ba3c9a9a, authored by Patrick von Platen, committed by GitHub

[SDE] Merge to unconditional model (#89)

* up

* more

* uP

* make dummy test pass

* save intermediate

* p

* p

* finish

* finish

* finish
parent b5c684f0
@@ -100,7 +100,7 @@ def test_output_pretrained_ldm():
# 2. DDPM
def get_model(model_id):
-model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", ldm=True)
+model = UNetUnconditionalModel.from_pretrained(model_id, ldm=True)
noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
time_step = torch.tensor([10] * noise.shape[0])
@@ -123,3 +123,16 @@ def get_model(model_id):
# e.g.
get_model("fusing/ddpm-cifar10")
# 3. NCSNpp
# Repos to convert and port to google (part of https://github.com/yang-song/score_sde)
# - https://huggingface.co/fusing/ffhq_ncsnpp
# - https://huggingface.co/fusing/church_256-ncsnpp-ve
# - https://huggingface.co/fusing/celebahq_256-ncsnpp-ve
# - https://huggingface.co/fusing/bedroom_256-ncsnpp-ve
# - https://huggingface.co/fusing/ffhq_256-ncsnpp-ve
# tests that need to pass
# - test_score_sde_ve_pipeline (in PipelineTesterMixin)
# - test_output_pretrained_ve_mid, test_output_pretrained_ve_large (in NCSNppModelTests)
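For orientation, a minimal sketch of how the ported VE checkpoints listed above are expected to load once converted, mirroring the get_model helper above; the sde=True flag and the ["sample"] output key are taken from the tests changed in this commit, while the import path is an assumption:

import torch
from diffusers import UNetUnconditionalModel  # import path is an assumption

def check_ve_checkpoint(model_id):
    # sde=True routes loading through the NCSNpp -> UNetUnconditionalModel conversion
    model = UNetUnconditionalModel.from_pretrained(model_id, sde=True)
    noise = torch.ones(1, model.config.in_channels, model.config.image_size, model.config.image_size)
    sigma_t = torch.tensor([1e-4] * noise.shape[0])  # VE models take a noise level, not a discrete step
    with torch.no_grad():
        sample = model(noise, sigma_t)["sample"]
    assert sample.shape == noise.shape

# e.g.
# check_ve_checkpoint("fusing/ffhq_ncsnpp")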
@@ -6,166 +6,6 @@ import torch.nn.functional as F
from torch import nn
# unet_grad_tts.py
# TODO(Patrick) - weird linear attention layer. Check with: https://github.com/huawei-noah/Speech-Backbones/issues/15
class LinearAttention(torch.nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super(LinearAttention, self).__init__()
self.heads = heads
self.dim_head = dim_head
hidden_dim = dim_head * heads
self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x, encoder_states=None):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = (
qkv.reshape(b, 3, self.heads, self.dim_head, h, w)
.permute(1, 0, 2, 3, 4, 5)
.reshape(3, b, self.heads, self.dim_head, -1)
)
k = k.softmax(dim=-1)
context = torch.einsum("bhdn,bhen->bhde", k, v)
out = torch.einsum("bhde,bhdn->bhen", context, q)
out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w)
return self.to_out(out)
# the main attention block that is used for all models
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def __init__(
self,
channels,
num_heads=1,
num_head_channels=None,
num_groups=32,
encoder_channels=None,
overwrite_qkv=False,
overwrite_linear=False,
rescale_output_factor=1.0,
eps=1e-5,
):
super().__init__()
self.channels = channels
if num_head_channels is None:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
self.qkv = nn.Conv1d(channels, channels * 3, 1)
self.n_heads = self.num_heads
self.rescale_output_factor = rescale_output_factor
if encoder_channels is not None:
self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
self.proj = zero_module(nn.Conv1d(channels, channels, 1))
self.overwrite_qkv = overwrite_qkv
self.overwrite_linear = overwrite_linear
if overwrite_qkv:
in_channels = channels
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
elif self.overwrite_linear:
num_groups = min(channels // 4, 32)
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
self.NIN_0 = NIN(channels, channels)
self.NIN_1 = NIN(channels, channels)
self.NIN_2 = NIN(channels, channels)
self.NIN_3 = NIN(channels, channels)
self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
else:
self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
self.set_weights(self)
self.is_overwritten = False
def set_weights(self, module):
if self.overwrite_qkv:
qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[
:, :, :, 0
]
qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)
self.qkv.weight.data = qkv_weight
self.qkv.bias.data = qkv_bias
proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
proj_out.bias.data = module.proj_out.bias.data
self.proj = proj_out
elif self.overwrite_linear:
self.qkv.weight.data = torch.concat(
[self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0
)[:, :, None]
self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
self.proj.weight.data = self.NIN_3.W.data.T[:, :, None]
self.proj.bias.data = self.NIN_3.b.data
self.norm.weight.data = self.GroupNorm_0.weight.data
self.norm.bias.data = self.GroupNorm_0.bias.data
else:
self.proj.weight.data = self.proj_out.weight.data
self.proj.bias.data = self.proj_out.bias.data
def forward(self, x, encoder_out=None):
if not self.is_overwritten and (self.overwrite_qkv or self.overwrite_linear):
self.set_weights(self)
self.is_overwritten = True
b, c, *spatial = x.shape
hid_states = self.norm(x).view(b, c, -1)
qkv = self.qkv(hid_states)
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
if encoder_out is not None:
encoder_kv = self.encoder_kv(encoder_out)
assert encoder_kv.shape[1] == self.n_heads * ch * 2
ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
k = torch.cat([ek, k], dim=-1)
v = torch.cat([ev, v], dim=-1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
a = torch.einsum("bts,bcs->bct", weight, v)
h = a.reshape(bs, -1, length)
h = self.proj(h)
h = h.reshape(b, c, *spatial)
result = x + h
result = result / self.rescale_output_factor
return result
class AttentionBlockNew(nn.Module):
"""
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
@@ -216,6 +56,7 @@ class AttentionBlockNew(nn.Module):
# norm
hidden_states = self.group_norm(hidden_states)
hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
# proj to q, k, v
@@ -229,9 +70,9 @@ class AttentionBlockNew(nn.Module):
value_states = self.transpose_for_scores(value_proj)
# get scores
-attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
-attention_scores = attention_scores / math.sqrt(self.channels // self.num_heads)
-attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
+attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
+attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
# compute attention output
context_states = torch.matmul(attention_probs, value_states)
@@ -263,6 +104,20 @@ class AttentionBlockNew(nn.Module):
self.proj_attn.weight.data = attn_layer.proj_out.weight.data[:, :, 0, 0]
self.proj_attn.bias.data = attn_layer.proj_out.bias.data
elif hasattr(attn_layer, "NIN_0"):
self.query.weight.data = attn_layer.NIN_0.W.data.T
self.key.weight.data = attn_layer.NIN_1.W.data.T
self.value.weight.data = attn_layer.NIN_2.W.data.T
self.query.bias.data = attn_layer.NIN_0.b.data
self.key.bias.data = attn_layer.NIN_1.b.data
self.value.bias.data = attn_layer.NIN_2.b.data
self.proj_attn.weight.data = attn_layer.NIN_3.W.data.T
self.proj_attn.bias.data = attn_layer.NIN_3.b.data
self.group_norm.weight.data = attn_layer.GroupNorm_0.weight.data
self.group_norm.bias.data = attn_layer.GroupNorm_0.bias.data
else:
qkv_weight = attn_layer.qkv.weight.data.reshape(
self.num_heads, 3 * self.channels // self.num_heads, self.channels
@@ -452,3 +307,137 @@ class GEGLU(nn.Module):
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
# the main attention block that is used for all models
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def __init__(
self,
channels,
num_heads=1,
num_head_channels=None,
num_groups=32,
encoder_channels=None,
overwrite_qkv=False,
overwrite_linear=False,
rescale_output_factor=1.0,
eps=1e-5,
):
super().__init__()
self.channels = channels
if num_head_channels is None:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
self.qkv = nn.Conv1d(channels, channels * 3, 1)
self.n_heads = self.num_heads
self.rescale_output_factor = rescale_output_factor
if encoder_channels is not None:
self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
self.proj = zero_module(nn.Conv1d(channels, channels, 1))
self.overwrite_qkv = overwrite_qkv
self.overwrite_linear = overwrite_linear
if overwrite_qkv:
in_channels = channels
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
elif self.overwrite_linear:
num_groups = min(channels // 4, 32)
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
self.NIN_0 = NIN(channels, channels)
self.NIN_1 = NIN(channels, channels)
self.NIN_2 = NIN(channels, channels)
self.NIN_3 = NIN(channels, channels)
self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
else:
self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
self.set_weights(self)
self.is_overwritten = False
def set_weights(self, module):
if self.overwrite_qkv:
qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[
:, :, :, 0
]
qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)
self.qkv.weight.data = qkv_weight
self.qkv.bias.data = qkv_bias
proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
proj_out.bias.data = module.proj_out.bias.data
self.proj = proj_out
elif self.overwrite_linear:
self.qkv.weight.data = torch.concat(
[self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0
)[:, :, None]
self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
self.proj.weight.data = self.NIN_3.W.data.T[:, :, None]
self.proj.bias.data = self.NIN_3.b.data
self.norm.weight.data = self.GroupNorm_0.weight.data
self.norm.bias.data = self.GroupNorm_0.bias.data
else:
self.proj.weight.data = self.proj_out.weight.data
self.proj.bias.data = self.proj_out.bias.data
def forward(self, x, encoder_out=None):
if not self.is_overwritten and (self.overwrite_qkv or self.overwrite_linear):
self.set_weights(self)
self.is_overwritten = True
b, c, *spatial = x.shape
hid_states = self.norm(x).view(b, c, -1)
qkv = self.qkv(hid_states)
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
if encoder_out is not None:
encoder_kv = self.encoder_kv(encoder_out)
assert encoder_kv.shape[1] == self.n_heads * ch * 2
ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
k = torch.cat([ek, k], dim=-1)
v = torch.cat([ev, v], dim=-1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
a = torch.einsum("bts,bcs->bct", weight, v)
h = a.reshape(bs, -1, length)
h = self.proj(h)
h = h.reshape(b, c, *spatial)
result = x + h
result = result / self.rescale_output_factor
return result
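The scale trick in the block above is worth spelling out: multiplying both q and k by 1 / ch ** 0.25 contributes a combined factor of 1 / sqrt(ch) to the scores, so it equals the usual post-hoc division by sqrt(ch) while keeping intermediate products small enough for fp16, as the comment notes. A standalone numeric check:

import math
import torch

torch.manual_seed(0)
ch, length = 64, 10
q = torch.randn(2, ch, length)
k = torch.randn(2, ch, length)

scale = 1 / math.sqrt(math.sqrt(ch))
# scale q and k before the product, as the block above does
stable = torch.einsum("bct,bcs->bts", q * scale, k * scale)
# versus dividing the raw scores afterwards
naive = torch.einsum("bct,bcs->bts", q, k) / math.sqrt(ch)
assert torch.allclose(stable, naive, atol=1e-5)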
@@ -54,14 +54,20 @@ def get_timestep_embedding(
return emb
# unet_sde_score_estimation.py
class GaussianFourierProjection(nn.Module):
"""Gaussian Fourier embeddings for noise levels."""

def __init__(self, embedding_size=256, scale=1.0):
super().__init__()
+self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+
+# to delete later
self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+self.weight = self.W

def forward(self, x):
-x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
-return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
+x = torch.log(x)
+x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
+out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
+return out
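In other words, forward now maps a batch of noise levels sigma to Fourier features of log(sigma); the log moved from NCSNpp.forward (see the temb change further down) into the embedding module itself. A minimal sketch of the computation; the scale value is an assumption, the shapes are the point:

import numpy as np
import torch

embedding_size = 256
weight = torch.randn(embedding_size) * 16.0  # frozen random frequencies; scale=16 is an assumption

sigmas = torch.tensor([0.01, 0.1, 1.0])       # noise levels
x = torch.log(sigmas)                         # the log now lives inside forward()
x_proj = x[:, None] * weight[None, :] * 2 * np.pi
emb = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
assert emb.shape == (3, 2 * embedding_size)   # [sin | cos] features per noise level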
@@ -87,96 +87,15 @@ class Downsample2D(nn.Module):
self.conv = conv

def forward(self, x):
-# print("use_conv", self.use_conv)
-# print("padding", self.padding)
assert x.shape[1] == self.channels
if self.use_conv and self.padding == 0:
pad = (0, 1, 0, 1)
x = F.pad(x, pad, mode="constant", value=0)
-# print("x", x.abs().sum())
-self.hey = x
assert x.shape[1] == self.channels
x = self.conv(x)
-self.yas = x
-# print("x", x.abs().sum())
return x
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
# if self.name == "conv":
# return self.conv(x)
# elif self.name == "Conv2d_0":
# return self.Conv2d_0(x)
# else:
# return self.op(x)
class Upsample1D(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
self.name = name
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
self.conv = None
if use_conv_transpose:
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
elif use_conv:
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
def forward(self, x):
assert x.shape[1] == self.channels
if self.use_conv_transpose:
return self.conv(x)
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class Downsample1D(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.padding = padding
stride = 2
self.name = name
if use_conv:
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
return self.conv(x)
class FirUpsample2D(nn.Module):
@@ -330,15 +249,137 @@ class FirDownsample2D(nn.Module):
return x
-# TODO (patil-suraj): needs test
-# class Upsample2D1d(nn.Module):
-# def __init__(self, dim):
-# super().__init__()
-# self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
-#
-# def forward(self, x):
-# return self.conv(x)
+class ResnetBlock(nn.Module):
+def __init__(
+self,
+*,
+in_channels,
+out_channels=None,
+conv_shortcut=False,
+dropout=0.0,
temb_channels=512,
groups=32,
groups_out=None,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
kernel=None,
output_scale_factor=1.0,
use_nin_shortcut=None,
up=False,
down=False,
):
super().__init__()
self.pre_norm = pre_norm
self.pre_norm = True
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.time_embedding_norm = time_embedding_norm
self.up = up
self.down = down
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = lambda x: F.silu(x)
elif non_linearity == "mish":
self.nonlinearity = Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
self.upsample = self.downsample = None
if self.up:
if kernel == "fir":
fir_kernel = (1, 3, 3, 1)
self.upsample = lambda x: upsample_2d(x, k=fir_kernel)
elif kernel == "sde_vp":
self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
else:
self.upsample = Upsample2D(in_channels, use_conv=False)
elif self.down:
if kernel == "fir":
fir_kernel = (1, 3, 3, 1)
self.downsample = lambda x: downsample_2d(x, k=fir_kernel)
elif kernel == "sde_vp":
self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
else:
self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
self.conv_shortcut = None
if self.use_nin_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, temb, hey=False):
h = x
h = self.norm1(h)
h = self.nonlinearity(h)
if self.upsample is not None:
x = self.upsample(x)
h = self.upsample(h)
elif self.downsample is not None:
x = self.downsample(x)
h = self.downsample(h)
h = self.conv1(h)
if temb is not None:
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
h = h + temb
h = self.norm2(h)
h = self.nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.conv_shortcut is not None:
x = self.conv_shortcut(x)
out = (x + h) / self.output_scale_factor
return out
def set_weight(self, resnet):
self.norm1.weight.data = resnet.norm1.weight.data
self.norm1.bias.data = resnet.norm1.bias.data
self.conv1.weight.data = resnet.conv1.weight.data
self.conv1.bias.data = resnet.conv1.bias.data
self.time_emb_proj.weight.data = resnet.temb_proj.weight.data
self.time_emb_proj.bias.data = resnet.temb_proj.bias.data
self.norm2.weight.data = resnet.norm2.weight.data
self.norm2.bias.data = resnet.norm2.bias.data
self.conv2.weight.data = resnet.conv2.weight.data
self.conv2.bias.data = resnet.conv2.bias.data
if self.use_nin_shortcut:
self.conv_shortcut.weight.data = resnet.nin_shortcut.weight.data
self.conv_shortcut.bias.data = resnet.nin_shortcut.bias.data
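To see what the up/down plumbing in the new ResnetBlock above buys, here is a hedged sketch of a FIR-downsampling block, the configuration the skip down blocks below construct; the import path, channel counts, and shapes are assumptions:

import torch
from diffusers.models.resnet import ResnetBlock  # path assumed

# residual block that halves the spatial size with a FIR kernel before conv1
res_down = ResnetBlock(
    in_channels=64,
    out_channels=64,
    temb_channels=128,
    groups=16,
    use_nin_shortcut=True,
    down=True,
    kernel="fir",
)

x = torch.randn(1, 64, 32, 32)
temb = torch.randn(1, 128)
out = res_down(x, temb)  # both the residual x and the hidden path h are downsampled: (1, 64, 16, 16)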
# THE FOLLOWING SHOULD BE DELETED ONCE ALL CHECKPOINTS ARE CONVERTED
# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py, unet_score_vde.py
# => All 2D-Resnets are included here now!
@@ -502,6 +543,7 @@ class ResnetBlock2D(nn.Module):
self.in_ch = in_ch
self.out_ch = out_ch
+self.set_weights_score_vde()

def set_weights_grad_tts(self):
self.conv1.weight.data = self.block1.block[0].weight.data
@@ -559,17 +601,21 @@ class ResnetBlock2D(nn.Module):
self.nin_shortcut.weight.data = self.Conv_2.weight.data
self.nin_shortcut.bias.data = self.Conv_2.bias.data

-def forward(self, x, temb, mask=1.0):
+def forward(self, x, temb, hey=False, mask=1.0):
# TODO(Patrick) eventually this class should be split into multiple classes
# too many if else statements
if self.overwrite_for_grad_tts and not self.is_overwritten:
self.set_weights_grad_tts()
self.is_overwritten = True
-elif self.overwrite_for_score_vde and not self.is_overwritten:
-self.set_weights_score_vde()
-self.is_overwritten = True
+# elif self.overwrite_for_score_vde and not self.is_overwritten:
+# self.set_weights_score_vde()
+# self.is_overwritten = True
+# h2 tensor(110029.2109)
+# h3 tensor(49596.9492)
h = x
h = h * mask
if self.pre_norm:
h = self.norm1(h)
@@ -619,154 +665,9 @@ class ResnetBlock2D(nn.Module):
if self.nin_shortcut is not None:
x = self.nin_shortcut(x)
-return (x + h) / self.output_scale_factor
+out = (x + h) / self.output_scale_factor
+return out
class ResnetBlock(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
groups_out=None,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
kernel=None,
output_scale_factor=1.0,
use_nin_shortcut=None,
up=False,
down=False,
overwrite_for_grad_tts=False,
overwrite_for_ldm=False,
overwrite_for_glide=False,
overwrite_for_score_vde=False,
):
super().__init__()
self.pre_norm = pre_norm
self.pre_norm = True
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.time_embedding_norm = time_embedding_norm
self.up = up
self.down = down
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
if self.pre_norm:
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
else:
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if time_embedding_norm == "default" and temb_channels > 0:
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
elif time_embedding_norm == "scale_shift" and temb_channels > 0:
self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = lambda x: F.silu(x)
elif non_linearity == "mish":
self.nonlinearity = Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
self.upsample = self.downsample = None
if self.up:
if kernel == "fir":
fir_kernel = (1, 3, 3, 1)
self.upsample = lambda x: upsample_2d(x, k=fir_kernel)
elif kernel == "sde_vp":
self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
else:
self.upsample = Upsample2D(in_channels, use_conv=False)
elif self.down:
if kernel == "fir":
fir_kernel = (1, 3, 3, 1)
self.downsample = lambda x: downsample_2d(x, k=fir_kernel)
elif kernel == "sde_vp":
self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
else:
self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
self.conv_shortcut = None
if self.use_nin_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = self.nonlinearity(h)
if self.upsample is not None:
x = self.upsample(x)
h = self.upsample(h)
elif self.downsample is not None:
x = self.downsample(x)
h = self.downsample(h)
h = self.conv1(h)
if temb is not None:
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
else:
temb = 0
if self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
h = self.norm2(h)
h = h + h * scale + shift
h = self.nonlinearity(h)
elif self.time_embedding_norm == "default":
h = h + temb
h = self.norm2(h)
h = self.nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.conv_shortcut is not None:
x = self.conv_shortcut(x)
return (x + h) / self.output_scale_factor
def set_weight(self, resnet):
self.norm1.weight.data = resnet.norm1.weight.data
self.norm1.bias.data = resnet.norm1.bias.data
self.conv1.weight.data = resnet.conv1.weight.data
self.conv1.bias.data = resnet.conv1.bias.data
self.time_emb_proj.weight.data = resnet.temb_proj.weight.data
self.time_emb_proj.bias.data = resnet.temb_proj.bias.data
self.norm2.weight.data = resnet.norm2.weight.data
self.norm2.bias.data = resnet.norm2.bias.data
self.conv2.weight.data = resnet.conv2.weight.data
self.conv2.bias.data = resnet.conv2.bias.data
if self.use_nin_shortcut:
self.conv_shortcut.weight.data = resnet.nin_shortcut.weight.data
self.conv_shortcut.bias.data = resnet.nin_shortcut.bias.data
# TODO(Patrick) - just there to convert the weights; can delete afterward
@@ -778,39 +679,6 @@ class Block(torch.nn.Module):
)
# unet_rl.py
class ResidualTemporalBlock(nn.Module):
def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
super().__init__()
self.blocks = nn.ModuleList(
[
Conv1dBlock(inp_channels, out_channels, kernel_size),
Conv1dBlock(out_channels, out_channels, kernel_size),
]
)
self.time_mlp = nn.Sequential(
nn.Mish(),
nn.Linear(embed_dim, out_channels),
RearrangeDim(),
# Rearrange("batch t -> batch t 1"),
)
self.residual_conv = (
nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
)
def forward(self, x, t):
"""
x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x
out_channels x horizon ]
"""
out = self.blocks[0](x) + self.time_mlp(t)
out = self.blocks[1](out)
return out + self.residual_conv(x)
# HELPER Modules
...
@@ -11,12 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import numpy as np
import torch
from torch import nn

from .attention import AttentionBlockNew
-from .resnet import Downsample2D, ResnetBlock, Upsample2D
+from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock, Upsample2D


def get_down_block(
@@ -54,6 +56,29 @@ def get_down_block(
downsample_padding=downsample_padding,
attn_num_head_channels=attn_num_head_channels,
)
elif down_block_type == "UNetResSkipDownBlock2D":
return UNetResSkipDownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
downsample_padding=downsample_padding,
)
elif down_block_type == "UNetResAttnSkipDownBlock2D":
return UNetResAttnSkipDownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
downsample_padding=downsample_padding,
attn_num_head_channels=attn_num_head_channels,
)
def get_up_block(
@@ -91,6 +116,30 @@ def get_up_block(
resnet_act_fn=resnet_act_fn,
attn_num_head_channels=attn_num_head_channels,
)
elif up_block_type == "UNetResSkipUpBlock2D":
return UNetResSkipUpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
)
elif up_block_type == "UNetResAttnSkipUpBlock2D":
return UNetResAttnSkipUpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
attn_num_head_channels=attn_num_head_channels,
)
raise ValueError(f"{up_block_type} does not exist.")
class UNetMidBlock2D(nn.Module):
@@ -113,6 +162,7 @@ class UNetMidBlock2D(nn.Module):
super().__init__()
self.attention_type = attention_type
+resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

# there is always at least one resnet
resnets = [
@@ -138,6 +188,7 @@ class UNetMidBlock2D(nn.Module):
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
+num_groups=resnet_groups,
)
)
resnets.append(
@@ -160,7 +211,6 @@ class UNetMidBlock2D(nn.Module):
def forward(self, hidden_states, temb=None, encoder_states=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if self.attention_type == "default":
hidden_states = attn(hidden_states)
@@ -318,6 +368,178 @@ class UNetResDownBlock2D(nn.Module):
return hidden_states, output_states
class UNetResAttnSkipDownBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
attention_type="default",
output_scale_factor=np.sqrt(2.0),
downsample_padding=1,
add_downsample=True,
):
super().__init__()
self.attentions = nn.ModuleList([])
self.resnets = nn.ModuleList([])
self.attention_type = attention_type
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
self.resnets.append(
ResnetBlock(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=min(in_channels // 4, 32),
groups_out=min(out_channels // 4, 32),
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.attentions.append(
AttentionBlockNew(
out_channels,
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
)
)
if add_downsample:
self.resnet_down = ResnetBlock(
in_channels=out_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=min(out_channels // 4, 32),
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_nin_shortcut=True,
down=True,
kernel="fir",
)
self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
else:
self.resnet_down = None
self.downsamplers = None
self.skip_conv = None
def forward(self, hidden_states, temb=None, skip_sample=None):
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states)
output_states += (hidden_states,)
if self.downsamplers is not None:
hidden_states = self.resnet_down(hidden_states, temb)
for downsampler in self.downsamplers:
skip_sample = downsampler(skip_sample)
hidden_states = self.skip_conv(skip_sample) + hidden_states
output_states += (hidden_states,)
return hidden_states, output_states, skip_sample
class UNetResSkipDownBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_pre_norm: bool = True,
output_scale_factor=np.sqrt(2.0),
add_downsample=True,
downsample_padding=1,
):
super().__init__()
self.resnets = nn.ModuleList([])
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
self.resnets.append(
ResnetBlock(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=min(in_channels // 4, 32),
groups_out=min(out_channels // 4, 32),
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
if add_downsample:
self.resnet_down = ResnetBlock(
in_channels=out_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=min(out_channels // 4, 32),
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_nin_shortcut=True,
down=True,
kernel="fir",
)
self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
else:
self.resnet_down = None
self.downsamplers = None
self.skip_conv = None
def forward(self, hidden_states, temb=None, skip_sample=None):
output_states = ()
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb)
output_states += (hidden_states,)
if self.downsamplers is not None:
hidden_states = self.resnet_down(hidden_states, temb)
for downsampler in self.downsamplers:
skip_sample = downsampler(skip_sample)
hidden_states = self.skip_conv(skip_sample) + hidden_states
output_states += (hidden_states,)
return hidden_states, output_states, skip_sample
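A rough usage sketch for the skip down block above (import path and channel sizes are assumptions): the raw 3-channel sample rides along as skip_sample, is FIR-downsampled in lockstep with the feature map, and is folded back in through skip_conv:

import torch
from diffusers.models.unet_blocks import UNetResSkipDownBlock2D  # path assumed

block = UNetResSkipDownBlock2D(in_channels=32, out_channels=64, temb_channels=128)

hidden = torch.randn(1, 32, 32, 32)   # feature map
temb = torch.randn(1, 128)            # time embedding
skip = torch.randn(1, 3, 32, 32)      # raw sample threaded through as skip_sample

hidden, states, skip = block(hidden, temb=temb, skip_sample=skip)
# hidden and skip should now be spatially halved; skip_conv(skip) was added into hidden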
class UNetResAttnUpBlock2D(nn.Module):
def __init__(
self,
@@ -457,3 +679,213 @@ class UNetResUpBlock2D(nn.Module):
hidden_states = upsampler(hidden_states)
return hidden_states
class UNetResAttnSkipUpBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
prev_output_channel: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
attention_type="default",
output_scale_factor=np.sqrt(2.0),
upsample_padding=1,
add_upsample=True,
):
super().__init__()
self.attentions = nn.ModuleList([])
self.resnets = nn.ModuleList([])
self.attention_type = attention_type
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
self.resnets.append(
ResnetBlock(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
groups_out=min(out_channels // 4, 32),
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.attentions.append(
AttentionBlockNew(
out_channels,
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
)
)
self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
if add_upsample:
self.resnet_up = ResnetBlock(
in_channels=out_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=min(out_channels // 4, 32),
groups_out=min(out_channels // 4, 32),
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_nin_shortcut=True,
up=True,
kernel="fir",
)
self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.skip_norm = torch.nn.GroupNorm(
num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
)
self.act = nn.SiLU()
else:
self.resnet_up = None
self.skip_conv = None
self.skip_norm = None
self.act = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
output_states = ()
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
hidden_states = self.attentions[0](hidden_states)
if skip_sample is not None:
skip_sample = self.upsampler(skip_sample)
else:
skip_sample = 0
if self.resnet_up is not None:
skip_sample_states = self.skip_norm(hidden_states)
skip_sample_states = self.act(skip_sample_states)
skip_sample_states = self.skip_conv(skip_sample_states)
skip_sample = skip_sample + skip_sample_states
hidden_states = self.resnet_up(hidden_states, temb)
return hidden_states, skip_sample
class UNetResSkipUpBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
prev_output_channel: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_pre_norm: bool = True,
output_scale_factor=np.sqrt(2.0),
add_upsample=True,
upsample_padding=1,
):
super().__init__()
self.resnets = nn.ModuleList([])
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
self.resnets.append(
ResnetBlock(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
groups_out=min(out_channels // 4, 32),
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
if add_upsample:
self.resnet_up = ResnetBlock(
in_channels=out_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=min(out_channels // 4, 32),
groups_out=min(out_channels // 4, 32),
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_nin_shortcut=True,
up=True,
kernel="fir",
)
self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.skip_norm = torch.nn.GroupNorm(
num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
)
self.act = nn.SiLU()
else:
self.resnet_up = None
self.skip_conv = None
self.skip_norm = None
self.act = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
output_states = ()
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
if skip_sample is not None:
skip_sample = self.upsampler(skip_sample)
else:
skip_sample = 0
if self.resnet_up is not None:
skip_sample_states = self.skip_norm(hidden_states)
skip_sample_states = self.act(skip_sample_states)
skip_sample_states = self.skip_conv(skip_sample_states)
skip_sample = skip_sample + skip_sample_states
hidden_states = self.resnet_up(hidden_states, temb)
return hidden_states, skip_sample
@@ -249,9 +249,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
overwrite_for_score_vde=True,
)
)
-self.mid.resnets[0] = modules[len(modules) - 3]
-self.mid.attentions[0] = modules[len(modules) - 2]
-self.mid.resnets[1] = modules[len(modules) - 1]
+# self.mid.resnets[0] = modules[len(modules) - 3]
+# self.mid.attentions[0] = modules[len(modules) - 2]
+# self.mid.resnets[1] = modules[len(modules) - 1]

pyramid_ch = 0
# Upsampling block
@@ -282,22 +282,22 @@ class NCSNpp(ModelMixin, ConfigMixin):
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(nn.Conv2d(in_ch, channels, kernel_size=3, padding=1))
pyramid_ch = channels
-elif progressive == "residual":
-modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
-modules.append(nn.Conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1))
-pyramid_ch = in_ch
-else:
-raise ValueError(f"{progressive} is not a valid name.")
+# elif progressive == "residual":
+# modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
+# modules.append(nn.Conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1))
+# pyramid_ch = in_ch
+# else:
+# raise ValueError(f"{progressive} is not a valid name.")
else:
if progressive == "output_skip":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(nn.Conv2d(in_ch, channels, bias=True, kernel_size=3, padding=1))
pyramid_ch = channels
-elif progressive == "residual":
-modules.append(pyramid_upsample(channels=pyramid_ch, out_channels=in_ch))
-pyramid_ch = in_ch
-else:
-raise ValueError(f"{progressive} is not a valid name")
+# elif progressive == "residual":
+# modules.append(pyramid_upsample(channels=pyramid_ch, out_channels=in_ch))
+# pyramid_ch = in_ch
+# else:
+# raise ValueError(f"{progressive} is not a valid name")

if i_level != 0:
modules.append(
@@ -332,7 +332,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
if self.embedding_type == "fourier":
# Gaussian Fourier features embeddings.
used_sigmas = timesteps
-temb = modules[m_idx](torch.log(used_sigmas))
+temb = modules[m_idx](used_sigmas)
m_idx += 1

elif self.embedding_type == "positional":
@@ -363,6 +363,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
hs = [modules[m_idx](x)]
m_idx += 1

for i_level in range(self.num_resolutions):
# Residual blocks for this resolution
for i_block in range(self.num_res_blocks):
@@ -394,16 +395,13 @@ class NCSNpp(ModelMixin, ConfigMixin):
hs.append(h)

-# h = hs[-1]
-# h = modules[m_idx](h, temb)
-# m_idx += 1
-# h = modules[m_idx](h)
-# m_idx += 1
-# h = modules[m_idx](h, temb)
-# m_idx += 1
-h = self.mid(h, temb)
-m_idx += 3
+h = hs[-1]
+h = modules[m_idx](h, temb)
+m_idx += 1
+h = modules[m_idx](h)
+m_idx += 1
+h = modules[m_idx](h, temb)
+m_idx += 1

pyramid = None
@@ -424,31 +422,32 @@ class NCSNpp(ModelMixin, ConfigMixin):
m_idx += 1
pyramid = modules[m_idx](pyramid)
m_idx += 1
-elif self.progressive == "residual":
-pyramid = self.act(modules[m_idx](h))
-m_idx += 1
-pyramid = modules[m_idx](pyramid)
-m_idx += 1
-else:
-raise ValueError(f"{self.progressive} is not a valid name.")
+# elif self.progressive == "residual":
+# pyramid = self.act(modules[m_idx](h))
+# m_idx += 1
+# pyramid = modules[m_idx](pyramid)
+# m_idx += 1
+# else:
+# raise ValueError(f"{self.progressive} is not a valid name.")
else:
if self.progressive == "output_skip":
-pyramid = self.pyramid_upsample(pyramid)
pyramid_h = self.act(modules[m_idx](h))
m_idx += 1
pyramid_h = modules[m_idx](pyramid_h)
m_idx += 1
-pyramid = pyramid + pyramid_h
-elif self.progressive == "residual":
-pyramid = modules[m_idx](pyramid)
-m_idx += 1
-if self.skip_rescale:
-pyramid = (pyramid + h) / np.sqrt(2.0)
-else:
-pyramid = pyramid + h
-h = pyramid
-else:
-raise ValueError(f"{self.progressive} is not a valid name")
+skip_sample = self.pyramid_upsample(pyramid)
+pyramid = skip_sample + pyramid_h
+# elif self.progressive == "residual":
+# pyramid = modules[m_idx](pyramid)
+# m_idx += 1
+# if self.skip_rescale:
+# pyramid = (pyramid + h) / np.sqrt(2.0)
+# else:
+# pyramid = pyramid + h
+# h = pyramid
+# else:
+# raise ValueError(f"{self.progressive} is not a valid name")

if i_level != 0:
h = modules[m_idx](h, temb)
...
@@ -14,8 +14,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

img_size = self.model.config.image_size
-channels = self.model.config.num_channels
-shape = (1, channels, img_size, img_size)
+shape = (1, 3, img_size, img_size)

model = self.model.to(device)
@@ -34,11 +33,18 @@ class ScoreSdeVePipeline(DiffusionPipeline):
for _ in range(n_steps):
with torch.no_grad():
result = self.model(x, sigma_t)
+if isinstance(result, dict):
+result = result["sample"]
x = self.scheduler.step_correct(result, x)

with torch.no_grad():
result = model(x, sigma_t)
+if isinstance(result, dict):
+result = result["sample"]

x, x_mean = self.scheduler.step_pred(result, x, t)

return x_mean
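Putting the pipeline changes together, a hedged end-to-end sketch matching the updated test_score_sde_ve_pipeline further down; import paths are assumptions:

import torch
from diffusers import ScoreSdeVePipeline, ScoreSdeVeScheduler, UNetUnconditionalModel  # paths assumed

model = UNetUnconditionalModel.from_pretrained("fusing/ffhq_ncsnpp", sde=True)
scheduler = ScoreSdeVeScheduler.from_config("fusing/ffhq_ncsnpp")
sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)

torch.manual_seed(0)
image = sde_ve(num_inference_steps=2)  # returns x_mean from the final predictor step
print(image.shape)                      # (1, 3, image_size, image_size)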
@@ -15,6 +15,7 @@
import inspect
+import math
import tempfile
import unittest
@@ -590,7 +591,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
-model_class = NCSNpp
+model_class = UNetUnconditionalModel

@property
def dummy_input(self):
@@ -613,22 +614,34 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
-"image_size": 32,
-"ch_mult": [1, 2, 2, 2],
-"nf": 32,
-"fir": True,
-"progressive": "output_skip",
-"progressive_combine": "sum",
-"progressive_input": "input_skip",
-"scale_by_sigma": True,
-"skip_rescale": True,
-"embedding_type": "fourier",
+"block_channels": [32, 64, 64, 64],
+"in_channels": 3,
+"num_res_blocks": 1,
+"out_channels": 3,
+"time_embedding_type": "fourier",
+"resnet_eps": 1e-6,
+"mid_block_scale_factor": math.sqrt(2.0),
+"resnet_num_groups": None,
+"down_blocks": [
+"UNetResSkipDownBlock2D",
+"UNetResAttnSkipDownBlock2D",
+"UNetResSkipDownBlock2D",
+"UNetResSkipDownBlock2D",
+],
+"up_blocks": [
+"UNetResSkipUpBlock2D",
+"UNetResSkipUpBlock2D",
+"UNetResAttnSkipUpBlock2D",
+"UNetResSkipUpBlock2D",
+],
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict

def test_from_pretrained_hub(self):
-model, loading_info = NCSNpp.from_pretrained("fusing/cifar10-ncsnpp-ve", output_loading_info=True)
+model, loading_info = UNetUnconditionalModel.from_pretrained(
+"fusing/ncsnpp-ffhq-ve-dummy", sde=True, output_loading_info=True
+)
self.assertIsNotNone(model)
# self.assertEqual(len(loading_info["missing_keys"]), 0)
@@ -663,9 +676,33 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
def test_output_pretrained_ve_mid(self):
model = UNetUnconditionalModel.from_pretrained("fusing/celebahq_256-ncsnpp-ve", sde=True)
model.to(torch_device)
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
batch_size = 4
num_channels = 3
sizes = (256, 256)
noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)
with torch.no_grad():
output = model(noise, time_step)["sample"]
output_slice = output[0, -3:, -3:, -1].flatten().cpu()
# fmt: off
expected_output_slice = torch.tensor([-4836.2231, -6487.1387, -3816.7969, -7964.9253, -10966.2842, -20043.6016, 8137.0571, 2340.3499, 544.6114])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
def test_output_pretrained_ve_large(self):
-model = NCSNpp.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy")
-model.eval()
+model = UNetUnconditionalModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy", sde=True)
model.to(torch_device)

torch.manual_seed(0)
@@ -680,7 +717,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)

with torch.no_grad():
-output = model(noise, time_step)
+output = model(noise, time_step)["sample"]

output_slice = output[0, -3:, -3:, -1].flatten().cpu()
# fmt: off
@@ -691,7 +728,6 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
def test_output_pretrained_vp(self):
model = NCSNpp.from_pretrained("fusing/cifar10-ddpmpp-vp")
-model.eval()
model.to(torch_device)

torch.manual_seed(0)
@@ -874,7 +910,6 @@ class PipelineTesterMixin(unittest.TestCase):
out_channels=3,
down_blocks=("UNetResDownBlock2D", "UNetResAttnDownBlock2D"),
up_blocks=("UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
-ddpm=True,
)
schedular = DDPMScheduler(timesteps=10)
@@ -1038,7 +1073,12 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
def test_score_sde_ve_pipeline(self):
-model = NCSNpp.from_pretrained("fusing/ffhq_ncsnpp")
+model = UNetUnconditionalModel.from_pretrained("fusing/ffhq_ncsnpp", sde=True)
+torch.manual_seed(0)
+if torch.cuda.is_available():
+torch.cuda.manual_seed_all(0)
scheduler = ScoreSdeVeScheduler.from_config("fusing/ffhq_ncsnpp")

sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)
@@ -1047,11 +1087,11 @@ class PipelineTesterMixin(unittest.TestCase):
image = sde_ve(num_inference_steps=2)

if model.device.type == "cpu":
-expected_image_sum = 3384805888.0
-expected_image_mean = 1076.00085
+expected_image_sum = 3384805632.0
+expected_image_mean = 1076.000732421875
else:
expected_image_sum = 3382849024.0
-expected_image_mean = 1075.3788
+expected_image_mean = 1075.3787841796875

assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2
assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4
...