Unverified commit ba3c9a9a, authored by Patrick von Platen, committed by GitHub

[SDE] Merge to unconditional model (#89)

* up

* more

* uP

* make dummy test pass

* save intermediate

* p

* p

* finish

* finish

* finish
parent b5c684f0
@@ -100,7 +100,7 @@ def test_output_pretrained_ldm():
# 2. DDPM
def get_model(model_id):
- model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", ldm=True)
+ model = UNetUnconditionalModel.from_pretrained(model_id, ldm=True)
noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
time_step = torch.tensor([10] * noise.shape[0])
@@ -123,3 +123,16 @@ def get_model(model_id):
# e.g.
get_model("fusing/ddpm-cifar10")
# 3. NCSNpp
# Repos to convert and port to google (part of https://github.com/yang-song/score_sde)
# - https://huggingface.co/fusing/ffhq_ncsnpp
# - https://huggingface.co/fusing/church_256-ncsnpp-ve
# - https://huggingface.co/fusing/celebahq_256-ncsnpp-ve
# - https://huggingface.co/fusing/bedroom_256-ncsnpp-ve
# - https://huggingface.co/fusing/ffhq_256-ncsnpp-ve
# tests that need to pass
# - test_score_sde_ve_pipeline (in PipelineTesterMixin)
# - test_output_pretrained_ve_mid, test_output_pretrained_ve_large (in NCSNppModelTests)
@@ -6,166 +6,6 @@ import torch.nn.functional as F
from torch import nn
# unet_grad_tts.py
# TODO(Patrick) - weird linear attention layer. Check with: https://github.com/huawei-noah/Speech-Backbones/issues/15
class LinearAttention(torch.nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super(LinearAttention, self).__init__()
self.heads = heads
self.dim_head = dim_head
hidden_dim = dim_head * heads
self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x, encoder_states=None):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = (
qkv.reshape(b, 3, self.heads, self.dim_head, h, w)
.permute(1, 0, 2, 3, 4, 5)
.reshape(3, b, self.heads, self.dim_head, -1)
)
k = k.softmax(dim=-1)
context = torch.einsum("bhdn,bhen->bhde", k, v)
out = torch.einsum("bhde,bhdn->bhen", context, q)
out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w)
return self.to_out(out)
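The two einsums above are the linear-attention trick: softmax is taken over key positions, keys and values are contracted into a small (dim_head x dim_head) context first, and the queries are applied afterwards, so the full n x n attention matrix is never materialized. A standalone sketch of the shapes involved (sizes are hypothetical):

```python
import torch

# Hypothetical sizes: batch b, heads h, head dim d, and n = height * width positions.
b, h, d, n = 1, 4, 32, 64 * 64
q = torch.randn(b, h, d, n)
k = torch.randn(b, h, d, n).softmax(dim=-1)  # softmax over positions, as in forward()
v = torch.randn(b, h, d, n)

# Contract keys and values into a (d, d) context: O(n * d^2) instead of O(n^2 * d).
context = torch.einsum("bhdn,bhen->bhde", k, v)
out = torch.einsum("bhde,bhdn->bhen", context, q)
print(out.shape)  # torch.Size([1, 4, 32, 4096])
```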
# the main attention block that is used for all models
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def __init__(
self,
channels,
num_heads=1,
num_head_channels=None,
num_groups=32,
encoder_channels=None,
overwrite_qkv=False,
overwrite_linear=False,
rescale_output_factor=1.0,
eps=1e-5,
):
super().__init__()
self.channels = channels
if num_head_channels is None:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
self.qkv = nn.Conv1d(channels, channels * 3, 1)
self.n_heads = self.num_heads
self.rescale_output_factor = rescale_output_factor
if encoder_channels is not None:
self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
self.proj = zero_module(nn.Conv1d(channels, channels, 1))
self.overwrite_qkv = overwrite_qkv
self.overwrite_linear = overwrite_linear
if overwrite_qkv:
in_channels = channels
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
elif self.overwrite_linear:
num_groups = min(channels // 4, 32)
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
self.NIN_0 = NIN(channels, channels)
self.NIN_1 = NIN(channels, channels)
self.NIN_2 = NIN(channels, channels)
self.NIN_3 = NIN(channels, channels)
self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
else:
self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
self.set_weights(self)
self.is_overwritten = False
def set_weights(self, module):
if self.overwrite_qkv:
qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[
:, :, :, 0
]
qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)
self.qkv.weight.data = qkv_weight
self.qkv.bias.data = qkv_bias
proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
proj_out.bias.data = module.proj_out.bias.data
self.proj = proj_out
elif self.overwrite_linear:
self.qkv.weight.data = torch.concat(
[self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0
)[:, :, None]
self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
self.proj.weight.data = self.NIN_3.W.data.T[:, :, None]
self.proj.bias.data = self.NIN_3.b.data
self.norm.weight.data = self.GroupNorm_0.weight.data
self.norm.bias.data = self.GroupNorm_0.bias.data
else:
self.proj.weight.data = self.proj_out.weight.data
self.proj.bias.data = self.proj_out.bias.data
def forward(self, x, encoder_out=None):
if not self.is_overwritten and (self.overwrite_qkv or self.overwrite_linear):
self.set_weights(self)
self.is_overwritten = True
b, c, *spatial = x.shape
hid_states = self.norm(x).view(b, c, -1)
qkv = self.qkv(hid_states)
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
if encoder_out is not None:
encoder_kv = self.encoder_kv(encoder_out)
assert encoder_kv.shape[1] == self.n_heads * ch * 2
ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
k = torch.cat([ek, k], dim=-1)
v = torch.cat([ev, v], dim=-1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
a = torch.einsum("bts,bcs->bct", weight, v)
h = a.reshape(bs, -1, length)
h = self.proj(h)
h = h.reshape(b, c, *spatial)
result = x + h
result = result / self.rescale_output_factor
return result
class AttentionBlockNew(nn.Module):
"""
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
@@ -216,6 +56,7 @@ class AttentionBlockNew(nn.Module):
# norm
hidden_states = self.group_norm(hidden_states)
hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
# proj to q, k, v
@@ -229,9 +70,9 @@ class AttentionBlockNew(nn.Module):
value_states = self.transpose_for_scores(value_proj)
# get scores
- attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
- attention_scores = attention_scores / math.sqrt(self.channels // self.num_heads)
- attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+ scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
+ attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
+ attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
# compute attention output
context_states = torch.matmul(attention_probs, value_states)
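The rewritten scoring matches the removed AttentionBlock above: the usual 1/sqrt(d) factor is split as 1/sqrt(sqrt(d)) onto queries and keys separately, which keeps the intermediates smaller in fp16, and the softmax runs in fp32 before casting back. A quick numerical check of the equivalence (hypothetical head dim):

```python
import math
import torch

d = 64  # hypothetical per-head channel count
q, k = torch.randn(2, 8, 16, d), torch.randn(2, 8, 16, d)

# Scale after the matmul (old) vs. scale each operand first (new).
scores_after = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d)
scale = 1 / math.sqrt(math.sqrt(d))
scores_before = torch.matmul(q * scale, k.transpose(-1, -2) * scale)

print(torch.allclose(scores_after, scores_before, atol=1e-5))  # True
```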
@@ -263,6 +104,20 @@ class AttentionBlockNew(nn.Module):
self.proj_attn.weight.data = attn_layer.proj_out.weight.data[:, :, 0, 0]
self.proj_attn.bias.data = attn_layer.proj_out.bias.data
elif hasattr(attn_layer, "NIN_0"):
self.query.weight.data = attn_layer.NIN_0.W.data.T
self.key.weight.data = attn_layer.NIN_1.W.data.T
self.value.weight.data = attn_layer.NIN_2.W.data.T
self.query.bias.data = attn_layer.NIN_0.b.data
self.key.bias.data = attn_layer.NIN_1.b.data
self.value.bias.data = attn_layer.NIN_2.b.data
self.proj_attn.weight.data = attn_layer.NIN_3.W.data.T
self.proj_attn.bias.data = attn_layer.NIN_3.b.data
self.group_norm.weight.data = attn_layer.GroupNorm_0.weight.data
self.group_norm.bias.data = attn_layer.GroupNorm_0.bias.data
else:
qkv_weight = attn_layer.qkv.weight.data.reshape(
self.num_heads, 3 * self.channels // self.num_heads, self.channels
@@ -452,3 +307,137 @@ class GEGLU(nn.Module):
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
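For reference, GEGLU projects to twice the output width and then splits into a value half and a gate half; a toy run with assumed dimensions:

```python
import torch
import torch.nn.functional as F
from torch import nn

dim_in, dim_out = 8, 16  # hypothetical dimensions
proj = nn.Linear(dim_in, dim_out * 2)  # stands in for GEGLU's self.proj

x = torch.randn(4, dim_in)
value, gate = proj(x).chunk(2, dim=-1)  # same split as GEGLU.forward above
out = value * F.gelu(gate)
print(out.shape)  # torch.Size([4, 16])
```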
# the main attention block that is used for all models
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def __init__(
self,
channels,
num_heads=1,
num_head_channels=None,
num_groups=32,
encoder_channels=None,
overwrite_qkv=False,
overwrite_linear=False,
rescale_output_factor=1.0,
eps=1e-5,
):
super().__init__()
self.channels = channels
if num_head_channels is None:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
self.qkv = nn.Conv1d(channels, channels * 3, 1)
self.n_heads = self.num_heads
self.rescale_output_factor = rescale_output_factor
if encoder_channels is not None:
self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
self.proj = zero_module(nn.Conv1d(channels, channels, 1))
self.overwrite_qkv = overwrite_qkv
self.overwrite_linear = overwrite_linear
if overwrite_qkv:
in_channels = channels
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
elif self.overwrite_linear:
num_groups = min(channels // 4, 32)
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
self.NIN_0 = NIN(channels, channels)
self.NIN_1 = NIN(channels, channels)
self.NIN_2 = NIN(channels, channels)
self.NIN_3 = NIN(channels, channels)
self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
else:
self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
self.set_weights(self)
self.is_overwritten = False
def set_weights(self, module):
if self.overwrite_qkv:
qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[
:, :, :, 0
]
qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)
self.qkv.weight.data = qkv_weight
self.qkv.bias.data = qkv_bias
proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
proj_out.bias.data = module.proj_out.bias.data
self.proj = proj_out
elif self.overwrite_linear:
self.qkv.weight.data = torch.concat(
[self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0
)[:, :, None]
self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
self.proj.weight.data = self.NIN_3.W.data.T[:, :, None]
self.proj.bias.data = self.NIN_3.b.data
self.norm.weight.data = self.GroupNorm_0.weight.data
self.norm.bias.data = self.GroupNorm_0.bias.data
else:
self.proj.weight.data = self.proj_out.weight.data
self.proj.bias.data = self.proj_out.bias.data
def forward(self, x, encoder_out=None):
if not self.is_overwritten and (self.overwrite_qkv or self.overwrite_linear):
self.set_weights(self)
self.is_overwritten = True
b, c, *spatial = x.shape
hid_states = self.norm(x).view(b, c, -1)
qkv = self.qkv(hid_states)
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
if encoder_out is not None:
encoder_kv = self.encoder_kv(encoder_out)
assert encoder_kv.shape[1] == self.n_heads * ch * 2
ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
k = torch.cat([ek, k], dim=-1)
v = torch.cat([ev, v], dim=-1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
a = torch.einsum("bts,bcs->bct", weight, v)
h = a.reshape(bs, -1, length)
h = self.proj(h)
h = h.reshape(b, c, *spatial)
result = x + h
result = result / self.rescale_output_factor
return result
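The `[:, :, :, 0]` slices in set_weights work because a 1x1 Conv2d over flattened spatial positions is the same linear map as a kernel-1 Conv1d; the weight tensors differ only by a trailing singleton dimension. A sanity check of that weight-porting pattern:

```python
import torch
from torch import nn

c = 8
conv2d = nn.Conv2d(c, c, kernel_size=1)
conv1d = nn.Conv1d(c, c, kernel_size=1)

# Port the weights the way set_weights does: drop the last spatial dim.
conv1d.weight.data = conv2d.weight.data[:, :, :, 0]  # (c, c, 1, 1) -> (c, c, 1)
conv1d.bias.data = conv2d.bias.data

x = torch.randn(1, c, 4, 4)
print(torch.allclose(conv2d(x).view(1, c, -1), conv1d(x.view(1, c, -1)), atol=1e-6))  # True
```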
@@ -54,14 +54,20 @@ def get_timestep_embedding(
return emb
# unet_sde_score_estimation.py
class GaussianFourierProjection(nn.Module):
"""Gaussian Fourier embeddings for noise levels."""
def __init__(self, embedding_size=256, scale=1.0):
super().__init__()
+ self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+ # to delete later
self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+ self.weight = self.W
def forward(self, x):
- x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
- return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
+ x = torch.log(x)
+ x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
+ out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
+ return out
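With this change the log of the continuous noise level is taken inside the module (it used to happen at the call site in NCSNpp.forward, see the diff further down), and the log-sigma is mapped onto fixed random sin/cos frequencies. A standalone sketch of the computation, using fourier_scale=16 as in the UNet config below:

```python
import numpy as np
import torch

embedding_size, scale = 256, 16.0  # fourier_scale=16 as in the new UNet config
weight = torch.randn(embedding_size) * scale  # fixed, never trained (requires_grad=False)

sigmas = torch.tensor([0.5, 1.0, 25.0])  # continuous noise levels
x = torch.log(sigmas)                    # the log now lives inside the module
x_proj = x[:, None] * weight[None, :] * 2 * np.pi
emb = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
print(emb.shape)  # torch.Size([3, 512]): 2 * embedding_size features per noise level
```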
@@ -87,96 +87,15 @@ class Downsample2D(nn.Module):
self.conv = conv
def forward(self, x):
- # print("use_conv", self.use_conv)
- # print("padding", self.padding)
assert x.shape[1] == self.channels
if self.use_conv and self.padding == 0:
pad = (0, 1, 0, 1)
x = F.pad(x, pad, mode="constant", value=0)
- # print("x", x.abs().sum())
- self.hey = x
assert x.shape[1] == self.channels
x = self.conv(x)
- self.yas = x
- # print("x", x.abs().sum())
return x
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
# if self.name == "conv":
# return self.conv(x)
# elif self.name == "Conv2d_0":
# return self.Conv2d_0(x)
# else:
# return self.op(x)
class Upsample1D(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
self.name = name
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
self.conv = None
if use_conv_transpose:
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
elif use_conv:
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
def forward(self, x):
assert x.shape[1] == self.channels
if self.use_conv_transpose:
return self.conv(x)
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class Downsample1D(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.padding = padding
stride = 2
self.name = name
if use_conv:
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
return self.conv(x)
class FirUpsample2D(nn.Module):
@@ -330,15 +249,137 @@ class FirDownsample2D(nn.Module):
return x
- # TODO (patil-suraj): needs test
- # class Upsample2D1d(nn.Module):
- #     def __init__(self, dim):
- #         super().__init__()
- #         self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
- #
- #     def forward(self, x):
- #         return self.conv(x)
class ResnetBlock(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
groups_out=None,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
kernel=None,
output_scale_factor=1.0,
use_nin_shortcut=None,
up=False,
down=False,
):
super().__init__()
self.pre_norm = True  # note: the pre_norm argument is currently always overridden to True
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.time_embedding_norm = time_embedding_norm
self.up = up
self.down = down
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = lambda x: F.silu(x)
elif non_linearity == "mish":
self.nonlinearity = Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
self.upsample = self.downsample = None
if self.up:
if kernel == "fir":
fir_kernel = (1, 3, 3, 1)
self.upsample = lambda x: upsample_2d(x, k=fir_kernel)
elif kernel == "sde_vp":
self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
else:
self.upsample = Upsample2D(in_channels, use_conv=False)
elif self.down:
if kernel == "fir":
fir_kernel = (1, 3, 3, 1)
self.downsample = lambda x: downsample_2d(x, k=fir_kernel)
elif kernel == "sde_vp":
self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
else:
self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
self.conv_shortcut = None
if self.use_nin_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = self.nonlinearity(h)
if self.upsample is not None:
x = self.upsample(x)
h = self.upsample(h)
elif self.downsample is not None:
x = self.downsample(x)
h = self.downsample(h)
h = self.conv1(h)
if temb is not None:
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
h = h + temb
h = self.norm2(h)
h = self.nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.conv_shortcut is not None:
x = self.conv_shortcut(x)
out = (x + h) / self.output_scale_factor
return out
def set_weight(self, resnet):
self.norm1.weight.data = resnet.norm1.weight.data
self.norm1.bias.data = resnet.norm1.bias.data
self.conv1.weight.data = resnet.conv1.weight.data
self.conv1.bias.data = resnet.conv1.bias.data
self.time_emb_proj.weight.data = resnet.temb_proj.weight.data
self.time_emb_proj.bias.data = resnet.temb_proj.bias.data
self.norm2.weight.data = resnet.norm2.weight.data
self.norm2.bias.data = resnet.norm2.bias.data
self.conv2.weight.data = resnet.conv2.weight.data
self.conv2.bias.data = resnet.conv2.bias.data
if self.use_nin_shortcut:
self.conv_shortcut.weight.data = resnet.nin_shortcut.weight.data
self.conv_shortcut.bias.data = resnet.nin_shortcut.bias.data
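A minimal usage sketch of the new ResnetBlock (hypothetical sizes; assumes the FIR helpers upsample_2d/downsample_2d from this file are in scope): a channel change triggers the 1x1 nin shortcut, and down=True with kernel="fir" reproduces the score-SDE style downsampling.

```python
import torch

# Hypothetical instantiation; argument names follow the class definition above.
block = ResnetBlock(
    in_channels=64,
    out_channels=128,  # != in_channels, so use_nin_shortcut defaults to True
    temb_channels=512,
    down=True,
    kernel="fir",      # FIR (1, 3, 3, 1) resampling, as in score-SDE
)

x = torch.randn(2, 64, 32, 32)
temb = torch.randn(2, 512)
out = block(x, temb)
print(out.shape)  # torch.Size([2, 128, 16, 16]): channels doubled, resolution halved
```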
# THE FOLLOWING SHOULD BE DELETED ONCE ALL CHECKPOINTS ARE CONVERTED
# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py, unet_score_vde.py
# => All 2D-Resnets are included here now!
@@ -502,6 +543,7 @@ class ResnetBlock2D(nn.Module):
self.in_ch = in_ch
self.out_ch = out_ch
+ self.set_weights_score_vde()
def set_weights_grad_tts(self):
self.conv1.weight.data = self.block1.block[0].weight.data
@@ -559,17 +601,21 @@ class ResnetBlock2D(nn.Module):
self.nin_shortcut.weight.data = self.Conv_2.weight.data
self.nin_shortcut.bias.data = self.Conv_2.bias.data
- def forward(self, x, temb, mask=1.0):
+ def forward(self, x, temb, hey=False, mask=1.0):
# TODO(Patrick) eventually this class should be split into multiple classes
# too many if else statements
if self.overwrite_for_grad_tts and not self.is_overwritten:
self.set_weights_grad_tts()
self.is_overwritten = True
- elif self.overwrite_for_score_vde and not self.is_overwritten:
-     self.set_weights_score_vde()
-     self.is_overwritten = True
+ # elif self.overwrite_for_score_vde and not self.is_overwritten:
+ #     self.set_weights_score_vde()
+ #     self.is_overwritten = True
+ # h2 tensor(110029.2109)
+ # h3 tensor(49596.9492)
h = x
h = h * mask
if self.pre_norm:
h = self.norm1(h)
@@ -619,154 +665,9 @@ class ResnetBlock2D(nn.Module):
if self.nin_shortcut is not None:
x = self.nin_shortcut(x)
- return (x + h) / self.output_scale_factor
+ out = (x + h) / self.output_scale_factor
+ return out
class ResnetBlock(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
groups_out=None,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
kernel=None,
output_scale_factor=1.0,
use_nin_shortcut=None,
up=False,
down=False,
overwrite_for_grad_tts=False,
overwrite_for_ldm=False,
overwrite_for_glide=False,
overwrite_for_score_vde=False,
):
super().__init__()
self.pre_norm = pre_norm
self.pre_norm = True
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.time_embedding_norm = time_embedding_norm
self.up = up
self.down = down
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
if self.pre_norm:
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
else:
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if time_embedding_norm == "default" and temb_channels > 0:
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
elif time_embedding_norm == "scale_shift" and temb_channels > 0:
self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = lambda x: F.silu(x)
elif non_linearity == "mish":
self.nonlinearity = Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
self.upsample = self.downsample = None
if self.up:
if kernel == "fir":
fir_kernel = (1, 3, 3, 1)
self.upsample = lambda x: upsample_2d(x, k=fir_kernel)
elif kernel == "sde_vp":
self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
else:
self.upsample = Upsample2D(in_channels, use_conv=False)
elif self.down:
if kernel == "fir":
fir_kernel = (1, 3, 3, 1)
self.downsample = lambda x: downsample_2d(x, k=fir_kernel)
elif kernel == "sde_vp":
self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
else:
self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
self.conv_shortcut = None
if self.use_nin_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = self.nonlinearity(h)
if self.upsample is not None:
x = self.upsample(x)
h = self.upsample(h)
elif self.downsample is not None:
x = self.downsample(x)
h = self.downsample(h)
h = self.conv1(h)
if temb is not None:
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
else:
temb = 0
if self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
h = self.norm2(h)
h = h + h * scale + shift
h = self.nonlinearity(h)
elif self.time_embedding_norm == "default":
h = h + temb
h = self.norm2(h)
h = self.nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.conv_shortcut is not None:
x = self.conv_shortcut(x)
return (x + h) / self.output_scale_factor
def set_weight(self, resnet):
self.norm1.weight.data = resnet.norm1.weight.data
self.norm1.bias.data = resnet.norm1.bias.data
self.conv1.weight.data = resnet.conv1.weight.data
self.conv1.bias.data = resnet.conv1.bias.data
self.time_emb_proj.weight.data = resnet.temb_proj.weight.data
self.time_emb_proj.bias.data = resnet.temb_proj.bias.data
self.norm2.weight.data = resnet.norm2.weight.data
self.norm2.bias.data = resnet.norm2.bias.data
self.conv2.weight.data = resnet.conv2.weight.data
self.conv2.bias.data = resnet.conv2.bias.data
if self.use_nin_shortcut:
self.conv_shortcut.weight.data = resnet.nin_shortcut.weight.data
self.conv_shortcut.bias.data = resnet.nin_shortcut.bias.data
# TODO(Patrick) - just there to convert the weights; can delete afterward
@@ -778,39 +679,6 @@ class Block(torch.nn.Module):
)
# unet_rl.py
class ResidualTemporalBlock(nn.Module):
def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
super().__init__()
self.blocks = nn.ModuleList(
[
Conv1dBlock(inp_channels, out_channels, kernel_size),
Conv1dBlock(out_channels, out_channels, kernel_size),
]
)
self.time_mlp = nn.Sequential(
nn.Mish(),
nn.Linear(embed_dim, out_channels),
RearrangeDim(),
# Rearrange("batch t -> batch t 1"),
)
self.residual_conv = (
nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
)
def forward(self, x, t):
"""
x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x
out_channels x horizon ]
"""
out = self.blocks[0](x) + self.time_mlp(t)
out = self.blocks[1](out)
return out + self.residual_conv(x)
# HELPER Modules
...
@@ -11,12 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+ import numpy as np
# limitations under the License.
import torch
from torch import nn
from .attention import AttentionBlockNew
- from .resnet import Downsample2D, ResnetBlock, Upsample2D
+ from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock, Upsample2D
def get_down_block(
@@ -54,6 +56,29 @@ def get_down_block(
downsample_padding=downsample_padding,
attn_num_head_channels=attn_num_head_channels,
)
elif down_block_type == "UNetResSkipDownBlock2D":
return UNetResSkipDownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
downsample_padding=downsample_padding,
)
elif down_block_type == "UNetResAttnSkipDownBlock2D":
return UNetResAttnSkipDownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
downsample_padding=downsample_padding,
attn_num_head_channels=attn_num_head_channels,
)
def get_up_block(
@@ -91,6 +116,30 @@ def get_up_block(
resnet_act_fn=resnet_act_fn,
attn_num_head_channels=attn_num_head_channels,
)
elif up_block_type == "UNetResSkipUpBlock2D":
return UNetResSkipUpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
)
elif up_block_type == "UNetResAttnSkipUpBlock2D":
return UNetResAttnSkipUpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
attn_num_head_channels=attn_num_head_channels,
)
raise ValueError(f"{up_block_type} does not exist.")
class UNetMidBlock2D(nn.Module):
@@ -113,6 +162,7 @@ class UNetMidBlock2D(nn.Module):
super().__init__()
self.attention_type = attention_type
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# there is always at least one resnet
resnets = [
@@ -138,6 +188,7 @@ class UNetMidBlock2D(nn.Module):
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
+ num_groups=resnet_groups,
)
)
resnets.append(
@@ -160,7 +211,6 @@ class UNetMidBlock2D(nn.Module):
def forward(self, hidden_states, temb=None, encoder_states=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if self.attention_type == "default":
hidden_states = attn(hidden_states)
@@ -318,6 +368,178 @@ class UNetResDownBlock2D(nn.Module):
return hidden_states, output_states
class UNetResAttnSkipDownBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
attention_type="default",
output_scale_factor=np.sqrt(2.0),
downsample_padding=1,
add_downsample=True,
):
super().__init__()
self.attentions = nn.ModuleList([])
self.resnets = nn.ModuleList([])
self.attention_type = attention_type
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
self.resnets.append(
ResnetBlock(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=min(in_channels // 4, 32),
groups_out=min(out_channels // 4, 32),
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.attentions.append(
AttentionBlockNew(
out_channels,
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
)
)
if add_downsample:
self.resnet_down = ResnetBlock(
in_channels=out_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=min(out_channels // 4, 32),
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_nin_shortcut=True,
down=True,
kernel="fir",
)
self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
else:
self.resnet_down = None
self.downsamplers = None
self.skip_conv = None
def forward(self, hidden_states, temb=None, skip_sample=None):
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states)
output_states += (hidden_states,)
if self.downsamplers is not None:
hidden_states = self.resnet_down(hidden_states, temb)
for downsampler in self.downsamplers:
skip_sample = downsampler(skip_sample)
hidden_states = self.skip_conv(skip_sample) + hidden_states
output_states += (hidden_states,)
return hidden_states, output_states, skip_sample
class UNetResSkipDownBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_pre_norm: bool = True,
output_scale_factor=np.sqrt(2.0),
add_downsample=True,
downsample_padding=1,
):
super().__init__()
self.resnets = nn.ModuleList([])
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
self.resnets.append(
ResnetBlock(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=min(in_channels // 4, 32),
groups_out=min(out_channels // 4, 32),
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
if add_downsample:
self.resnet_down = ResnetBlock(
in_channels=out_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=min(out_channels // 4, 32),
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_nin_shortcut=True,
down=True,
kernel="fir",
)
self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
else:
self.resnet_down = None
self.downsamplers = None
self.skip_conv = None
def forward(self, hidden_states, temb=None, skip_sample=None):
output_states = ()
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb)
output_states += (hidden_states,)
if self.downsamplers is not None:
hidden_states = self.resnet_down(hidden_states, temb)
for downsampler in self.downsamplers:
skip_sample = downsampler(skip_sample)
hidden_states = self.skip_conv(skip_sample) + hidden_states
output_states += (hidden_states,)
return hidden_states, output_states, skip_sample
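These skip blocks implement score-SDE's progressive_input="input_skip": the raw input image (skip_sample) is FIR-downsampled alongside the features and injected through a 1x1 conv at every level. A shape-level sketch (hypothetical sizes; assumes FirDownsample2D without use_conv acts as a channel-preserving FIR resample, and 3 is the image channel count hard-coded in skip_conv):

```python
import torch

block = UNetResSkipDownBlock2D(
    in_channels=128, out_channels=256, temb_channels=512, num_layers=2
)

hidden_states = torch.randn(2, 128, 32, 32)
temb = torch.randn(2, 512)
skip_sample = torch.randn(2, 3, 32, 32)  # the network input, carried along

hidden_states, output_states, skip_sample = block(hidden_states, temb, skip_sample)
print(hidden_states.shape, skip_sample.shape)
# torch.Size([2, 256, 16, 16]) torch.Size([2, 3, 16, 16])
```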
class UNetResAttnUpBlock2D(nn.Module):
def __init__(
self,
@@ -457,3 +679,213 @@ class UNetResUpBlock2D(nn.Module):
hidden_states = upsampler(hidden_states)
return hidden_states
class UNetResAttnSkipUpBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
prev_output_channel: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
attention_type="default",
output_scale_factor=np.sqrt(2.0),
upsample_padding=1,
add_upsample=True,
):
super().__init__()
self.attentions = nn.ModuleList([])
self.resnets = nn.ModuleList([])
self.attention_type = attention_type
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
self.resnets.append(
ResnetBlock(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
groups_out=min(out_channels // 4, 32),
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.attentions.append(
AttentionBlockNew(
out_channels,
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
)
)
self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
if add_upsample:
self.resnet_up = ResnetBlock(
in_channels=out_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=min(out_channels // 4, 32),
groups_out=min(out_channels // 4, 32),
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_nin_shortcut=True,
up=True,
kernel="fir",
)
self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.skip_norm = torch.nn.GroupNorm(
num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
)
self.act = nn.SiLU()
else:
self.resnet_up = None
self.skip_conv = None
self.skip_norm = None
self.act = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states)
if skip_sample is not None:
skip_sample = self.upsampler(skip_sample)
else:
skip_sample = 0
if self.resnet_up is not None:
skip_sample_states = self.skip_norm(hidden_states)
skip_sample_states = self.act(skip_sample_states)
skip_sample_states = self.skip_conv(skip_sample_states)
skip_sample = skip_sample + skip_sample_states
hidden_states = self.resnet_up(hidden_states, temb)
return hidden_states, skip_sample
class UNetResSkipUpBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
prev_output_channel: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_pre_norm: bool = True,
output_scale_factor=np.sqrt(2.0),
add_upsample=True,
upsample_padding=1,
):
super().__init__()
self.resnets = nn.ModuleList([])
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
self.resnets.append(
ResnetBlock(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
groups_out=min(out_channels // 4, 32),
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
if add_upsample:
self.resnet_up = ResnetBlock(
in_channels=out_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=min(out_channels // 4, 32),
groups_out=min(out_channels // 4, 32),
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_nin_shortcut=True,
up=True,
kernel="fir",
)
self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.skip_norm = torch.nn.GroupNorm(
num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
)
self.act = nn.SiLU()
else:
self.resnet_up = None
self.skip_conv = None
self.skip_norm = None
self.act = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
output_states = ()
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
if skip_sample is not None:
skip_sample = self.upsampler(skip_sample)
else:
skip_sample = 0
if self.resnet_up is not None:
skip_sample_states = self.skip_norm(hidden_states)
skip_sample_states = self.act(skip_sample_states)
skip_sample_states = self.skip_conv(skip_sample_states)
skip_sample = skip_sample + skip_sample_states
hidden_states = self.resnet_up(hidden_states, temb)
return hidden_states, skip_sample
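And this is the mirror image, progressive="output_skip": each up level emits an image-space residual (norm -> act -> 3-channel conv) that is accumulated into the upsampled skip_sample, so the final output is built up pyramid-style. A sketch with hypothetical sizes:

```python
import torch

block = UNetResSkipUpBlock2D(
    in_channels=128, prev_output_channel=256, out_channels=128,
    temb_channels=512, num_layers=2,
)

hidden_states = torch.randn(2, 256, 16, 16)
res_hidden_states_tuple = (torch.randn(2, 128, 16, 16), torch.randn(2, 128, 16, 16))
temb = torch.randn(2, 512)

# skip_sample=None starts the pyramid; later levels upsample it and add their residual.
hidden_states, skip_sample = block(hidden_states, res_hidden_states_tuple, temb)
print(hidden_states.shape, skip_sample.shape)
# torch.Size([2, 128, 32, 32]) torch.Size([2, 3, 16, 16])
```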
@@ -249,9 +249,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
overwrite_for_score_vde=True,
)
)
- self.mid.resnets[0] = modules[len(modules) - 3]
- self.mid.attentions[0] = modules[len(modules) - 2]
- self.mid.resnets[1] = modules[len(modules) - 1]
+ # self.mid.resnets[0] = modules[len(modules) - 3]
+ # self.mid.attentions[0] = modules[len(modules) - 2]
+ # self.mid.resnets[1] = modules[len(modules) - 1]
pyramid_ch = 0
# Upsampling block
@@ -282,22 +282,22 @@ class NCSNpp(ModelMixin, ConfigMixin):
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(nn.Conv2d(in_ch, channels, kernel_size=3, padding=1))
pyramid_ch = channels
- elif progressive == "residual":
-     modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
-     modules.append(nn.Conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1))
-     pyramid_ch = in_ch
- else:
-     raise ValueError(f"{progressive} is not a valid name.")
+ # elif progressive == "residual":
+ #     modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
+ #     modules.append(nn.Conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1))
+ #     pyramid_ch = in_ch
+ # else:
+ #     raise ValueError(f"{progressive} is not a valid name.")
else:
if progressive == "output_skip":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(nn.Conv2d(in_ch, channels, bias=True, kernel_size=3, padding=1))
pyramid_ch = channels
- elif progressive == "residual":
-     modules.append(pyramid_upsample(channels=pyramid_ch, out_channels=in_ch))
-     pyramid_ch = in_ch
- else:
-     raise ValueError(f"{progressive} is not a valid name")
+ # elif progressive == "residual":
+ #     modules.append(pyramid_upsample(channels=pyramid_ch, out_channels=in_ch))
+ #     pyramid_ch = in_ch
+ # else:
+ #     raise ValueError(f"{progressive} is not a valid name")
if i_level != 0:
modules.append(
@@ -332,7 +332,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
if self.embedding_type == "fourier":
# Gaussian Fourier features embeddings.
used_sigmas = timesteps
- temb = modules[m_idx](torch.log(used_sigmas))
+ temb = modules[m_idx](used_sigmas)
m_idx += 1
elif self.embedding_type == "positional":
@@ -363,6 +363,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
hs = [modules[m_idx](x)]
m_idx += 1
for i_level in range(self.num_resolutions):
# Residual blocks for this resolution
for i_block in range(self.num_res_blocks):
@@ -394,16 +395,13 @@ class NCSNpp(ModelMixin, ConfigMixin):
hs.append(h)
- # h = hs[-1]
- # h = modules[m_idx](h, temb)
- # m_idx += 1
- # h = modules[m_idx](h)
- # m_idx += 1
- # h = modules[m_idx](h, temb)
- # m_idx += 1
- h = self.mid(h, temb)
- m_idx += 3
+ h = hs[-1]
+ h = modules[m_idx](h, temb)
+ m_idx += 1
+ h = modules[m_idx](h)
+ m_idx += 1
+ h = modules[m_idx](h, temb)
+ m_idx += 1
pyramid = None
@@ -424,31 +422,32 @@ class NCSNpp(ModelMixin, ConfigMixin):
m_idx += 1
pyramid = modules[m_idx](pyramid)
m_idx += 1
- elif self.progressive == "residual":
-     pyramid = self.act(modules[m_idx](h))
-     m_idx += 1
-     pyramid = modules[m_idx](pyramid)
-     m_idx += 1
- else:
-     raise ValueError(f"{self.progressive} is not a valid name.")
+ # elif self.progressive == "residual":
+ #     pyramid = self.act(modules[m_idx](h))
+ #     m_idx += 1
+ #     pyramid = modules[m_idx](pyramid)
+ #     m_idx += 1
+ # else:
+ #     raise ValueError(f"{self.progressive} is not a valid name.")
else:
if self.progressive == "output_skip":
- pyramid = self.pyramid_upsample(pyramid)
pyramid_h = self.act(modules[m_idx](h))
m_idx += 1
pyramid_h = modules[m_idx](pyramid_h)
m_idx += 1
- pyramid = pyramid + pyramid_h
- elif self.progressive == "residual":
-     pyramid = modules[m_idx](pyramid)
-     m_idx += 1
-     if self.skip_rescale:
-         pyramid = (pyramid + h) / np.sqrt(2.0)
-     else:
-         pyramid = pyramid + h
-     h = pyramid
- else:
-     raise ValueError(f"{self.progressive} is not a valid name")
+ skip_sample = self.pyramid_upsample(pyramid)
+ pyramid = skip_sample + pyramid_h
+ # elif self.progressive == "residual":
+ #     pyramid = modules[m_idx](pyramid)
+ #     m_idx += 1
+ #     if self.skip_rescale:
+ #         pyramid = (pyramid + h) / np.sqrt(2.0)
+ #     else:
+ #         pyramid = pyramid + h
+ #     h = pyramid
+ # else:
+ #     raise ValueError(f"{self.progressive} is not a valid name")
if i_level != 0:
h = modules[m_idx](h, temb)
...
+ import functools
+ import math
from typing import Dict, Union
+ import numpy as np
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
- from .embeddings import get_timestep_embedding
+ from .embeddings import GaussianFourierProjection, get_timestep_embedding
- from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
+ from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
from .unet_new import UNetMidBlock2D, get_down_block, get_up_block
class Combine(nn.Module):
"""Combine information from skip connections."""
def __init__(self, dim1, dim2, method="cat"):
super().__init__()
# 1x1 convolution with DDPM initialization.
self.Conv_0 = nn.Conv2d(dim1, dim2, kernel_size=1, padding=0)
self.method = method
# def forward(self, x, y):
# h = self.Conv_0(x)
# if self.method == "cat":
# return torch.cat([h, y], dim=1)
# elif self.method == "sum":
# return h + y
# else:
# raise ValueError(f"Method {self.method} not recognized.")
class TimestepEmbedding(nn.Module):
- def __init__(self, channel, time_embed_dim):
+ def __init__(self, channel, time_embed_dim, act_fn="silu"):
super().__init__()
self.linear_1 = nn.Linear(channel, time_embed_dim)
- self.act = nn.SiLU()
+ self.act = None
+ if act_fn == "silu":
+     self.act = nn.SiLU()
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
def forward(self, sample):
sample = self.linear_1(sample)
- sample = self.act(sample)
+ if self.act is not None:
+     sample = self.act(sample)
sample = self.linear_2(sample)
return sample
class Timesteps(nn.Module):
def __init__(self, num_channels, flip_sin_to_cos, downscale_freq_shift):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
)
return t_emb
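Timestep conditioning is now split into two stages that the constructor below wires together: a projection module (Timesteps for sinusoidal embeddings, GaussianFourierProjection for continuous SDE noise levels) followed by the TimestepEmbedding MLP. A sketch with assumed sizes:

```python
import torch

block_channels_0, time_embed_dim = 128, 512  # hypothetical config values

# Positional path: timestep_input_dim == block_channels[0].
time_steps = Timesteps(block_channels_0, flip_sin_to_cos=True, downscale_freq_shift=0)
time_embedding = TimestepEmbedding(block_channels_0, time_embed_dim)

t = torch.tensor([10, 500])
t_emb = time_steps(t)        # (2, 128) sinusoidal features
emb = time_embedding(t_emb)  # (2, 512) after Linear -> SiLU -> Linear
print(t_emb.shape, emb.shape)
```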
class UNetUnconditionalModel(ModelMixin, ConfigMixin):
"""
The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param
@@ -72,6 +117,9 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
num_head_channels=32,
flip_sin_to_cos=True,
downscale_freq_shift=0,
time_embedding_type="positional",
mid_block_scale_factor=1,
center_input_sample=False,
# TODO(PVP) - to delete later at release
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
# ======================================
@@ -86,7 +134,24 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
ch_mult=None,
ch=None,
ddpm=False,
- # ======================================
+ # SDE
sde=False,
nf=None,
fir=None,
progressive=None,
progressive_combine=None,
scale_by_sigma=None,
skip_rescale=None,
num_channels=None,
centered=False,
conditional=True,
conv_size=3,
fir_kernel=(1, 3, 3, 1),
fourier_scale=16,
init_scale=0.0,
progressive_input="input_skip",
resnet_num_groups=32,
continuous=True,
):
super().__init__()
# register all __init__ params to be accessible via `self.config.<...>`
@@ -101,19 +166,43 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
down_blocks=down_blocks,
up_blocks=up_blocks,
dropout=dropout,
+ resnet_eps=resnet_eps,
conv_resample=conv_resample,
num_head_channels=num_head_channels,
flip_sin_to_cos=flip_sin_to_cos,
downscale_freq_shift=downscale_freq_shift,
- # TODO(PVP) - to delete later at release
- # IMPORTANT: NOT RELEVANT WHEN REVIEWING API
- # ======================================
+ time_embedding_type=time_embedding_type,
attention_resolutions=attention_resolutions,
attn_resolutions=attn_resolutions,
+ mid_block_scale_factor=mid_block_scale_factor,
+ resnet_num_groups=resnet_num_groups,
+ center_input_sample=center_input_sample,
+ # to delete later
ldm=ldm,
ddpm=ddpm,
- # ======================================
+ sde=sde,
)
# if sde:
# block_channels = [nf * x for x in ch_mult]
# in_channels = out_channels = num_channels
# conv_resample = resamp_with_conv
# time_embedding_type = "fourier"
# self.config.time_embedding_type = time_embedding_type
# self.config.resnet_eps = 1e-6
# self.config.mid_block_scale_factor = math.sqrt(2.0)
# self.config.resnet_num_groups = None
# down_blocks = (
# "UNetResSkipDownBlock2D",
# "UNetResAttnSkipDownBlock2D",
# "UNetResSkipDownBlock2D",
# "UNetResSkipDownBlock2D",
# )
# up_blocks = (
# "UNetResSkipUpBlock2D",
# "UNetResSkipUpBlock2D",
# "UNetResAttnSkipUpBlock2D",
# "UNetResSkipUpBlock2D",
# )
# TODO(PVP) - to delete later at release
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
@@ -122,11 +211,18 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
time_embed_dim = block_channels[0] * 4
# ======================================
- # # input
+ # input
self.conv_in = nn.Conv2d(in_channels, block_channels[0], kernel_size=3, padding=(1, 1))
- # # time
- self.time_embedding = TimestepEmbedding(block_channels[0], time_embed_dim)
+ # time
if time_embedding_type == "fourier":
self.time_steps = GaussianFourierProjection(embedding_size=block_channels[0], scale=fourier_scale)
timestep_input_dim = 2 * block_channels[0]
elif time_embedding_type == "positional":
self.time_steps = Timesteps(block_channels[0], flip_sin_to_cos, downscale_freq_shift)
timestep_input_dim = block_channels[0]
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
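# Shape note (illustrative sketch, not in the source): a size-C
# GaussianFourierProjection concatenates sin/cos features, so it maps B
# timesteps to (B, 2*C), while "positional" Timesteps keeps dimension C;
# either output is then projected to time_embed_dim above. For example:
#   proj = GaussianFourierProjection(embedding_size=32, scale=16)
#   assert proj(torch.ones(4)).shape == (4, 64)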
self.downsample_blocks = nn.ModuleList([])
self.mid = None
...@@ -154,15 +250,17 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
self.downsample_blocks.append(down_block)
# mid
if ddpm:
self.mid_new_2 = UNetMidBlock2D(
in_channels=block_channels[-1],
dropout=dropout,
temb_channels=time_embed_dim,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift="default",
attn_num_head_channels=num_head_channels,
resnet_groups=resnet_num_groups,
)
else:
self.mid = UNetMidBlock2D(
...@@ -171,8 +269,10 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
temb_channels=time_embed_dim,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift="default",
attn_num_head_channels=num_head_channels,
resnet_groups=resnet_num_groups,
)
# up
...@@ -201,16 +301,19 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
prev_output_channel = output_channel
# out
num_groups_out = resnet_num_groups if resnet_num_groups is not None else min(block_channels[0] // 4, 32)
self.conv_norm_out = nn.GroupNorm(num_channels=block_channels[0], num_groups=num_groups_out, eps=resnet_eps)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_channels[0], out_channels, 3, padding=1)
# TODO(PVP) - to delete later at release
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
# ======================================
self.is_overwritten = False
if ldm:
# =========== TO DELETE AFTER CONVERSION ==========
transformer_depth = 1
context_dim = None
legacy = True
...@@ -234,7 +337,11 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
conv_resample,
out_channels,
)
elif ddpm:
out_channels = out_ch
image_size = resolution
block_channels = [x * ch for x in ch_mult]
conv_resample = resamp_with_conv
out_ch = out_channels
resolution = image_size
ch = block_channels[0]
...@@ -251,7 +358,54 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
out_ch,
dropout=0.1,
)
elif sde:
nf = block_channels[0]
ch_mult = [x // nf for x in block_channels]
num_channels = in_channels
# in_channels = out_channels = num_channels = in_channels
# block_channels = [nf * x for x in ch_mult]
# conv_resample = resamp_with_conv
resamp_with_conv = conv_resample
time_embedding_type = self.config.time_embedding_type
# time_embedding_type = "fourier"
# self.config.time_embedding_type = time_embedding_type
fir = True
progressive = "output_skip"
progressive_combine = "sum"
scale_by_sigma = True
skip_rescale = True
centered = False
conditional = True
conv_size = 3
fir_kernel = (1, 3, 3, 1)
fourier_scale = 16
init_scale = 0.0
progressive_input = "input_skip"
continuous = True
self.init_for_sde(
image_size,
num_channels,
centered,
attn_resolutions,
ch_mult,
conditional,
conv_size,
dropout,
time_embedding_type,
fir,
fir_kernel,
fourier_scale,
init_scale,
nf,
num_res_blocks,
progressive,
progressive_combine,
progressive_input,
resamp_with_conv,
scale_by_sigma,
skip_rescale,
continuous,
)
def forward(
self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int]
...@@ -261,54 +415,70 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
# ======================================
if not self.is_overwritten:
self.set_weights()
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
t_emb = self.time_steps(timesteps)
emb = self.time_embedding(t_emb)
# 2. pre-process
skip_sample = sample
sample = self.conv_in(sample)
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.downsample_blocks:
if hasattr(downsample_block, "skip_conv"):
sample, res_samples, skip_sample = downsample_block(
hidden_states=sample, temb=emb, skip_sample=skip_sample
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
# append to tuple
down_block_res_samples += res_samples
# 4. mid
if self.config.ddpm:
sample = self.mid_new_2(sample, emb)
else:
sample = self.mid(sample, emb)
# 5. up
skip_sample = None
for upsample_block in self.upsample_blocks:
# pop from tuple
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
if hasattr(upsample_block, "skip_conv"):
sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
else:
sample = upsample_block(sample, res_samples, emb)
# 6. post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if skip_sample is not None:
sample += skip_sample
if (
self.config.time_embedding_type == "fourier"
or self.time_steps.__class__.__name__ == "GaussianFourierProjection"
):
timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
sample = sample / timesteps
output = {"sample": sample} output = {"sample": sample}
return output return output
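# Minimal call sketch (values are illustrative): forward accepts a float/int
# or tensor timestep and returns a dict, so a typical invocation is
#   sample = torch.randn(1, model.config.in_channels, 32, 32)
#   pred = model(sample, timestep=10)["sample"]  # same shape as `sample`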
...@@ -319,7 +489,6 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
def set_weights(self):
self.is_overwritten = True
if self.config.ldm:
self.time_embedding.linear_1.weight.data = self.time_embed[0].weight.data
self.time_embedding.linear_1.bias.data = self.time_embed[0].bias.data
self.time_embedding.linear_2.weight.data = self.time_embed[2].weight.data
...@@ -373,8 +542,6 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
self.remove_ldm()
elif self.config.ddpm:
self.time_embedding.linear_1.weight.data = self.temb.dense[0].weight.data
self.time_embedding.linear_1.bias.data = self.temb.dense[0].bias.data
self.time_embedding.linear_2.weight.data = self.temb.dense[1].weight.data
...@@ -411,6 +578,73 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
self.conv_norm_out.bias.data = self.norm_out.bias.data
self.remove_ddpm()
elif self.config.sde:
self.time_steps.weight = self.all_modules[0].weight
self.time_embedding.linear_1.weight.data = self.all_modules[1].weight.data
self.time_embedding.linear_1.bias.data = self.all_modules[1].bias.data
self.time_embedding.linear_2.weight.data = self.all_modules[2].weight.data
self.time_embedding.linear_2.bias.data = self.all_modules[2].bias.data
self.conv_in.weight.data = self.all_modules[3].weight.data
self.conv_in.bias.data = self.all_modules[3].bias.data
module_index = 4
for i, block in enumerate(self.downsample_blocks):
has_attentions = hasattr(block, "attentions")
if has_attentions:
for j in range(len(block.attentions)):
block.resnets[j].set_weight(self.all_modules[module_index])
module_index += 1
block.attentions[j].set_weight(self.all_modules[module_index])
module_index += 1
if hasattr(block, "downsamplers") and block.downsamplers is not None:
block.resnet_down.set_weight(self.all_modules[module_index])
module_index += 1
block.skip_conv.weight.data = self.all_modules[module_index].Conv_0.weight.data
block.skip_conv.bias.data = self.all_modules[module_index].Conv_0.bias.data
module_index += 1
else:
for j in range(len(block.resnets)):
block.resnets[j].set_weight(self.all_modules[module_index])
module_index += 1
if hasattr(block, "downsamplers") and block.downsamplers is not None:
block.resnet_down.set_weight(self.all_modules[module_index])
module_index += 1
block.skip_conv.weight.data = self.all_modules[module_index].Conv_0.weight.data
block.skip_conv.bias.data = self.all_modules[module_index].Conv_0.bias.data
module_index += 1
self.mid.resnets[0].set_weight(self.all_modules[module_index])
module_index += 1
self.mid.attentions[0].set_weight(self.all_modules[module_index])
module_index += 1
self.mid.resnets[1].set_weight(self.all_modules[module_index])
module_index += 1
for i, block in enumerate(self.upsample_blocks):
for j in range(len(block.resnets)):
block.resnets[j].set_weight(self.all_modules[module_index])
module_index += 1
if hasattr(block, "attentions") and block.attentions is not None:
block.attentions[0].set_weight(self.all_modules[module_index])
module_index += 1
if hasattr(block, "resnet_up") and block.resnet_up is not None:
block.skip_norm.weight.data = self.all_modules[module_index].weight.data
block.skip_norm.bias.data = self.all_modules[module_index].bias.data
module_index += 1
block.skip_conv.weight.data = self.all_modules[module_index].weight.data
block.skip_conv.bias.data = self.all_modules[module_index].bias.data
module_index += 1
block.resnet_up.set_weight(self.all_modules[module_index])
module_index += 1
self.conv_norm_out.weight.data = self.all_modules[module_index].weight.data
self.conv_norm_out.bias.data = self.all_modules[module_index].bias.data
module_index += 1
self.conv_out.weight.data = self.all_modules[module_index].weight.data
self.conv_out.bias.data = self.all_modules[module_index].bias.data
self.remove_sde()
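# Sanity-check sketch (an assumption, not in the source): the copy above walks
# self.all_modules with one running index and conv_out reuses the final value,
# so after the last copy the index should point at the last ported module:
#   assert module_index == len(self.all_modules) - 1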
def init_for_ddpm(
self,
...@@ -700,6 +934,240 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
nn.Conv2d(model_channels, out_channels, 3, padding=1),
)
def init_for_sde(
self,
image_size,
num_channels,
centered,
attn_resolutions,
ch_mult,
conditional,
conv_size,
dropout,
embedding_type,
fir,
fir_kernel,
fourier_scale,
init_scale,
nf,
num_res_blocks,
progressive,
progressive_combine,
progressive_input,
resamp_with_conv,
scale_by_sigma,
skip_rescale,
continuous,
):
self.act = nn.SiLU()
self.nf = nf
self.num_res_blocks = num_res_blocks
self.attn_resolutions = attn_resolutions
self.num_resolutions = len(ch_mult)
self.all_resolutions = all_resolutions = [image_size // (2**i) for i in range(self.num_resolutions)]
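# e.g. image_size=256 with ch_mult=[1, 2, 2, 2] gives num_resolutions=4 and
# all_resolutions == [256, 128, 64, 32]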
self.conditional = conditional
self.skip_rescale = skip_rescale
self.progressive = progressive
self.progressive_input = progressive_input
self.embedding_type = embedding_type
assert progressive in ["none", "output_skip", "residual"]
assert progressive_input in ["none", "input_skip", "residual"]
assert embedding_type in ["fourier", "positional"]
combine_method = progressive_combine.lower()
combiner = functools.partial(Combine, method=combine_method)
modules = []
# timestep/noise_level embedding; only for continuous training
if embedding_type == "fourier":
# Gaussian Fourier features embeddings.
modules.append(GaussianFourierProjection(embedding_size=nf, scale=fourier_scale))
embed_dim = 2 * nf
elif embedding_type == "positional":
embed_dim = nf
else:
raise ValueError(f"embedding type {embedding_type} unknown.")
modules.append(nn.Linear(embed_dim, nf * 4))
modules.append(nn.Linear(nf * 4, nf * 4))
AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
if fir:
Up_sample = functools.partial(FirUpsample2D, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
else:
Up_sample = functools.partial(Upsample2D, name="Conv2d_0")
if progressive == "output_skip":
self.pyramid_upsample = Up_sample(channels=None, use_conv=False)
elif progressive == "residual":
pyramid_upsample = functools.partial(Up_sample, use_conv=True)
if fir:
Down_sample = functools.partial(FirDownsample2D, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
else:
Down_sample = functools.partial(Downsample2D, padding=0, name="Conv2d_0")
if progressive_input == "input_skip":
self.pyramid_downsample = Down_sample(channels=None, use_conv=False)
elif progressive_input == "residual":
pyramid_downsample = functools.partial(Down_sample, use_conv=True)
channels = num_channels
if progressive_input != "none":
input_pyramid_ch = channels
modules.append(nn.Conv2d(channels, nf, kernel_size=3, padding=1))
hs_c = [nf]
in_ch = nf
for i_level in range(self.num_resolutions):
# Residual blocks for this resolution
for i_block in range(num_res_blocks):
out_ch = nf * ch_mult[i_level]
modules.append(
ResnetBlock2D(
in_channels=in_ch,
out_channels=out_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
non_linearity="silu",
groups=min(in_ch // 4, 32),
groups_out=min(out_ch // 4, 32),
overwrite_for_score_vde=True,
)
)
in_ch = out_ch
if all_resolutions[i_level] in attn_resolutions:
modules.append(AttnBlock(channels=in_ch))
hs_c.append(in_ch)
if i_level != self.num_resolutions - 1:
modules.append(
ResnetBlock2D(
in_channels=in_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
non_linearity="silu",
groups=min(in_ch // 4, 32),
groups_out=min(out_ch // 4, 32),
overwrite_for_score_vde=True,
down=True,
kernel="fir" if fir else "sde_vp",
use_nin_shortcut=True,
)
)
if progressive_input == "input_skip":
modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
if combine_method == "cat":
in_ch *= 2
elif progressive_input == "residual":
modules.append(pyramid_downsample(channels=input_pyramid_ch, out_channels=in_ch))
input_pyramid_ch = in_ch
hs_c.append(in_ch)
# mid
in_ch = hs_c[-1]
modules.append(
ResnetBlock2D(
in_channels=in_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
non_linearity="silu",
groups=min(in_ch // 4, 32),
groups_out=min(out_ch // 4, 32),
overwrite_for_score_vde=True,
)
)
modules.append(AttnBlock(channels=in_ch))
modules.append(
ResnetBlock2D(
in_channels=in_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
non_linearity="silu",
groups=min(in_ch // 4, 32),
groups_out=min(out_ch // 4, 32),
overwrite_for_score_vde=True,
)
)
pyramid_ch = 0
# Upsampling block
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(num_res_blocks + 1):
out_ch = nf * ch_mult[i_level]
in_ch = in_ch + hs_c.pop()
modules.append(
ResnetBlock2D(
in_channels=in_ch,
out_channels=out_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
non_linearity="silu",
groups=min(in_ch // 4, 32),
groups_out=min(out_ch // 4, 32),
overwrite_for_score_vde=True,
)
)
in_ch = out_ch
if all_resolutions[i_level] in attn_resolutions:
modules.append(AttnBlock(channels=in_ch))
if progressive != "none":
if i_level == self.num_resolutions - 1:
if progressive == "output_skip":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(nn.Conv2d(in_ch, channels, kernel_size=3, padding=1))
pyramid_ch = channels
elif progressive == "residual":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(nn.Conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1))
pyramid_ch = in_ch
else:
raise ValueError(f"{progressive} is not a valid name.")
else:
if progressive == "output_skip":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(nn.Conv2d(in_ch, channels, bias=True, kernel_size=3, padding=1))
pyramid_ch = channels
elif progressive == "residual":
modules.append(pyramid_upsample(channels=pyramid_ch, out_channels=in_ch))
pyramid_ch = in_ch
else:
raise ValueError(f"{progressive} is not a valid name")
if i_level != 0:
modules.append(
ResnetBlock2D(
in_channels=in_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
non_linearity="silu",
groups=min(in_ch // 4, 32),
groups_out=min(out_ch // 4, 32),
overwrite_for_score_vde=True,
up=True,
kernel="fir" if fir else "sde_vp",
use_nin_shortcut=True,
)
)
assert not hs_c
if progressive != "output_skip":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(nn.Conv2d(in_ch, channels, kernel_size=3, padding=1))
self.all_modules = nn.ModuleList(modules)
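# Note: the append order above must stay in sync with set_weights(), which
# copies pretrained weights by walking self.all_modules front to back with a
# single module_index.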
def remove_ldm(self):
del self.time_embed
del self.input_blocks
...@@ -714,6 +1182,9 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
del self.up
del self.norm_out
def remove_sde(self):
del self.all_modules
def nonlinearity(x):
# swish
...
...@@ -14,8 +14,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
img_size = self.model.config.image_size
shape = (1, 3, img_size, img_size)
model = self.model.to(device)
...@@ -34,11 +33,18 @@ class ScoreSdeVePipeline(DiffusionPipeline):
for _ in range(n_steps):
with torch.no_grad():
result = self.model(x, sigma_t)
if isinstance(result, dict):
result = result["sample"]
x = self.scheduler.step_correct(result, x)
with torch.no_grad():
result = model(x, sigma_t)
if isinstance(result, dict):
result = result["sample"]
x, x_mean = self.scheduler.step_pred(result, x, t)
return x_mean
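# Usage sketch (model and scheduler ids are the ones used in the tests below):
#   model = UNetUnconditionalModel.from_pretrained("fusing/ffhq_ncsnpp", sde=True)
#   scheduler = ScoreSdeVeScheduler.from_config("fusing/ffhq_ncsnpp")
#   sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)
#   image = sde_ve(num_inference_steps=2)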
...@@ -15,6 +15,7 @@
import inspect
import math
import tempfile
import unittest
...@@ -590,7 +591,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
model_class = UNetUnconditionalModel
@property
def dummy_input(self):
...@@ -613,22 +614,34 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"block_channels": [32, 64, 64, 64],
"in_channels": 3,
"num_res_blocks": 1,
"out_channels": 3,
"time_embedding_type": "fourier",
"resnet_eps": 1e-6,
"mid_block_scale_factor": math.sqrt(2.0),
"resnet_num_groups": None,
"down_blocks": [
"UNetResSkipDownBlock2D",
"UNetResAttnSkipDownBlock2D",
"UNetResSkipDownBlock2D",
"UNetResSkipDownBlock2D",
],
"up_blocks": [
"UNetResSkipUpBlock2D",
"UNetResSkipUpBlock2D",
"UNetResAttnSkipUpBlock2D",
"UNetResSkipUpBlock2D",
],
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
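# Consistency note (an observation, not in the source): this init_dict mirrors
# the SDE defaults sketched in UNetUnconditionalModel.__init__, e.g.
# block_channels == [nf * m for m in ch_mult] with nf=32, ch_mult=[1, 2, 2, 2]
# giving [32, 64, 64, 64].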
def test_from_pretrained_hub(self):
model, loading_info = UNetUnconditionalModel.from_pretrained(
"fusing/ncsnpp-ffhq-ve-dummy", sde=True, output_loading_info=True
)
self.assertIsNotNone(model)
# self.assertEqual(len(loading_info["missing_keys"]), 0)
...@@ -663,9 +676,33 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
def test_output_pretrained_ve_mid(self):
model = UNetUnconditionalModel.from_pretrained("fusing/celebahq_256-ncsnpp-ve", sde=True)
model.to(torch_device)
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
batch_size = 4
num_channels = 3
sizes = (256, 256)
noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)
with torch.no_grad():
output = model(noise, time_step)["sample"]
output_slice = output[0, -3:, -3:, -1].flatten().cpu()
# fmt: off
expected_output_slice = torch.tensor([-4836.2231, -6487.1387, -3816.7969, -7964.9253, -10966.2842, -20043.6016, 8137.0571, 2340.3499, 544.6114])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
def test_output_pretrained_ve_large(self):
model = UNetUnconditionalModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy", sde=True)
model.to(torch_device)
torch.manual_seed(0)
...@@ -680,7 +717,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)
with torch.no_grad():
output = model(noise, time_step)["sample"]
output_slice = output[0, -3:, -3:, -1].flatten().cpu()
# fmt: off
...@@ -691,7 +728,6 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
def test_output_pretrained_vp(self):
model = NCSNpp.from_pretrained("fusing/cifar10-ddpmpp-vp")
model.to(torch_device)
torch.manual_seed(0)
...@@ -874,7 +910,6 @@ class PipelineTesterMixin(unittest.TestCase):
out_channels=3,
down_blocks=("UNetResDownBlock2D", "UNetResAttnDownBlock2D"),
up_blocks=("UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
)
scheduler = DDPMScheduler(timesteps=10)
...@@ -1038,7 +1073,12 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
def test_score_sde_ve_pipeline(self):
model = UNetUnconditionalModel.from_pretrained("fusing/ffhq_ncsnpp", sde=True)
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
scheduler = ScoreSdeVeScheduler.from_config("fusing/ffhq_ncsnpp")
sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)
...@@ -1047,11 +1087,11 @@ class PipelineTesterMixin(unittest.TestCase):
image = sde_ve(num_inference_steps=2)
if model.device.type == "cpu":
expected_image_sum = 3384805632.0
expected_image_mean = 1076.000732421875
else:
expected_image_sum = 3382849024.0
expected_image_mean = 1075.3787841796875
assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2
assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4
...