Unverified commit d224c637, authored by Patrick von Platen, committed by GitHub
Browse files

Resnet => Resnet2D (#66)

parent 44705a64
...@@ -176,7 +176,7 @@ class Downsample(nn.Module): ...@@ -176,7 +176,7 @@ class Downsample(nn.Module):
# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py, unet_score_vde.py # unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py, unet_score_vde.py
# => All 2D-Resnets are included here now! # => All 2D-Resnets are included here now!
class ResnetBlock(nn.Module): class ResnetBlock2D(nn.Module):
def __init__( def __init__(
self, self,
*, *,
......
...@@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin ...@@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .attention import AttentionBlock from .attention import AttentionBlock
from .embeddings import get_timestep_embedding from .embeddings import get_timestep_embedding
from .resnet import Downsample, ResnetBlock, Upsample from .resnet import Downsample, ResnetBlock2D, Upsample
def nonlinearity(x): def nonlinearity(x):
...@@ -89,7 +89,7 @@ class UNetModel(ModelMixin, ConfigMixin): ...@@ -89,7 +89,7 @@ class UNetModel(ModelMixin, ConfigMixin):
block_out = ch * ch_mult[i_level] block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks): for i_block in range(self.num_res_blocks):
block.append( block.append(
ResnetBlock( ResnetBlock2D(
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
) )
) )
...@@ -106,11 +106,11 @@ class UNetModel(ModelMixin, ConfigMixin): ...@@ -106,11 +106,11 @@ class UNetModel(ModelMixin, ConfigMixin):
# middle # middle
self.mid = nn.Module() self.mid = nn.Module()
self.mid.block_1 = ResnetBlock( self.mid.block_1 = ResnetBlock2D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
) )
self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True) self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True)
self.mid.block_2 = ResnetBlock( self.mid.block_2 = ResnetBlock2D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
) )
...@@ -125,7 +125,7 @@ class UNetModel(ModelMixin, ConfigMixin): ...@@ -125,7 +125,7 @@ class UNetModel(ModelMixin, ConfigMixin):
if i_block == self.num_res_blocks: if i_block == self.num_res_blocks:
skip_in = ch * in_ch_mult[i_level] skip_in = ch * in_ch_mult[i_level]
block.append( block.append(
ResnetBlock( ResnetBlock2D(
in_channels=block_in + skip_in, in_channels=block_in + skip_in,
out_channels=block_out, out_channels=block_out,
temb_channels=self.temb_ch, temb_channels=self.temb_ch,
......
...@@ -6,7 +6,7 @@ from ..configuration_utils import ConfigMixin ...@@ -6,7 +6,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .attention import AttentionBlock from .attention import AttentionBlock
from .embeddings import get_timestep_embedding from .embeddings import get_timestep_embedding
from .resnet import Downsample, ResnetBlock, TimestepBlock, Upsample from .resnet import Downsample, ResnetBlock2D, TimestepBlock, Upsample
def convert_module_to_f16(l): def convert_module_to_f16(l):
...@@ -88,7 +88,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock): ...@@ -88,7 +88,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
def forward(self, x, emb, encoder_out=None): def forward(self, x, emb, encoder_out=None):
for layer in self: for layer in self:
if isinstance(layer, TimestepBlock) or isinstance(layer, ResnetBlock): if isinstance(layer, TimestepBlock) or isinstance(layer, ResnetBlock2D):
x = layer(x, emb) x = layer(x, emb)
elif isinstance(layer, AttentionBlock): elif isinstance(layer, AttentionBlock):
x = layer(x, encoder_out) x = layer(x, encoder_out)
...@@ -177,7 +177,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin): ...@@ -177,7 +177,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
for level, mult in enumerate(channel_mult): for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks): for _ in range(num_res_blocks):
layers = [ layers = [
ResnetBlock( ResnetBlock2D(
in_channels=ch, in_channels=ch,
out_channels=mult * model_channels, out_channels=mult * model_channels,
dropout=dropout, dropout=dropout,
...@@ -206,7 +206,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin): ...@@ -206,7 +206,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
out_ch = ch out_ch = ch
self.input_blocks.append( self.input_blocks.append(
TimestepEmbedSequential( TimestepEmbedSequential(
ResnetBlock( ResnetBlock2D(
in_channels=ch, in_channels=ch,
out_channels=out_ch, out_channels=out_ch,
dropout=dropout, dropout=dropout,
...@@ -229,7 +229,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin): ...@@ -229,7 +229,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
self._feature_size += ch self._feature_size += ch
self.middle_block = TimestepEmbedSequential( self.middle_block = TimestepEmbedSequential(
ResnetBlock( ResnetBlock2D(
in_channels=ch, in_channels=ch,
dropout=dropout, dropout=dropout,
temb_channels=time_embed_dim, temb_channels=time_embed_dim,
...@@ -245,7 +245,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin): ...@@ -245,7 +245,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
num_head_channels=num_head_channels, num_head_channels=num_head_channels,
encoder_channels=transformer_dim, encoder_channels=transformer_dim,
), ),
ResnetBlock( ResnetBlock2D(
in_channels=ch, in_channels=ch,
dropout=dropout, dropout=dropout,
temb_channels=time_embed_dim, temb_channels=time_embed_dim,
...@@ -262,7 +262,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin): ...@@ -262,7 +262,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
for i in range(num_res_blocks + 1): for i in range(num_res_blocks + 1):
ich = input_block_chans.pop() ich = input_block_chans.pop()
layers = [ layers = [
ResnetBlock( ResnetBlock2D(
in_channels=ch + ich, in_channels=ch + ich,
out_channels=model_channels * mult, out_channels=model_channels * mult,
dropout=dropout, dropout=dropout,
...@@ -287,7 +287,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin): ...@@ -287,7 +287,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
if level and i == num_res_blocks: if level and i == num_res_blocks:
out_ch = ch out_ch = ch
layers.append( layers.append(
ResnetBlock( ResnetBlock2D(
in_channels=ch, in_channels=ch,
out_channels=out_ch, out_channels=out_ch,
dropout=dropout, dropout=dropout,
......
...@@ -4,7 +4,7 @@ from ..configuration_utils import ConfigMixin ...@@ -4,7 +4,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .attention import LinearAttention from .attention import LinearAttention
from .embeddings import get_timestep_embedding from .embeddings import get_timestep_embedding
from .resnet import Downsample, ResnetBlock, Upsample from .resnet import Downsample, ResnetBlock2D, Upsample
class Mish(torch.nn.Module): class Mish(torch.nn.Module):
...@@ -84,7 +84,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): ...@@ -84,7 +84,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
self.downs.append( self.downs.append(
torch.nn.ModuleList( torch.nn.ModuleList(
[ [
ResnetBlock( ResnetBlock2D(
in_channels=dim_in, in_channels=dim_in,
out_channels=dim_out, out_channels=dim_out,
temb_channels=dim, temb_channels=dim,
...@@ -94,7 +94,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): ...@@ -94,7 +94,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
non_linearity="mish", non_linearity="mish",
overwrite_for_grad_tts=True, overwrite_for_grad_tts=True,
), ),
ResnetBlock( ResnetBlock2D(
in_channels=dim_out, in_channels=dim_out,
out_channels=dim_out, out_channels=dim_out,
temb_channels=dim, temb_channels=dim,
...@@ -111,7 +111,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): ...@@ -111,7 +111,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
) )
mid_dim = dims[-1] mid_dim = dims[-1]
self.mid_block1 = ResnetBlock( self.mid_block1 = ResnetBlock2D(
in_channels=mid_dim, in_channels=mid_dim,
out_channels=mid_dim, out_channels=mid_dim,
temb_channels=dim, temb_channels=dim,
...@@ -122,7 +122,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): ...@@ -122,7 +122,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
overwrite_for_grad_tts=True, overwrite_for_grad_tts=True,
) )
self.mid_attn = Residual(Rezero(LinearAttention(mid_dim))) self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
self.mid_block2 = ResnetBlock( self.mid_block2 = ResnetBlock2D(
in_channels=mid_dim, in_channels=mid_dim,
out_channels=mid_dim, out_channels=mid_dim,
temb_channels=dim, temb_channels=dim,
...@@ -137,7 +137,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): ...@@ -137,7 +137,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
self.ups.append( self.ups.append(
torch.nn.ModuleList( torch.nn.ModuleList(
[ [
ResnetBlock( ResnetBlock2D(
in_channels=dim_out * 2, in_channels=dim_out * 2,
out_channels=dim_in, out_channels=dim_in,
temb_channels=dim, temb_channels=dim,
...@@ -147,7 +147,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): ...@@ -147,7 +147,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
non_linearity="mish", non_linearity="mish",
overwrite_for_grad_tts=True, overwrite_for_grad_tts=True,
), ),
ResnetBlock( ResnetBlock2D(
in_channels=dim_in, in_channels=dim_in,
out_channels=dim_in, out_channels=dim_in,
temb_channels=dim, temb_channels=dim,
......
...@@ -10,7 +10,7 @@ from ..configuration_utils import ConfigMixin ...@@ -10,7 +10,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .attention import AttentionBlock from .attention import AttentionBlock
from .embeddings import get_timestep_embedding from .embeddings import get_timestep_embedding
from .resnet import Downsample, ResnetBlock, TimestepBlock, Upsample from .resnet import Downsample, ResnetBlock2D, TimestepBlock, Upsample
# from .resnet import ResBlock # from .resnet import ResBlock
...@@ -148,7 +148,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock): ...@@ -148,7 +148,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
def forward(self, x, emb, context=None): def forward(self, x, emb, context=None):
for layer in self: for layer in self:
if isinstance(layer, TimestepBlock) or isinstance(layer, ResnetBlock): if isinstance(layer, TimestepBlock) or isinstance(layer, ResnetBlock2D):
x = layer(x, emb) x = layer(x, emb)
elif isinstance(layer, SpatialTransformer): elif isinstance(layer, SpatialTransformer):
x = layer(x, context) x = layer(x, context)
...@@ -310,7 +310,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin): ...@@ -310,7 +310,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
for level, mult in enumerate(channel_mult): for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks): for _ in range(num_res_blocks):
layers = [ layers = [
ResnetBlock( ResnetBlock2D(
in_channels=ch, in_channels=ch,
out_channels=mult * model_channels, out_channels=mult * model_channels,
dropout=dropout, dropout=dropout,
...@@ -367,7 +367,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin): ...@@ -367,7 +367,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
# num_heads = 1 # num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential( self.middle_block = TimestepEmbedSequential(
ResnetBlock( ResnetBlock2D(
in_channels=ch, in_channels=ch,
out_channels=None, out_channels=None,
dropout=dropout, dropout=dropout,
...@@ -385,7 +385,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin): ...@@ -385,7 +385,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
) )
if not use_spatial_transformer if not use_spatial_transformer
else SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim), else SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim),
ResnetBlock( ResnetBlock2D(
in_channels=ch, in_channels=ch,
out_channels=None, out_channels=None,
dropout=dropout, dropout=dropout,
...@@ -402,7 +402,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin): ...@@ -402,7 +402,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
for i in range(num_res_blocks + 1): for i in range(num_res_blocks + 1):
ich = input_block_chans.pop() ich = input_block_chans.pop()
layers = [ layers = [
ResnetBlock( ResnetBlock2D(
in_channels=ch + ich, in_channels=ch + ich,
out_channels=model_channels * mult, out_channels=model_channels * mult,
dropout=dropout, dropout=dropout,
......
...@@ -27,7 +27,7 @@ from ..configuration_utils import ConfigMixin ...@@ -27,7 +27,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .attention import AttentionBlock from .attention import AttentionBlock
from .embeddings import GaussianFourierProjection, get_timestep_embedding from .embeddings import GaussianFourierProjection, get_timestep_embedding
from .resnet import Downsample, ResnetBlock, Upsample, downsample_2d, upfirdn2d, upsample_2d from .resnet import Downsample, ResnetBlock2D, Upsample, downsample_2d, upfirdn2d, upsample_2d
def _setup_kernel(k): def _setup_kernel(k):
...@@ -345,7 +345,7 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -345,7 +345,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
for i_block in range(num_res_blocks): for i_block in range(num_res_blocks):
out_ch = nf * ch_mult[i_level] out_ch = nf * ch_mult[i_level]
modules.append( modules.append(
ResnetBlock( ResnetBlock2D(
in_channels=in_ch, in_channels=in_ch,
out_channels=out_ch, out_channels=out_ch,
temb_channels=4 * nf, temb_channels=4 * nf,
...@@ -364,7 +364,7 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -364,7 +364,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
if i_level != self.num_resolutions - 1: if i_level != self.num_resolutions - 1:
modules.append( modules.append(
ResnetBlock( ResnetBlock2D(
in_channels=in_ch, in_channels=in_ch,
temb_channels=4 * nf, temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0), output_scale_factor=np.sqrt(2.0),
...@@ -391,7 +391,7 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -391,7 +391,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
in_ch = hs_c[-1] in_ch = hs_c[-1]
modules.append( modules.append(
ResnetBlock( ResnetBlock2D(
in_channels=in_ch, in_channels=in_ch,
temb_channels=4 * nf, temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0), output_scale_factor=np.sqrt(2.0),
...@@ -403,7 +403,7 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -403,7 +403,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
) )
modules.append(AttnBlock(channels=in_ch)) modules.append(AttnBlock(channels=in_ch))
modules.append( modules.append(
ResnetBlock( ResnetBlock2D(
in_channels=in_ch, in_channels=in_ch,
temb_channels=4 * nf, temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0), output_scale_factor=np.sqrt(2.0),
...@@ -421,7 +421,7 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -421,7 +421,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
out_ch = nf * ch_mult[i_level] out_ch = nf * ch_mult[i_level]
in_ch = in_ch + hs_c.pop() in_ch = in_ch + hs_c.pop()
modules.append( modules.append(
ResnetBlock( ResnetBlock2D(
in_channels=in_ch, in_channels=in_ch,
out_channels=out_ch, out_channels=out_ch,
temb_channels=4 * nf, temb_channels=4 * nf,
...@@ -464,7 +464,7 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -464,7 +464,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
if i_level != 0: if i_level != 0:
modules.append( modules.append(
ResnetBlock( ResnetBlock2D(
in_channels=in_ch, in_channels=in_ch,
temb_channels=4 * nf, temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0), output_scale_factor=np.sqrt(2.0),
......
...@@ -5,7 +5,7 @@ import torch.nn as nn ...@@ -5,7 +5,7 @@ import torch.nn as nn
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .attention import AttentionBlock from .attention import AttentionBlock
from .resnet import Downsample, ResnetBlock, Upsample from .resnet import Downsample, ResnetBlock2D, Upsample
def nonlinearity(x): def nonlinearity(x):
...@@ -54,7 +54,7 @@ class Encoder(nn.Module): ...@@ -54,7 +54,7 @@ class Encoder(nn.Module):
block_out = ch * ch_mult[i_level] block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks): for i_block in range(self.num_res_blocks):
block.append( block.append(
ResnetBlock( ResnetBlock2D(
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
) )
) )
...@@ -71,11 +71,11 @@ class Encoder(nn.Module): ...@@ -71,11 +71,11 @@ class Encoder(nn.Module):
# middle # middle
self.mid = nn.Module() self.mid = nn.Module()
self.mid.block_1 = ResnetBlock( self.mid.block_1 = ResnetBlock2D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
) )
self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True) self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True)
self.mid.block_2 = ResnetBlock( self.mid.block_2 = ResnetBlock2D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
) )
...@@ -152,11 +152,11 @@ class Decoder(nn.Module): ...@@ -152,11 +152,11 @@ class Decoder(nn.Module):
# middle # middle
self.mid = nn.Module() self.mid = nn.Module()
self.mid.block_1 = ResnetBlock( self.mid.block_1 = ResnetBlock2D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
) )
self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True) self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True)
self.mid.block_2 = ResnetBlock( self.mid.block_2 = ResnetBlock2D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
) )
...@@ -168,7 +168,7 @@ class Decoder(nn.Module): ...@@ -168,7 +168,7 @@ class Decoder(nn.Module):
block_out = ch * ch_mult[i_level] block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1): for i_block in range(self.num_res_blocks + 1):
block.append( block.append(
ResnetBlock( ResnetBlock2D(
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
) )
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment