"src/vscode:/vscode.git/clone" did not exist on "1fbcc78d6e55613b902015ff65a1d850594fa859"
Commit 046dc430 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

make style

parent c174bcf4
@@ -166,8 +166,8 @@ class Downsample(nn.Module):
 #
 # class GlideUpsample(nn.Module):
 # """
-# An upsampling layer with an optional convolution. # # :param channels: channels in the inputs and outputs. :param
-# use_conv: a bool determining if a convolution is # applied. :param dims: determines if the signal is 1D, 2D, or 3D. If
+# An upsampling layer with an optional convolution. # # :param channels: channels in the inputs and outputs. :param #
+# use_conv: a bool determining if a convolution is # applied. :param dims: determines if the signal is 1D, 2D, or 3D. If
 # 3D, then # upsampling occurs in the inner-two dimensions. #"""
 #
 # def __init__(self, channels, use_conv, dims=2, out_channels=None):
@@ -192,8 +192,8 @@ class Downsample(nn.Module):
 #
 # class LDMUpsample(nn.Module):
 # """
-# An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param #
-# use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. # If
+# An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param # #
+# use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. # If
 # 3D, then # upsampling occurs in the inner-two dimensions. #"""
 #
 # def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
@@ -342,7 +342,20 @@ class ResBlock(TimestepBlock):
 # unet.py and unet_grad_tts.py
 class ResnetBlock(nn.Module):
-    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, temb_channels=512, groups=32, pre_norm=True, eps=1e-6, non_linearity="swish", overwrite_for_grad_tts=False):
+    def __init__(
+        self,
+        *,
+        in_channels,
+        out_channels=None,
+        conv_shortcut=False,
+        dropout=0.0,
+        temb_channels=512,
+        groups=32,
+        pre_norm=True,
+        eps=1e-6,
+        non_linearity="swish",
+        overwrite_for_grad_tts=False,
+    ):
         super().__init__()
         self.pre_norm = pre_norm
         self.in_channels = in_channels
......
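
For reference, a hypothetical usage sketch of the ResnetBlock signature reformatted above, configured the way the grad-tts U-Net below instantiates it. The import path and the concrete channel numbers are assumptions for illustration, not taken from this commit:

# Assumed import path for the module this diff edits; adjust to the
# actual package layout at this commit.
from diffusers.models.resnet import ResnetBlock

# Grad-TTS flavored block: post-norm (pre_norm=False), GroupNorm with
# 8 groups, Mish non-linearity, grad-tts weight layout. The channel
# sizes (64 -> 128, temb 512) are illustrative only.
block = ResnetBlock(
    in_channels=64,
    out_channels=128,
    temb_channels=512,
    groups=8,
    pre_norm=False,
    eps=1e-5,
    non_linearity="mish",
    overwrite_for_grad_tts=True,
)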
@@ -4,9 +4,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import LinearAttention
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample
-from .resnet import ResnetBlock
-from .resnet import Upsample
+from .resnet import Downsample, ResnetBlock, Upsample
 
 
 class Mish(torch.nn.Module):
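
The body of Mish is collapsed in this hunk. As a reminder, the standard Mish activation is x * tanh(softplus(x)); a minimal self-contained version (a sketch of the usual definition, not code copied from this commit):

import torch


class Mish(torch.nn.Module):
    # Mish(x) = x * tanh(softplus(x)): a smooth, non-monotonic activation,
    # presumably what non_linearity="mish" selects in the ResnetBlock calls.
    def forward(self, x):
        return x * torch.tanh(torch.nn.functional.softplus(x))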
@@ -86,8 +84,26 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
             self.downs.append(
                 torch.nn.ModuleList(
                     [
-                        ResnetBlock(in_channels=dim_in, out_channels=dim_out, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
-                        ResnetBlock(in_channels=dim_out, out_channels=dim_out, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
+                        ResnetBlock(
+                            in_channels=dim_in,
+                            out_channels=dim_out,
+                            temb_channels=dim,
+                            groups=8,
+                            pre_norm=False,
+                            eps=1e-5,
+                            non_linearity="mish",
+                            overwrite_for_grad_tts=True,
+                        ),
+                        ResnetBlock(
+                            in_channels=dim_out,
+                            out_channels=dim_out,
+                            temb_channels=dim,
+                            groups=8,
+                            pre_norm=False,
+                            eps=1e-5,
+                            non_linearity="mish",
+                            overwrite_for_grad_tts=True,
+                        ),
                         Residual(Rezero(LinearAttention(dim_out))),
                         Downsample(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(),
                     ]
@@ -95,16 +111,52 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
             )
         mid_dim = dims[-1]
-        self.mid_block1 = ResnetBlock(in_channels=mid_dim, out_channels=mid_dim, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True)
+        self.mid_block1 = ResnetBlock(
+            in_channels=mid_dim,
+            out_channels=mid_dim,
+            temb_channels=dim,
+            groups=8,
+            pre_norm=False,
+            eps=1e-5,
+            non_linearity="mish",
+            overwrite_for_grad_tts=True,
+        )
         self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
-        self.mid_block2 = ResnetBlock(in_channels=mid_dim, out_channels=mid_dim, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True)
+        self.mid_block2 = ResnetBlock(
+            in_channels=mid_dim,
+            out_channels=mid_dim,
+            temb_channels=dim,
+            groups=8,
+            pre_norm=False,
+            eps=1e-5,
+            non_linearity="mish",
+            overwrite_for_grad_tts=True,
+        )
         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
             self.ups.append(
                 torch.nn.ModuleList(
                     [
-                        ResnetBlock(in_channels=dim_out * 2, out_channels=dim_in, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
-                        ResnetBlock(in_channels=dim_in, out_channels=dim_in, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
+                        ResnetBlock(
+                            in_channels=dim_out * 2,
+                            out_channels=dim_in,
+                            temb_channels=dim,
+                            groups=8,
+                            pre_norm=False,
+                            eps=1e-5,
+                            non_linearity="mish",
+                            overwrite_for_grad_tts=True,
+                        ),
+                        ResnetBlock(
+                            in_channels=dim_in,
+                            out_channels=dim_in,
+                            temb_channels=dim,
+                            groups=8,
+                            pre_norm=False,
+                            eps=1e-5,
+                            non_linearity="mish",
+                            overwrite_for_grad_tts=True,
+                        ),
                         Residual(Rezero(LinearAttention(dim_in))),
                         Upsample(dim_in, use_conv_transpose=True),
                     ]
......
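
A rough, self-contained sketch of the channel bookkeeping driving these down/mid/up loops. The values of dim and dim_mults are made-up assumptions, and dims[0] is simplified here (in the actual grad-tts model the first entry depends on the input feature channels):

# Illustrative channel plan; dim and dim_mults are hypothetical values.
dim = 64
dim_mults = (1, 2, 4)
dims = [dim] + [dim * m for m in dim_mults]  # [64, 64, 128, 256]
in_out = list(zip(dims[:-1], dims[1:]))      # [(64, 64), (64, 128), (128, 256)]

mid_dim = dims[-1]                           # 256, as in "mid_dim = dims[-1]" above

# Down path: each (dim_in, dim_out) pair becomes ResnetBlock(dim_in -> dim_out).
# Up path walks reversed(in_out[1:]) with in_channels=dim_out * 2, because the
# skip connection from the matching down block is concatenated onto the input.
for dim_in, dim_out in reversed(in_out[1:]):
    print(f"up block: in_channels={dim_out * 2}, out_channels={dim_in}")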