Commit ebf3717c authored by Patrick von Platen's avatar Patrick von Platen
Browse files

resnet in one file

parent e5d9baf0
This diff is collapsed.
...@@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin ...@@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .attention import AttentionBlock from .attention import AttentionBlock
from .embeddings import get_timestep_embedding from .embeddings import get_timestep_embedding
from .resnet import Downsample, Upsample from .resnet import Downsample, ResnetBlock, Upsample
def nonlinearity(x): def nonlinearity(x):
...@@ -34,46 +34,46 @@ def Normalize(in_channels): ...@@ -34,46 +34,46 @@ def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class ResnetBlock(nn.Module): # class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): # def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
super().__init__() # super().__init__()
self.in_channels = in_channels # self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels # out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels # self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut # self.use_conv_shortcut = conv_shortcut
#
self.norm1 = Normalize(in_channels) # self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) # self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.temb_proj = torch.nn.Linear(temb_channels, out_channels) # self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize(out_channels) # self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout) # self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) # self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels: # if self.in_channels != self.out_channels:
if self.use_conv_shortcut: # if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) # self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
else: # else:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) # self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
#
def forward(self, x, temb): # def forward(self, x, temb):
h = x # h = x
h = self.norm1(h) # h = self.norm1(h)
h = nonlinearity(h) # h = nonlinearity(h)
h = self.conv1(h) # h = self.conv1(h)
#
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] # h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
#
h = self.norm2(h) # h = self.norm2(h)
h = nonlinearity(h) # h = nonlinearity(h)
h = self.dropout(h) # h = self.dropout(h)
h = self.conv2(h) # h = self.conv2(h)
#
if self.in_channels != self.out_channels: # if self.in_channels != self.out_channels:
if self.use_conv_shortcut: # if self.use_conv_shortcut:
x = self.conv_shortcut(x) # x = self.conv_shortcut(x)
else: # else:
x = self.nin_shortcut(x) # x = self.nin_shortcut(x)
#
return x + h # return x + h
class UNetModel(ModelMixin, ConfigMixin): class UNetModel(ModelMixin, ConfigMixin):
...@@ -127,7 +127,6 @@ class UNetModel(ModelMixin, ConfigMixin): ...@@ -127,7 +127,6 @@ class UNetModel(ModelMixin, ConfigMixin):
for i_level in range(self.num_resolutions): for i_level in range(self.num_resolutions):
block = nn.ModuleList() block = nn.ModuleList()
attn = nn.ModuleList() attn = nn.ModuleList()
attn_2 = nn.ModuleList()
block_in = ch * in_ch_mult[i_level] block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level] block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks): for i_block in range(self.num_res_blocks):
...@@ -142,7 +141,6 @@ class UNetModel(ModelMixin, ConfigMixin): ...@@ -142,7 +141,6 @@ class UNetModel(ModelMixin, ConfigMixin):
down = nn.Module() down = nn.Module()
down.block = block down.block = block
down.attn = attn down.attn = attn
down.attn_2 = attn_2
if i_level != self.num_resolutions - 1: if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0) down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0)
curr_res = curr_res // 2 curr_res = curr_res // 2
......
...@@ -8,7 +8,7 @@ from ..configuration_utils import ConfigMixin ...@@ -8,7 +8,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .attention import AttentionBlock from .attention import AttentionBlock
from .embeddings import get_timestep_embedding from .embeddings import get_timestep_embedding
from .resnet import Downsample, Upsample from .resnet import Downsample, ResBlock, TimestepBlock, Upsample
def convert_module_to_f16(l): def convert_module_to_f16(l):
...@@ -96,16 +96,14 @@ def zero_module(module): ...@@ -96,16 +96,14 @@ def zero_module(module):
return module return module
class TimestepBlock(nn.Module): # class TimestepBlock(nn.Module):
""" # """
Any module where forward() takes timestep embeddings as a second argument. # Any module where forward() takes timestep embeddings as a second argument. #"""
""" #
# @abstractmethod
@abstractmethod # def forward(self, x, emb):
def forward(self, x, emb): # """
""" # Apply the module to `x` given `emb` timestep embeddings. #"""
Apply the module to `x` given `emb` timestep embeddings.
"""
class TimestepEmbedSequential(nn.Sequential, TimestepBlock): class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
...@@ -124,106 +122,99 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock): ...@@ -124,106 +122,99 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
return x return x
class ResBlock(TimestepBlock): # class ResBlock(TimestepBlock):
""" # """
A residual block that can optionally change the number of channels. # A residual block that can optionally change the number of channels. # # :param channels: the number of input
channels. :param emb_channels: the number of timestep embedding channels. # :param dropout: the rate of dropout. :param
:param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels. out_channels: if specified, the number of out channels. :param # use_conv: if True and out_channels is specified, use a
:param dropout: the rate of dropout. :param out_channels: if specified, the number of out channels. :param spatial # convolution instead of a smaller 1x1 convolution to change the channels in the skip connection. # :param
use_conv: if True and out_channels is specified, use a spatial dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing # on this
convolution instead of a smaller 1x1 convolution to change the channels in the skip connection. module. :param up: if True, use this block for upsampling. :param down: if True, use this block for # downsampling. #"""
:param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing #
on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for # def __init__(
downsampling. # self,
""" # channels,
# emb_channels,
def __init__( # dropout,
self, # out_channels=None,
channels, # use_conv=False,
emb_channels, # use_scale_shift_norm=False,
dropout, # dims=2,
out_channels=None, # use_checkpoint=False,
use_conv=False, # up=False,
use_scale_shift_norm=False, # down=False,
dims=2, # ):
use_checkpoint=False, # super().__init__()
up=False, # self.channels = channels
down=False, # self.emb_channels = emb_channels
): # self.dropout = dropout
super().__init__() # self.out_channels = out_channels or channels
self.channels = channels # self.use_conv = use_conv
self.emb_channels = emb_channels # self.use_checkpoint = use_checkpoint
self.dropout = dropout # self.use_scale_shift_norm = use_scale_shift_norm
self.out_channels = out_channels or channels #
self.use_conv = use_conv # self.in_layers = nn.Sequential(
self.use_checkpoint = use_checkpoint # normalization(channels, swish=1.0),
self.use_scale_shift_norm = use_scale_shift_norm # nn.Identity(),
# conv_nd(dims, channels, self.out_channels, 3, padding=1),
self.in_layers = nn.Sequential( # )
normalization(channels, swish=1.0), #
nn.Identity(), # self.updown = up or down
conv_nd(dims, channels, self.out_channels, 3, padding=1), #
) # if up:
# self.h_upd = Upsample(channels, use_conv=False, dims=dims)
self.updown = up or down # self.x_upd = Upsample(channels, use_conv=False, dims=dims)
# elif down:
if up: # self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
self.h_upd = Upsample(channels, use_conv=False, dims=dims) # self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
self.x_upd = Upsample(channels, use_conv=False, dims=dims) # else:
elif down: # self.h_upd = self.x_upd = nn.Identity()
self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op") #
self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op") # self.emb_layers = nn.Sequential(
else: # nn.SiLU(),
self.h_upd = self.x_upd = nn.Identity() # linear(
# emb_channels,
self.emb_layers = nn.Sequential( # 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
nn.SiLU(), # ),
linear( # )
emb_channels, # self.out_layers = nn.Sequential(
2 * self.out_channels if use_scale_shift_norm else self.out_channels, # normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
), # nn.SiLU() if use_scale_shift_norm else nn.Identity(),
) # nn.Dropout(p=dropout),
self.out_layers = nn.Sequential( # zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0), # )
nn.SiLU() if use_scale_shift_norm else nn.Identity(), #
nn.Dropout(p=dropout), # if self.out_channels == channels:
zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)), # self.skip_connection = nn.Identity()
) # elif use_conv:
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
if self.out_channels == channels: # else:
self.skip_connection = nn.Identity() # self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
elif use_conv: #
self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1) # def forward(self, x, emb):
else: # """
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) # Apply the block to a Tensor, conditioned on a timestep embedding. # # :param x: an [N x C x ...] Tensor of features.
:param emb: an [N x emb_channels] Tensor of timestep embeddings. # :return: an [N x C x ...] Tensor of outputs. #"""
def forward(self, x, emb): # if self.updown:
""" # in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
Apply the block to a Tensor, conditioned on a timestep embedding. # h = in_rest(x)
# h = self.h_upd(h)
:param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep embeddings. # x = self.x_upd(x)
:return: an [N x C x ...] Tensor of outputs. # h = in_conv(h)
""" # else:
if self.updown: # h = self.in_layers(x)
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] # emb_out = self.emb_layers(emb).type(h.dtype)
h = in_rest(x) # while len(emb_out.shape) < len(h.shape):
h = self.h_upd(h) # emb_out = emb_out[..., None]
x = self.x_upd(x) # if self.use_scale_shift_norm:
h = in_conv(h) # out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
else: # scale, shift = torch.chunk(emb_out, 2, dim=1)
h = self.in_layers(x) # h = out_norm(h) * (1 + scale) + shift
emb_out = self.emb_layers(emb).type(h.dtype) # h = out_rest(h)
while len(emb_out.shape) < len(h.shape): # else:
emb_out = emb_out[..., None] # h = h + emb_out
if self.use_scale_shift_norm: # h = self.out_layers(h)
out_norm, out_rest = self.out_layers[0], self.out_layers[1:] # return self.skip_connection(x) + h
scale, shift = torch.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
h = out_rest(h)
else:
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h
class GlideUNetModel(ModelMixin, ConfigMixin): class GlideUNetModel(ModelMixin, ConfigMixin):
......
...@@ -4,7 +4,9 @@ from ..configuration_utils import ConfigMixin ...@@ -4,7 +4,9 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .attention import LinearAttention from .attention import LinearAttention
from .embeddings import get_timestep_embedding from .embeddings import get_timestep_embedding
from .resnet import Downsample, Upsample from .resnet import Downsample
from .resnet import ResnetBlockGradTTS as ResnetBlock
from .resnet import Upsample
class Mish(torch.nn.Module): class Mish(torch.nn.Module):
...@@ -34,24 +36,24 @@ class Block(torch.nn.Module): ...@@ -34,24 +36,24 @@ class Block(torch.nn.Module):
return output * mask return output * mask
class ResnetBlock(torch.nn.Module): # class ResnetBlock(torch.nn.Module):
def __init__(self, dim, dim_out, time_emb_dim, groups=8): # def __init__(self, dim, dim_out, time_emb_dim, groups=8):
super(ResnetBlock, self).__init__() # super(ResnetBlock, self).__init__()
self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out)) # self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
#
self.block1 = Block(dim, dim_out, groups=groups) # self.block1 = Block(dim, dim_out, groups=groups)
self.block2 = Block(dim_out, dim_out, groups=groups) # self.block2 = Block(dim_out, dim_out, groups=groups)
if dim != dim_out: # if dim != dim_out:
self.res_conv = torch.nn.Conv2d(dim, dim_out, 1) # self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
else: # else:
self.res_conv = torch.nn.Identity() # self.res_conv = torch.nn.Identity()
#
def forward(self, x, mask, time_emb): # def forward(self, x, mask, time_emb):
h = self.block1(x, mask) # h = self.block1(x, mask)
h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1) # h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
h = self.block2(h, mask) # h = self.block2(h, mask)
output = h + self.res_conv(x * mask) # output = h + self.res_conv(x * mask)
return output # return output
class Residual(torch.nn.Module): class Residual(torch.nn.Module):
......
...@@ -11,7 +11,7 @@ from ..configuration_utils import ConfigMixin ...@@ -11,7 +11,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .attention import AttentionBlock from .attention import AttentionBlock
from .embeddings import get_timestep_embedding from .embeddings import get_timestep_embedding
from .resnet import Downsample, Upsample from .resnet import Downsample, ResBlock, TimestepBlock, Upsample
def exists(val): def exists(val):
...@@ -359,16 +359,14 @@ class AttentionPool2d(nn.Module): ...@@ -359,16 +359,14 @@ class AttentionPool2d(nn.Module):
return x[:, :, 0] return x[:, :, 0]
class TimestepBlock(nn.Module): # class TimestepBlock(nn.Module):
""" # """
Any module where forward() takes timestep embeddings as a second argument. # Any module where forward() takes timestep embeddings as a second argument. #"""
""" #
# @abstractmethod
@abstractmethod # def forward(self, x, emb):
def forward(self, x, emb): # """
""" # Apply the module to `x` given `emb` timestep embeddings. #"""
Apply the module to `x` given `emb` timestep embeddings.
"""
class TimestepEmbedSequential(nn.Sequential, TimestepBlock): class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
...@@ -387,99 +385,97 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock): ...@@ -387,99 +385,97 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
return x return x
class ResBlock(TimestepBlock): # class A_ResBlock(TimestepBlock):
""" # """
A residual block that can optionally change the number of channels. :param channels: the number of input channels. # A residual block that can optionally change the number of channels. :param channels: the number of input channels. #
:param emb_channels: the number of timestep embedding channels. :param dropout: the rate of dropout. :param :param emb_channels: the number of timestep embedding channels. :param dropout: the rate of dropout. :param #
out_channels: if specified, the number of out channels. :param use_conv: if True and out_channels is specified, use out_channels: if specified, the number of out channels. :param use_conv: if True and out_channels is specified, use # a
a spatial spatial # convolution instead of a smaller 1x1 convolution to change the channels in the skip connection. # :param
convolution instead of a smaller 1x1 convolution to change the channels in the skip connection. dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing # on this
:param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing module. :param up: if True, use this block for upsampling. :param down: if True, use this block for # downsampling. #"""
on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for #
downsampling. # def __init__(
""" # self,
# channels,
def __init__( # emb_channels,
self, # dropout,
channels, # out_channels=None,
emb_channels, # use_conv=False,
dropout, # use_scale_shift_norm=False,
out_channels=None, # dims=2,
use_conv=False, # use_checkpoint=False,
use_scale_shift_norm=False, # up=False,
dims=2, # down=False,
use_checkpoint=False, # ):
up=False, # super().__init__()
down=False, # self.channels = channels
): # self.emb_channels = emb_channels
super().__init__() # self.dropout = dropout
self.channels = channels # self.out_channels = out_channels or channels
self.emb_channels = emb_channels # self.use_conv = use_conv
self.dropout = dropout # self.use_checkpoint = use_checkpoint
self.out_channels = out_channels or channels # self.use_scale_shift_norm = use_scale_shift_norm
self.use_conv = use_conv #
self.use_checkpoint = use_checkpoint # self.in_layers = nn.Sequential(
self.use_scale_shift_norm = use_scale_shift_norm # normalization(channels),
# nn.SiLU(),
self.in_layers = nn.Sequential( # conv_nd(dims, channels, self.out_channels, 3, padding=1),
normalization(channels), # )
nn.SiLU(), #
conv_nd(dims, channels, self.out_channels, 3, padding=1), # self.updown = up or down
) #
# if up:
self.updown = up or down # self.h_upd = Upsample(channels, use_conv=False, dims=dims)
# self.x_upd = Upsample(channels, use_conv=False, dims=dims)
if up: # elif down:
self.h_upd = Upsample(channels, use_conv=False, dims=dims) # self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
self.x_upd = Upsample(channels, use_conv=False, dims=dims) # self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
elif down: # else:
self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op") # self.h_upd = self.x_upd = nn.Identity()
self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op") #
else: # self.emb_layers = nn.Sequential(
self.h_upd = self.x_upd = nn.Identity() # nn.SiLU(),
# linear(
self.emb_layers = nn.Sequential( # emb_channels,
nn.SiLU(), # 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
linear( # ),
emb_channels, # )
2 * self.out_channels if use_scale_shift_norm else self.out_channels, # self.out_layers = nn.Sequential(
), # normalization(self.out_channels),
) # nn.SiLU(),
self.out_layers = nn.Sequential( # nn.Dropout(p=dropout),
normalization(self.out_channels), # zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
nn.SiLU(), # )
nn.Dropout(p=dropout), #
zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)), # if self.out_channels == channels:
) # self.skip_connection = nn.Identity()
# elif use_conv:
if self.out_channels == channels: # self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
self.skip_connection = nn.Identity() # else:
elif use_conv: # self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1) #
else: # def forward(self, x, emb):
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) # if self.updown:
# in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
def forward(self, x, emb): # h = in_rest(x)
if self.updown: # h = self.h_upd(h)
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] # x = self.x_upd(x)
h = in_rest(x) # h = in_conv(h)
h = self.h_upd(h) # else:
x = self.x_upd(x) # h = self.in_layers(x)
h = in_conv(h) # emb_out = self.emb_layers(emb).type(h.dtype)
else: # while len(emb_out.shape) < len(h.shape):
h = self.in_layers(x) # emb_out = emb_out[..., None]
emb_out = self.emb_layers(emb).type(h.dtype) # if self.use_scale_shift_norm:
while len(emb_out.shape) < len(h.shape): # out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
emb_out = emb_out[..., None] # scale, shift = torch.chunk(emb_out, 2, dim=1)
if self.use_scale_shift_norm: # h = out_norm(h) * (1 + scale) + shift
out_norm, out_rest = self.out_layers[0], self.out_layers[1:] # h = out_rest(h)
scale, shift = torch.chunk(emb_out, 2, dim=1) # else:
h = out_norm(h) * (1 + scale) + shift # h = h + emb_out
h = out_rest(h) # h = self.out_layers(h)
else: # return self.skip_connection(x) + h
h = h + emb_out #
h = self.out_layers(h)
return self.skip_connection(x) + h
class QKVAttention(nn.Module): class QKVAttention(nn.Module):
......
...@@ -6,6 +6,7 @@ import torch.nn as nn ...@@ -6,6 +6,7 @@ import torch.nn as nn
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding from .embeddings import get_timestep_embedding
from .resnet import ResidualTemporalBlock
class SinusoidalPosEmb(nn.Module): class SinusoidalPosEmb(nn.Module):
...@@ -72,36 +73,35 @@ class Conv1dBlock(nn.Module): ...@@ -72,36 +73,35 @@ class Conv1dBlock(nn.Module):
return self.block(x) return self.block(x)
class ResidualTemporalBlock(nn.Module): # class ResidualTemporalBlock(nn.Module):
def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5): # def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
super().__init__() # super().__init__()
#
self.blocks = nn.ModuleList( # self.blocks = nn.ModuleList(
[ # [
Conv1dBlock(inp_channels, out_channels, kernel_size), # Conv1dBlock(inp_channels, out_channels, kernel_size),
Conv1dBlock(out_channels, out_channels, kernel_size), # Conv1dBlock(out_channels, out_channels, kernel_size),
] # ]
) # )
#
self.time_mlp = nn.Sequential( # self.time_mlp = nn.Sequential(
nn.Mish(), # nn.Mish(),
nn.Linear(embed_dim, out_channels), # nn.Linear(embed_dim, out_channels),
RearrangeDim(), # RearrangeDim(),
# Rearrange("batch t -> batch t 1"), # Rearrange("batch t -> batch t 1"),
) # )
#
self.residual_conv = ( # self.residual_conv = (
nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity() # nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
) # )
#
def forward(self, x, t): # def forward(self, x, t):
""" # """
x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x # x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x #
out_channels x horizon ] out_channels x horizon ] #"""
""" # out = self.blocks[0](x) + self.time_mlp(t)
out = self.blocks[0](x) + self.time_mlp(t) # out = self.blocks[1](out)
out = self.blocks[1](out) # return out + self.residual_conv(x)
return out + self.residual_conv(x)
class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
......
...@@ -28,6 +28,7 @@ from ..configuration_utils import ConfigMixin ...@@ -28,6 +28,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .attention import AttentionBlock from .attention import AttentionBlock
from .embeddings import GaussianFourierProjection, get_timestep_embedding from .embeddings import GaussianFourierProjection, get_timestep_embedding
from .resnet import ResnetBlockBigGANpp, ResnetBlockDDPMpp
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
...@@ -299,7 +300,7 @@ def downsample_2d(x, k=None, factor=2, gain=1): ...@@ -299,7 +300,7 @@ def downsample_2d(x, k=None, factor=2, gain=1):
return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2)) return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0): def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
"""1x1 convolution with DDPM initialization.""" """1x1 convolution with DDPM initialization."""
conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias) conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
conv.weight.data = default_init(init_scale)(conv.weight.data.shape) conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
...@@ -307,7 +308,7 @@ def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, pad ...@@ -307,7 +308,7 @@ def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, pad
return conv return conv
def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1): def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
"""3x3 convolution with DDPM initialization.""" """3x3 convolution with DDPM initialization."""
conv = nn.Conv2d( conv = nn.Conv2d(
in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias
...@@ -317,10 +318,6 @@ def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_sc ...@@ -317,10 +318,6 @@ def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_sc
return conv return conv
conv1x1 = ddpm_conv1x1
conv3x3 = ddpm_conv3x3
def _einsum(a, b, c, x, y): def _einsum(a, b, c, x, y):
einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c)) einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
return torch.einsum(einsum_str, x, y) return torch.einsum(einsum_str, x, y)
...@@ -494,135 +491,135 @@ class Downsample(nn.Module): ...@@ -494,135 +491,135 @@ class Downsample(nn.Module):
return x return x
class ResnetBlockDDPMpp(nn.Module): # class ResnetBlockDDPMpp(nn.Module):
"""ResBlock adapted from DDPM.""" # """ResBlock adapted from DDPM."""
#
def __init__( # def __init__(
self, # self,
act, # act,
in_ch, # in_ch,
out_ch=None, # out_ch=None,
temb_dim=None, # temb_dim=None,
conv_shortcut=False, # conv_shortcut=False,
dropout=0.1, # dropout=0.1,
skip_rescale=False, # skip_rescale=False,
init_scale=0.0, # init_scale=0.0,
): # ):
super().__init__() # super().__init__()
out_ch = out_ch if out_ch else in_ch # out_ch = out_ch if out_ch else in_ch
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) # self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
self.Conv_0 = conv3x3(in_ch, out_ch) # self.Conv_0 = conv3x3(in_ch, out_ch)
if temb_dim is not None: # if temb_dim is not None:
self.Dense_0 = nn.Linear(temb_dim, out_ch) # self.Dense_0 = nn.Linear(temb_dim, out_ch)
self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape) # self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
nn.init.zeros_(self.Dense_0.bias) # nn.init.zeros_(self.Dense_0.bias)
self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6) # self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
self.Dropout_0 = nn.Dropout(dropout) # self.Dropout_0 = nn.Dropout(dropout)
self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) # self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
if in_ch != out_ch: # if in_ch != out_ch:
if conv_shortcut: # if conv_shortcut:
self.Conv_2 = conv3x3(in_ch, out_ch) # self.Conv_2 = conv3x3(in_ch, out_ch)
else: # else:
self.NIN_0 = NIN(in_ch, out_ch) # self.NIN_0 = NIN(in_ch, out_ch)
#
self.skip_rescale = skip_rescale # self.skip_rescale = skip_rescale
self.act = act # self.act = act
self.out_ch = out_ch # self.out_ch = out_ch
self.conv_shortcut = conv_shortcut # self.conv_shortcut = conv_shortcut
#
def forward(self, x, temb=None): # def forward(self, x, temb=None):
h = self.act(self.GroupNorm_0(x)) # h = self.act(self.GroupNorm_0(x))
h = self.Conv_0(h) # h = self.Conv_0(h)
if temb is not None: # if temb is not None:
h += self.Dense_0(self.act(temb))[:, :, None, None] # h += self.Dense_0(self.act(temb))[:, :, None, None]
h = self.act(self.GroupNorm_1(h)) # h = self.act(self.GroupNorm_1(h))
h = self.Dropout_0(h) # h = self.Dropout_0(h)
h = self.Conv_1(h) # h = self.Conv_1(h)
if x.shape[1] != self.out_ch: # if x.shape[1] != self.out_ch:
if self.conv_shortcut: # if self.conv_shortcut:
x = self.Conv_2(x) # x = self.Conv_2(x)
else: # else:
x = self.NIN_0(x) # x = self.NIN_0(x)
if not self.skip_rescale: # if not self.skip_rescale:
return x + h # return x + h
else: # else:
return (x + h) / np.sqrt(2.0) # return (x + h) / np.sqrt(2.0)
class ResnetBlockBigGANpp(nn.Module): # class ResnetBlockBigGANpp(nn.Module):
def __init__( # def __init__(
self, # self,
act, # act,
in_ch, # in_ch,
out_ch=None, # out_ch=None,
temb_dim=None, # temb_dim=None,
up=False, # up=False,
down=False, # down=False,
dropout=0.1, # dropout=0.1,
fir=False, # fir=False,
fir_kernel=(1, 3, 3, 1), # fir_kernel=(1, 3, 3, 1),
skip_rescale=True, # skip_rescale=True,
init_scale=0.0, # init_scale=0.0,
): # ):
super().__init__() # super().__init__()
#
out_ch = out_ch if out_ch else in_ch # out_ch = out_ch if out_ch else in_ch
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) # self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
self.up = up # self.up = up
self.down = down # self.down = down
self.fir = fir # self.fir = fir
self.fir_kernel = fir_kernel # self.fir_kernel = fir_kernel
#
self.Conv_0 = conv3x3(in_ch, out_ch) # self.Conv_0 = conv3x3(in_ch, out_ch)
if temb_dim is not None: # if temb_dim is not None:
self.Dense_0 = nn.Linear(temb_dim, out_ch) # self.Dense_0 = nn.Linear(temb_dim, out_ch)
self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape) # self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
nn.init.zeros_(self.Dense_0.bias) # nn.init.zeros_(self.Dense_0.bias)
#
self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6) # self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
self.Dropout_0 = nn.Dropout(dropout) # self.Dropout_0 = nn.Dropout(dropout)
self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) # self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
if in_ch != out_ch or up or down: # if in_ch != out_ch or up or down:
self.Conv_2 = conv1x1(in_ch, out_ch) # self.Conv_2 = conv1x1(in_ch, out_ch)
#
self.skip_rescale = skip_rescale # self.skip_rescale = skip_rescale
self.act = act # self.act = act
self.in_ch = in_ch # self.in_ch = in_ch
self.out_ch = out_ch # self.out_ch = out_ch
#
def forward(self, x, temb=None): # def forward(self, x, temb=None):
h = self.act(self.GroupNorm_0(x)) # h = self.act(self.GroupNorm_0(x))
#
if self.up: # if self.up:
if self.fir: # if self.fir:
h = upsample_2d(h, self.fir_kernel, factor=2) # h = upsample_2d(h, self.fir_kernel, factor=2)
x = upsample_2d(x, self.fir_kernel, factor=2) # x = upsample_2d(x, self.fir_kernel, factor=2)
else: # else:
h = naive_upsample_2d(h, factor=2) # h = naive_upsample_2d(h, factor=2)
x = naive_upsample_2d(x, factor=2) # x = naive_upsample_2d(x, factor=2)
elif self.down: # elif self.down:
if self.fir: # if self.fir:
h = downsample_2d(h, self.fir_kernel, factor=2) # h = downsample_2d(h, self.fir_kernel, factor=2)
x = downsample_2d(x, self.fir_kernel, factor=2) # x = downsample_2d(x, self.fir_kernel, factor=2)
else: # else:
h = naive_downsample_2d(h, factor=2) # h = naive_downsample_2d(h, factor=2)
x = naive_downsample_2d(x, factor=2) # x = naive_downsample_2d(x, factor=2)
#
h = self.Conv_0(h) # h = self.Conv_0(h)
# Add bias to each feature map conditioned on the time embedding # Add bias to each feature map conditioned on the time embedding
if temb is not None: # if temb is not None:
h += self.Dense_0(self.act(temb))[:, :, None, None] # h += self.Dense_0(self.act(temb))[:, :, None, None]
h = self.act(self.GroupNorm_1(h)) # h = self.act(self.GroupNorm_1(h))
h = self.Dropout_0(h) # h = self.Dropout_0(h)
h = self.Conv_1(h) # h = self.Conv_1(h)
#
if self.in_ch != self.out_ch or self.up or self.down: # if self.in_ch != self.out_ch or self.up or self.down:
x = self.Conv_2(x) # x = self.Conv_2(x)
#
if not self.skip_rescale: # if not self.skip_rescale:
return x + h # return x + h
else: # else:
return (x + h) / np.sqrt(2.0) # return (x + h) / np.sqrt(2.0)
class NCSNpp(ModelMixin, ConfigMixin): class NCSNpp(ModelMixin, ConfigMixin):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment