Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
466214d2
Commit
466214d2
authored
Jun 29, 2022
by
Patrick von Platen
Browse files
Remove bogus file
parent
4e125f72
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
19 deletions
+7
-19
src/diffusers/models/unet_grad_tts.py
src/diffusers/models/unet_grad_tts.py
+7
-19
No files found.
src/diffusers/models/unet_grad_tts.py
View file @
466214d2
...
...
@@ -5,8 +5,7 @@ from ..modeling_utils import ModelMixin
from
.attention
import
LinearAttention
from
.embeddings
import
get_timestep_embedding
from
.resnet
import
Downsample
from
.resnet
import
ResnetBlock
as
ResnetBlockNew
from
.resnet
import
ResnetBlockGradTTS
as
ResnetBlock
from
.resnet
import
ResnetBlock
from
.resnet
import
Upsample
...
...
@@ -82,20 +81,13 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
self
.
ups
=
torch
.
nn
.
ModuleList
([])
num_resolutions
=
len
(
in_out
)
# num_groups = 8
# self.pre_norm = False
# eps = 1e-5
# non_linearity = "mish"
for
ind
,
(
dim_in
,
dim_out
)
in
enumerate
(
in_out
):
is_last
=
ind
>=
(
num_resolutions
-
1
)
self
.
downs
.
append
(
torch
.
nn
.
ModuleList
(
[
# ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
# ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
ResnetBlockNew
(
in_channels
=
dim_in
,
out_channels
=
dim_out
,
temb_channels
=
dim
,
groups
=
8
,
pre_norm
=
False
,
eps
=
1e-5
,
non_linearity
=
"mish"
,
overwrite_for_grad_tts
=
True
),
ResnetBlockNew
(
in_channels
=
dim_out
,
out_channels
=
dim_out
,
temb_channels
=
dim
,
groups
=
8
,
pre_norm
=
False
,
eps
=
1e-5
,
non_linearity
=
"mish"
,
overwrite_for_grad_tts
=
True
),
ResnetBlock
(
in_channels
=
dim_in
,
out_channels
=
dim_out
,
temb_channels
=
dim
,
groups
=
8
,
pre_norm
=
False
,
eps
=
1e-5
,
non_linearity
=
"mish"
,
overwrite_for_grad_tts
=
True
),
ResnetBlock
(
in_channels
=
dim_out
,
out_channels
=
dim_out
,
temb_channels
=
dim
,
groups
=
8
,
pre_norm
=
False
,
eps
=
1e-5
,
non_linearity
=
"mish"
,
overwrite_for_grad_tts
=
True
),
Residual
(
Rezero
(
LinearAttention
(
dim_out
))),
Downsample
(
dim_out
,
use_conv
=
True
,
padding
=
1
)
if
not
is_last
else
torch
.
nn
.
Identity
(),
]
...
...
@@ -103,20 +95,16 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
)
mid_dim
=
dims
[
-
1
]
# self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
# self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
self
.
mid_block1
=
ResnetBlockNew
(
in_channels
=
mid_dim
,
out_channels
=
mid_dim
,
temb_channels
=
dim
,
groups
=
8
,
pre_norm
=
False
,
eps
=
1e-5
,
non_linearity
=
"mish"
,
overwrite_for_grad_tts
=
True
)
self
.
mid_block1
=
ResnetBlock
(
in_channels
=
mid_dim
,
out_channels
=
mid_dim
,
temb_channels
=
dim
,
groups
=
8
,
pre_norm
=
False
,
eps
=
1e-5
,
non_linearity
=
"mish"
,
overwrite_for_grad_tts
=
True
)
self
.
mid_attn
=
Residual
(
Rezero
(
LinearAttention
(
mid_dim
)))
self
.
mid_block2
=
ResnetBlock
New
(
in_channels
=
mid_dim
,
out_channels
=
mid_dim
,
temb_channels
=
dim
,
groups
=
8
,
pre_norm
=
False
,
eps
=
1e-5
,
non_linearity
=
"mish"
,
overwrite_for_grad_tts
=
True
)
self
.
mid_block2
=
ResnetBlock
(
in_channels
=
mid_dim
,
out_channels
=
mid_dim
,
temb_channels
=
dim
,
groups
=
8
,
pre_norm
=
False
,
eps
=
1e-5
,
non_linearity
=
"mish"
,
overwrite_for_grad_tts
=
True
)
for
ind
,
(
dim_in
,
dim_out
)
in
enumerate
(
reversed
(
in_out
[
1
:])):
self
.
ups
.
append
(
torch
.
nn
.
ModuleList
(
[
# ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
# ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
ResnetBlockNew
(
in_channels
=
dim_out
*
2
,
out_channels
=
dim_in
,
temb_channels
=
dim
,
groups
=
8
,
pre_norm
=
False
,
eps
=
1e-5
,
non_linearity
=
"mish"
,
overwrite_for_grad_tts
=
True
),
ResnetBlockNew
(
in_channels
=
dim_in
,
out_channels
=
dim_in
,
temb_channels
=
dim
,
groups
=
8
,
pre_norm
=
False
,
eps
=
1e-5
,
non_linearity
=
"mish"
,
overwrite_for_grad_tts
=
True
),
ResnetBlock
(
in_channels
=
dim_out
*
2
,
out_channels
=
dim_in
,
temb_channels
=
dim
,
groups
=
8
,
pre_norm
=
False
,
eps
=
1e-5
,
non_linearity
=
"mish"
,
overwrite_for_grad_tts
=
True
),
ResnetBlock
(
in_channels
=
dim_in
,
out_channels
=
dim_in
,
temb_channels
=
dim
,
groups
=
8
,
pre_norm
=
False
,
eps
=
1e-5
,
non_linearity
=
"mish"
,
overwrite_for_grad_tts
=
True
),
Residual
(
Rezero
(
LinearAttention
(
dim_in
))),
Upsample
(
dim_in
,
use_conv_transpose
=
True
),
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment