Unverified Commit 1899457b authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Merge pull request #40 from huggingface/start_resnet_unificiation

resnet in one file
parents e5d9baf0 ebf3717c
This diff is collapsed.
......@@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample, Upsample
from .resnet import Downsample, ResnetBlock, Upsample
def nonlinearity(x):
......@@ -34,46 +34,46 @@ def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
else:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
# class ResnetBlock(nn.Module):
# def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
# super().__init__()
# self.in_channels = in_channels
# out_channels = in_channels if out_channels is None else out_channels
# self.out_channels = out_channels
# self.use_conv_shortcut = conv_shortcut
#
# self.norm1 = Normalize(in_channels)
# self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
# self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
# self.norm2 = Normalize(out_channels)
# self.dropout = torch.nn.Dropout(dropout)
# self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
# if self.in_channels != self.out_channels:
# if self.use_conv_shortcut:
# self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
# else:
# self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
#
# def forward(self, x, temb):
# h = x
# h = self.norm1(h)
# h = nonlinearity(h)
# h = self.conv1(h)
#
# h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
#
# h = self.norm2(h)
# h = nonlinearity(h)
# h = self.dropout(h)
# h = self.conv2(h)
#
# if self.in_channels != self.out_channels:
# if self.use_conv_shortcut:
# x = self.conv_shortcut(x)
# else:
# x = self.nin_shortcut(x)
#
# return x + h
class UNetModel(ModelMixin, ConfigMixin):
......@@ -127,7 +127,6 @@ class UNetModel(ModelMixin, ConfigMixin):
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
attn_2 = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
......@@ -142,7 +141,6 @@ class UNetModel(ModelMixin, ConfigMixin):
down = nn.Module()
down.block = block
down.attn = attn
down.attn_2 = attn_2
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0)
curr_res = curr_res // 2
......
......@@ -8,7 +8,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample, Upsample
from .resnet import Downsample, ResBlock, TimestepBlock, Upsample
def convert_module_to_f16(l):
......@@ -96,16 +96,14 @@ def zero_module(module):
return module
class TimestepBlock(nn.Module):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@abstractmethod
def forward(self, x, emb):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
# class TimestepBlock(nn.Module):
# """
# Any module where forward() takes timestep embeddings as a second argument. #"""
#
# @abstractmethod
# def forward(self, x, emb):
# """
# Apply the module to `x` given `emb` timestep embeddings. #"""
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
......@@ -124,106 +122,99 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
return x
class ResBlock(TimestepBlock):
"""
A residual block that can optionally change the number of channels.
:param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout. :param out_channels: if specified, the number of out channels. :param
use_conv: if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing
on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for
downsampling.
"""
def __init__(
self,
channels,
emb_channels,
dropout,
out_channels=None,
use_conv=False,
use_scale_shift_norm=False,
dims=2,
use_checkpoint=False,
up=False,
down=False,
):
super().__init__()
self.channels = channels
self.emb_channels = emb_channels
self.dropout = dropout
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_checkpoint = use_checkpoint
self.use_scale_shift_norm = use_scale_shift_norm
self.in_layers = nn.Sequential(
normalization(channels, swish=1.0),
nn.Identity(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
)
self.updown = up or down
if up:
self.h_upd = Upsample(channels, use_conv=False, dims=dims)
self.x_upd = Upsample(channels, use_conv=False, dims=dims)
elif down:
self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
else:
self.h_upd = self.x_upd = nn.Identity()
self.emb_layers = nn.Sequential(
nn.SiLU(),
linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
nn.SiLU() if use_scale_shift_norm else nn.Identity(),
nn.Dropout(p=dropout),
zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
def forward(self, x, emb):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
:param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
h = self.h_upd(h)
x = self.x_upd(x)
h = in_conv(h)
else:
h = self.in_layers(x)
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = torch.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
h = out_rest(h)
else:
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h
# class ResBlock(TimestepBlock):
# """
# A residual block that can optionally change the number of channels. # # :param channels: the number of input
channels. :param emb_channels: the number of timestep embedding channels. # :param dropout: the rate of dropout. :param
out_channels: if specified, the number of out channels. :param # use_conv: if True and out_channels is specified, use a
spatial # convolution instead of a smaller 1x1 convolution to change the channels in the skip connection. # :param
dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing # on this
module. :param up: if True, use this block for upsampling. :param down: if True, use this block for # downsampling. #"""
#
# def __init__(
# self,
# channels,
# emb_channels,
# dropout,
# out_channels=None,
# use_conv=False,
# use_scale_shift_norm=False,
# dims=2,
# use_checkpoint=False,
# up=False,
# down=False,
# ):
# super().__init__()
# self.channels = channels
# self.emb_channels = emb_channels
# self.dropout = dropout
# self.out_channels = out_channels or channels
# self.use_conv = use_conv
# self.use_checkpoint = use_checkpoint
# self.use_scale_shift_norm = use_scale_shift_norm
#
# self.in_layers = nn.Sequential(
# normalization(channels, swish=1.0),
# nn.Identity(),
# conv_nd(dims, channels, self.out_channels, 3, padding=1),
# )
#
# self.updown = up or down
#
# if up:
# self.h_upd = Upsample(channels, use_conv=False, dims=dims)
# self.x_upd = Upsample(channels, use_conv=False, dims=dims)
# elif down:
# self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
# self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
# else:
# self.h_upd = self.x_upd = nn.Identity()
#
# self.emb_layers = nn.Sequential(
# nn.SiLU(),
# linear(
# emb_channels,
# 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
# ),
# )
# self.out_layers = nn.Sequential(
# normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
# nn.SiLU() if use_scale_shift_norm else nn.Identity(),
# nn.Dropout(p=dropout),
# zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
# )
#
# if self.out_channels == channels:
# self.skip_connection = nn.Identity()
# elif use_conv:
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
# else:
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
#
# def forward(self, x, emb):
# """
# Apply the block to a Tensor, conditioned on a timestep embedding. # # :param x: an [N x C x ...] Tensor of features.
:param emb: an [N x emb_channels] Tensor of timestep embeddings. # :return: an [N x C x ...] Tensor of outputs. #"""
# if self.updown:
# in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
# h = in_rest(x)
# h = self.h_upd(h)
# x = self.x_upd(x)
# h = in_conv(h)
# else:
# h = self.in_layers(x)
# emb_out = self.emb_layers(emb).type(h.dtype)
# while len(emb_out.shape) < len(h.shape):
# emb_out = emb_out[..., None]
# if self.use_scale_shift_norm:
# out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
# scale, shift = torch.chunk(emb_out, 2, dim=1)
# h = out_norm(h) * (1 + scale) + shift
# h = out_rest(h)
# else:
# h = h + emb_out
# h = self.out_layers(h)
# return self.skip_connection(x) + h
class GlideUNetModel(ModelMixin, ConfigMixin):
......
......@@ -4,7 +4,9 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import LinearAttention
from .embeddings import get_timestep_embedding
from .resnet import Downsample, Upsample
from .resnet import Downsample
from .resnet import ResnetBlockGradTTS as ResnetBlock
from .resnet import Upsample
class Mish(torch.nn.Module):
......@@ -34,24 +36,24 @@ class Block(torch.nn.Module):
return output * mask
class ResnetBlock(torch.nn.Module):
def __init__(self, dim, dim_out, time_emb_dim, groups=8):
super(ResnetBlock, self).__init__()
self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
self.block1 = Block(dim, dim_out, groups=groups)
self.block2 = Block(dim_out, dim_out, groups=groups)
if dim != dim_out:
self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
else:
self.res_conv = torch.nn.Identity()
def forward(self, x, mask, time_emb):
h = self.block1(x, mask)
h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
h = self.block2(h, mask)
output = h + self.res_conv(x * mask)
return output
# class ResnetBlock(torch.nn.Module):
# def __init__(self, dim, dim_out, time_emb_dim, groups=8):
# super(ResnetBlock, self).__init__()
# self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
#
# self.block1 = Block(dim, dim_out, groups=groups)
# self.block2 = Block(dim_out, dim_out, groups=groups)
# if dim != dim_out:
# self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
# else:
# self.res_conv = torch.nn.Identity()
#
# def forward(self, x, mask, time_emb):
# h = self.block1(x, mask)
# h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
# h = self.block2(h, mask)
# output = h + self.res_conv(x * mask)
# return output
class Residual(torch.nn.Module):
......
......@@ -11,7 +11,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample, Upsample
from .resnet import Downsample, ResBlock, TimestepBlock, Upsample
def exists(val):
......@@ -359,16 +359,14 @@ class AttentionPool2d(nn.Module):
return x[:, :, 0]
class TimestepBlock(nn.Module):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@abstractmethod
def forward(self, x, emb):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
# class TimestepBlock(nn.Module):
# """
# Any module where forward() takes timestep embeddings as a second argument. #"""
#
# @abstractmethod
# def forward(self, x, emb):
# """
# Apply the module to `x` given `emb` timestep embeddings. #"""
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
......@@ -387,99 +385,97 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
return x
class ResBlock(TimestepBlock):
"""
A residual block that can optionally change the number of channels. :param channels: the number of input channels.
:param emb_channels: the number of timestep embedding channels. :param dropout: the rate of dropout. :param
out_channels: if specified, the number of out channels. :param use_conv: if True and out_channels is specified, use
a spatial
convolution instead of a smaller 1x1 convolution to change the channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing
on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for
downsampling.
"""
def __init__(
self,
channels,
emb_channels,
dropout,
out_channels=None,
use_conv=False,
use_scale_shift_norm=False,
dims=2,
use_checkpoint=False,
up=False,
down=False,
):
super().__init__()
self.channels = channels
self.emb_channels = emb_channels
self.dropout = dropout
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_checkpoint = use_checkpoint
self.use_scale_shift_norm = use_scale_shift_norm
self.in_layers = nn.Sequential(
normalization(channels),
nn.SiLU(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
)
self.updown = up or down
if up:
self.h_upd = Upsample(channels, use_conv=False, dims=dims)
self.x_upd = Upsample(channels, use_conv=False, dims=dims)
elif down:
self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
else:
self.h_upd = self.x_upd = nn.Identity()
self.emb_layers = nn.Sequential(
nn.SiLU(),
linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
def forward(self, x, emb):
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
h = self.h_upd(h)
x = self.x_upd(x)
h = in_conv(h)
else:
h = self.in_layers(x)
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = torch.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
h = out_rest(h)
else:
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h
# class A_ResBlock(TimestepBlock):
# """
# A residual block that can optionally change the number of channels. :param channels: the number of input channels. #
:param emb_channels: the number of timestep embedding channels. :param dropout: the rate of dropout. :param #
out_channels: if specified, the number of out channels. :param use_conv: if True and out_channels is specified, use # a
spatial # convolution instead of a smaller 1x1 convolution to change the channels in the skip connection. # :param
dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing # on this
module. :param up: if True, use this block for upsampling. :param down: if True, use this block for # downsampling. #"""
#
# def __init__(
# self,
# channels,
# emb_channels,
# dropout,
# out_channels=None,
# use_conv=False,
# use_scale_shift_norm=False,
# dims=2,
# use_checkpoint=False,
# up=False,
# down=False,
# ):
# super().__init__()
# self.channels = channels
# self.emb_channels = emb_channels
# self.dropout = dropout
# self.out_channels = out_channels or channels
# self.use_conv = use_conv
# self.use_checkpoint = use_checkpoint
# self.use_scale_shift_norm = use_scale_shift_norm
#
# self.in_layers = nn.Sequential(
# normalization(channels),
# nn.SiLU(),
# conv_nd(dims, channels, self.out_channels, 3, padding=1),
# )
#
# self.updown = up or down
#
# if up:
# self.h_upd = Upsample(channels, use_conv=False, dims=dims)
# self.x_upd = Upsample(channels, use_conv=False, dims=dims)
# elif down:
# self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
# self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
# else:
# self.h_upd = self.x_upd = nn.Identity()
#
# self.emb_layers = nn.Sequential(
# nn.SiLU(),
# linear(
# emb_channels,
# 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
# ),
# )
# self.out_layers = nn.Sequential(
# normalization(self.out_channels),
# nn.SiLU(),
# nn.Dropout(p=dropout),
# zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
# )
#
# if self.out_channels == channels:
# self.skip_connection = nn.Identity()
# elif use_conv:
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
# else:
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
#
# def forward(self, x, emb):
# if self.updown:
# in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
# h = in_rest(x)
# h = self.h_upd(h)
# x = self.x_upd(x)
# h = in_conv(h)
# else:
# h = self.in_layers(x)
# emb_out = self.emb_layers(emb).type(h.dtype)
# while len(emb_out.shape) < len(h.shape):
# emb_out = emb_out[..., None]
# if self.use_scale_shift_norm:
# out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
# scale, shift = torch.chunk(emb_out, 2, dim=1)
# h = out_norm(h) * (1 + scale) + shift
# h = out_rest(h)
# else:
# h = h + emb_out
# h = self.out_layers(h)
# return self.skip_connection(x) + h
#
class QKVAttention(nn.Module):
......
......@@ -6,6 +6,7 @@ import torch.nn as nn
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
from .resnet import ResidualTemporalBlock
class SinusoidalPosEmb(nn.Module):
......@@ -72,36 +73,35 @@ class Conv1dBlock(nn.Module):
return self.block(x)
class ResidualTemporalBlock(nn.Module):
def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
super().__init__()
self.blocks = nn.ModuleList(
[
Conv1dBlock(inp_channels, out_channels, kernel_size),
Conv1dBlock(out_channels, out_channels, kernel_size),
]
)
self.time_mlp = nn.Sequential(
nn.Mish(),
nn.Linear(embed_dim, out_channels),
RearrangeDim(),
# Rearrange("batch t -> batch t 1"),
)
self.residual_conv = (
nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
)
def forward(self, x, t):
"""
x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x
out_channels x horizon ]
"""
out = self.blocks[0](x) + self.time_mlp(t)
out = self.blocks[1](out)
return out + self.residual_conv(x)
# class ResidualTemporalBlock(nn.Module):
# def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
# super().__init__()
#
# self.blocks = nn.ModuleList(
# [
# Conv1dBlock(inp_channels, out_channels, kernel_size),
# Conv1dBlock(out_channels, out_channels, kernel_size),
# ]
# )
#
# self.time_mlp = nn.Sequential(
# nn.Mish(),
# nn.Linear(embed_dim, out_channels),
# RearrangeDim(),
# Rearrange("batch t -> batch t 1"),
# )
#
# self.residual_conv = (
# nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
# )
#
# def forward(self, x, t):
# """
# x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x #
out_channels x horizon ] #"""
# out = self.blocks[0](x) + self.time_mlp(t)
# out = self.blocks[1](out)
# return out + self.residual_conv(x)
class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
......
......@@ -28,6 +28,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import GaussianFourierProjection, get_timestep_embedding
from .resnet import ResnetBlockBigGANpp, ResnetBlockDDPMpp
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
......@@ -299,7 +300,7 @@ def downsample_2d(x, k=None, factor=2, gain=1):
return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
"""1x1 convolution with DDPM initialization."""
conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
......@@ -307,7 +308,7 @@ def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, pad
return conv
def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
"""3x3 convolution with DDPM initialization."""
conv = nn.Conv2d(
in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias
......@@ -317,10 +318,6 @@ def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_sc
return conv
conv1x1 = ddpm_conv1x1
conv3x3 = ddpm_conv3x3
def _einsum(a, b, c, x, y):
einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
return torch.einsum(einsum_str, x, y)
......@@ -494,135 +491,135 @@ class Downsample(nn.Module):
return x
class ResnetBlockDDPMpp(nn.Module):
"""ResBlock adapted from DDPM."""
def __init__(
self,
act,
in_ch,
out_ch=None,
temb_dim=None,
conv_shortcut=False,
dropout=0.1,
skip_rescale=False,
init_scale=0.0,
):
super().__init__()
out_ch = out_ch if out_ch else in_ch
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
self.Conv_0 = conv3x3(in_ch, out_ch)
if temb_dim is not None:
self.Dense_0 = nn.Linear(temb_dim, out_ch)
self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
nn.init.zeros_(self.Dense_0.bias)
self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
self.Dropout_0 = nn.Dropout(dropout)
self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
if in_ch != out_ch:
if conv_shortcut:
self.Conv_2 = conv3x3(in_ch, out_ch)
else:
self.NIN_0 = NIN(in_ch, out_ch)
self.skip_rescale = skip_rescale
self.act = act
self.out_ch = out_ch
self.conv_shortcut = conv_shortcut
def forward(self, x, temb=None):
h = self.act(self.GroupNorm_0(x))
h = self.Conv_0(h)
if temb is not None:
h += self.Dense_0(self.act(temb))[:, :, None, None]
h = self.act(self.GroupNorm_1(h))
h = self.Dropout_0(h)
h = self.Conv_1(h)
if x.shape[1] != self.out_ch:
if self.conv_shortcut:
x = self.Conv_2(x)
else:
x = self.NIN_0(x)
if not self.skip_rescale:
return x + h
else:
return (x + h) / np.sqrt(2.0)
class ResnetBlockBigGANpp(nn.Module):
def __init__(
self,
act,
in_ch,
out_ch=None,
temb_dim=None,
up=False,
down=False,
dropout=0.1,
fir=False,
fir_kernel=(1, 3, 3, 1),
skip_rescale=True,
init_scale=0.0,
):
super().__init__()
out_ch = out_ch if out_ch else in_ch
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
self.up = up
self.down = down
self.fir = fir
self.fir_kernel = fir_kernel
self.Conv_0 = conv3x3(in_ch, out_ch)
if temb_dim is not None:
self.Dense_0 = nn.Linear(temb_dim, out_ch)
self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
nn.init.zeros_(self.Dense_0.bias)
self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
self.Dropout_0 = nn.Dropout(dropout)
self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
if in_ch != out_ch or up or down:
self.Conv_2 = conv1x1(in_ch, out_ch)
self.skip_rescale = skip_rescale
self.act = act
self.in_ch = in_ch
self.out_ch = out_ch
def forward(self, x, temb=None):
h = self.act(self.GroupNorm_0(x))
if self.up:
if self.fir:
h = upsample_2d(h, self.fir_kernel, factor=2)
x = upsample_2d(x, self.fir_kernel, factor=2)
else:
h = naive_upsample_2d(h, factor=2)
x = naive_upsample_2d(x, factor=2)
elif self.down:
if self.fir:
h = downsample_2d(h, self.fir_kernel, factor=2)
x = downsample_2d(x, self.fir_kernel, factor=2)
else:
h = naive_downsample_2d(h, factor=2)
x = naive_downsample_2d(x, factor=2)
h = self.Conv_0(h)
# Add bias to each feature map conditioned on the time embedding
if temb is not None:
h += self.Dense_0(self.act(temb))[:, :, None, None]
h = self.act(self.GroupNorm_1(h))
h = self.Dropout_0(h)
h = self.Conv_1(h)
if self.in_ch != self.out_ch or self.up or self.down:
x = self.Conv_2(x)
if not self.skip_rescale:
return x + h
else:
return (x + h) / np.sqrt(2.0)
# class ResnetBlockDDPMpp(nn.Module):
# """ResBlock adapted from DDPM."""
#
# def __init__(
# self,
# act,
# in_ch,
# out_ch=None,
# temb_dim=None,
# conv_shortcut=False,
# dropout=0.1,
# skip_rescale=False,
# init_scale=0.0,
# ):
# super().__init__()
# out_ch = out_ch if out_ch else in_ch
# self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
# self.Conv_0 = conv3x3(in_ch, out_ch)
# if temb_dim is not None:
# self.Dense_0 = nn.Linear(temb_dim, out_ch)
# self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
# nn.init.zeros_(self.Dense_0.bias)
# self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
# self.Dropout_0 = nn.Dropout(dropout)
# self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
# if in_ch != out_ch:
# if conv_shortcut:
# self.Conv_2 = conv3x3(in_ch, out_ch)
# else:
# self.NIN_0 = NIN(in_ch, out_ch)
#
# self.skip_rescale = skip_rescale
# self.act = act
# self.out_ch = out_ch
# self.conv_shortcut = conv_shortcut
#
# def forward(self, x, temb=None):
# h = self.act(self.GroupNorm_0(x))
# h = self.Conv_0(h)
# if temb is not None:
# h += self.Dense_0(self.act(temb))[:, :, None, None]
# h = self.act(self.GroupNorm_1(h))
# h = self.Dropout_0(h)
# h = self.Conv_1(h)
# if x.shape[1] != self.out_ch:
# if self.conv_shortcut:
# x = self.Conv_2(x)
# else:
# x = self.NIN_0(x)
# if not self.skip_rescale:
# return x + h
# else:
# return (x + h) / np.sqrt(2.0)
# class ResnetBlockBigGANpp(nn.Module):
# def __init__(
# self,
# act,
# in_ch,
# out_ch=None,
# temb_dim=None,
# up=False,
# down=False,
# dropout=0.1,
# fir=False,
# fir_kernel=(1, 3, 3, 1),
# skip_rescale=True,
# init_scale=0.0,
# ):
# super().__init__()
#
# out_ch = out_ch if out_ch else in_ch
# self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
# self.up = up
# self.down = down
# self.fir = fir
# self.fir_kernel = fir_kernel
#
# self.Conv_0 = conv3x3(in_ch, out_ch)
# if temb_dim is not None:
# self.Dense_0 = nn.Linear(temb_dim, out_ch)
# self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
# nn.init.zeros_(self.Dense_0.bias)
#
# self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
# self.Dropout_0 = nn.Dropout(dropout)
# self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
# if in_ch != out_ch or up or down:
# self.Conv_2 = conv1x1(in_ch, out_ch)
#
# self.skip_rescale = skip_rescale
# self.act = act
# self.in_ch = in_ch
# self.out_ch = out_ch
#
# def forward(self, x, temb=None):
# h = self.act(self.GroupNorm_0(x))
#
# if self.up:
# if self.fir:
# h = upsample_2d(h, self.fir_kernel, factor=2)
# x = upsample_2d(x, self.fir_kernel, factor=2)
# else:
# h = naive_upsample_2d(h, factor=2)
# x = naive_upsample_2d(x, factor=2)
# elif self.down:
# if self.fir:
# h = downsample_2d(h, self.fir_kernel, factor=2)
# x = downsample_2d(x, self.fir_kernel, factor=2)
# else:
# h = naive_downsample_2d(h, factor=2)
# x = naive_downsample_2d(x, factor=2)
#
# h = self.Conv_0(h)
# Add bias to each feature map conditioned on the time embedding
# if temb is not None:
# h += self.Dense_0(self.act(temb))[:, :, None, None]
# h = self.act(self.GroupNorm_1(h))
# h = self.Dropout_0(h)
# h = self.Conv_1(h)
#
# if self.in_ch != self.out_ch or self.up or self.down:
# x = self.Conv_2(x)
#
# if not self.skip_rescale:
# return x + h
# else:
# return (x + h) / np.sqrt(2.0)
class NCSNpp(ModelMixin, ConfigMixin):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment