Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
046dc430
"src/vscode:/vscode.git/clone" did not exist on "1fbcc78d6e55613b902015ff65a1d850594fa859"
Commit
046dc430
authored
Jun 29, 2022
by
Patrick von Platen
Browse files
make style
parent
c174bcf4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
79 additions
and
14 deletions
+79
-14
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+18
-5
src/diffusers/models/unet_grad_tts.py
src/diffusers/models/unet_grad_tts.py
+61
-9
No files found.
src/diffusers/models/resnet.py
View file @
046dc430
...
...
@@ -166,8 +166,8 @@ class Downsample(nn.Module):
#
# class GlideUpsample(nn.Module):
# """
# An upsampling layer with an optional convolution. # # :param channels: channels in the inputs and outputs. :param
#
use_conv: a bool determining if a convolution is # applied. :param dims: determines if the signal is 1D, 2D, or 3D. If
# An upsampling layer with an optional convolution. # # :param channels: channels in the inputs and outputs. :param
#
use_conv
:
a
bool
determining
if
a
convolution
is
# applied. :param dims: determines if the signal is 1D, 2D, or 3D. If
# 3D, then # upsampling occurs in the inner-two dimensions. #"""
#
# def __init__(self, channels, use_conv, dims=2, out_channels=None):
...
...
@@ -192,8 +192,8 @@ class Downsample(nn.Module):
#
# class LDMUpsample(nn.Module):
# """
# An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param #
#
use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. # If
# An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param #
#
use_conv
:
a
bool
determining
if
a
convolution
is
applied
.
:
param
dims
:
determines
if
the
signal
is
1
D
,
2
D
,
or
3
D
.
# If
# 3D, then # upsampling occurs in the inner-two dimensions. #"""
#
# def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
...
...
@@ -342,7 +342,20 @@ class ResBlock(TimestepBlock):
# unet.py and unet_grad_tts.py
class
ResnetBlock
(
nn
.
Module
):
def
__init__
(
self
,
*
,
in_channels
,
out_channels
=
None
,
conv_shortcut
=
False
,
dropout
=
0.0
,
temb_channels
=
512
,
groups
=
32
,
pre_norm
=
True
,
eps
=
1e-6
,
non_linearity
=
"swish"
,
overwrite_for_grad_tts
=
False
):
def
__init__
(
self
,
*
,
in_channels
,
out_channels
=
None
,
conv_shortcut
=
False
,
dropout
=
0.0
,
temb_channels
=
512
,
groups
=
32
,
pre_norm
=
True
,
eps
=
1e-6
,
non_linearity
=
"swish"
,
overwrite_for_grad_tts
=
False
,
):
super
().
__init__
()
self
.
pre_norm
=
pre_norm
self
.
in_channels
=
in_channels
...
...
src/diffusers/models/unet_grad_tts.py
View file @
046dc430
...
...
@@ -4,9 +4,7 @@ from ..configuration_utils import ConfigMixin
from
..modeling_utils
import
ModelMixin
from
.attention
import
LinearAttention
from
.embeddings
import
get_timestep_embedding
from
.resnet
import
Downsample
from
.resnet
import
ResnetBlock
from
.resnet
import
Upsample
from
.resnet
import
Downsample
,
ResnetBlock
,
Upsample
class
Mish
(
torch
.
nn
.
Module
):
...
...
@@ -86,8 +84,26 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
self
.
downs
.
append
(
torch
.
nn
.
ModuleList
(
[
ResnetBlock
(
in_channels
=
dim_in
,
out_channels
=
dim_out
,
temb_channels
=
dim
,
groups
=
8
,
pre_norm
=
False
,
eps
=
1e-5
,
non_linearity
=
"mish"
,
overwrite_for_grad_tts
=
True
),
ResnetBlock
(
in_channels
=
dim_out
,
out_channels
=
dim_out
,
temb_channels
=
dim
,
groups
=
8
,
pre_norm
=
False
,
eps
=
1e-5
,
non_linearity
=
"mish"
,
overwrite_for_grad_tts
=
True
),
ResnetBlock
(
in_channels
=
dim_in
,
out_channels
=
dim_out
,
temb_channels
=
dim
,
groups
=
8
,
pre_norm
=
False
,
eps
=
1e-5
,
non_linearity
=
"mish"
,
overwrite_for_grad_tts
=
True
,
),
ResnetBlock
(
in_channels
=
dim_out
,
out_channels
=
dim_out
,
temb_channels
=
dim
,
groups
=
8
,
pre_norm
=
False
,
eps
=
1e-5
,
non_linearity
=
"mish"
,
overwrite_for_grad_tts
=
True
,
),
Residual
(
Rezero
(
LinearAttention
(
dim_out
))),
Downsample
(
dim_out
,
use_conv
=
True
,
padding
=
1
)
if
not
is_last
else
torch
.
nn
.
Identity
(),
]
...
...
@@ -95,16 +111,52 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
)
mid_dim
=
dims
[
-
1
]
self
.
mid_block1
=
ResnetBlock
(
in_channels
=
mid_dim
,
out_channels
=
mid_dim
,
temb_channels
=
dim
,
groups
=
8
,
pre_norm
=
False
,
eps
=
1e-5
,
non_linearity
=
"mish"
,
overwrite_for_grad_tts
=
True
)
self
.
mid_block1
=
ResnetBlock
(
in_channels
=
mid_dim
,
out_channels
=
mid_dim
,
temb_channels
=
dim
,
groups
=
8
,
pre_norm
=
False
,
eps
=
1e-5
,
non_linearity
=
"mish"
,
overwrite_for_grad_tts
=
True
,
)
self
.
mid_attn
=
Residual
(
Rezero
(
LinearAttention
(
mid_dim
)))
self
.
mid_block2
=
ResnetBlock
(
in_channels
=
mid_dim
,
out_channels
=
mid_dim
,
temb_channels
=
dim
,
groups
=
8
,
pre_norm
=
False
,
eps
=
1e-5
,
non_linearity
=
"mish"
,
overwrite_for_grad_tts
=
True
)
self
.
mid_block2
=
ResnetBlock
(
in_channels
=
mid_dim
,
out_channels
=
mid_dim
,
temb_channels
=
dim
,
groups
=
8
,
pre_norm
=
False
,
eps
=
1e-5
,
non_linearity
=
"mish"
,
overwrite_for_grad_tts
=
True
,
)
for
ind
,
(
dim_in
,
dim_out
)
in
enumerate
(
reversed
(
in_out
[
1
:])):
self
.
ups
.
append
(
torch
.
nn
.
ModuleList
(
[
ResnetBlock
(
in_channels
=
dim_out
*
2
,
out_channels
=
dim_in
,
temb_channels
=
dim
,
groups
=
8
,
pre_norm
=
False
,
eps
=
1e-5
,
non_linearity
=
"mish"
,
overwrite_for_grad_tts
=
True
),
ResnetBlock
(
in_channels
=
dim_in
,
out_channels
=
dim_in
,
temb_channels
=
dim
,
groups
=
8
,
pre_norm
=
False
,
eps
=
1e-5
,
non_linearity
=
"mish"
,
overwrite_for_grad_tts
=
True
),
ResnetBlock
(
in_channels
=
dim_out
*
2
,
out_channels
=
dim_in
,
temb_channels
=
dim
,
groups
=
8
,
pre_norm
=
False
,
eps
=
1e-5
,
non_linearity
=
"mish"
,
overwrite_for_grad_tts
=
True
,
),
ResnetBlock
(
in_channels
=
dim_in
,
out_channels
=
dim_in
,
temb_channels
=
dim
,
groups
=
8
,
pre_norm
=
False
,
eps
=
1e-5
,
non_linearity
=
"mish"
,
overwrite_for_grad_tts
=
True
,
),
Residual
(
Rezero
(
LinearAttention
(
dim_in
))),
Upsample
(
dim_in
,
use_conv_transpose
=
True
),
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment