Commit 046dc430 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

make style

parent c174bcf4
...@@ -166,8 +166,8 @@ class Downsample(nn.Module): ...@@ -166,8 +166,8 @@ class Downsample(nn.Module):
# #
# class GlideUpsample(nn.Module): # class GlideUpsample(nn.Module):
# """ # """
# An upsampling layer with an optional convolution. # # :param channels: channels in the inputs and outputs. :param # An upsampling layer with an optional convolution. # # :param channels: channels in the inputs and outputs. :param #
# use_conv: a bool determining if a convolution is # applied. :param dims: determines if the signal is 1D, 2D, or 3D. If use_conv: a bool determining if a convolution is # applied. :param dims: determines if the signal is 1D, 2D, or 3D. If
# 3D, then # upsampling occurs in the inner-two dimensions. #""" # 3D, then # upsampling occurs in the inner-two dimensions. #"""
# #
# def __init__(self, channels, use_conv, dims=2, out_channels=None): # def __init__(self, channels, use_conv, dims=2, out_channels=None):
...@@ -192,8 +192,8 @@ class Downsample(nn.Module): ...@@ -192,8 +192,8 @@ class Downsample(nn.Module):
# #
# class LDMUpsample(nn.Module): # class LDMUpsample(nn.Module):
# """ # """
# An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param # # An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param # #
# use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. # If use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. # If
# 3D, then # upsampling occurs in the inner-two dimensions. #""" # 3D, then # upsampling occurs in the inner-two dimensions. #"""
# #
# def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): # def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
...@@ -342,7 +342,20 @@ class ResBlock(TimestepBlock): ...@@ -342,7 +342,20 @@ class ResBlock(TimestepBlock):
# unet.py and unet_grad_tts.py # unet.py and unet_grad_tts.py
class ResnetBlock(nn.Module): class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, temb_channels=512, groups=32, pre_norm=True, eps=1e-6, non_linearity="swish", overwrite_for_grad_tts=False): def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
overwrite_for_grad_tts=False,
):
super().__init__() super().__init__()
self.pre_norm = pre_norm self.pre_norm = pre_norm
self.in_channels = in_channels self.in_channels = in_channels
......
...@@ -4,9 +4,7 @@ from ..configuration_utils import ConfigMixin ...@@ -4,9 +4,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .attention import LinearAttention from .attention import LinearAttention
from .embeddings import get_timestep_embedding from .embeddings import get_timestep_embedding
from .resnet import Downsample from .resnet import Downsample, ResnetBlock, Upsample
from .resnet import ResnetBlock
from .resnet import Upsample
class Mish(torch.nn.Module): class Mish(torch.nn.Module):
...@@ -86,8 +84,26 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): ...@@ -86,8 +84,26 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
self.downs.append( self.downs.append(
torch.nn.ModuleList( torch.nn.ModuleList(
[ [
ResnetBlock(in_channels=dim_in, out_channels=dim_out, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True), ResnetBlock(
ResnetBlock(in_channels=dim_out, out_channels=dim_out, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True), in_channels=dim_in,
out_channels=dim_out,
temb_channels=dim,
groups=8,
pre_norm=False,
eps=1e-5,
non_linearity="mish",
overwrite_for_grad_tts=True,
),
ResnetBlock(
in_channels=dim_out,
out_channels=dim_out,
temb_channels=dim,
groups=8,
pre_norm=False,
eps=1e-5,
non_linearity="mish",
overwrite_for_grad_tts=True,
),
Residual(Rezero(LinearAttention(dim_out))), Residual(Rezero(LinearAttention(dim_out))),
Downsample(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(), Downsample(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(),
] ]
...@@ -95,16 +111,52 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): ...@@ -95,16 +111,52 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
) )
mid_dim = dims[-1] mid_dim = dims[-1]
self.mid_block1 = ResnetBlock(in_channels=mid_dim, out_channels=mid_dim, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True) self.mid_block1 = ResnetBlock(
in_channels=mid_dim,
out_channels=mid_dim,
temb_channels=dim,
groups=8,
pre_norm=False,
eps=1e-5,
non_linearity="mish",
overwrite_for_grad_tts=True,
)
self.mid_attn = Residual(Rezero(LinearAttention(mid_dim))) self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
self.mid_block2 = ResnetBlock(in_channels=mid_dim, out_channels=mid_dim, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True) self.mid_block2 = ResnetBlock(
in_channels=mid_dim,
out_channels=mid_dim,
temb_channels=dim,
groups=8,
pre_norm=False,
eps=1e-5,
non_linearity="mish",
overwrite_for_grad_tts=True,
)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
self.ups.append( self.ups.append(
torch.nn.ModuleList( torch.nn.ModuleList(
[ [
ResnetBlock(in_channels=dim_out * 2, out_channels=dim_in, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True), ResnetBlock(
ResnetBlock(in_channels=dim_in, out_channels=dim_in, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True), in_channels=dim_out * 2,
out_channels=dim_in,
temb_channels=dim,
groups=8,
pre_norm=False,
eps=1e-5,
non_linearity="mish",
overwrite_for_grad_tts=True,
),
ResnetBlock(
in_channels=dim_in,
out_channels=dim_in,
temb_channels=dim,
groups=8,
pre_norm=False,
eps=1e-5,
non_linearity="mish",
overwrite_for_grad_tts=True,
),
Residual(Rezero(LinearAttention(dim_in))), Residual(Rezero(LinearAttention(dim_in))),
Upsample(dim_in, use_conv_transpose=True), Upsample(dim_in, use_conv_transpose=True),
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment