"...git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "f7ac7b604a6c8e96a8c2ee8b502fa725d3899df8"
Unverified commit 321f9791, authored by Patrick von Platen, committed via GitHub

Downsample / Upsample - clean to 1D and 2D (#68)

* make unet rl work

* upload files / code

* upload files

* make style correct

* finish
parent c524244f
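The commit below splits the dimension-generic Upsample / Downsample blocks in resnet.py into explicit 2D and 1D variants (Upsample2D, Downsample2D, Upsample1D, Downsample1D, plus FIR-based FirUpsample2D / FirDownsample2D) and drops the dims= argument at every call site. A minimal usage sketch of the renamed 2D blocks, assuming they are importable from diffusers.models.resnet as the updated tests at the bottom of this diff do; the printed shapes follow the forward methods shown in the hunks below:

import torch
from diffusers.models.resnet import Downsample2D, Upsample2D

x = torch.randn(1, 32, 32, 32)  # [batch, channels, height, width]

# nearest-neighbor interpolation (x2) followed by a 3x3 convolution
upsample = Upsample2D(channels=32, use_conv=True)
# 3x3 convolution with stride 2 (padding=1 by default)
downsample = Downsample2D(channels=32, use_conv=True)

print(upsample(x).shape)    # torch.Size([1, 32, 64, 64])
print(downsample(x).shape)  # torch.Size([1, 32, 16, 16])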
@@ -6,46 +6,7 @@ import torch.nn as nn
import torch.nn.functional as F


-def avg_pool_nd(dims, *args, **kwargs):
-    """
-    Create a 1D, 2D, or 3D average pooling module.
-    """
-    if dims == 1:
-        return nn.AvgPool1d(*args, **kwargs)
-    elif dims == 2:
-        return nn.AvgPool2d(*args, **kwargs)
-    elif dims == 3:
-        return nn.AvgPool3d(*args, **kwargs)
-    raise ValueError(f"unsupported dimensions: {dims}")
-
-
-def conv_nd(dims, *args, **kwargs):
-    """
-    Create a 1D, 2D, or 3D convolution module.
-    """
-    if dims == 1:
-        return nn.Conv1d(*args, **kwargs)
-    elif dims == 2:
-        return nn.Conv2d(*args, **kwargs)
-    elif dims == 3:
-        return nn.Conv3d(*args, **kwargs)
-    raise ValueError(f"unsupported dimensions: {dims}")
-
-
-def conv_transpose_nd(dims, *args, **kwargs):
-    """
-    Create a 1D, 2D, or 3D convolution module.
-    """
-    if dims == 1:
-        return nn.ConvTranspose1d(*args, **kwargs)
-    elif dims == 2:
-        return nn.ConvTranspose2d(*args, **kwargs)
-    elif dims == 3:
-        return nn.ConvTranspose3d(*args, **kwargs)
-    raise ValueError(f"unsupported dimensions: {dims}")
-
-
-class Upsample(nn.Module):
+class Upsample2D(nn.Module):
    """
    An upsampling layer with an optional convolution.
@@ -54,21 +15,21 @@ class Upsample(nn.Module):
        upsampling occurs in the inner-two dimensions.
    """

-    def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None, name="conv"):
+    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
-        self.dims = dims
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        conv = None
        if use_conv_transpose:
-            conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
+            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
-            conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+            conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)

+        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        else:
@@ -79,11 +40,9 @@ class Upsample(nn.Module):
        if self.use_conv_transpose:
            return self.conv(x)

-        if self.dims == 3:
-            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
-        else:
-            x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+        x = F.interpolate(x, scale_factor=2.0, mode="nearest")

+        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if self.use_conv:
            if self.name == "conv":
                x = self.conv(x)
@@ -93,7 +52,7 @@ class Upsample(nn.Module):
        return x


-class Downsample(nn.Module):
+class Downsample2D(nn.Module):
    """
    A downsampling layer with an optional convolution.
@@ -102,22 +61,22 @@ class Downsample(nn.Module):
        downsampling occurs in the inner-two dimensions.
    """

-    def __init__(self, channels, use_conv=False, dims=2, out_channels=None, padding=1, name="conv"):
+    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
-        self.dims = dims
        self.padding = padding
-        stride = 2 if dims != 3 else (1, 2, 2)
+        stride = 2
        self.name = name

        if use_conv:
-            conv = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
+            conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
-            conv = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+            conv = nn.AvgPool2d(kernel_size=stride, stride=stride)

+        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        elif name == "Conv2d_0":
@@ -127,10 +86,11 @@ class Downsample(nn.Module):
    def forward(self, x):
        assert x.shape[1] == self.channels
-        if self.use_conv and self.padding == 0 and self.dims == 2:
+        if self.use_conv and self.padding == 0:
            pad = (0, 1, 0, 1)
            x = F.pad(x, pad, mode="constant", value=0)

+        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if self.name == "conv":
            return self.conv(x)
        elif self.name == "Conv2d_0":
@@ -139,8 +99,204 @@ class Downsample(nn.Module):
        return self.op(x)
class Upsample1D(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
self.name = name
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
self.conv = None
if use_conv_transpose:
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
elif use_conv:
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
def forward(self, x):
assert x.shape[1] == self.channels
if self.use_conv_transpose:
return self.conv(x)
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class Downsample1D(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.padding = padding
stride = 2
self.name = name
if use_conv:
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
return self.conv(x)
class FirUpsample2D(nn.Module):
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
super().__init__()
out_channels = out_channels if out_channels else channels
if use_conv:
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
self.use_conv = use_conv
self.fir_kernel = fir_kernel
self.out_channels = out_channels
def forward(self, x):
if self.use_conv:
h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else:
h = upsample_2d(x, self.fir_kernel, factor=2)
return h
class FirDownsample2D(nn.Module):
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
super().__init__()
out_channels = out_channels if out_channels else channels
if use_conv:
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
self.fir_kernel = fir_kernel
self.use_conv = use_conv
self.out_channels = out_channels
def forward(self, x):
if self.use_conv:
x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else:
x = downsample_2d(x, self.fir_kernel, factor=2)
return x
def _conv_downsample_2d(x, w, k=None, factor=2, gain=1):
"""Fused `Conv2d()` followed by `downsample_2d()`.
Args:
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
order.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
w: Weight tensor of the shape `[filterH, filterW, inChannels,
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to average pooling.
factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same datatype
as `x`.
"""
assert isinstance(factor, int) and factor >= 1
_outC, _inC, convH, convW = w.shape
assert convW == convH
if k is None:
k = [1] * factor
k = _setup_kernel(k) * gain
p = (k.shape[0] - factor) + (convW - 1)
s = [factor, factor]
x = upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2, p // 2))
return F.conv2d(x, w, stride=s, padding=0)
def _upsample_conv_2d(x, w, k=None, factor=2, gain=1):
"""Fused `upsample_2d()` followed by `Conv2d()`.
Args:
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
order.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
w: Weight tensor of the shape `[filterH, filterW, inChannels,
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
`x`.
"""
assert isinstance(factor, int) and factor >= 1
# Check weight shape.
assert len(w.shape) == 4
convH = w.shape[2]
convW = w.shape[3]
inC = w.shape[1]
assert convW == convH
# Setup filter kernel.
if k is None:
k = [1] * factor
k = _setup_kernel(k) * (gain * (factor**2))
p = (k.shape[0] - factor) - (convW - 1)
stride = (factor, factor)
# Determine data dimensions.
stride = [1, 1, factor, factor]
output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
output_padding = (
output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
)
assert output_padding[0] >= 0 and output_padding[1] >= 0
num_groups = x.shape[1] // inC
# Transpose weights.
w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)
w = torch.reshape(w, (num_groups * inC, -1, convH, convW))
x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
 # TODO (patil-suraj): needs test
-# class Upsample1d(nn.Module):
+# class Upsample2D1d(nn.Module):
 #     def __init__(self, dim):
 #         super().__init__()
 #         self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
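The new Upsample1D and Downsample1D added above mirror the 2D blocks but operate on [batch, channels, length] tensors: ConvTranspose1d(kernel 4, stride 2, padding 1) doubles the length, while Conv1d(kernel 3, stride 2, padding 1) or AvgPool1d halves it. A small shape sketch based on those parameters (values are illustrative):

import torch
from diffusers.models.resnet import Downsample1D, Upsample1D

x = torch.randn(1, 16, 64)  # [batch, channels, length]

up = Upsample1D(channels=16, use_conv_transpose=True)  # ConvTranspose1d(16, 16, 4, 2, 1)
down = Downsample1D(channels=16, use_conv=True)        # Conv1d(16, 16, 3, stride=2, padding=1)

print(up(x).shape)    # torch.Size([1, 16, 128])
print(down(x).shape)  # torch.Size([1, 16, 32])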
@@ -221,7 +377,7 @@ class ResnetBlock2D(nn.Module):
            elif kernel == "sde_vp":
                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
            else:
-                self.upsample = Upsample(in_channels, use_conv=False, dims=2)
+                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
@@ -229,7 +385,7 @@ class ResnetBlock2D(nn.Module):
            elif kernel == "sde_vp":
                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
            else:
-                self.downsample = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
+                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
@@ -257,7 +413,6 @@ class ResnetBlock2D(nn.Module):
            else:
                self.res_conv = torch.nn.Identity()
        elif self.overwrite_for_ldm:
-            dims = 2
            channels = in_channels
            emb_channels = temb_channels
            use_scale_shift_norm = False
@@ -266,7 +421,7 @@ class ResnetBlock2D(nn.Module):
            self.in_layers = nn.Sequential(
                normalization(channels, swish=1.0),
                nn.Identity(),
-                conv_nd(dims, channels, self.out_channels, 3, padding=1),
+                nn.Conv2d(channels, self.out_channels, 3, padding=1),
            )
            self.emb_layers = nn.Sequential(
                nn.SiLU(),
@@ -279,12 +434,12 @@ class ResnetBlock2D(nn.Module):
                normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
                nn.SiLU() if use_scale_shift_norm else nn.Identity(),
                nn.Dropout(p=dropout),
-                zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
+                zero_module(nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
            )

            if self.out_channels == in_channels:
                self.skip_connection = nn.Identity()
            else:
-                self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+                self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)
        elif self.overwrite_for_score_vde:
            in_ch = in_channels
            out_ch = out_channels
@@ -631,7 +786,7 @@ def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
def upsample_2d(x, k=None, factor=2, gain=1):
-    r"""Upsample a batch of 2D images with the given filter.
+    r"""Upsample2D a batch of 2D images with the given filter.

    Args:
    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
@@ -656,7 +811,7 @@ def upsample_2d(x, k=None, factor=2, gain=1):
def downsample_2d(x, k=None, factor=2, gain=1):
-    r"""Downsample a batch of 2D images with the given filter.
+    r"""Downsample2D a batch of 2D images with the given filter.

    Args:
    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
...
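Several of the encoders below construct Downsample2D(..., padding=0); in that case forward applies an asymmetric (0, 1, 0, 1) pad before the stride-2 convolution, so an even-sized input still halves cleanly. A quick sketch of the resulting shapes, following the forward method shown above:

import torch
from diffusers.models.resnet import Downsample2D

x = torch.randn(1, 32, 64, 64)

# padding=0: pad right/bottom to 65x65, then 3x3 conv with stride 2 -> 32x32
print(Downsample2D(channels=32, use_conv=True, padding=0)(x).shape)  # torch.Size([1, 32, 32, 32])

# padding=1 (default): symmetric conv padding, same 32x32 output size
print(Downsample2D(channels=32, use_conv=True, padding=1)(x).shape)  # torch.Size([1, 32, 32, 32])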
@@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResnetBlock2D, Upsample
+from .resnet import Downsample2D, ResnetBlock2D, Upsample2D


def nonlinearity(x):
@@ -100,7 +100,7 @@ class UNetModel(ModelMixin, ConfigMixin):
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
-                down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0)
+                down.downsample = Downsample2D(block_in, use_conv=resamp_with_conv, padding=0)
                curr_res = curr_res // 2
            self.down.append(down)
@@ -139,7 +139,7 @@ class UNetModel(ModelMixin, ConfigMixin):
            up.block = block
            up.attn = attn
            if i_level != 0:
-                up.upsample = Upsample(block_in, use_conv=resamp_with_conv)
+                up.upsample = Upsample2D(block_in, use_conv=resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order
...
@@ -6,7 +6,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResnetBlock2D, Upsample
+from .resnet import Downsample2D, ResnetBlock2D, Upsample2D


def convert_module_to_f16(l):
@@ -218,9 +218,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                            down=True,
                        )
                        if resblock_updown
-                        else Downsample(
-                            ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
-                        )
+                        else Downsample2D(ch, use_conv=conv_resample, out_channels=out_ch, padding=1, name="op")
                    )
                )
                ch = out_ch
@@ -299,7 +297,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                        up=True,
                    )
                    if resblock_updown
-                    else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
+                    else Upsample2D(ch, use_conv=conv_resample, out_channels=out_ch)
                )
                ds //= 2
            self.output_blocks.append(TimestepEmbedSequential(*layers))
...
@@ -4,7 +4,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import LinearAttention
from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResnetBlock2D, Upsample
+from .resnet import Downsample2D, ResnetBlock2D, Upsample2D


class Mish(torch.nn.Module):
@@ -105,7 +105,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
                            overwrite_for_grad_tts=True,
                        ),
                        Residual(Rezero(LinearAttention(dim_out))),
-                        Downsample(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(),
+                        Downsample2D(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(),
                    ]
                )
            )
@@ -158,7 +158,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
                            overwrite_for_grad_tts=True,
                        ),
                        Residual(Rezero(LinearAttention(dim_in))),
-                        Upsample(dim_in, use_conv_transpose=True),
+                        Upsample2D(dim_in, use_conv_transpose=True),
                    ]
                )
            )
...
@@ -10,7 +10,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResnetBlock2D, Upsample
+from .resnet import Downsample2D, ResnetBlock2D, Upsample2D


# from .resnet import ResBlock
@@ -350,7 +350,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
-                        Downsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op")
+                        Downsample2D(ch, use_conv=conv_resample, out_channels=out_ch, padding=1, name="op")
                    )
                )
                ch = out_ch
@@ -437,7 +437,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                )
                if level and i == num_res_blocks:
                    out_ch = ch
-                    layers.append(Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch))
+                    layers.append(Upsample2D(ch, use_conv=conv_resample, out_channels=out_ch))
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
...
@@ -6,7 +6,7 @@ import torch.nn as nn
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResidualTemporalBlock, Upsample
+from .resnet import Downsample1D, ResidualTemporalBlock, Upsample1D


class SinusoidalPosEmb(nn.Module):
@@ -96,7 +96,7 @@ class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):
                    [
                        ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon),
                        ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon),
-                        Downsample(dim_out, use_conv=True, dims=1) if not is_last else nn.Identity(),
+                        Downsample1D(dim_out, use_conv=True) if not is_last else nn.Identity(),
                    ]
                )
            )
@@ -116,7 +116,7 @@ class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):
                    [
                        ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon),
                        ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon),
-                        Upsample(dim_in, use_conv_transpose=True, dims=1) if not is_last else nn.Identity(),
+                        Upsample1D(dim_in, use_conv_transpose=True) if not is_last else nn.Identity(),
                    ]
                )
            )
...
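In the temporal UNet above, each down stage now ends with Downsample1D(dim_out, use_conv=True) and each up stage with Upsample1D(dim_in, use_conv_transpose=True), so the horizon axis is halved on the way down and doubled back on the way up. A hedged round-trip sketch (the dim and horizon values are illustrative, not from the model config):

import torch
from diffusers.models.resnet import Downsample1D, Upsample1D

dim, horizon = 32, 128
h = torch.randn(1, dim, horizon)

h = Downsample1D(dim, use_conv=True)(h)          # horizon: 128 -> 64
h = Upsample1D(dim, use_conv_transpose=True)(h)  # horizon: 64 -> 128

print(h.shape)  # torch.Size([1, 32, 128])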
@@ -21,13 +21,12 @@ import math
import numpy as np
import torch
import torch.nn as nn
-import torch.nn.functional as F

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import GaussianFourierProjection, get_timestep_embedding
-from .resnet import Downsample, ResnetBlock2D, Upsample, downsample_2d, upfirdn2d, upsample_2d
+from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D


def _setup_kernel(k):
@@ -40,96 +39,6 @@ def _setup_kernel(k):
    return k
def _upsample_conv_2d(x, w, k=None, factor=2, gain=1):
"""Fused `upsample_2d()` followed by `Conv2d()`.
Args:
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
order.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
w: Weight tensor of the shape `[filterH, filterW, inChannels,
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
`x`.
"""
assert isinstance(factor, int) and factor >= 1
# Check weight shape.
assert len(w.shape) == 4
convH = w.shape[2]
convW = w.shape[3]
inC = w.shape[1]
assert convW == convH
# Setup filter kernel.
if k is None:
k = [1] * factor
k = _setup_kernel(k) * (gain * (factor**2))
p = (k.shape[0] - factor) - (convW - 1)
stride = (factor, factor)
# Determine data dimensions.
stride = [1, 1, factor, factor]
output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
output_padding = (
output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
)
assert output_padding[0] >= 0 and output_padding[1] >= 0
num_groups = x.shape[1] // inC
# Transpose weights.
w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)
w = torch.reshape(w, (num_groups * inC, -1, convH, convW))
x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
def _conv_downsample_2d(x, w, k=None, factor=2, gain=1):
"""Fused `Conv2d()` followed by `downsample_2d()`.
Args:
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
order.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
w: Weight tensor of the shape `[filterH, filterW, inChannels,
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to average pooling.
factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same datatype
as `x`.
"""
assert isinstance(factor, int) and factor >= 1
_outC, _inC, convH, convW = w.shape
assert convW == convH
if k is None:
k = [1] * factor
k = _setup_kernel(k) * gain
p = (k.shape[0] - factor) + (convW - 1)
s = [factor, factor]
x = upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2, p // 2))
return F.conv2d(x, w, stride=s, padding=0)
def _variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
    """Ported from JAX."""
    scale = 1e-10 if scale == 0 else scale
@@ -183,46 +92,6 @@ class Combine(nn.Module):
        raise ValueError(f"Method {self.method} not recognized.")
class FirUpsample(nn.Module):
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
super().__init__()
out_channels = out_channels if out_channels else channels
if use_conv:
self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
self.use_conv = use_conv
self.fir_kernel = fir_kernel
self.out_channels = out_channels
def forward(self, x):
if self.use_conv:
h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else:
h = upsample_2d(x, self.fir_kernel, factor=2)
return h
class FirDownsample(nn.Module):
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
super().__init__()
out_channels = out_channels if out_channels else channels
if use_conv:
self.Conv2d_0 = self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
self.fir_kernel = fir_kernel
self.use_conv = use_conv
self.out_channels = out_channels
def forward(self, x):
if self.use_conv:
x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else:
x = downsample_2d(x, self.fir_kernel, factor=2)
return x
class NCSNpp(ModelMixin, ConfigMixin):
    """NCSN++ model"""
@@ -313,9 +182,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
        AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))

        if self.fir:
-            Up_sample = functools.partial(FirUpsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
+            Up_sample = functools.partial(FirUpsample2D, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
        else:
-            Up_sample = functools.partial(Upsample, name="Conv2d_0")
+            Up_sample = functools.partial(Upsample2D, name="Conv2d_0")

        if progressive == "output_skip":
            self.pyramid_upsample = Up_sample(channels=None, use_conv=False)
@@ -323,9 +192,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
            pyramid_upsample = functools.partial(Up_sample, use_conv=True)

        if self.fir:
-            Down_sample = functools.partial(FirDownsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
+            Down_sample = functools.partial(FirDownsample2D, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
        else:
-            Down_sample = functools.partial(Downsample, padding=0, name="Conv2d_0")
+            Down_sample = functools.partial(Downsample2D, padding=0, name="Conv2d_0")

        if progressive_input == "input_skip":
            self.pyramid_downsample = Down_sample(channels=None, use_conv=False)
...
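NCSNpp picks its resampling blocks with functools.partial: with fir=True it uses the FIR variants, which smooth with the (1, 3, 3, 1) kernel through upfirdn2d, otherwise it falls back to the plain 2D blocks under the Conv2d_0 weight name. A rough sketch of the FIR blocks in isolation, assuming the helpers (upfirdn2d, upsample_2d, downsample_2d) stay in resnet.py as the hunks above indicate; without a convolution they reduce to upsample_2d / downsample_2d with factor 2:

import functools

import torch

from diffusers.models.resnet import FirDownsample2D, FirUpsample2D

fir_kernel = (1, 3, 3, 1)
Up_sample = functools.partial(FirUpsample2D, fir_kernel=fir_kernel, use_conv=False)
Down_sample = functools.partial(FirDownsample2D, fir_kernel=fir_kernel, use_conv=False)

x = torch.randn(1, 8, 16, 16)
print(Up_sample(channels=8)(x).shape)    # torch.Size([1, 8, 32, 32])
print(Down_sample(channels=8)(x).shape)  # torch.Size([1, 8, 8, 8])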
@@ -5,7 +5,7 @@ import torch.nn as nn
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
-from .resnet import Downsample, ResnetBlock2D, Upsample
+from .resnet import Downsample2D, ResnetBlock2D, Upsample2D


def nonlinearity(x):
@@ -65,7 +65,7 @@ class Encoder(nn.Module):
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
-                down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0)
+                down.downsample = Downsample2D(block_in, use_conv=resamp_with_conv, padding=0)
                curr_res = curr_res // 2
            self.down.append(down)
@@ -179,7 +179,7 @@ class Decoder(nn.Module):
            up.block = block
            up.attn = attn
            if i_level != 0:
-                up.upsample = Upsample(block_in, use_conv=resamp_with_conv)
+                up.upsample = Upsample2D(block_in, use_conv=resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order
...
@@ -137,7 +137,7 @@ class ResidualBlock(nn.Module):
        # Dilated conv layer
        h = self.dilated_conv_layer(h)

-        # Upsample spectrogram to size of audio
+        # Upsample2D spectrogram to size of audio
        mel_spec = torch.unsqueeze(mel_spec, dim=1)
        mel_spec = F.leaky_relu(self.upsample_conv2d[0](mel_spec), 0.4, inplace=False)
        mel_spec = F.leaky_relu(self.upsample_conv2d[1](mel_spec), 0.4, inplace=False)
...
@@ -22,7 +22,7 @@ import numpy as np
import torch

from diffusers.models.embeddings import get_timestep_embedding
-from diffusers.models.resnet import Downsample, Upsample
+from diffusers.models.resnet import Downsample2D, Upsample2D
from diffusers.testing_utils import floats_tensor, slow, torch_device
@@ -116,11 +116,11 @@ class EmbeddingsTests(unittest.TestCase):
        )


-class UpsampleBlockTests(unittest.TestCase):
+class Upsample2DBlockTests(unittest.TestCase):
    def test_upsample_default(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 32, 32)
-        upsample = Upsample(channels=32, use_conv=False)
+        upsample = Upsample2D(channels=32, use_conv=False)
        with torch.no_grad():
            upsampled = upsample(sample)
@@ -132,7 +132,7 @@ class UpsampleBlockTests(unittest.TestCase):
    def test_upsample_with_conv(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 32, 32)
-        upsample = Upsample(channels=32, use_conv=True)
+        upsample = Upsample2D(channels=32, use_conv=True)
        with torch.no_grad():
            upsampled = upsample(sample)
@@ -144,7 +144,7 @@ class UpsampleBlockTests(unittest.TestCase):
    def test_upsample_with_conv_out_dim(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 32, 32)
-        upsample = Upsample(channels=32, use_conv=True, out_channels=64)
+        upsample = Upsample2D(channels=32, use_conv=True, out_channels=64)
        with torch.no_grad():
            upsampled = upsample(sample)
@@ -156,7 +156,7 @@ class UpsampleBlockTests(unittest.TestCase):
    def test_upsample_with_transpose(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 32, 32)
-        upsample = Upsample(channels=32, use_conv=False, use_conv_transpose=True)
+        upsample = Upsample2D(channels=32, use_conv=False, use_conv_transpose=True)
        with torch.no_grad():
            upsampled = upsample(sample)
@@ -166,11 +166,11 @@ class UpsampleBlockTests(unittest.TestCase):
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)


-class DownsampleBlockTests(unittest.TestCase):
+class Downsample2DBlockTests(unittest.TestCase):
    def test_downsample_default(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64)
-        downsample = Downsample(channels=32, use_conv=False)
+        downsample = Downsample2D(channels=32, use_conv=False)
        with torch.no_grad():
            downsampled = downsample(sample)
@@ -184,7 +184,7 @@ class DownsampleBlockTests(unittest.TestCase):
    def test_downsample_with_conv(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64)
-        downsample = Downsample(channels=32, use_conv=True)
+        downsample = Downsample2D(channels=32, use_conv=True)
        with torch.no_grad():
            downsampled = downsample(sample)
@@ -199,7 +199,7 @@ class DownsampleBlockTests(unittest.TestCase):
    def test_downsample_with_conv_pad1(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64)
-        downsample = Downsample(channels=32, use_conv=True, padding=1)
+        downsample = Downsample2D(channels=32, use_conv=True, padding=1)
        with torch.no_grad():
            downsampled = downsample(sample)
@@ -211,7 +211,7 @@ class DownsampleBlockTests(unittest.TestCase):
    def test_downsample_with_conv_out_dim(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64)
-        downsample = Downsample(channels=32, use_conv=True, out_channels=16)
+        downsample = Downsample2D(channels=32, use_conv=True, out_channels=16)
        with torch.no_grad():
            downsampled = downsample(sample)
...
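The updated tests only exercise the 2D blocks. A possible companion test for the new 1D blocks (hypothetical, not part of this commit) could assert shapes in the same style:

import unittest

import torch

from diffusers.models.resnet import Downsample1D, Upsample1D


class Downsample1DBlockTests(unittest.TestCase):
    def test_downsample_default(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64)
        # use_conv=False falls back to AvgPool1d(kernel_size=2, stride=2)
        downsample = Downsample1D(channels=32, use_conv=False)
        with torch.no_grad():
            downsampled = downsample(sample)

        assert downsampled.shape == (1, 32, 32)


class Upsample1DBlockTests(unittest.TestCase):
    def test_upsample_default(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 32)
        # no conv and no transposed conv: pure nearest-neighbor interpolation (x2)
        upsample = Upsample1D(channels=32, use_conv=False)
        with torch.no_grad():
            upsampled = upsample(sample)

        assert upsampled.shape == (1, 32, 64)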