Commit 183056f2 authored by patil-suraj's avatar patil-suraj
Browse files

consolidate Upsample

parent dc7c49e4
...@@ -64,7 +64,7 @@ class Upsample(nn.Module): ...@@ -64,7 +64,7 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions. upsampling occurs in the inner-two dimensions.
""" """
def __init__(self, channels, use_conv, use_conv_transpose=False, dims=2, out_channels=None): def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.out_channels = out_channels or channels self.out_channels = out_channels or channels
......
...@@ -31,6 +31,7 @@ from tqdm import tqdm ...@@ -31,6 +31,7 @@ from tqdm import tqdm
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding from .embeddings import get_timestep_embedding
from .resnet import Upsample
def nonlinearity(x): def nonlinearity(x):
...@@ -42,20 +43,6 @@ def Normalize(in_channels): ...@@ -42,20 +43,6 @@ def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
return x
class Downsample(nn.Module): class Downsample(nn.Module):
def __init__(self, in_channels, with_conv): def __init__(self, in_channels, with_conv):
super().__init__() super().__init__()
...@@ -259,7 +246,7 @@ class UNetModel(ModelMixin, ConfigMixin): ...@@ -259,7 +246,7 @@ class UNetModel(ModelMixin, ConfigMixin):
up.block = block up.block = block
up.attn = attn up.attn = attn
if i_level != 0: if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv) up.upsample = Upsample(block_in, use_conv=resamp_with_conv)
curr_res = curr_res * 2 curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order self.up.insert(0, up) # prepend to get consistent order
......
...@@ -8,6 +8,7 @@ import torch.nn.functional as F ...@@ -8,6 +8,7 @@ import torch.nn.functional as F
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding from .embeddings import get_timestep_embedding
from .resnet import Upsample
def convert_module_to_f16(l): def convert_module_to_f16(l):
...@@ -125,36 +126,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock): ...@@ -125,36 +126,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
return x return x
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class Downsample(nn.Module): class Downsample(nn.Module):
""" """
A downsampling layer with an optional convolution. A downsampling layer with an optional convolution.
...@@ -231,8 +202,8 @@ class ResBlock(TimestepBlock): ...@@ -231,8 +202,8 @@ class ResBlock(TimestepBlock):
self.updown = up or down self.updown = up or down
if up: if up:
self.h_upd = Upsample(channels, False, dims) self.h_upd = Upsample(channels, use_conv=False, dims=dims)
self.x_upd = Upsample(channels, False, dims) self.x_upd = Upsample(channels, use_conv=False, dims=dims)
elif down: elif down:
self.h_upd = Downsample(channels, False, dims) self.h_upd = Downsample(channels, False, dims)
self.x_upd = Downsample(channels, False, dims) self.x_upd = Downsample(channels, False, dims)
...@@ -567,7 +538,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin): ...@@ -567,7 +538,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
up=True, up=True,
) )
if resblock_updown if resblock_updown
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
) )
ds //= 2 ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers)) self.output_blocks.append(TimestepEmbedSequential(*layers))
......
...@@ -10,6 +10,7 @@ except: ...@@ -10,6 +10,7 @@ except:
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding from .embeddings import get_timestep_embedding
from .resnet import Upsample
class Mish(torch.nn.Module): class Mish(torch.nn.Module):
...@@ -17,15 +18,6 @@ class Mish(torch.nn.Module): ...@@ -17,15 +18,6 @@ class Mish(torch.nn.Module):
return x * torch.tanh(torch.nn.functional.softplus(x)) return x * torch.tanh(torch.nn.functional.softplus(x))
class Upsample(torch.nn.Module):
def __init__(self, dim):
super(Upsample, self).__init__()
self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)
def forward(self, x):
return self.conv(x)
class Downsample(torch.nn.Module): class Downsample(torch.nn.Module):
def __init__(self, dim): def __init__(self, dim):
super(Downsample, self).__init__() super(Downsample, self).__init__()
...@@ -166,7 +158,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): ...@@ -166,7 +158,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim), ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
ResnetBlock(dim_in, dim_in, time_emb_dim=dim), ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
Residual(Rezero(LinearAttention(dim_in))), Residual(Rezero(LinearAttention(dim_in))),
Upsample(dim_in), Upsample(dim_in, use_conv_transpose=True),
] ]
) )
) )
......
...@@ -17,6 +17,7 @@ except: ...@@ -17,6 +17,7 @@ except:
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding from .embeddings import get_timestep_embedding
from .resnet import Upsample
def exists(val): def exists(val):
...@@ -377,35 +378,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock): ...@@ -377,35 +378,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
return x return x
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class Downsample(nn.Module): class Downsample(nn.Module):
""" """
A downsampling layer with an optional convolution. A downsampling layer with an optional convolution.
...@@ -480,8 +452,8 @@ class ResBlock(TimestepBlock): ...@@ -480,8 +452,8 @@ class ResBlock(TimestepBlock):
self.updown = up or down self.updown = up or down
if up: if up:
self.h_upd = Upsample(channels, False, dims) self.h_upd = Upsample(channels, use_conv=False, dims=dims)
self.x_upd = Upsample(channels, False, dims) self.x_upd = Upsample(channels, use_conv=False, dims=dims)
elif down: elif down:
self.h_upd = Downsample(channels, False, dims) self.h_upd = Downsample(channels, False, dims)
self.x_upd = Downsample(channels, False, dims) self.x_upd = Downsample(channels, False, dims)
...@@ -948,7 +920,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin): ...@@ -948,7 +920,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
up=True, up=True,
) )
if resblock_updown if resblock_updown
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
) )
ds //= 2 ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers)) self.output_blocks.append(TimestepEmbedSequential(*layers))
......
...@@ -21,7 +21,7 @@ import unittest ...@@ -21,7 +21,7 @@ import unittest
import numpy as np import numpy as np
import torch import torch
from diffusers import ( from diffusers import ( # GradTTSPipeline,
BDDMPipeline, BDDMPipeline,
DDIMPipeline, DDIMPipeline,
DDIMScheduler, DDIMScheduler,
...@@ -30,7 +30,6 @@ from diffusers import ( ...@@ -30,7 +30,6 @@ from diffusers import (
GlidePipeline, GlidePipeline,
GlideSuperResUNetModel, GlideSuperResUNetModel,
GlideTextToImageUNetModel, GlideTextToImageUNetModel,
GradTTSPipeline,
GradTTSScheduler, GradTTSScheduler,
LatentDiffusionPipeline, LatentDiffusionPipeline,
NCSNpp, NCSNpp,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment