Unverified commit c524244f authored by Patrick von Platen, committed by GitHub

[Resnet] Remove unnecessary functions / classes (#67)

Remove unnecessary functions / classes
parent d224c637
-from abc import abstractmethod
 from functools import partial
 import numpy as np
@@ -46,30 +45,6 @@ def conv_transpose_nd(dims, *args, **kwargs):
     raise ValueError(f"unsupported dimensions: {dims}")


-def Normalize(in_channels, num_groups=32, eps=1e-6):
-    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=eps, affine=True)
-
-
-def nonlinearity(x, swish=1.0):
-    # swish
-    if swish == 1.0:
-        return F.silu(x)
-    else:
-        return x * F.sigmoid(x * float(swish))
-
-
-class TimestepBlock(nn.Module):
-    """
-    Any module where forward() takes timestep embeddings as a second argument.
-    """
-
-    @abstractmethod
-    def forward(self, x, emb):
-        """
-        Apply the module to `x` given `emb` timestep embeddings.
-        """
-
-
 class Upsample(nn.Module):
     """
     An upsampling layer with an optional convolution.
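Note: the removed helpers were thin wrappers, which is what makes the deletion safe. A minimal sketch (not part of the diff) checking the equivalences:

import torch
import torch.nn.functional as F

x = torch.randn(2, 32, 8, 8)

# nonlinearity() returned F.silu(x) at its default swish=1.0 and
# x * sigmoid(swish * x) otherwise; silu(x) is exactly x * sigmoid(x),
# so at the default the two paths coincide.
assert torch.allclose(F.silu(x), x * torch.sigmoid(x))

# Normalize() only forwarded its arguments to GroupNorm, so call sites
# can construct torch.nn.GroupNorm directly with the same defaults.
norm = torch.nn.GroupNorm(num_groups=32, num_channels=32, eps=1e-6, affine=True)
y = norm(x)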
@@ -216,9 +191,9 @@ class ResnetBlock2D(nn.Module):
         groups_out = groups

         if self.pre_norm:
-            self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
+            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
         else:
-            self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
+            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)

         self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
@@ -227,12 +202,12 @@ class ResnetBlock2D(nn.Module):
         elif time_embedding_norm == "scale_shift" and temb_channels > 0:
             self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)

-        self.norm2 = Normalize(out_channels, num_groups=groups_out, eps=eps)
+        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
         self.dropout = torch.nn.Dropout(dropout)
         self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

         if non_linearity == "swish":
-            self.nonlinearity = nonlinearity
+            self.nonlinearity = lambda x: F.silu(x)
         elif non_linearity == "mish":
             self.nonlinearity = Mish()
         elif non_linearity == "silu":
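Note: the scale_shift branch above sizes its projection at 2 * out_channels, presumably one per-channel scale and one per-channel shift, as the name suggests. A standalone sketch with hypothetical sizes:

import torch

temb_channels, out_channels = 512, 128  # hypothetical sizes for illustration

# Mirrors the Linear built in the scale_shift branch of ResnetBlock2D.
temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)

temb = torch.randn(4, temb_channels)
scale, shift = temb_proj(temb).chunk(2, dim=1)  # split into scale and shift chunks
assert scale.shape == (4, out_channels) and shift.shape == (4, out_channels)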
@@ -6,7 +6,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResnetBlock2D, TimestepBlock, Upsample
+from .resnet import Downsample, ResnetBlock2D, Upsample


 def convert_module_to_f16(l):
@@ -81,14 +81,14 @@ def zero_module(module):
     return module


-class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+class TimestepEmbedSequential(nn.Sequential):
     """
     A sequential module that passes timestep embeddings to the children that support it as an extra input.
     """

     def forward(self, x, emb, encoder_out=None):
         for layer in self:
-            if isinstance(layer, TimestepBlock) or isinstance(layer, ResnetBlock2D):
+            if isinstance(layer, ResnetBlock2D) or isinstance(layer, TimestepEmbedSequential):
                 x = layer(x, emb)
             elif isinstance(layer, AttentionBlock):
                 x = layer(x, encoder_out)
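Note: with the abstract base class dropped, dispatch here is purely by isinstance. A self-contained sketch of the threading this hunk implements, using toy stand-in classes (not the real blocks):

import torch
import torch.nn as nn

class ToyResnetBlock2D(nn.Module):
    # Stand-in for ResnetBlock2D: a child that consumes the timestep embedding.
    def forward(self, x, emb):
        return x + emb.mean()

class ToyAttentionBlock(nn.Module):
    # Stand-in for AttentionBlock: a child that takes encoder_out, not emb.
    def forward(self, x, encoder_out=None):
        return x

class TimestepEmbedSequential(nn.Sequential):
    # Mirrors the post-change dispatch: resnet children (and nested
    # TimestepEmbedSequential) receive the embedding, attention children
    # receive encoder_out, everything else just receives x.
    def forward(self, x, emb, encoder_out=None):
        for layer in self:
            if isinstance(layer, (ToyResnetBlock2D, TimestepEmbedSequential)):
                x = layer(x, emb)
            elif isinstance(layer, ToyAttentionBlock):
                x = layer(x, encoder_out)
            else:
                x = layer(x)
        return x

block = TimestepEmbedSequential(ToyResnetBlock2D(), nn.Identity(), ToyAttentionBlock())
out = block(torch.randn(1, 4), torch.randn(1, 4))  # emb reaches only the resnet child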
@@ -10,7 +10,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResnetBlock2D, TimestepBlock, Upsample
+from .resnet import Downsample, ResnetBlock2D, Upsample

 # from .resnet import ResBlock
@@ -141,14 +141,14 @@ def normalization(channels, swish=0.0):
     return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)


-class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+class TimestepEmbedSequential(nn.Sequential):
     """
     A sequential module that passes timestep embeddings to the children that support it as an extra input.
     """

     def forward(self, x, emb, context=None):
         for layer in self:
-            if isinstance(layer, TimestepBlock) or isinstance(layer, ResnetBlock2D):
+            if isinstance(layer, ResnetBlock2D) or isinstance(layer, TimestepEmbedSequential):
                 x = layer(x, emb)
             elif isinstance(layer, SpatialTransformer):
                 x = layer(x, context)
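Design note: with the TimestepBlock base class gone, both copies of TimestepEmbedSequential dispatch on concrete types, so the explicit TimestepEmbedSequential case is what keeps nested sequentials receiving the embedding (previously they qualified by inheriting TimestepBlock). Any new timestep-aware layer now has to be added to the isinstance check by hand.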