Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
f1aade05
Commit
f1aade05
authored
Jul 01, 2022
by
Patrick von Platen
Browse files
up
parent
abedfb08
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
1946 additions
and
4 deletions
+1946
-4
@
@
+859
-0
G
G
+883
-0
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+202
-3
src/diffusers/models/unet_sde_score_estimation.py
src/diffusers/models/unet_sde_score_estimation.py
+2
-1
No files found.
@
0 → 100644
View file @
f1aade05
from abc import abstractmethod
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def conv_transpose_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.ConvTranspose1d(*args, **kwargs)
elif dims == 2:
return nn.ConvTranspose2d(*args, **kwargs)
elif dims == 3:
return nn.ConvTranspose3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def Normalize(in_channels, num_groups=32, eps=1e-6):
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=eps, affine=True)
def nonlinearity(x, swish=1.0):
# swish
if swish == 1.0:
return F.silu(x)
else:
return x * F.sigmoid(x * float(swish))
class TimestepBlock(nn.Module):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@abstractmethod
def forward(self, x, emb):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
self.use_conv_transpose = use_conv_transpose
if use_conv_transpose:
self.conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
elif use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
def forward(self, x):
assert x.shape[1] == self.channels
if self.use_conv_transpose:
return self.conv(x)
if self.dims == 3:
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
else:
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv=False, dims=2, out_channels=None, padding=1, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
self.padding = padding
stride = 2 if dims != 3 else (1, 2, 2)
self.name = name
if use_conv:
conv = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
conv = avg_pool_nd(dims, kernel_size=stride, stride=stride)
if name == "conv":
self.conv = conv
else:
self.op = conv
def forward(self, x):
assert x.shape[1] == self.channels
if self.use_conv and self.padding == 0 and self.dims == 2:
pad = (0, 1, 0, 1)
x = F.pad(x, pad, mode="constant", value=0)
if self.name == "conv":
return self.conv(x)
else:
return self.op(x)
# TODO (patil-suraj): needs test
# class Upsample1d(nn.Module):
# def __init__(self, dim):
# super().__init__()
# self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
#
# def forward(self, x):
# return self.conv(x)
# RESNETS
# unet_score_estimation.py
class ResnetBlockBigGANppNew(nn.Module):
def __init__(
self,
act,
in_ch,
out_ch=None,
temb_dim=None,
up=False,
down=False,
dropout=0.1,
fir_kernel=(1, 3, 3, 1),
skip_rescale=True,
init_scale=0.0,
overwrite=True,
):
super().__init__()
out_ch = out_ch if out_ch else in_ch
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
self.up = up
self.down = down
self.fir_kernel = fir_kernel
self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
if temb_dim is not None:
self.Dense_0 = nn.Linear(temb_dim, out_ch)
self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
nn.init.zeros_(self.Dense_0.bias)
self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
self.Dropout_0 = nn.Dropout(dropout)
self.Conv_1 = conv2d(out_ch, out_ch, init_scale=init_scale, kernel_size=3, padding=1)
if in_ch != out_ch or up or down:
# 1x1 convolution with DDPM initialization.
self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
self.skip_rescale = skip_rescale
self.act = act
self.in_ch = in_ch
self.out_ch = out_ch
if self.overwrite:
in_channels = in_ch
out_channels = out_ch
groups = min(in_ch // 4, 32)
eps = 1e-6
self.pre_norm = True
temb_channels = temb_dim
non_linearity = "silu"
time_embedding_norm = "default"
if self.pre_norm:
self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
else:
self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if time_embedding_norm == "default":
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
elif time_embedding_norm == "scale_shift":
self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = nonlinearity
elif non_linearity == "mish":
self.nonlinearity = Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
if up:
self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
elif down:
self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
if self.in_channels != self.out_channels:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, temb=None):
h = self.act(self.GroupNorm_0(x))
if self.up:
h = upsample_2d(h, self.fir_kernel, factor=2)
x = upsample_2d(x, self.fir_kernel, factor=2)
elif self.down:
h = downsample_2d(h, self.fir_kernel, factor=2)
x = downsample_2d(x, self.fir_kernel, factor=2)
h = self.Conv_0(h)
# Add bias to each feature map conditioned on the time embedding
if temb is not None:
h += self.Dense_0(self.act(temb))[:, :, None, None]
h = self.act(self.GroupNorm_1(h))
h = self.Dropout_0(h)
h = self.Conv_1(h)
if self.in_ch != self.out_ch or self.up or self.down:
x = self.Conv_2(x)
if not self.skip_rescale:
return x + h
else:
return (x + h) / np.sqrt(2.0)
def forward_2(self, x, temb, mask=1.0):
# TODO(Patrick) eventually this class should be split into multiple classes
# too many if else statements
if self.overwrite_for_grad_tts and not self.is_overwritten:
self.set_weights_grad_tts()
self.is_overwritten = True
elif self.overwrite_for_ldm and not self.is_overwritten:
self.set_weights_ldm()
self.is_overwritten = True
h = x
h = h * mask
if self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
if self.up or self.down:
x = self.x_upd(x)
h = self.h_upd(h)
h = self.conv1(h)
if not self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
h = h * mask
temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
if self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
h = self.norm2(h)
h = h + h * scale + shift
h = self.nonlinearity(h)
elif self.time_embedding_norm == "default":
h = h + temb
h = h * mask
if self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if not self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
h = h * mask
x = x * mask
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x + h
# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py
class ResnetBlock(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
up=False,
down=False,
overwrite_for_grad_tts=False,
overwrite_for_ldm=False,
overwrite_for_glide=False,
):
super().__init__()
self.pre_norm = pre_norm
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.time_embedding_norm = time_embedding_norm
self.up = up
self.down = down
if self.pre_norm:
self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
else:
self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if time_embedding_norm == "default":
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
elif time_embedding_norm == "scale_shift":
self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = nonlinearity
elif non_linearity == "mish":
self.nonlinearity = Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
if up:
self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
elif down:
self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
if self.in_channels != self.out_channels:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
# TODO(SURAJ, PATRICK): ALL OF THE FOLLOWING OF THE INIT METHOD CAN BE DELETED ONCE WEIGHTS ARE CONVERTED
self.is_overwritten = False
self.overwrite_for_glide = overwrite_for_glide
self.overwrite_for_grad_tts = overwrite_for_grad_tts
self.overwrite_for_ldm = overwrite_for_ldm or overwrite_for_glide
if self.overwrite_for_grad_tts:
dim = in_channels
dim_out = out_channels
time_emb_dim = temb_channels
self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
self.pre_norm = pre_norm
self.block1 = Block(dim, dim_out, groups=groups)
self.block2 = Block(dim_out, dim_out, groups=groups)
if dim != dim_out:
self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
else:
self.res_conv = torch.nn.Identity()
elif self.overwrite_for_ldm:
dims = 2
channels = in_channels
emb_channels = temb_channels
use_scale_shift_norm = False
self.in_layers = nn.Sequential(
normalization(channels, swish=1.0),
nn.Identity(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
)
self.emb_layers = nn.Sequential(
nn.SiLU(),
linear(
emb_channels,
2 * self.out_channels if self.time_embedding_norm == "scale_shift" else self.out_channels,
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
nn.SiLU() if use_scale_shift_norm else nn.Identity(),
nn.Dropout(p=dropout),
zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
)
if self.out_channels == in_channels:
self.skip_connection = nn.Identity()
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
def set_weights_grad_tts(self):
self.conv1.weight.data = self.block1.block[0].weight.data
self.conv1.bias.data = self.block1.block[0].bias.data
self.norm1.weight.data = self.block1.block[1].weight.data
self.norm1.bias.data = self.block1.block[1].bias.data
self.conv2.weight.data = self.block2.block[0].weight.data
self.conv2.bias.data = self.block2.block[0].bias.data
self.norm2.weight.data = self.block2.block[1].weight.data
self.norm2.bias.data = self.block2.block[1].bias.data
self.temb_proj.weight.data = self.mlp[1].weight.data
self.temb_proj.bias.data = self.mlp[1].bias.data
if self.in_channels != self.out_channels:
self.nin_shortcut.weight.data = self.res_conv.weight.data
self.nin_shortcut.bias.data = self.res_conv.bias.data
def set_weights_ldm(self):
self.norm1.weight.data = self.in_layers[0].weight.data
self.norm1.bias.data = self.in_layers[0].bias.data
self.conv1.weight.data = self.in_layers[-1].weight.data
self.conv1.bias.data = self.in_layers[-1].bias.data
self.temb_proj.weight.data = self.emb_layers[-1].weight.data
self.temb_proj.bias.data = self.emb_layers[-1].bias.data
self.norm2.weight.data = self.out_layers[0].weight.data
self.norm2.bias.data = self.out_layers[0].bias.data
self.conv2.weight.data = self.out_layers[-1].weight.data
self.conv2.bias.data = self.out_layers[-1].bias.data
if self.in_channels != self.out_channels:
self.nin_shortcut.weight.data = self.skip_connection.weight.data
self.nin_shortcut.bias.data = self.skip_connection.bias.data
def forward(self, x, temb, mask=1.0):
# TODO(Patrick) eventually this class should be split into multiple classes
# too many if else statements
if self.overwrite_for_grad_tts and not self.is_overwritten:
self.set_weights_grad_tts()
self.is_overwritten = True
elif self.overwrite_for_ldm and not self.is_overwritten:
self.set_weights_ldm()
self.is_overwritten = True
h = x
h = h * mask
if self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
if self.up or self.down:
x = self.x_upd(x)
h = self.h_upd(h)
h = self.conv1(h)
if not self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
h = h * mask
temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
if self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
h = self.norm2(h)
h = h + h * scale + shift
h = self.nonlinearity(h)
elif self.time_embedding_norm == "default":
h = h + temb
h = h * mask
if self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if not self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
h = h * mask
x = x * mask
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x + h
# TODO(Patrick) - just there to convert the weights; can delete afterward
class Block(torch.nn.Module):
def __init__(self, dim, dim_out, groups=8):
super(Block, self).__init__()
self.block = torch.nn.Sequential(
torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
)
# unet_score_estimation.py
class ResnetBlockBigGANpp(nn.Module):
def __init__(
self,
act,
in_ch,
out_ch=None,
temb_dim=None,
up=False,
down=False,
dropout=0.1,
fir_kernel=(1, 3, 3, 1),
skip_rescale=True,
init_scale=0.0,
):
super().__init__()
out_ch = out_ch if out_ch else in_ch
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
self.up = up
self.down = down
self.fir_kernel = fir_kernel
self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
if temb_dim is not None:
self.Dense_0 = nn.Linear(temb_dim, out_ch)
self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
nn.init.zeros_(self.Dense_0.bias)
self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
self.Dropout_0 = nn.Dropout(dropout)
self.Conv_1 = conv2d(out_ch, out_ch, init_scale=init_scale, kernel_size=3, padding=1)
if in_ch != out_ch or up or down:
# 1x1 convolution with DDPM initialization.
self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
self.skip_rescale = skip_rescale
self.act = act
self.in_ch = in_ch
self.out_ch = out_ch
def forward(self, x, temb=None):
h = self.act(self.GroupNorm_0(x))
if self.up:
h = upsample_2d(h, self.fir_kernel, factor=2)
x = upsample_2d(x, self.fir_kernel, factor=2)
elif self.down:
h = downsample_2d(h, self.fir_kernel, factor=2)
x = downsample_2d(x, self.fir_kernel, factor=2)
h = self.Conv_0(h)
# Add bias to each feature map conditioned on the time embedding
if temb is not None:
h += self.Dense_0(self.act(temb))[:, :, None, None]
h = self.act(self.GroupNorm_1(h))
h = self.Dropout_0(h)
h = self.Conv_1(h)
if self.in_ch != self.out_ch or self.up or self.down:
x = self.Conv_2(x)
if not self.skip_rescale:
return x + h
else:
return (x + h) / np.sqrt(2.0)
# unet_rl.py
class ResidualTemporalBlock(nn.Module):
def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
super().__init__()
self.blocks = nn.ModuleList(
[
Conv1dBlock(inp_channels, out_channels, kernel_size),
Conv1dBlock(out_channels, out_channels, kernel_size),
]
)
self.time_mlp = nn.Sequential(
nn.Mish(),
nn.Linear(embed_dim, out_channels),
RearrangeDim(),
# Rearrange("batch t -> batch t 1"),
)
self.residual_conv = (
nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
)
def forward(self, x, t):
"""
x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x
out_channels x horizon ]
"""
out = self.blocks[0](x) + self.time_mlp(t)
out = self.blocks[1](out)
return out + self.residual_conv(x)
# HELPER Modules
def normalization(channels, swish=0.0):
"""
Make a standard normalization layer, with an optional swish activation.
:param channels: number of input channels. :return: an nn.Module for normalization.
"""
return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
class GroupNorm32(nn.GroupNorm):
def __init__(self, num_groups, num_channels, swish, eps=1e-5):
super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
self.swish = swish
def forward(self, x):
y = super().forward(x.float()).to(x.dtype)
if self.swish == 1.0:
y = F.silu(y)
elif self.swish:
y = y * F.sigmoid(y * float(self.swish))
return y
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
class Mish(torch.nn.Module):
def forward(self, x):
return x * torch.tanh(torch.nn.functional.softplus(x))
class Conv1dBlock(nn.Module):
"""
Conv1d --> GroupNorm --> Mish
"""
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
RearrangeDim(),
# Rearrange("batch channels horizon -> batch channels 1 horizon"),
nn.GroupNorm(n_groups, out_channels),
RearrangeDim(),
# Rearrange("batch channels 1 horizon -> batch channels horizon"),
nn.Mish(),
)
def forward(self, x):
return self.block(x)
class RearrangeDim(nn.Module):
def __init__(self):
super().__init__()
def forward(self, tensor):
if len(tensor.shape) == 2:
return tensor[:, :, None]
if len(tensor.shape) == 3:
return tensor[:, :, None, :]
elif len(tensor.shape) == 4:
return tensor[:, :, 0, :]
else:
raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
"""nXn convolution with DDPM initialization."""
conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape)
nn.init.zeros_(conv.bias)
return conv
def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
"""Ported from JAX."""
scale = 1e-10 if scale == 0 else scale
def _compute_fans(shape, in_axis=1, out_axis=0):
receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
fan_in = shape[in_axis] * receptive_field_size
fan_out = shape[out_axis] * receptive_field_size
return fan_in, fan_out
def init(shape, dtype=dtype, device=device):
fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
denominator = (fan_in + fan_out) / 2
variance = scale / denominator
return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
return init
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
_, channel, in_h, in_w = input.shape
input = input.reshape(-1, in_h, in_w, 1)
_, in_h, in_w, minor = input.shape
kernel_h, kernel_w = kernel.shape
out = input.view(-1, in_h, 1, in_w, 1, minor)
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
out = out[
:,
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
:,
]
out = out.permute(0, 3, 1, 2)
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
out = F.conv2d(out, w)
out = out.reshape(
-1,
minor,
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
)
out = out.permute(0, 2, 3, 1)
out = out[:, ::down_y, ::down_x, :]
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
return out.view(-1, channel, out_h, out_w)
def upsample_2d(x, k=None, factor=2, gain=1):
r"""Upsample a batch of 2D images with the given filter.
Args:
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a:
multiple of the upsampling factor.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H * factor, W * factor]`
"""
assert isinstance(factor, int) and factor >= 1
if k is None:
k = [1] * factor
k = _setup_kernel(k) * (gain * (factor**2))
p = k.shape[0] - factor
return upfirdn2d(x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
def downsample_2d(x, k=None, factor=2, gain=1):
r"""Downsample a batch of 2D images with the given filter.
Args:
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
shape is a multiple of the downsampling factor.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to average pooling.
factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H // factor, W // factor]`
"""
assert isinstance(factor, int) and factor >= 1
if k is None:
k = [1] * factor
k = _setup_kernel(k) * gain
p = k.shape[0] - factor
return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
def _setup_kernel(k):
k = np.asarray(k, dtype=np.float32)
if k.ndim == 1:
k = np.outer(k, k)
k /= np.sum(k)
assert k.ndim == 2
assert k.shape[0] == k.shape[1]
return k
G
0 → 100644
View file @
f1aade05
from abc import abstractmethod
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def conv_transpose_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.ConvTranspose1d(*args, **kwargs)
elif dims == 2:
return nn.ConvTranspose2d(*args, **kwargs)
elif dims == 3:
return nn.ConvTranspose3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def Normalize(in_channels, num_groups=32, eps=1e-6):
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=eps, affine=True)
def nonlinearity(x, swish=1.0):
# swish
if swish == 1.0:
return F.silu(x)
else:
return x * F.sigmoid(x * float(swish))
class TimestepBlock(nn.Module):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@abstractmethod
def forward(self, x, emb):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
self.use_conv_transpose = use_conv_transpose
if use_conv_transpose:
self.conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
elif use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
def forward(self, x):
assert x.shape[1] == self.channels
if self.use_conv_transpose:
return self.conv(x)
if self.dims == 3:
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
else:
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv=False, dims=2, out_channels=None, padding=1, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
self.padding = padding
stride = 2 if dims != 3 else (1, 2, 2)
self.name = name
if use_conv:
conv = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
conv = avg_pool_nd(dims, kernel_size=stride, stride=stride)
if name == "conv":
self.conv = conv
else:
self.op = conv
def forward(self, x):
assert x.shape[1] == self.channels
if self.use_conv and self.padding == 0 and self.dims == 2:
pad = (0, 1, 0, 1)
x = F.pad(x, pad, mode="constant", value=0)
if self.name == "conv":
return self.conv(x)
else:
return self.op(x)
# TODO (patil-suraj): needs test
# class Upsample1d(nn.Module):
# def __init__(self, dim):
# super().__init__()
# self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
#
# def forward(self, x):
# return self.conv(x)
# RESNETS
# unet_score_estimation.py
class ResnetBlockBigGANppNew(nn.Module):
def __init__(
self,
act,
in_ch,
out_ch=None,
temb_dim=None,
up=False,
down=False,
dropout=0.1,
fir_kernel=(1, 3, 3, 1),
skip_rescale=True,
init_scale=0.0,
overwrite=True,
):
super().__init__()
out_ch = out_ch if out_ch else in_ch
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
self.up = up
self.down = down
self.fir_kernel = fir_kernel
self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
if temb_dim is not None:
self.Dense_0 = nn.Linear(temb_dim, out_ch)
self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
nn.init.zeros_(self.Dense_0.bias)
self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
self.Dropout_0 = nn.Dropout(dropout)
self.Conv_1 = conv2d(out_ch, out_ch, init_scale=init_scale, kernel_size=3, padding=1)
if in_ch != out_ch or up or down:
# 1x1 convolution with DDPM initialization.
self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
self.skip_rescale = skip_rescale
self.act = act
self.in_ch = in_ch
self.out_ch = out_ch
self.is_overwritten = False
self.overwrite = overwrite
if overwrite:
self.in_channels = in_channels = in_ch
self.out_channels = out_channels = out_ch
groups = min(in_ch // 4, 32)
eps = 1e-6
self.pre_norm = True
temb_channels = temb_dim
non_linearity = "silu"
self.time_embedding_norm = time_embedding_norm = "default"
if self.pre_norm:
self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
else:
self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if time_embedding_norm == "default":
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
elif time_embedding_norm == "scale_shift":
self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = nonlinearity
elif non_linearity == "mish":
self.nonlinearity = Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
if up:
self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
elif down:
self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
if self.in_channels != self.out_channels:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def set_weights(self):
self.conv1.weight.data = self.Conv_0.weight.data
self.conv1.bias.data = self.Conv_0.bias.data
self.norm1.weight.data = self.GroupNorm_0.weight.data
self.norm1.bias.data = self.GroupNorm_0.bias.data
self.conv2.weight.data = self.Conv_1.weight.data
self.conv2.bias.data = self.Conv_1.bias.data
self.norm2.weight.data = self.GroupNorm_1.weight.data
self.norm2.bias.data = self.GroupNorm_1.bias.data
self.temb_proj.weight.data = self.Dense_0.weight.data
self.temb_proj.bias.data = self.Dense_0.bias.data
if self.in_channels != self.out_channels:
self.nin_shortcut.weight.data = self.Conv_2.weight.data
self.nin_shortcut.bias.data = self.Conv_2.bias.data
def set_weights_ldm(self):
import ipdb; ipdb.set_trace()
def forward(self, x, temb=None):
if self.overwrite and not self.is_overwritten:
self.set_weights()
self.is_overwritten = True
orig_x = x
h = self.act(self.GroupNorm_0(x))
if self.up:
h = upsample_2d(h, self.fir_kernel, factor=2)
x = upsample_2d(x, self.fir_kernel, factor=2)
elif self.down:
h = downsample_2d(h, self.fir_kernel, factor=2)
x = downsample_2d(x, self.fir_kernel, factor=2)
h = self.Conv_0(h)
# Add bias to each feature map conditioned on the time embedding
if temb is not None:
h += self.Dense_0(self.act(temb))[:, :, None, None]
h = self.act(self.GroupNorm_1(h))
h = self.Dropout_0(h)
h = self.Conv_1(h)
if self.in_ch != self.out_ch or self.up or self.down:
x = self.Conv_2(x)
if not self.skip_rescale:
result = x + h
else:
result = (x + h) / np.sqrt(2.0)
result_2 = self.forward_2(orig_x, temb)
print("Diff", (result - result_2).abs().sum())
import ipdb; ipdb.set_trace()
return result
def forward_2(self, x, temb, mask=1.0):
h = x
h = h * mask
if self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
if self.up or self.down:
x = self.x_upd(x)
h = self.h_upd(h)
h = self.conv1(h)
if not self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
h = h * mask
temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
if self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
h = self.norm2(h)
h = h + h * scale + shift
h = self.nonlinearity(h)
elif self.time_embedding_norm == "default":
h = h + temb
h = h * mask
if self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if not self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
h = h * mask
x = x * mask
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x + h
# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py
class ResnetBlock(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
up=False,
down=False,
overwrite_for_grad_tts=False,
overwrite_for_ldm=False,
overwrite_for_glide=False,
):
super().__init__()
self.pre_norm = pre_norm
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.time_embedding_norm = time_embedding_norm
self.up = up
self.down = down
if self.pre_norm:
self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
else:
self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if time_embedding_norm == "default":
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
elif time_embedding_norm == "scale_shift":
self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = nonlinearity
elif non_linearity == "mish":
self.nonlinearity = Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
if up:
self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
elif down:
self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
if self.in_channels != self.out_channels:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
# TODO(SURAJ, PATRICK): ALL OF THE FOLLOWING OF THE INIT METHOD CAN BE DELETED ONCE WEIGHTS ARE CONVERTED
self.is_overwritten = False
self.overwrite_for_glide = overwrite_for_glide
self.overwrite_for_grad_tts = overwrite_for_grad_tts
self.overwrite_for_ldm = overwrite_for_ldm or overwrite_for_glide
if self.overwrite_for_grad_tts:
dim = in_channels
dim_out = out_channels
time_emb_dim = temb_channels
self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
self.pre_norm = pre_norm
self.block1 = Block(dim, dim_out, groups=groups)
self.block2 = Block(dim_out, dim_out, groups=groups)
if dim != dim_out:
self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
else:
self.res_conv = torch.nn.Identity()
elif self.overwrite_for_ldm:
dims = 2
channels = in_channels
emb_channels = temb_channels
use_scale_shift_norm = False
self.in_layers = nn.Sequential(
normalization(channels, swish=1.0),
nn.Identity(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
)
self.emb_layers = nn.Sequential(
nn.SiLU(),
linear(
emb_channels,
2 * self.out_channels if self.time_embedding_norm == "scale_shift" else self.out_channels,
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
nn.SiLU() if use_scale_shift_norm else nn.Identity(),
nn.Dropout(p=dropout),
zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
)
if self.out_channels == in_channels:
self.skip_connection = nn.Identity()
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
def set_weights_grad_tts(self):
self.conv1.weight.data = self.block1.block[0].weight.data
self.conv1.bias.data = self.block1.block[0].bias.data
self.norm1.weight.data = self.block1.block[1].weight.data
self.norm1.bias.data = self.block1.block[1].bias.data
self.conv2.weight.data = self.block2.block[0].weight.data
self.conv2.bias.data = self.block2.block[0].bias.data
self.norm2.weight.data = self.block2.block[1].weight.data
self.norm2.bias.data = self.block2.block[1].bias.data
self.temb_proj.weight.data = self.mlp[1].weight.data
self.temb_proj.bias.data = self.mlp[1].bias.data
if self.in_channels != self.out_channels:
self.nin_shortcut.weight.data = self.res_conv.weight.data
self.nin_shortcut.bias.data = self.res_conv.bias.data
def set_weights_ldm(self):
self.norm1.weight.data = self.in_layers[0].weight.data
self.norm1.bias.data = self.in_layers[0].bias.data
self.conv1.weight.data = self.in_layers[-1].weight.data
self.conv1.bias.data = self.in_layers[-1].bias.data
self.temb_proj.weight.data = self.emb_layers[-1].weight.data
self.temb_proj.bias.data = self.emb_layers[-1].bias.data
self.norm2.weight.data = self.out_layers[0].weight.data
self.norm2.bias.data = self.out_layers[0].bias.data
self.conv2.weight.data = self.out_layers[-1].weight.data
self.conv2.bias.data = self.out_layers[-1].bias.data
if self.in_channels != self.out_channels:
self.nin_shortcut.weight.data = self.skip_connection.weight.data
self.nin_shortcut.bias.data = self.skip_connection.bias.data
def forward(self, x, temb, mask=1.0):
# TODO(Patrick) eventually this class should be split into multiple classes
# too many if else statements
if self.overwrite_for_grad_tts and not self.is_overwritten:
self.set_weights_grad_tts()
self.is_overwritten = True
elif self.overwrite_for_ldm and not self.is_overwritten:
self.set_weights_ldm()
self.is_overwritten = True
h = x
h = h * mask
if self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
if self.up or self.down:
x = self.x_upd(x)
h = self.h_upd(h)
h = self.conv1(h)
if not self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
h = h * mask
temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
if self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
h = self.norm2(h)
h = h + h * scale + shift
h = self.nonlinearity(h)
elif self.time_embedding_norm == "default":
h = h + temb
h = h * mask
if self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if not self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
h = h * mask
x = x * mask
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x + h
# TODO(Patrick) - just there to convert the weights; can delete afterward
class Block(torch.nn.Module):
def __init__(self, dim, dim_out, groups=8):
super(Block, self).__init__()
self.block = torch.nn.Sequential(
torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
)
# unet_score_estimation.py
class ResnetBlockBigGANpp(nn.Module):
def __init__(
self,
act,
in_ch,
out_ch=None,
temb_dim=None,
up=False,
down=False,
dropout=0.1,
fir_kernel=(1, 3, 3, 1),
skip_rescale=True,
init_scale=0.0,
):
super().__init__()
out_ch = out_ch if out_ch else in_ch
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
self.up = up
self.down = down
self.fir_kernel = fir_kernel
self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
if temb_dim is not None:
self.Dense_0 = nn.Linear(temb_dim, out_ch)
self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
nn.init.zeros_(self.Dense_0.bias)
self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
self.Dropout_0 = nn.Dropout(dropout)
self.Conv_1 = conv2d(out_ch, out_ch, init_scale=init_scale, kernel_size=3, padding=1)
if in_ch != out_ch or up or down:
# 1x1 convolution with DDPM initialization.
self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
self.skip_rescale = skip_rescale
self.act = act
self.in_ch = in_ch
self.out_ch = out_ch
def forward(self, x, temb=None):
h = self.act(self.GroupNorm_0(x))
if self.up:
h = upsample_2d(h, self.fir_kernel, factor=2)
x = upsample_2d(x, self.fir_kernel, factor=2)
elif self.down:
h = downsample_2d(h, self.fir_kernel, factor=2)
x = downsample_2d(x, self.fir_kernel, factor=2)
h = self.Conv_0(h)
# Add bias to each feature map conditioned on the time embedding
if temb is not None:
h += self.Dense_0(self.act(temb))[:, :, None, None]
h = self.act(self.GroupNorm_1(h))
h = self.Dropout_0(h)
h = self.Conv_1(h)
if self.in_ch != self.out_ch or self.up or self.down:
x = self.Conv_2(x)
if not self.skip_rescale:
return x + h
else:
return (x + h) / np.sqrt(2.0)
# unet_rl.py
class ResidualTemporalBlock(nn.Module):
def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
super().__init__()
self.blocks = nn.ModuleList(
[
Conv1dBlock(inp_channels, out_channels, kernel_size),
Conv1dBlock(out_channels, out_channels, kernel_size),
]
)
self.time_mlp = nn.Sequential(
nn.Mish(),
nn.Linear(embed_dim, out_channels),
RearrangeDim(),
# Rearrange("batch t -> batch t 1"),
)
self.residual_conv = (
nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
)
def forward(self, x, t):
"""
x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x
out_channels x horizon ]
"""
out = self.blocks[0](x) + self.time_mlp(t)
out = self.blocks[1](out)
return out + self.residual_conv(x)
# HELPER Modules
def normalization(channels, swish=0.0):
"""
Make a standard normalization layer, with an optional swish activation.
:param channels: number of input channels. :return: an nn.Module for normalization.
"""
return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
class GroupNorm32(nn.GroupNorm):
def __init__(self, num_groups, num_channels, swish, eps=1e-5):
super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
self.swish = swish
def forward(self, x):
y = super().forward(x.float()).to(x.dtype)
if self.swish == 1.0:
y = F.silu(y)
elif self.swish:
y = y * F.sigmoid(y * float(self.swish))
return y
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
class Mish(torch.nn.Module):
def forward(self, x):
return x * torch.tanh(torch.nn.functional.softplus(x))
class Conv1dBlock(nn.Module):
"""
Conv1d --> GroupNorm --> Mish
"""
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
RearrangeDim(),
# Rearrange("batch channels horizon -> batch channels 1 horizon"),
nn.GroupNorm(n_groups, out_channels),
RearrangeDim(),
# Rearrange("batch channels 1 horizon -> batch channels horizon"),
nn.Mish(),
)
def forward(self, x):
return self.block(x)
class RearrangeDim(nn.Module):
def __init__(self):
super().__init__()
def forward(self, tensor):
if len(tensor.shape) == 2:
return tensor[:, :, None]
if len(tensor.shape) == 3:
return tensor[:, :, None, :]
elif len(tensor.shape) == 4:
return tensor[:, :, 0, :]
else:
raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
"""nXn convolution with DDPM initialization."""
conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape)
nn.init.zeros_(conv.bias)
return conv
def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
"""Ported from JAX."""
scale = 1e-10 if scale == 0 else scale
def _compute_fans(shape, in_axis=1, out_axis=0):
receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
fan_in = shape[in_axis] * receptive_field_size
fan_out = shape[out_axis] * receptive_field_size
return fan_in, fan_out
def init(shape, dtype=dtype, device=device):
fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
denominator = (fan_in + fan_out) / 2
variance = scale / denominator
return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
return init
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
_, channel, in_h, in_w = input.shape
input = input.reshape(-1, in_h, in_w, 1)
_, in_h, in_w, minor = input.shape
kernel_h, kernel_w = kernel.shape
out = input.view(-1, in_h, 1, in_w, 1, minor)
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
out = out[
:,
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
:,
]
out = out.permute(0, 3, 1, 2)
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
out = F.conv2d(out, w)
out = out.reshape(
-1,
minor,
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
)
out = out.permute(0, 2, 3, 1)
out = out[:, ::down_y, ::down_x, :]
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
return out.view(-1, channel, out_h, out_w)
def upsample_2d(x, k=None, factor=2, gain=1):
r"""Upsample a batch of 2D images with the given filter.
Args:
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a:
multiple of the upsampling factor.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H * factor, W * factor]`
"""
assert isinstance(factor, int) and factor >= 1
if k is None:
k = [1] * factor
k = _setup_kernel(k) * (gain * (factor**2))
p = k.shape[0] - factor
return upfirdn2d(x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
def downsample_2d(x, k=None, factor=2, gain=1):
r"""Downsample a batch of 2D images with the given filter.
Args:
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
shape is a multiple of the downsampling factor.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to average pooling.
factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H // factor, W // factor]`
"""
assert isinstance(factor, int) and factor >= 1
if k is None:
k = [1] * factor
k = _setup_kernel(k) * gain
p = k.shape[0] - factor
return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
def _setup_kernel(k):
k = np.asarray(k, dtype=np.float32)
if k.ndim == 1:
k = np.outer(k, k)
k /= np.sum(k)
assert k.ndim == 2
assert k.shape[0] == k.shape[1]
return k
src/diffusers/models/resnet.py
View file @
f1aade05
...
@@ -160,6 +160,208 @@ class Downsample(nn.Module):
...
@@ -160,6 +160,208 @@ class Downsample(nn.Module):
# RESNETS
# RESNETS
# unet_score_estimation.py
class
ResnetBlockBigGANppNew
(
nn
.
Module
):
def
__init__
(
self
,
act
,
in_ch
,
out_ch
=
None
,
temb_dim
=
None
,
up
=
False
,
down
=
False
,
dropout
=
0.1
,
fir_kernel
=
(
1
,
3
,
3
,
1
),
skip_rescale
=
True
,
init_scale
=
0.0
,
overwrite
=
True
,
):
super
().
__init__
()
out_ch
=
out_ch
if
out_ch
else
in_ch
self
.
GroupNorm_0
=
nn
.
GroupNorm
(
num_groups
=
min
(
in_ch
//
4
,
32
),
num_channels
=
in_ch
,
eps
=
1e-6
)
self
.
up
=
up
self
.
down
=
down
self
.
fir_kernel
=
fir_kernel
self
.
Conv_0
=
conv2d
(
in_ch
,
out_ch
,
kernel_size
=
3
,
padding
=
1
)
if
temb_dim
is
not
None
:
self
.
Dense_0
=
nn
.
Linear
(
temb_dim
,
out_ch
)
self
.
Dense_0
.
weight
.
data
=
variance_scaling
()(
self
.
Dense_0
.
weight
.
shape
)
nn
.
init
.
zeros_
(
self
.
Dense_0
.
bias
)
self
.
GroupNorm_1
=
nn
.
GroupNorm
(
num_groups
=
min
(
out_ch
//
4
,
32
),
num_channels
=
out_ch
,
eps
=
1e-6
)
self
.
Dropout_0
=
nn
.
Dropout
(
dropout
)
self
.
Conv_1
=
conv2d
(
out_ch
,
out_ch
,
init_scale
=
init_scale
,
kernel_size
=
3
,
padding
=
1
)
if
in_ch
!=
out_ch
or
up
or
down
:
# 1x1 convolution with DDPM initialization.
self
.
Conv_2
=
conv2d
(
in_ch
,
out_ch
,
kernel_size
=
1
,
padding
=
0
)
self
.
skip_rescale
=
skip_rescale
self
.
act
=
act
self
.
in_ch
=
in_ch
self
.
out_ch
=
out_ch
self
.
is_overwritten
=
False
self
.
overwrite
=
overwrite
if
overwrite
:
self
.
output_scale_factor
=
np
.
sqrt
(
2.0
)
self
.
in_channels
=
in_channels
=
in_ch
self
.
out_channels
=
out_channels
=
out_ch
groups
=
min
(
in_ch
//
4
,
32
)
out_groups
=
min
(
out_ch
//
4
,
32
)
eps
=
1e-6
self
.
pre_norm
=
True
temb_channels
=
temb_dim
non_linearity
=
"silu"
self
.
time_embedding_norm
=
time_embedding_norm
=
"default"
if
self
.
pre_norm
:
self
.
norm1
=
Normalize
(
in_channels
,
num_groups
=
groups
,
eps
=
eps
)
else
:
self
.
norm1
=
Normalize
(
out_channels
,
num_groups
=
groups
,
eps
=
eps
)
self
.
conv1
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
if
time_embedding_norm
==
"default"
:
self
.
temb_proj
=
torch
.
nn
.
Linear
(
temb_channels
,
out_channels
)
elif
time_embedding_norm
==
"scale_shift"
:
self
.
temb_proj
=
torch
.
nn
.
Linear
(
temb_channels
,
2
*
out_channels
)
self
.
norm2
=
Normalize
(
out_channels
,
num_groups
=
out_groups
,
eps
=
eps
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
dropout
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
out_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
if
non_linearity
==
"swish"
:
self
.
nonlinearity
=
nonlinearity
elif
non_linearity
==
"mish"
:
self
.
nonlinearity
=
Mish
()
elif
non_linearity
==
"silu"
:
self
.
nonlinearity
=
nn
.
SiLU
()
if
up
:
self
.
h_upd
=
Upsample
(
in_channels
,
use_conv
=
False
,
dims
=
2
)
self
.
x_upd
=
Upsample
(
in_channels
,
use_conv
=
False
,
dims
=
2
)
elif
down
:
self
.
h_upd
=
Downsample
(
in_channels
,
use_conv
=
False
,
dims
=
2
,
padding
=
1
,
name
=
"op"
)
self
.
x_upd
=
Downsample
(
in_channels
,
use_conv
=
False
,
dims
=
2
,
padding
=
1
,
name
=
"op"
)
if
self
.
in_channels
!=
self
.
out_channels
or
self
.
up
or
self
.
down
:
self
.
nin_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
def
set_weights
(
self
):
self
.
conv1
.
weight
.
data
=
self
.
Conv_0
.
weight
.
data
self
.
conv1
.
bias
.
data
=
self
.
Conv_0
.
bias
.
data
self
.
norm1
.
weight
.
data
=
self
.
GroupNorm_0
.
weight
.
data
self
.
norm1
.
bias
.
data
=
self
.
GroupNorm_0
.
bias
.
data
self
.
conv2
.
weight
.
data
=
self
.
Conv_1
.
weight
.
data
self
.
conv2
.
bias
.
data
=
self
.
Conv_1
.
bias
.
data
self
.
norm2
.
weight
.
data
=
self
.
GroupNorm_1
.
weight
.
data
self
.
norm2
.
bias
.
data
=
self
.
GroupNorm_1
.
bias
.
data
self
.
temb_proj
.
weight
.
data
=
self
.
Dense_0
.
weight
.
data
self
.
temb_proj
.
bias
.
data
=
self
.
Dense_0
.
bias
.
data
if
self
.
in_channels
!=
self
.
out_channels
or
self
.
up
or
self
.
down
:
self
.
nin_shortcut
.
weight
.
data
=
self
.
Conv_2
.
weight
.
data
self
.
nin_shortcut
.
bias
.
data
=
self
.
Conv_2
.
bias
.
data
def
forward
(
self
,
x
,
temb
=
None
):
if
self
.
overwrite
and
not
self
.
is_overwritten
:
self
.
set_weights
()
self
.
is_overwritten
=
True
orig_x
=
x
h
=
self
.
act
(
self
.
GroupNorm_0
(
x
))
if
self
.
up
:
h
=
upsample_2d
(
h
,
self
.
fir_kernel
,
factor
=
2
)
x
=
upsample_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
elif
self
.
down
:
h
=
downsample_2d
(
h
,
self
.
fir_kernel
,
factor
=
2
)
x
=
downsample_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
h
=
self
.
Conv_0
(
h
)
# Add bias to each feature map conditioned on the time embedding
if
temb
is
not
None
:
h
+=
self
.
Dense_0
(
self
.
act
(
temb
))[:,
:,
None
,
None
]
h
=
self
.
act
(
self
.
GroupNorm_1
(
h
))
h
=
self
.
Dropout_0
(
h
)
h
=
self
.
Conv_1
(
h
)
if
self
.
in_ch
!=
self
.
out_ch
or
self
.
up
or
self
.
down
:
x
=
self
.
Conv_2
(
x
)
if
not
self
.
skip_rescale
:
raise
ValueError
(
"Is this branch run?!"
)
# import ipdb; ipdb.set_trace()
result
=
x
+
h
else
:
result
=
(
x
+
h
)
/
np
.
sqrt
(
2.0
)
result_2
=
self
.
forward_2
(
orig_x
,
temb
)
return
result_2
def
forward_2
(
self
,
x
,
temb
,
mask
=
1.0
):
h
=
x
h
=
h
*
mask
if
self
.
pre_norm
:
h
=
self
.
norm1
(
h
)
h
=
self
.
nonlinearity
(
h
)
# if self.up or self.down:
# x = self.x_upd(x)
# h = self.h_upd(h)
if
self
.
up
:
h
=
upsample_2d
(
h
,
self
.
fir_kernel
,
factor
=
2
)
x
=
upsample_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
elif
self
.
down
:
h
=
downsample_2d
(
h
,
self
.
fir_kernel
,
factor
=
2
)
x
=
downsample_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
h
=
self
.
conv1
(
h
)
if
not
self
.
pre_norm
:
h
=
self
.
norm1
(
h
)
h
=
self
.
nonlinearity
(
h
)
h
=
h
*
mask
temb
=
self
.
temb_proj
(
self
.
nonlinearity
(
temb
))[:,
:,
None
,
None
]
if
self
.
time_embedding_norm
==
"scale_shift"
:
scale
,
shift
=
torch
.
chunk
(
temb
,
2
,
dim
=
1
)
h
=
self
.
norm2
(
h
)
h
=
h
+
h
*
scale
+
shift
h
=
self
.
nonlinearity
(
h
)
elif
self
.
time_embedding_norm
==
"default"
:
h
=
h
+
temb
h
=
h
*
mask
if
self
.
pre_norm
:
h
=
self
.
norm2
(
h
)
h
=
self
.
nonlinearity
(
h
)
else
:
raise
ValueError
(
"Nananan nanana - don't go here!"
)
h
=
self
.
dropout
(
h
)
h
=
self
.
conv2
(
h
)
if
not
self
.
pre_norm
:
h
=
self
.
norm2
(
h
)
h
=
self
.
nonlinearity
(
h
)
h
=
h
*
mask
x
=
x
*
mask
# if self.in_channels != self.out_channels:
if
self
.
in_channels
!=
self
.
out_channels
or
self
.
up
or
self
.
down
:
x
=
self
.
nin_shortcut
(
x
)
result
=
x
+
h
return
result
/
self
.
output_scale_factor
# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py
# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py
class
ResnetBlock
(
nn
.
Module
):
class
ResnetBlock
(
nn
.
Module
):
...
@@ -245,9 +447,6 @@ class ResnetBlock(nn.Module):
...
@@ -245,9 +447,6 @@ class ResnetBlock(nn.Module):
self
.
res_conv
=
torch
.
nn
.
Identity
()
self
.
res_conv
=
torch
.
nn
.
Identity
()
elif
self
.
overwrite_for_ldm
:
elif
self
.
overwrite_for_ldm
:
dims
=
2
dims
=
2
# eps = 1e-5
# non_linearity = "silu"
# overwrite_for_ldm
channels
=
in_channels
channels
=
in_channels
emb_channels
=
temb_channels
emb_channels
=
temb_channels
use_scale_shift_norm
=
False
use_scale_shift_norm
=
False
...
...
src/diffusers/models/unet_sde_score_estimation.py
View file @
f1aade05
...
@@ -27,7 +27,8 @@ from ..configuration_utils import ConfigMixin
...
@@ -27,7 +27,8 @@ from ..configuration_utils import ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.attention
import
AttentionBlock
from
.attention
import
AttentionBlock
from
.embeddings
import
GaussianFourierProjection
,
get_timestep_embedding
from
.embeddings
import
GaussianFourierProjection
,
get_timestep_embedding
from
.resnet
import
ResnetBlockBigGANpp
,
downsample_2d
,
upfirdn2d
,
upsample_2d
from
.resnet
import
downsample_2d
,
upfirdn2d
,
upsample_2d
from
.resnet
import
ResnetBlockBigGANppNew
as
ResnetBlockBigGANpp
def
_setup_kernel
(
k
):
def
_setup_kernel
(
k
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment