Project: renzhc / diffusers_dcu

Commit eb90d3be (unverified), authored Jun 29, 2022 by Patrick von Platen, committed via GitHub on Jun 29, 2022

Merge pull request #44 from huggingface/unify_resnet

Unify resnet [GradTTS & Unet.py]

Parents: 8cba133f, df2e145e

Showing 2 changed files with 203 additions and 93 deletions (+203 -93)
src/diffusers/models/resnet.py          +136  -78
src/diffusers/models/unet_grad_tts.py    +67  -15
src/diffusers/models/resnet.py
@@ -46,8 +46,8 @@ def conv_transpose_nd(dims, *args, **kwargs):
     raise ValueError(f"unsupported dimensions: {dims}")


-def Normalize(in_channels):
-    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+def Normalize(in_channels, num_groups=32, eps=1e-6):
+    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=eps, affine=True)


 def nonlinearity(x, swish=1.0):
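For reference, a minimal check (not part of the commit) that the widened `Normalize` signature keeps the previous behaviour under its defaults; it assumes `Normalize` is importable from `diffusers.models.resnet` on this branch:

import torch
from diffusers.models.resnet import Normalize  # assumed importable on this branch

# Old call sites pass only in_channels; the defaults num_groups=32, eps=1e-6
# match the values that were previously hard-coded.
norm_default = Normalize(64)
# GradTTS-style call sites can now request their own settings.
norm_grad_tts = Normalize(64, num_groups=8, eps=1e-5)

x = torch.randn(2, 64, 16, 16)
assert norm_default(x).shape == x.shape
assert norm_grad_tts(x).shape == x.shape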
@@ -166,8 +166,8 @@ class Downsample(nn.Module):
 #
 # class GlideUpsample(nn.Module):
 #     """
-#     An upsampling layer with an optional convolution. # # :param channels: channels in the inputs and outputs. :param
-#     use_conv: a bool determining if a convolution is # applied. :param dims: determines if the signal is 1D, 2D, or 3D. If
+#     An upsampling layer with an optional convolution. # # :param channels: channels in the inputs and outputs. :param
+#     use_conv: a bool determining if a convolution is # applied. :param dims: determines if the signal is 1D, 2D, or 3D. If
 #     3D, then # upsampling occurs in the inner-two dimensions. #"""
 #
 #     def __init__(self, channels, use_conv, dims=2, out_channels=None):


@@ -192,8 +192,8 @@ class Downsample(nn.Module):
 #
 # class LDMUpsample(nn.Module):
 #     """
-#     An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param #
-#     use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. # If
+#     An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param #
+#     use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. # If
 #     3D, then # upsampling occurs in the inner-two dimensions. #"""
 #
 #     def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
@@ -340,40 +340,118 @@ class ResBlock(TimestepBlock):
         return self.skip_connection(x) + h


-# unet.py
+# unet.py and unet_grad_tts.py
 class ResnetBlock(nn.Module):
-    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
+    def __init__(
+        self,
+        *,
+        in_channels,
+        out_channels=None,
+        conv_shortcut=False,
+        dropout=0.0,
+        temb_channels=512,
+        groups=32,
+        pre_norm=True,
+        eps=1e-6,
+        non_linearity="swish",
+        overwrite_for_grad_tts=False,
+    ):
         super().__init__()
+        self.pre_norm = pre_norm
         self.in_channels = in_channels
         out_channels = in_channels if out_channels is None else out_channels
         self.out_channels = out_channels
         self.use_conv_shortcut = conv_shortcut

-        self.norm1 = Normalize(in_channels)
+        if self.pre_norm:
+            self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
+        else:
+            self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
         self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
         self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
-        self.norm2 = Normalize(out_channels)
+        self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
         self.dropout = torch.nn.Dropout(dropout)
         self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+        if non_linearity == "swish":
+            self.nonlinearity = nonlinearity
+        elif non_linearity == "mish":
+            self.nonlinearity = Mish()
+
         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
                 self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
             else:
                 self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

-    def forward(self, x, temb):
+        self.is_overwritten = False
+        self.overwrite_for_grad_tts = overwrite_for_grad_tts
+        if self.overwrite_for_grad_tts:
+            dim = in_channels
+            dim_out = out_channels
+            time_emb_dim = temb_channels
+
+            self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
+            self.pre_norm = pre_norm
+
+            self.block1 = Block(dim, dim_out, groups=groups)
+            self.block2 = Block(dim_out, dim_out, groups=groups)
+            if dim != dim_out:
+                self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
+            else:
+                self.res_conv = torch.nn.Identity()
+
+    def set_weights_grad_tts(self):
+        self.conv1.weight.data = self.block1.block[0].weight.data
+        self.conv1.bias.data = self.block1.block[0].bias.data
+        self.norm1.weight.data = self.block1.block[1].weight.data
+        self.norm1.bias.data = self.block1.block[1].bias.data
+
+        self.conv2.weight.data = self.block2.block[0].weight.data
+        self.conv2.bias.data = self.block2.block[0].bias.data
+        self.norm2.weight.data = self.block2.block[1].weight.data
+        self.norm2.bias.data = self.block2.block[1].bias.data
+
+        self.temb_proj.weight.data = self.mlp[1].weight.data
+        self.temb_proj.bias.data = self.mlp[1].bias.data
+
+        if self.in_channels != self.out_channels:
+            self.nin_shortcut.weight.data = self.res_conv.weight.data
+            self.nin_shortcut.bias.data = self.res_conv.bias.data
+
+    def forward(self, x, temb, mask=None):
+        if self.overwrite_for_grad_tts and not self.is_overwritten:
+            self.set_weights_grad_tts()
+            self.is_overwritten = True
+
         h = x
-        h = self.norm1(h)
-        h = nonlinearity(h)
+        h = h * mask if mask is not None else h
+        if self.pre_norm:
+            h = self.norm1(h)
+        h = self.nonlinearity(h)
         h = self.conv1(h)
+        if not self.pre_norm:
+            h = self.norm1(h)
+            h = self.nonlinearity(h)
+        h = h * mask if mask is not None else h

-        h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+        h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
+        h = h * mask if mask is not None else h

-        h = self.norm2(h)
-        h = nonlinearity(h)
+        if self.pre_norm:
+            h = self.norm2(h)
+        h = self.nonlinearity(h)
         h = self.dropout(h)
         h = self.conv2(h)
+        if not self.pre_norm:
+            h = self.norm2(h)
+            h = self.nonlinearity(h)
+        h = h * mask if mask is not None else h

+        x = x * mask if mask is not None else x
         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
                 x = self.conv_shortcut(x)
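To illustrate the unification, a hypothetical sketch of the same ResnetBlock class instantiated in its default unet.py configuration and in the GradTTS configuration used below; the keyword arguments come from the diff, the import path and tensor shapes are assumptions:

import torch
from diffusers.models.resnet import ResnetBlock  # assumed importable on this branch

# Default configuration (pre-norm, swish, GroupNorm(32), eps=1e-6), as used by unet.py.
ddpm_block = ResnetBlock(in_channels=64, out_channels=128, dropout=0.0, temb_channels=512)

# GradTTS configuration (post-norm, Mish, GroupNorm(8), eps=1e-5); the temporary Block
# modules are built here and their weights are copied into conv1/norm1/conv2/norm2
# by set_weights_grad_tts() on the first forward pass.
grad_tts_block = ResnetBlock(
    in_channels=64,
    out_channels=128,
    temb_channels=512,
    groups=8,
    pre_norm=False,
    eps=1e-5,
    non_linearity="mish",
    overwrite_for_grad_tts=True,
)

x = torch.randn(1, 64, 80, 172)    # assumed feature-map shape
temb = torch.randn(1, 512)         # timestep embedding
mask = torch.ones(1, 1, 80, 172)   # GradTTS-style sequence mask, broadcast over channels

out = ddpm_block(x, temb)                # mask defaults to None
out_tts = grad_tts_block(x, temb, mask)  # mask is multiplied in around each sub-block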
@@ -383,58 +461,17 @@ class ResnetBlock(nn.Module):
         return x + h


-# unet_grad_tts.py
-class ResnetBlockGradTTS(torch.nn.Module):
-    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
-        super(ResnetBlockGradTTS, self).__init__()
-        self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
-
-        self.block1 = Block(dim, dim_out, groups=groups)
-        self.block2 = Block(dim_out, dim_out, groups=groups)
-        if dim != dim_out:
-            self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
-        else:
-            self.res_conv = torch.nn.Identity()
-
-    def forward(self, x, mask, time_emb):
-        h = self.block1(x, mask)
-        h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
-        h = self.block2(h, mask)
-        output = h + self.res_conv(x * mask)
-        return output
-
-
-# unet_rl.py
-class ResidualTemporalBlock(nn.Module):
-    def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
-        super().__init__()
-
-        self.blocks = nn.ModuleList(
-            [
-                Conv1dBlock(inp_channels, out_channels, kernel_size),
-                Conv1dBlock(out_channels, out_channels, kernel_size),
-            ]
-        )
-
-        self.time_mlp = nn.Sequential(
-            nn.Mish(),
-            nn.Linear(embed_dim, out_channels),
-            RearrangeDim(),
-            # Rearrange("batch t -> batch t 1"),
-        )
-
-        self.residual_conv = (
-            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
-        )
-
-    def forward(self, x, t):
-        """
-        x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x
-        out_channels x horizon ]
-        """
-        out = self.blocks[0](x) + self.time_mlp(t)
-        out = self.blocks[1](out)
-        return out + self.residual_conv(x)
+# TODO(Patrick) - just there to convert the weights; can delete afterward
+class Block(torch.nn.Module):
+    def __init__(self, dim, dim_out, groups=8):
+        super(Block, self).__init__()
+        self.block = torch.nn.Sequential(
+            torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
+        )
+
+    def forward(self, x, mask):
+        output = self.block(x * mask)
+        return output * mask


 # unet_score_estimation.py
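The retained Block class exists only so set_weights_grad_tts can copy pretrained GradTTS weights into the unified layers. A hypothetical sanity check of that copy (dimensions are made up; only the attribute names and the lazy copy on the first forward come from the diff):

import torch
from diffusers.models.resnet import ResnetBlock  # assumed importable on this branch

blk = ResnetBlock(
    in_channels=4, out_channels=8, temb_channels=16,
    groups=2, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True,
)
x, temb, mask = torch.randn(1, 4, 8, 8), torch.randn(1, 16), torch.ones(1, 1, 8, 8)

_ = blk(x, temb, mask)  # first call triggers set_weights_grad_tts()

# conv1/norm1 should now hold the weights of the temporary block1 (Conv2d, GroupNorm, Mish).
assert torch.equal(blk.conv1.weight.data, blk.block1.block[0].weight.data)
assert torch.equal(blk.norm1.weight.data, blk.block1.block[1].weight.data)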
@@ -570,6 +607,39 @@ class ResnetBlockDDPMpp(nn.Module):
         return (x + h) / np.sqrt(2.0)


+# unet_rl.py
+class ResidualTemporalBlock(nn.Module):
+    def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
+        super().__init__()
+
+        self.blocks = nn.ModuleList(
+            [
+                Conv1dBlock(inp_channels, out_channels, kernel_size),
+                Conv1dBlock(out_channels, out_channels, kernel_size),
+            ]
+        )
+
+        self.time_mlp = nn.Sequential(
+            nn.Mish(),
+            nn.Linear(embed_dim, out_channels),
+            RearrangeDim(),
+            # Rearrange("batch t -> batch t 1"),
+        )
+
+        self.residual_conv = (
+            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
+        )
+
+    def forward(self, x, t):
+        """
+        x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x
+        out_channels x horizon ]
+        """
+        out = self.blocks[0](x) + self.time_mlp(t)
+        out = self.blocks[1](out)
+        return out + self.residual_conv(x)
+
+
 # HELPER Modules
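For the relocated ResidualTemporalBlock, a minimal shape sketch following its docstring; it assumes the class is importable from this branch and that Conv1dBlock (Conv1d --> GroupNorm --> Mish) preserves the horizon length:

import torch
from diffusers.models.resnet import ResidualTemporalBlock  # assumed importable on this branch

block = ResidualTemporalBlock(inp_channels=14, out_channels=32, embed_dim=64, horizon=128)

x = torch.randn(2, 14, 128)  # [batch_size, inp_channels, horizon]
t = torch.randn(2, 64)       # [batch_size, embed_dim]

out = block(x, t)            # expected: [batch_size, out_channels, horizon]
assert out.shape == (2, 32, 128)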
@@ -617,18 +687,6 @@ class Mish(torch.nn.Module):
         return x * torch.tanh(torch.nn.functional.softplus(x))


-class Block(torch.nn.Module):
-    def __init__(self, dim, dim_out, groups=8):
-        super(Block, self).__init__()
-        self.block = torch.nn.Sequential(
-            torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
-        )
-
-    def forward(self, x, mask):
-        output = self.block(x * mask)
-        return output * mask
-
-
 class Conv1dBlock(nn.Module):
     """
     Conv1d --> GroupNorm --> Mish
src/diffusers/models/unet_grad_tts.py
@@ -4,9 +4,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import LinearAttention
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample
-from .resnet import ResnetBlockGradTTS as ResnetBlock
-from .resnet import Upsample
+from .resnet import Downsample, ResnetBlock, Upsample


 class Mish(torch.nn.Module):
@@ -86,8 +84,26 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
             self.downs.append(
                 torch.nn.ModuleList(
                     [
-                        ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
-                        ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
+                        ResnetBlock(
+                            in_channels=dim_in,
+                            out_channels=dim_out,
+                            temb_channels=dim,
+                            groups=8,
+                            pre_norm=False,
+                            eps=1e-5,
+                            non_linearity="mish",
+                            overwrite_for_grad_tts=True,
+                        ),
+                        ResnetBlock(
+                            in_channels=dim_out,
+                            out_channels=dim_out,
+                            temb_channels=dim,
+                            groups=8,
+                            pre_norm=False,
+                            eps=1e-5,
+                            non_linearity="mish",
+                            overwrite_for_grad_tts=True,
+                        ),
                         Residual(Rezero(LinearAttention(dim_out))),
                         Downsample(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(),
                     ]
@@ -95,16 +111,52 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
             )

         mid_dim = dims[-1]
-        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
+        self.mid_block1 = ResnetBlock(
+            in_channels=mid_dim,
+            out_channels=mid_dim,
+            temb_channels=dim,
+            groups=8,
+            pre_norm=False,
+            eps=1e-5,
+            non_linearity="mish",
+            overwrite_for_grad_tts=True,
+        )
         self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
-        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
+        self.mid_block2 = ResnetBlock(
+            in_channels=mid_dim,
+            out_channels=mid_dim,
+            temb_channels=dim,
+            groups=8,
+            pre_norm=False,
+            eps=1e-5,
+            non_linearity="mish",
+            overwrite_for_grad_tts=True,
+        )

         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
             self.ups.append(
                 torch.nn.ModuleList(
                     [
-                        ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
-                        ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
+                        ResnetBlock(
+                            in_channels=dim_out * 2,
+                            out_channels=dim_in,
+                            temb_channels=dim,
+                            groups=8,
+                            pre_norm=False,
+                            eps=1e-5,
+                            non_linearity="mish",
+                            overwrite_for_grad_tts=True,
+                        ),
+                        ResnetBlock(
+                            in_channels=dim_in,
+                            out_channels=dim_in,
+                            temb_channels=dim,
+                            groups=8,
+                            pre_norm=False,
+                            eps=1e-5,
+                            non_linearity="mish",
+                            overwrite_for_grad_tts=True,
+                        ),
                         Residual(Rezero(LinearAttention(dim_in))),
                         Upsample(dim_in, use_conv_transpose=True),
                     ]
@@ -135,8 +187,8 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
         masks = [mask]
         for resnet1, resnet2, attn, downsample in self.downs:
             mask_down = masks[-1]
-            x = resnet1(x, mask_down, t)
-            x = resnet2(x, mask_down, t)
+            x = resnet1(x, t, mask_down)
+            x = resnet2(x, t, mask_down)
             x = attn(x)
             hiddens.append(x)
             x = downsample(x * mask_down)
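The call-site change mirrors the unified signature forward(self, x, temb, mask=None): the timestep embedding now comes second and the mask last. A hypothetical stand-alone rendering of one down-path call (dimensions are made up; only the keyword configuration and the call order come from the diff):

import torch
from diffusers.models.resnet import ResnetBlock  # assumed importable on this branch

# Stand-in for one resnet1 entry of UNetGradTTSModel.downs.
resnet1 = ResnetBlock(
    in_channels=64, out_channels=64, temb_channels=64,
    groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True,
)
x = torch.randn(1, 64, 80, 172)
t = torch.randn(1, 64)
mask_down = torch.ones(1, 1, 80, 172)

x = resnet1(x, t, mask_down)  # was resnet1(x, mask_down, t) with the GradTTS-specific block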
@@ -144,15 +196,15 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
         masks = masks[:-1]
         mask_mid = masks[-1]
-        x = self.mid_block1(x, mask_mid, t)
+        x = self.mid_block1(x, t, mask_mid)
         x = self.mid_attn(x)
-        x = self.mid_block2(x, mask_mid, t)
+        x = self.mid_block2(x, t, mask_mid)

         for resnet1, resnet2, attn, upsample in self.ups:
             mask_up = masks.pop()
             x = torch.cat((x, hiddens.pop()), dim=1)
-            x = resnet1(x, mask_up, t)
-            x = resnet2(x, mask_up, t)
+            x = resnet1(x, t, mask_up)
+            x = resnet2(x, t, mask_up)
             x = attn(x)
             x = upsample(x * mask_up)