OpenDAS / diffusers · Commits

Commit c352faea (Unverified)
Authored Jul 04, 2022 by Patrick von Platen
Committed by GitHub, Jul 04, 2022
Add MidBlock to Grad-TTS (#74)
Finish
Parent: 10798663
Showing 2 changed files with 33 additions and 18 deletions:

src/diffusers/models/unet_grad_tts.py   +17 -5
src/diffusers/models/unet_new.py        +16 -13
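In brief (a reading of the diff below): Grad-TTS previously ran its mid section as three separate modules, mid_block1 -> mid_attn -> mid_block2. This commit routes them through the shared UNetMidBlock2D instead, which gains a "linear" attention option, a resnet_pre_norm flag, and an optional mask argument on forward() so that Grad-TTS sequence masks can be threaded through the mid resnets.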
src/diffusers/models/unet_grad_tts.py

@@ -5,6 +5,7 @@ from ..modeling_utils import ModelMixin
 from .attention import LinearAttention
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
+from .unet_new import UNetMidBlock2D


 class Mish(torch.nn.Module):
@@ -111,6 +112,17 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
             )

         mid_dim = dims[-1]
+        self.mid = UNetMidBlock2D(
+            in_channels=mid_dim,
+            temb_channels=dim,
+            resnet_groups=8,
+            resnet_pre_norm=False,
+            resnet_eps=1e-5,
+            resnet_act_fn="mish",
+            attention_layer_type="linear",
+        )
         self.mid_block1 = ResnetBlock2D(
             in_channels=mid_dim,
             out_channels=mid_dim,
@@ -132,8 +144,9 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
             non_linearity="mish",
             overwrite_for_grad_tts=True,
         )
-        # self.mid = UNetMidBlock2D
+        self.mid.resnet_1 = self.mid_block1
+        self.mid.attn = self.mid_attn
+        self.mid.resnet_2 = self.mid_block2

         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
             self.ups.append(
@@ -198,9 +211,8 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
         masks = masks[:-1]
         mask_mid = masks[-1]
-        x = self.mid_block1(x, t, mask_mid)
-        x = self.mid_attn(x)
-        x = self.mid_block2(x, t, mask_mid)
+        x = self.mid(x, t, mask=mask_mid)

         for resnet1, resnet2, attn, upsample in self.ups:
             mask_up = masks.pop()
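To make the wiring concrete, here is a minimal, self-contained sketch of the resnet -> attention -> resnet pattern that UNetMidBlock2D consolidates. TinyMidBlock and its layers are illustrative stand-ins (plain convolutions), not the real ResnetBlock2D or LinearAttention, and the way the mask is applied inside the resnets is an assumption; only the resnet_1 / attn / resnet_2 structure and the mask=1.0 default are taken from this commit.

import torch
import torch.nn as nn


class TinyMidBlock(nn.Module):
    """Illustrative stand-in for UNetMidBlock2D's resnet -> attn -> resnet."""

    def __init__(self, channels):
        super().__init__()
        self.resnet_1 = nn.Conv2d(channels, channels, 3, padding=1)  # stand-in for ResnetBlock2D
        self.attn = nn.Conv2d(channels, channels, 1)                 # stand-in for LinearAttention
        self.resnet_2 = nn.Conv2d(channels, channels, 3, padding=1)  # stand-in for ResnetBlock2D

    def forward(self, hidden_states, temb=None, mask=1.0):
        # mask=1.0 keeps unmasked callers (the image UNets) unchanged;
        # Grad-TTS passes a broadcastable sequence mask instead.
        hidden_states = self.resnet_1(hidden_states * mask) * mask  # mask placement is assumed
        hidden_states = self.attn(hidden_states)
        hidden_states = self.resnet_2(hidden_states * mask) * mask
        return hidden_states


x = torch.randn(2, 8, 1, 40)     # (batch, channels, height=1, mel frames)
mask = torch.ones(2, 1, 1, 40)   # 1 = real frame, 0 = padding
print(TinyMidBlock(8)(x, mask=mask).shape)  # torch.Size([2, 8, 1, 40])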
src/diffusers/models/unet_new.py

@@ -14,7 +14,7 @@
 # limitations under the License.
 from torch import nn

-from .attention import AttentionBlock, SpatialTransformer
+from .attention import AttentionBlock, LinearAttention, SpatialTransformer
 from .resnet import ResnetBlock2D
@@ -23,11 +23,12 @@ class UNetMidBlock2D(nn.Module):
         self,
         in_channels: int,
         temb_channels: int,
-        dropout: float,
+        dropout: float = 0.0,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
         attention_layer_type: str = "self",
         attn_num_heads=1,
         attn_num_head_channels=None,
@@ -49,6 +50,7 @@ class UNetMidBlock2D(nn.Module):
             time_embedding_norm=resnet_time_scale_shift,
             non_linearity=resnet_act_fn,
             output_scale_factor=output_scale_factor,
+            pre_norm=resnet_pre_norm,
         )

         if attention_layer_type == "self":
@@ -61,15 +63,14 @@ class UNetMidBlock2D(nn.Module):
                 rescale_output_factor=output_scale_factor,
             )
         elif attention_layer_type == "spatial":
-            self.attn = (
-                SpatialTransformer(
-                    in_channels,
-                    attn_num_heads,
-                    attn_num_head_channels,
-                    depth=attn_depth,
-                    context_dim=attn_encoder_channels,
-                ),
-            )
+            self.attn = SpatialTransformer(
+                in_channels,
+                attn_num_heads,
+                attn_num_head_channels,
+                depth=attn_depth,
+                context_dim=attn_encoder_channels,
+            )
+        elif attention_layer_type == "linear":
+            self.attn = LinearAttention(in_channels)

         self.resnet_2 = ResnetBlock2D(
             in_channels=in_channels,
@@ -80,6 +81,7 @@ class UNetMidBlock2D(nn.Module):
             time_embedding_norm=resnet_time_scale_shift,
             non_linearity=resnet_act_fn,
             output_scale_factor=output_scale_factor,
+            pre_norm=resnet_pre_norm,
         )

         # TODO(Patrick) - delete all of the following code
@@ -110,19 +112,20 @@ class UNetMidBlock2D(nn.Module):
             eps=resnet_eps,
         )

-    def forward(self, hidden_states, temb=None, encoder_states=None):
+    def forward(self, hidden_states, temb=None, encoder_states=None, mask=1.0):
         if not self.is_overwritten and self.overwrite_unet:
             self.resnet_1 = self.block_1
             self.attn = self.attn_1
             self.resnet_2 = self.block_2
             self.is_overwritten = True

-        hidden_states = self.resnet_1(hidden_states, temb)
+        hidden_states = self.resnet_1(hidden_states, temb, mask=mask)

         if encoder_states is None:
             hidden_states = self.attn(hidden_states)
         else:
             hidden_states = self.attn(hidden_states, encoder_states)

-        hidden_states = self.resnet_2(hidden_states, temb)
+        hidden_states = self.resnet_2(hidden_states, temb, mask=mask)

         return hidden_states
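A short note on the mask=1.0 default added to forward() above: because multiplying by the Python float 1.0 is an identity, existing callers that never pass a mask (the image UNets) are unaffected, while Grad-TTS can supply a broadcastable tensor that zeroes padded frames. A minimal check of that reasoning (shapes are illustrative):

import torch

h = torch.randn(2, 8, 1, 40)

# Scalar default: a no-op, so unmasked callers behave exactly as before.
assert torch.equal(h * 1.0, h)

# Tensor mask: broadcasts over channels and zeroes padded positions.
mask = torch.ones(2, 1, 1, 40)
mask[:, :, :, 30:] = 0.0          # pretend frames 30+ are padding
assert (h * mask)[:, :, :, 30:].abs().sum() == 0.0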