renzhc/diffusers_dcu · Commits · c352faea

Unverified commit c352faea, authored Jul 04, 2022 by Patrick von Platen, committed by GitHub on Jul 04, 2022.

Add MidBlock to Grad-TTS (#74)

Finish
parent 10798663
Showing 2 changed files with 33 additions and 18 deletions:

  src/diffusers/models/unet_grad_tts.py   +17  -5
  src/diffusers/models/unet_new.py        +16  -13
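Taken together, the two files fold Grad-TTS's hand-rolled mid section (mid_block1 -> mid_attn -> mid_block2) into the shared UNetMidBlock2D. A minimal usage sketch, assuming the source tree at this commit is importable; the channel widths, tensor shapes, and mask below are illustrative assumptions, not values from the commit:

import torch
from diffusers.models.unet_new import UNetMidBlock2D

# Constructor arguments mirror the unet_grad_tts.py hunk below.
mid = UNetMidBlock2D(
    in_channels=256,              # stands in for mid_dim = dims[-1]
    temb_channels=64,             # stands in for dim (timestep-embedding width)
    resnet_groups=8,
    resnet_pre_norm=False,
    resnet_eps=1e-5,
    resnet_act_fn="mish",
    attention_layer_type="linear",
)

x = torch.randn(1, 256, 20, 43)   # (batch, channels, mel bins, frames), assumed shape
t = torch.randn(1, 64)            # timestep embedding
mask = torch.ones(1, 1, 20, 43)   # Grad-TTS-style sequence mask
x = mid(x, t, mask=mask)          # one call instead of mid_block1 / mid_attn / mid_block2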
src/diffusers/models/unet_grad_tts.py @ c352faea
@@ -5,6 +5,7 @@ from ..modeling_utils import ModelMixin
 from .attention import LinearAttention
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
+from .unet_new import UNetMidBlock2D


 class Mish(torch.nn.Module):
@@ -111,6 +112,17 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
             )

         mid_dim = dims[-1]
+        self.mid = UNetMidBlock2D(
+            in_channels=mid_dim,
+            temb_channels=dim,
+            resnet_groups=8,
+            resnet_pre_norm=False,
+            resnet_eps=1e-5,
+            resnet_act_fn="mish",
+            attention_layer_type="linear",
+        )
         self.mid_block1 = ResnetBlock2D(
             in_channels=mid_dim,
             out_channels=mid_dim,
@@ -132,8 +144,9 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
             non_linearity="mish",
             overwrite_for_grad_tts=True,
         )
-
-        # self.mid = UNetMidBlock2D
+        self.mid.resnet_1 = self.mid_block1
+        self.mid.attn = self.mid_attn
+        self.mid.resnet_2 = self.mid_block2

         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
             self.ups.append(
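Assigning the pre-existing modules onto the new block means UNetMidBlock2D runs the very same parameters as mid_block1, mid_attn, and mid_block2: in PyTorch, attribute assignment shares the module object rather than copying it. A self-contained illustration of that semantics (generic nn modules, nothing from this commit):

import torch.nn as nn

block = nn.Linear(4, 4)
container = nn.Module()       # any nn.Module can hold submodules
container.inner = block       # registers the same object, not a copy

assert container.inner is block
assert container.inner.weight is block.weight   # parameters are shared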
@@ -198,9 +211,8 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
         masks = masks[:-1]
         mask_mid = masks[-1]

-        x = self.mid_block1(x, t, mask_mid)
-        x = self.mid_attn(x)
-        x = self.mid_block2(x, t, mask_mid)
+        x = self.mid(x, t, mask=mask_mid)

         for resnet1, resnet2, attn, upsample in self.ups:
             mask_up = masks.pop()
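The forward pass now makes a single call into the mid block and threads the sequence mask through it. Note the keyword form mask=mask_mid: the block's forward (see unet_new.py below) defaults mask to the scalar 1.0, presumably applied multiplicatively, so callers without sequence masks are unaffected. A pure-PyTorch illustration of that default, with assumed shapes:

import torch

x = torch.randn(2, 8, 4, 10)
assert torch.equal(x * 1.0, x)        # scalar default: masking is a no-op

mask = torch.ones(2, 1, 1, 10)        # broadcastable 0/1 mask over frames
mask[..., 7:] = 0.0                   # e.g. zero out padded frames
assert (x * mask)[..., 7:].abs().sum() == 0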
src/diffusers/models/unet_new.py @ c352faea
@@ -14,7 +14,7 @@
 # limitations under the License.

 from torch import nn

-from .attention import AttentionBlock, SpatialTransformer
+from .attention import AttentionBlock, LinearAttention, SpatialTransformer
 from .resnet import ResnetBlock2D
@@ -23,11 +23,12 @@ class UNetMidBlock2D(nn.Module):
         self,
         in_channels: int,
         temb_channels: int,
-        dropout: float,
+        dropout: float = 0.0,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
         attention_layer_type: str = "self",
         attn_num_heads=1,
         attn_num_head_channels=None,
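With dropout given a default of 0.0 and resnet_pre_norm exposed, the block can now be built without specifying dropout, and Grad-TTS can opt out of pre-norm resnets. A hedged constructor sketch, assuming the tree at this commit; the widths are illustrative:

from diffusers.models.unet_new import UNetMidBlock2D

# Defaults now suffice for the common case...
block = UNetMidBlock2D(in_channels=32, temb_channels=32)

# ...while Grad-TTS opts into post-norm "mish" resnets with linear attention.
grad_tts_mid = UNetMidBlock2D(
    in_channels=32,
    temb_channels=32,
    resnet_pre_norm=False,        # new flag added in this hunk
    resnet_act_fn="mish",
    attention_layer_type="linear",
)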
@@ -49,6 +50,7 @@ class UNetMidBlock2D(nn.Module):
             time_embedding_norm=resnet_time_scale_shift,
             non_linearity=resnet_act_fn,
             output_scale_factor=output_scale_factor,
+            pre_norm=resnet_pre_norm,
         )

         if attention_layer_type == "self":
@@ -61,15 +63,14 @@ class UNetMidBlock2D(nn.Module):
                 rescale_output_factor=output_scale_factor,
             )
         elif attention_layer_type == "spatial":
-            self.attn = (
-                SpatialTransformer(
-                    in_channels,
-                    attn_num_heads,
-                    attn_num_head_channels,
-                    depth=attn_depth,
-                    context_dim=attn_encoder_channels,
-                ),
-            )
+            self.attn = SpatialTransformer(
+                in_channels,
+                attn_num_heads,
+                attn_num_head_channels,
+                depth=attn_depth,
+                context_dim=attn_encoder_channels,
+            )
+        elif attention_layer_type == "linear":
+            self.attn = LinearAttention(in_channels)

         self.resnet_2 = ResnetBlock2D(
             in_channels=in_channels,
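The new "linear" branch wires in the LinearAttention module that Grad-TTS previously instantiated itself; per the diff it takes only the channel count, and it preserves the input shape. A small sketch, assuming the tree at this commit; the channel count and shape are illustrative:

import torch
from diffusers.models.attention import LinearAttention

attn = LinearAttention(64)            # in_channels only, matching the diff
h = torch.randn(1, 64, 20, 43)
out = attn(h)                         # linear attention over spatial positions
assert out.shape == h.shape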
@@ -80,6 +81,7 @@ class UNetMidBlock2D(nn.Module):
             time_embedding_norm=resnet_time_scale_shift,
             non_linearity=resnet_act_fn,
             output_scale_factor=output_scale_factor,
+            pre_norm=resnet_pre_norm,
         )

         # TODO(Patrick) - delete all of the following code
@@ -110,19 +112,20 @@ class UNetMidBlock2D(nn.Module):
             eps=resnet_eps,
         )

-    def forward(self, hidden_states, temb=None, encoder_states=None):
+    def forward(self, hidden_states, temb=None, encoder_states=None, mask=1.0):
         if not self.is_overwritten and self.overwrite_unet:
             self.resnet_1 = self.block_1
             self.attn = self.attn_1
             self.resnet_2 = self.block_2
             self.is_overwritten = True

-        hidden_states = self.resnet_1(hidden_states, temb)
+        hidden_states = self.resnet_1(hidden_states, temb, mask=mask)
         if encoder_states is None:
             hidden_states = self.attn(hidden_states)
         else:
             hidden_states = self.attn(hidden_states, encoder_states)
-        hidden_states = self.resnet_2(hidden_states, temb)
+        hidden_states = self.resnet_2(hidden_states, temb, mask=mask)

         return hidden_states
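The updated forward keeps its structure (resnet_1 -> attn -> resnet_2, with an encoder_states switch for cross-attention) and simply threads mask into both resnet calls. A standalone mirror of that control flow, with the real modules replaced by hypothetical stubs so the snippet runs anywhere:

import torch

def mid_forward(h, temb=None, encoder_states=None, mask=1.0,
                resnet_1=None, attn=None, resnet_2=None):
    # Mirrors UNetMidBlock2D.forward after this commit (overwrite logic omitted).
    h = resnet_1(h, temb, mask=mask)
    h = attn(h) if encoder_states is None else attn(h, encoder_states)
    return resnet_2(h, temb, mask=mask)

# Hypothetical stand-ins for ResnetBlock2D and the attention module.
stub_resnet = lambda h, temb, mask=1.0: h * mask
stub_attn = lambda h, ctx=None: h

out = mid_forward(torch.randn(1, 4, 8, 8),
                  resnet_1=stub_resnet, attn=stub_attn, resnet_2=stub_resnet)
assert out.shape == (1, 4, 8, 8)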