Commit d1fb3093 authored by patil-suraj's avatar patil-suraj
Browse files

consolidate downsample

parent 7b9b946c
......@@ -31,7 +31,7 @@ from tqdm import tqdm
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
from .resnet import Upsample
from .resnet import Downsample, Upsample
def nonlinearity(x):
......@@ -43,24 +43,6 @@ def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x):
if self.with_conv:
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
super().__init__()
......@@ -207,7 +189,7 @@ class UNetModel(ModelMixin, ConfigMixin):
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, resamp_with_conv)
down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0)
curr_res = curr_res // 2
self.down.append(down)
......
......@@ -8,7 +8,7 @@ import torch.nn.functional as F
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
from .resnet import Upsample
from .resnet import Downsample, Upsample
def convert_module_to_f16(l):
......@@ -126,34 +126,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
return x
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=1)
else:
assert self.channels == self.out_channels
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
return self.op(x)
class ResBlock(TimestepBlock):
"""
A residual block that can optionally change the number of channels.
......@@ -205,8 +177,8 @@ class ResBlock(TimestepBlock):
self.h_upd = Upsample(channels, use_conv=False, dims=dims)
self.x_upd = Upsample(channels, use_conv=False, dims=dims)
elif down:
self.h_upd = Downsample(channels, False, dims)
self.x_upd = Downsample(channels, False, dims)
self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
else:
self.h_upd = self.x_upd = nn.Identity()
......@@ -463,7 +435,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
down=True,
)
if resblock_updown
else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
else Downsample(
ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
)
)
)
ch = out_ch
......
import torch
from numpy import pad
try:
......@@ -10,7 +11,7 @@ except:
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
from .resnet import Upsample
from .resnet import Downsample, Upsample
class Mish(torch.nn.Module):
......@@ -18,15 +19,6 @@ class Mish(torch.nn.Module):
return x * torch.tanh(torch.nn.functional.softplus(x))
class Downsample(torch.nn.Module):
def __init__(self, dim):
super(Downsample, self).__init__()
self.conv = torch.nn.Conv2d(dim, dim, 3, 2, 1)
def forward(self, x):
return self.conv(x)
class Rezero(torch.nn.Module):
def __init__(self, fn):
super(Rezero, self).__init__()
......@@ -141,7 +133,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
Residual(Rezero(LinearAttention(dim_out))),
Downsample(dim_out) if not is_last else torch.nn.Identity(),
Downsample(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(),
]
)
)
......
......@@ -17,7 +17,7 @@ except:
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
from .resnet import Upsample
from .resnet import Downsample, Upsample
def exists(val):
......@@ -380,33 +380,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
return x
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
return self.op(x)
class ResBlock(TimestepBlock):
"""
A residual block that can optionally change the number of channels.
......@@ -457,8 +430,8 @@ class ResBlock(TimestepBlock):
self.h_upd = Upsample(channels, use_conv=False, dims=dims)
self.x_upd = Upsample(channels, use_conv=False, dims=dims)
elif down:
self.h_upd = Downsample(channels, False, dims)
self.x_upd = Downsample(channels, False, dims)
self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
else:
self.h_upd = self.x_upd = nn.Identity()
......@@ -825,7 +798,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
down=True,
)
if resblock_updown
else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
else Downsample(
ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
)
)
)
ch = out_ch
......@@ -1098,7 +1073,9 @@ class EncoderUNetModel(nn.Module):
down=True,
)
if resblock_updown
else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
else Downsample(
ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
)
)
)
ch = out_ch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment