"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "ccb7f45a3570b2175d8e8def66629528d557da3c"
Commit 466214d2 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

Remove bogus file

parent 4e125f72
...@@ -5,8 +5,7 @@ from ..modeling_utils import ModelMixin ...@@ -5,8 +5,7 @@ from ..modeling_utils import ModelMixin
from .attention import LinearAttention from .attention import LinearAttention
from .embeddings import get_timestep_embedding from .embeddings import get_timestep_embedding
from .resnet import Downsample from .resnet import Downsample
from .resnet import ResnetBlock as ResnetBlockNew from .resnet import ResnetBlock
from .resnet import ResnetBlockGradTTS as ResnetBlock
from .resnet import Upsample from .resnet import Upsample
...@@ -82,20 +81,13 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): ...@@ -82,20 +81,13 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
self.ups = torch.nn.ModuleList([]) self.ups = torch.nn.ModuleList([])
num_resolutions = len(in_out) num_resolutions = len(in_out)
# num_groups = 8
# self.pre_norm = False
# eps = 1e-5
# non_linearity = "mish"
for ind, (dim_in, dim_out) in enumerate(in_out): for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1) is_last = ind >= (num_resolutions - 1)
self.downs.append( self.downs.append(
torch.nn.ModuleList( torch.nn.ModuleList(
[ [
# ResnetBlock(dim_in, dim_out, time_emb_dim=dim), ResnetBlock(in_channels=dim_in, out_channels=dim_out, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
# ResnetBlock(dim_out, dim_out, time_emb_dim=dim), ResnetBlock(in_channels=dim_out, out_channels=dim_out, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
ResnetBlockNew(in_channels=dim_in, out_channels=dim_out, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
ResnetBlockNew(in_channels=dim_out, out_channels=dim_out, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
Residual(Rezero(LinearAttention(dim_out))), Residual(Rezero(LinearAttention(dim_out))),
Downsample(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(), Downsample(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(),
] ]
...@@ -103,20 +95,16 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): ...@@ -103,20 +95,16 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
) )
mid_dim = dims[-1] mid_dim = dims[-1]
# self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim) self.mid_block1 = ResnetBlock(in_channels=mid_dim, out_channels=mid_dim, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True)
# self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
self.mid_block1 = ResnetBlockNew(in_channels=mid_dim, out_channels=mid_dim, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True)
self.mid_attn = Residual(Rezero(LinearAttention(mid_dim))) self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
self.mid_block2 = ResnetBlockNew(in_channels=mid_dim, out_channels=mid_dim, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True) self.mid_block2 = ResnetBlock(in_channels=mid_dim, out_channels=mid_dim, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
self.ups.append( self.ups.append(
torch.nn.ModuleList( torch.nn.ModuleList(
[ [
# ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim), ResnetBlock(in_channels=dim_out * 2, out_channels=dim_in, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
# ResnetBlock(dim_in, dim_in, time_emb_dim=dim), ResnetBlock(in_channels=dim_in, out_channels=dim_in, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
ResnetBlockNew(in_channels=dim_out * 2, out_channels=dim_in, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
ResnetBlockNew(in_channels=dim_in, out_channels=dim_in, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
Residual(Rezero(LinearAttention(dim_in))), Residual(Rezero(LinearAttention(dim_in))),
Upsample(dim_in, use_conv_transpose=True), Upsample(dim_in, use_conv_transpose=True),
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment