Unverified Commit d224c637 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Resnet => Resnet2D (#66)

parent 44705a64
......@@ -176,7 +176,7 @@ class Downsample(nn.Module):
# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py, unet_score_vde.py
# => All 2D-Resnets are included here now!
class ResnetBlock(nn.Module):
class ResnetBlock2D(nn.Module):
def __init__(
self,
*,
......
......@@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample, ResnetBlock, Upsample
from .resnet import Downsample, ResnetBlock2D, Upsample
def nonlinearity(x):
......@@ -89,7 +89,7 @@ class UNetModel(ModelMixin, ConfigMixin):
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock(
ResnetBlock2D(
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
)
)
......@@ -106,11 +106,11 @@ class UNetModel(ModelMixin, ConfigMixin):
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
self.mid.block_1 = ResnetBlock2D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True)
self.mid.block_2 = ResnetBlock(
self.mid.block_2 = ResnetBlock2D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
......@@ -125,7 +125,7 @@ class UNetModel(ModelMixin, ConfigMixin):
if i_block == self.num_res_blocks:
skip_in = ch * in_ch_mult[i_level]
block.append(
ResnetBlock(
ResnetBlock2D(
in_channels=block_in + skip_in,
out_channels=block_out,
temb_channels=self.temb_ch,
......
......@@ -6,7 +6,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample, ResnetBlock, TimestepBlock, Upsample
from .resnet import Downsample, ResnetBlock2D, TimestepBlock, Upsample
def convert_module_to_f16(l):
......@@ -88,7 +88,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
def forward(self, x, emb, encoder_out=None):
for layer in self:
if isinstance(layer, TimestepBlock) or isinstance(layer, ResnetBlock):
if isinstance(layer, TimestepBlock) or isinstance(layer, ResnetBlock2D):
x = layer(x, emb)
elif isinstance(layer, AttentionBlock):
x = layer(x, encoder_out)
......@@ -177,7 +177,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
ResnetBlock(
ResnetBlock2D(
in_channels=ch,
out_channels=mult * model_channels,
dropout=dropout,
......@@ -206,7 +206,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResnetBlock(
ResnetBlock2D(
in_channels=ch,
out_channels=out_ch,
dropout=dropout,
......@@ -229,7 +229,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
self._feature_size += ch
self.middle_block = TimestepEmbedSequential(
ResnetBlock(
ResnetBlock2D(
in_channels=ch,
dropout=dropout,
temb_channels=time_embed_dim,
......@@ -245,7 +245,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
num_head_channels=num_head_channels,
encoder_channels=transformer_dim,
),
ResnetBlock(
ResnetBlock2D(
in_channels=ch,
dropout=dropout,
temb_channels=time_embed_dim,
......@@ -262,7 +262,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
for i in range(num_res_blocks + 1):
ich = input_block_chans.pop()
layers = [
ResnetBlock(
ResnetBlock2D(
in_channels=ch + ich,
out_channels=model_channels * mult,
dropout=dropout,
......@@ -287,7 +287,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
if level and i == num_res_blocks:
out_ch = ch
layers.append(
ResnetBlock(
ResnetBlock2D(
in_channels=ch,
out_channels=out_ch,
dropout=dropout,
......
......@@ -4,7 +4,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import LinearAttention
from .embeddings import get_timestep_embedding
from .resnet import Downsample, ResnetBlock, Upsample
from .resnet import Downsample, ResnetBlock2D, Upsample
class Mish(torch.nn.Module):
......@@ -84,7 +84,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
self.downs.append(
torch.nn.ModuleList(
[
ResnetBlock(
ResnetBlock2D(
in_channels=dim_in,
out_channels=dim_out,
temb_channels=dim,
......@@ -94,7 +94,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
non_linearity="mish",
overwrite_for_grad_tts=True,
),
ResnetBlock(
ResnetBlock2D(
in_channels=dim_out,
out_channels=dim_out,
temb_channels=dim,
......@@ -111,7 +111,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
)
mid_dim = dims[-1]
self.mid_block1 = ResnetBlock(
self.mid_block1 = ResnetBlock2D(
in_channels=mid_dim,
out_channels=mid_dim,
temb_channels=dim,
......@@ -122,7 +122,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
overwrite_for_grad_tts=True,
)
self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
self.mid_block2 = ResnetBlock(
self.mid_block2 = ResnetBlock2D(
in_channels=mid_dim,
out_channels=mid_dim,
temb_channels=dim,
......@@ -137,7 +137,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
self.ups.append(
torch.nn.ModuleList(
[
ResnetBlock(
ResnetBlock2D(
in_channels=dim_out * 2,
out_channels=dim_in,
temb_channels=dim,
......@@ -147,7 +147,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
non_linearity="mish",
overwrite_for_grad_tts=True,
),
ResnetBlock(
ResnetBlock2D(
in_channels=dim_in,
out_channels=dim_in,
temb_channels=dim,
......
......@@ -10,7 +10,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample, ResnetBlock, TimestepBlock, Upsample
from .resnet import Downsample, ResnetBlock2D, TimestepBlock, Upsample
# from .resnet import ResBlock
......@@ -148,7 +148,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
def forward(self, x, emb, context=None):
for layer in self:
if isinstance(layer, TimestepBlock) or isinstance(layer, ResnetBlock):
if isinstance(layer, TimestepBlock) or isinstance(layer, ResnetBlock2D):
x = layer(x, emb)
elif isinstance(layer, SpatialTransformer):
x = layer(x, context)
......@@ -310,7 +310,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
ResnetBlock(
ResnetBlock2D(
in_channels=ch,
out_channels=mult * model_channels,
dropout=dropout,
......@@ -367,7 +367,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential(
ResnetBlock(
ResnetBlock2D(
in_channels=ch,
out_channels=None,
dropout=dropout,
......@@ -385,7 +385,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
)
if not use_spatial_transformer
else SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim),
ResnetBlock(
ResnetBlock2D(
in_channels=ch,
out_channels=None,
dropout=dropout,
......@@ -402,7 +402,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
for i in range(num_res_blocks + 1):
ich = input_block_chans.pop()
layers = [
ResnetBlock(
ResnetBlock2D(
in_channels=ch + ich,
out_channels=model_channels * mult,
dropout=dropout,
......
......@@ -27,7 +27,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import GaussianFourierProjection, get_timestep_embedding
from .resnet import Downsample, ResnetBlock, Upsample, downsample_2d, upfirdn2d, upsample_2d
from .resnet import Downsample, ResnetBlock2D, Upsample, downsample_2d, upfirdn2d, upsample_2d
def _setup_kernel(k):
......@@ -345,7 +345,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
for i_block in range(num_res_blocks):
out_ch = nf * ch_mult[i_level]
modules.append(
ResnetBlock(
ResnetBlock2D(
in_channels=in_ch,
out_channels=out_ch,
temb_channels=4 * nf,
......@@ -364,7 +364,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
if i_level != self.num_resolutions - 1:
modules.append(
ResnetBlock(
ResnetBlock2D(
in_channels=in_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
......@@ -391,7 +391,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
in_ch = hs_c[-1]
modules.append(
ResnetBlock(
ResnetBlock2D(
in_channels=in_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
......@@ -403,7 +403,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
)
modules.append(AttnBlock(channels=in_ch))
modules.append(
ResnetBlock(
ResnetBlock2D(
in_channels=in_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
......@@ -421,7 +421,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
out_ch = nf * ch_mult[i_level]
in_ch = in_ch + hs_c.pop()
modules.append(
ResnetBlock(
ResnetBlock2D(
in_channels=in_ch,
out_channels=out_ch,
temb_channels=4 * nf,
......@@ -464,7 +464,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
if i_level != 0:
modules.append(
ResnetBlock(
ResnetBlock2D(
in_channels=in_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
......
......@@ -5,7 +5,7 @@ import torch.nn as nn
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .resnet import Downsample, ResnetBlock, Upsample
from .resnet import Downsample, ResnetBlock2D, Upsample
def nonlinearity(x):
......@@ -54,7 +54,7 @@ class Encoder(nn.Module):
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock(
ResnetBlock2D(
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
)
)
......@@ -71,11 +71,11 @@ class Encoder(nn.Module):
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
self.mid.block_1 = ResnetBlock2D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True)
self.mid.block_2 = ResnetBlock(
self.mid.block_2 = ResnetBlock2D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
......@@ -152,11 +152,11 @@ class Decoder(nn.Module):
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
self.mid.block_1 = ResnetBlock2D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True)
self.mid.block_2 = ResnetBlock(
self.mid.block_2 = ResnetBlock2D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
......@@ -168,7 +168,7 @@ class Decoder(nn.Module):
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ResnetBlock(
ResnetBlock2D(
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
)
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment