Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
ea8d58ea
Unverified
Commit
ea8d58ea
authored
Jul 05, 2022
by
Patrick von Platen
Committed by
GitHub
Jul 05, 2022
Browse files
[MidBlock] Fix mid block (#78)
* upload files * finish
parent
c352faea
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
112 additions
and
99 deletions
+112
-99
src/diffusers/models/attention.py
src/diffusers/models/attention.py
+1
-1
src/diffusers/models/unet.py
src/diffusers/models/unet.py
+13
-5
src/diffusers/models/unet_glide.py
src/diffusers/models/unet_glide.py
+4
-4
src/diffusers/models/unet_grad_tts.py
src/diffusers/models/unet_grad_tts.py
+5
-5
src/diffusers/models/unet_ldm.py
src/diffusers/models/unet_ldm.py
+3
-3
src/diffusers/models/unet_new.py
src/diffusers/models/unet_new.py
+83
-78
src/diffusers/models/unet_sde_score_estimation.py
src/diffusers/models/unet_sde_score_estimation.py
+3
-3
No files found.
src/diffusers/models/attention.py
View file @
ea8d58ea
...
@@ -17,7 +17,7 @@ class LinearAttention(torch.nn.Module):
...
@@ -17,7 +17,7 @@ class LinearAttention(torch.nn.Module):
self
.
to_qkv
=
torch
.
nn
.
Conv2d
(
dim
,
hidden_dim
*
3
,
1
,
bias
=
False
)
self
.
to_qkv
=
torch
.
nn
.
Conv2d
(
dim
,
hidden_dim
*
3
,
1
,
bias
=
False
)
self
.
to_out
=
torch
.
nn
.
Conv2d
(
hidden_dim
,
dim
,
1
)
self
.
to_out
=
torch
.
nn
.
Conv2d
(
hidden_dim
,
dim
,
1
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
,
encoder_states
=
None
):
b
,
c
,
h
,
w
=
x
.
shape
b
,
c
,
h
,
w
=
x
.
shape
qkv
=
self
.
to_qkv
(
x
)
qkv
=
self
.
to_qkv
(
x
)
q
,
k
,
v
=
(
q
,
k
,
v
=
(
...
...
src/diffusers/models/unet.py
View file @
ea8d58ea
...
@@ -106,9 +106,20 @@ class UNetModel(ModelMixin, ConfigMixin):
...
@@ -106,9 +106,20 @@ class UNetModel(ModelMixin, ConfigMixin):
self
.
down
.
append
(
down
)
self
.
down
.
append
(
down
)
# middle
# middle
self
.
mid
=
UNetMidBlock2D
(
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock2D
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid
.
attn_1
=
AttentionBlock
(
block_in
,
overwrite_qkv
=
True
)
self
.
mid
.
block_2
=
ResnetBlock2D
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid_new
=
UNetMidBlock2D
(
in_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
,
overwrite_qkv
=
True
,
overwrite_unet
=
True
in_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
,
overwrite_qkv
=
True
,
overwrite_unet
=
True
)
)
self
.
mid_new
.
resnets
[
0
]
=
self
.
mid
.
block_1
self
.
mid_new
.
attentions
[
0
]
=
self
.
mid
.
attn_1
self
.
mid_new
.
resnets
[
1
]
=
self
.
mid
.
block_2
# upsampling
# upsampling
self
.
up
=
nn
.
ModuleList
()
self
.
up
=
nn
.
ModuleList
()
...
@@ -167,10 +178,7 @@ class UNetModel(ModelMixin, ConfigMixin):
...
@@ -167,10 +178,7 @@ class UNetModel(ModelMixin, ConfigMixin):
hs
.
append
(
self
.
down
[
i_level
].
downsample
(
hs
[
-
1
]))
hs
.
append
(
self
.
down
[
i_level
].
downsample
(
hs
[
-
1
]))
# middle
# middle
h
=
self
.
mid
(
hs
[
-
1
],
temb
)
h
=
self
.
mid_new
(
hs
[
-
1
],
temb
)
# h = self.mid.block_1(h, temb)
# h = self.mid.attn_1(h)
# h = self.mid.block_2(h, temb)
# upsampling
# upsampling
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
...
...
src/diffusers/models/unet_glide.py
View file @
ea8d58ea
...
@@ -266,9 +266,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
...
@@ -266,9 +266,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
overwrite_for_glide
=
True
,
overwrite_for_glide
=
True
,
),
),
)
)
self
.
mid
.
resnet
_1
=
self
.
middle_block
[
0
]
self
.
mid
.
resnet
s
[
0
]
=
self
.
middle_block
[
0
]
self
.
mid
.
att
n
=
self
.
middle_block
[
1
]
self
.
mid
.
att
entions
[
0
]
=
self
.
middle_block
[
1
]
self
.
mid
.
resnet
_2
=
self
.
middle_block
[
2
]
self
.
mid
.
resnet
s
[
1
]
=
self
.
middle_block
[
2
]
self
.
_feature_size
+=
ch
self
.
_feature_size
+=
ch
...
@@ -542,7 +542,7 @@ class GlideSuperResUNetModel(GlideUNetModel):
...
@@ -542,7 +542,7 @@ class GlideSuperResUNetModel(GlideUNetModel):
for
module
in
self
.
input_blocks
:
for
module
in
self
.
input_blocks
:
h
=
module
(
h
,
emb
)
h
=
module
(
h
,
emb
)
hs
.
append
(
h
)
hs
.
append
(
h
)
h
=
self
.
mid
dle_block
(
h
,
emb
)
h
=
self
.
mid
(
h
,
emb
)
for
module
in
self
.
output_blocks
:
for
module
in
self
.
output_blocks
:
h
=
torch
.
cat
([
h
,
hs
.
pop
()],
dim
=
1
)
h
=
torch
.
cat
([
h
,
hs
.
pop
()],
dim
=
1
)
h
=
module
(
h
,
emb
)
h
=
module
(
h
,
emb
)
...
...
src/diffusers/models/unet_grad_tts.py
View file @
ea8d58ea
...
@@ -19,8 +19,8 @@ class Rezero(torch.nn.Module):
...
@@ -19,8 +19,8 @@ class Rezero(torch.nn.Module):
self
.
fn
=
fn
self
.
fn
=
fn
self
.
g
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
1
))
self
.
g
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
1
))
def
forward
(
self
,
x
):
def
forward
(
self
,
x
,
encoder_out
=
None
):
return
self
.
fn
(
x
)
*
self
.
g
return
self
.
fn
(
x
,
encoder_out
)
*
self
.
g
class
Block
(
torch
.
nn
.
Module
):
class
Block
(
torch
.
nn
.
Module
):
...
@@ -144,9 +144,9 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
...
@@ -144,9 +144,9 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
non_linearity
=
"mish"
,
non_linearity
=
"mish"
,
overwrite_for_grad_tts
=
True
,
overwrite_for_grad_tts
=
True
,
)
)
self
.
mid
.
resnet
_1
=
self
.
mid_block1
self
.
mid
.
resnet
s
[
0
]
=
self
.
mid_block1
self
.
mid
.
att
n
=
self
.
mid_attn
self
.
mid
.
att
entions
[
0
]
=
self
.
mid_attn
self
.
mid
.
resnet
_2
=
self
.
mid_block2
self
.
mid
.
resnet
s
[
1
]
=
self
.
mid_block2
for
ind
,
(
dim_in
,
dim_out
)
in
enumerate
(
reversed
(
in_out
[
1
:])):
for
ind
,
(
dim_in
,
dim_out
)
in
enumerate
(
reversed
(
in_out
[
1
:])):
self
.
ups
.
append
(
self
.
ups
.
append
(
...
...
src/diffusers/models/unet_ldm.py
View file @
ea8d58ea
...
@@ -408,9 +408,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
...
@@ -408,9 +408,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
overwrite_for_ldm
=
True
,
overwrite_for_ldm
=
True
,
),
),
)
)
self
.
mid
.
resnet
_1
=
self
.
middle_block
[
0
]
self
.
mid
.
resnet
s
[
0
]
=
self
.
middle_block
[
0
]
self
.
mid
.
att
n
=
self
.
middle_block
[
1
]
self
.
mid
.
att
entions
[
0
]
=
self
.
middle_block
[
1
]
self
.
mid
.
resnet
_2
=
self
.
middle_block
[
2
]
self
.
mid
.
resnet
s
[
1
]
=
self
.
middle_block
[
2
]
self
.
_feature_size
+=
ch
self
.
_feature_size
+=
ch
...
...
src/diffusers/models/unet_new.py
View file @
ea8d58ea
...
@@ -24,6 +24,7 @@ class UNetMidBlock2D(nn.Module):
...
@@ -24,6 +24,7 @@ class UNetMidBlock2D(nn.Module):
in_channels
:
int
,
in_channels
:
int
,
temb_channels
:
int
,
temb_channels
:
int
,
dropout
:
float
=
0.0
,
dropout
:
float
=
0.0
,
num_blocks
:
int
=
1
,
resnet_eps
:
float
=
1e-6
,
resnet_eps
:
float
=
1e-6
,
resnet_time_scale_shift
:
str
=
"default"
,
resnet_time_scale_shift
:
str
=
"default"
,
resnet_act_fn
:
str
=
"swish"
,
resnet_act_fn
:
str
=
"swish"
,
...
@@ -41,7 +42,9 @@ class UNetMidBlock2D(nn.Module):
...
@@ -41,7 +42,9 @@ class UNetMidBlock2D(nn.Module):
):
):
super
().
__init__
()
super
().
__init__
()
self
.
resnet_1
=
ResnetBlock2D
(
# there is always at least one resnet
resnets
=
[
ResnetBlock2D
(
in_channels
=
in_channels
,
in_channels
=
in_channels
,
out_channels
=
in_channels
,
out_channels
=
in_channels
,
temb_channels
=
temb_channels
,
temb_channels
=
temb_channels
,
...
@@ -52,9 +55,13 @@ class UNetMidBlock2D(nn.Module):
...
@@ -52,9 +55,13 @@ class UNetMidBlock2D(nn.Module):
output_scale_factor
=
output_scale_factor
,
output_scale_factor
=
output_scale_factor
,
pre_norm
=
resnet_pre_norm
,
pre_norm
=
resnet_pre_norm
,
)
)
]
attentions
=
[]
for
_
in
range
(
num_blocks
):
if
attention_layer_type
==
"self"
:
if
attention_layer_type
==
"self"
:
self
.
attn
=
AttentionBlock
(
attentions
.
append
(
AttentionBlock
(
in_channels
,
in_channels
,
num_heads
=
attn_num_heads
,
num_heads
=
attn_num_heads
,
num_head_channels
=
attn_num_head_channels
,
num_head_channels
=
attn_num_head_channels
,
...
@@ -62,17 +69,22 @@ class UNetMidBlock2D(nn.Module):
...
@@ -62,17 +69,22 @@ class UNetMidBlock2D(nn.Module):
overwrite_qkv
=
overwrite_qkv
,
overwrite_qkv
=
overwrite_qkv
,
rescale_output_factor
=
output_scale_factor
,
rescale_output_factor
=
output_scale_factor
,
)
)
)
elif
attention_layer_type
==
"spatial"
:
elif
attention_layer_type
==
"spatial"
:
self
.
attn
=
SpatialTransformer
(
attentions
.
append
(
SpatialTransformer
(
in_channels
,
attn_num_heads
,
attn_num_heads
,
attn_num_head_channels
,
attn_num_head_channels
,
depth
=
attn_depth
,
depth
=
attn_depth
,
context_dim
=
attn_encoder_channels
,
context_dim
=
attn_encoder_channels
,
)
)
)
elif
attention_layer_type
==
"linear"
:
elif
attention_layer_type
==
"linear"
:
self
.
attn
=
LinearAttention
(
in_channels
)
attentions
.
append
(
LinearAttention
(
in_channels
)
)
self
.
resnet_2
=
ResnetBlock2D
(
resnets
.
append
(
ResnetBlock2D
(
in_channels
=
in_channels
,
in_channels
=
in_channels
,
out_channels
=
in_channels
,
out_channels
=
in_channels
,
temb_channels
=
temb_channels
,
temb_channels
=
temb_channels
,
...
@@ -83,49 +95,42 @@ class UNetMidBlock2D(nn.Module):
...
@@ -83,49 +95,42 @@ class UNetMidBlock2D(nn.Module):
output_scale_factor
=
output_scale_factor
,
output_scale_factor
=
output_scale_factor
,
pre_norm
=
resnet_pre_norm
,
pre_norm
=
resnet_pre_norm
,
)
)
# TODO(Patrick) - delete all of the following code
self
.
is_overwritten
=
False
self
.
overwrite_unet
=
overwrite_unet
if
self
.
overwrite_unet
:
block_in
=
in_channels
self
.
temb_ch
=
temb_channels
self
.
block_1
=
ResnetBlock2D
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
,
eps
=
resnet_eps
,
)
self
.
attn_1
=
AttentionBlock
(
block_in
,
num_heads
=
attn_num_heads
,
num_head_channels
=
attn_num_head_channels
,
encoder_channels
=
attn_encoder_channels
,
overwrite_qkv
=
True
,
)
self
.
block_2
=
ResnetBlock2D
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
,
eps
=
resnet_eps
,
)
)
self
.
attentions
=
nn
.
ModuleList
(
attentions
)
self
.
resnets
=
nn
.
ModuleList
(
resnets
)
def
forward
(
self
,
hidden_states
,
temb
=
None
,
encoder_states
=
None
,
mask
=
1.0
):
def
forward
(
self
,
hidden_states
,
temb
=
None
,
encoder_states
=
None
,
mask
=
1.0
):
if
not
self
.
is_overwritten
and
self
.
overwrite_unet
:
hidden_states
=
self
.
resnets
[
0
](
hidden_states
,
temb
,
mask
=
mask
)
self
.
resnet_1
=
self
.
block_1
self
.
attn
=
self
.
attn_1
self
.
resnet_2
=
self
.
block_2
self
.
is_overwritten
=
True
hidden_states
=
self
.
resnet_1
(
hidden_states
,
temb
,
mask
=
mask
)
for
attn
,
resnet
in
zip
(
self
.
attentions
,
self
.
resnets
[
1
:]):
hidden_states
=
attn
(
hidden_states
,
encoder_states
)
hidden_states
=
resnet
(
hidden_states
,
temb
,
mask
=
mask
)
if
encoder_states
is
None
:
return
hidden_states
hidden_states
=
self
.
attn
(
hidden_states
)
else
:
hidden_states
=
self
.
attn
(
hidden_states
,
encoder_states
)
hidden_states
=
self
.
resnet_2
(
hidden_states
,
temb
,
mask
=
mask
)
return
hidden_states
# class UNetResAttnDownBlock(nn.Module):
# def __init__(
# self,
# in_channels: int,
# out_channels: int,
# temb_channels: int,
# dropout: float = 0.0,
# resnet_eps: float = 1e-6,
# resnet_time_scale_shift: str = "default",
# resnet_act_fn: str = "swish",
# resnet_groups: int = 32,
# resnet_pre_norm: bool = True,
# attention_layer_type: str = "self",
# attn_num_heads=1,
# attn_num_head_channels=None,
# attn_encoder_channels=None,
# attn_dim_head=None,
# attn_depth=None,
# output_scale_factor=1.0,
# overwrite_qkv=False,
# overwrite_unet=False,
# ):
#
# self.resents =
src/diffusers/models/unet_sde_score_estimation.py
View file @
ea8d58ea
...
@@ -249,9 +249,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
...
@@ -249,9 +249,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
overwrite_for_score_vde
=
True
,
overwrite_for_score_vde
=
True
,
)
)
)
)
self
.
mid
.
resnet
_1
=
modules
[
len
(
modules
)
-
3
]
self
.
mid
.
resnet
s
[
0
]
=
modules
[
len
(
modules
)
-
3
]
self
.
mid
.
att
n
=
modules
[
len
(
modules
)
-
2
]
self
.
mid
.
att
entions
[
0
]
=
modules
[
len
(
modules
)
-
2
]
self
.
mid
.
resnet
_2
=
modules
[
len
(
modules
)
-
1
]
self
.
mid
.
resnet
s
[
1
]
=
modules
[
len
(
modules
)
-
1
]
pyramid_ch
=
0
pyramid_ch
=
0
# Upsampling block
# Upsampling block
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment