Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
ea8d58ea
Unverified
Commit
ea8d58ea
authored
Jul 05, 2022
by
Patrick von Platen
Committed by
GitHub
Jul 05, 2022
Browse files
[MidBlock] Fix mid block (#78)
* upload files
* finish
parent
c352faea
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
112 additions
and
99 deletions
+112
-99
src/diffusers/models/attention.py
src/diffusers/models/attention.py
+1
-1
src/diffusers/models/unet.py
src/diffusers/models/unet.py
+13
-5
src/diffusers/models/unet_glide.py
src/diffusers/models/unet_glide.py
+4
-4
src/diffusers/models/unet_grad_tts.py
src/diffusers/models/unet_grad_tts.py
+5
-5
src/diffusers/models/unet_ldm.py
src/diffusers/models/unet_ldm.py
+3
-3
src/diffusers/models/unet_new.py
src/diffusers/models/unet_new.py
+83
-78
src/diffusers/models/unet_sde_score_estimation.py
src/diffusers/models/unet_sde_score_estimation.py
+3
-3
No files found.
src/diffusers/models/attention.py
View file @
ea8d58ea
...
@@ -17,7 +17,7 @@ class LinearAttention(torch.nn.Module):
...
@@ -17,7 +17,7 @@ class LinearAttention(torch.nn.Module):
self
.
to_qkv
=
torch
.
nn
.
Conv2d
(
dim
,
hidden_dim
*
3
,
1
,
bias
=
False
)
self
.
to_qkv
=
torch
.
nn
.
Conv2d
(
dim
,
hidden_dim
*
3
,
1
,
bias
=
False
)
self
.
to_out
=
torch
.
nn
.
Conv2d
(
hidden_dim
,
dim
,
1
)
self
.
to_out
=
torch
.
nn
.
Conv2d
(
hidden_dim
,
dim
,
1
)
- def forward(self, x):
+ def forward(self, x, encoder_states=None):
b
,
c
,
h
,
w
=
x
.
shape
b
,
c
,
h
,
w
=
x
.
shape
qkv
=
self
.
to_qkv
(
x
)
qkv
=
self
.
to_qkv
(
x
)
q
,
k
,
v
=
(
q
,
k
,
v
=
(
...
...
src/diffusers/models/unet.py
View file @
ea8d58ea
...
@@ -106,9 +106,20 @@ class UNetModel(ModelMixin, ConfigMixin):
...
@@ -106,9 +106,20 @@ class UNetModel(ModelMixin, ConfigMixin):
self
.
down
.
append
(
down
)
self
.
down
.
append
(
down
)
# middle
# middle
self
.
mid
=
UNetMidBlock2D
(
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock2D
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid
.
attn_1
=
AttentionBlock
(
block_in
,
overwrite_qkv
=
True
)
self
.
mid
.
block_2
=
ResnetBlock2D
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid_new
=
UNetMidBlock2D
(
in_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
,
overwrite_qkv
=
True
,
overwrite_unet
=
True
in_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
,
overwrite_qkv
=
True
,
overwrite_unet
=
True
)
)
self
.
mid_new
.
resnets
[
0
]
=
self
.
mid
.
block_1
self
.
mid_new
.
attentions
[
0
]
=
self
.
mid
.
attn_1
self
.
mid_new
.
resnets
[
1
]
=
self
.
mid
.
block_2
# upsampling
# upsampling
self
.
up
=
nn
.
ModuleList
()
self
.
up
=
nn
.
ModuleList
()
...
@@ -167,10 +178,7 @@ class UNetModel(ModelMixin, ConfigMixin):
...
@@ -167,10 +178,7 @@ class UNetModel(ModelMixin, ConfigMixin):
hs
.
append
(
self
.
down
[
i_level
].
downsample
(
hs
[
-
1
]))
hs
.
append
(
self
.
down
[
i_level
].
downsample
(
hs
[
-
1
]))
# middle
# middle
- h = self.mid(hs[-1], temb)
+ h = self.mid_new(hs[-1], temb)
# h = self.mid.block_1(h, temb)
# h = self.mid.attn_1(h)
# h = self.mid.block_2(h, temb)
# upsampling
# upsampling
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
...
...
src/diffusers/models/unet_glide.py
View file @
ea8d58ea
...
@@ -266,9 +266,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
...
@@ -266,9 +266,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
overwrite_for_glide
=
True
,
overwrite_for_glide
=
True
,
),
),
)
)
- self.mid.resnet_1 = self.middle_block[0]
+ self.mid.resnets[0] = self.middle_block[0]
- self.mid.attn = self.middle_block[1]
+ self.mid.attentions[0] = self.middle_block[1]
- self.mid.resnet_2 = self.middle_block[2]
+ self.mid.resnets[1] = self.middle_block[2]
self
.
_feature_size
+=
ch
self
.
_feature_size
+=
ch
...
@@ -542,7 +542,7 @@ class GlideSuperResUNetModel(GlideUNetModel):
...
@@ -542,7 +542,7 @@ class GlideSuperResUNetModel(GlideUNetModel):
for
module
in
self
.
input_blocks
:
for
module
in
self
.
input_blocks
:
h
=
module
(
h
,
emb
)
h
=
module
(
h
,
emb
)
hs
.
append
(
h
)
hs
.
append
(
h
)
- h = self.middle_block(h, emb)
+ h = self.mid(h, emb)
for
module
in
self
.
output_blocks
:
for
module
in
self
.
output_blocks
:
h
=
torch
.
cat
([
h
,
hs
.
pop
()],
dim
=
1
)
h
=
torch
.
cat
([
h
,
hs
.
pop
()],
dim
=
1
)
h
=
module
(
h
,
emb
)
h
=
module
(
h
,
emb
)
...
...
src/diffusers/models/unet_grad_tts.py
View file @
ea8d58ea
...
@@ -19,8 +19,8 @@ class Rezero(torch.nn.Module):
...
@@ -19,8 +19,8 @@ class Rezero(torch.nn.Module):
self
.
fn
=
fn
self
.
fn
=
fn
self
.
g
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
1
))
self
.
g
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
1
))
- def forward(self, x):
+ def forward(self, x, encoder_out=None):
- return self.fn(x) * self.g
+ return self.fn(x, encoder_out) * self.g
class
Block
(
torch
.
nn
.
Module
):
class
Block
(
torch
.
nn
.
Module
):
...
@@ -144,9 +144,9 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
...
@@ -144,9 +144,9 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
non_linearity
=
"mish"
,
non_linearity
=
"mish"
,
overwrite_for_grad_tts
=
True
,
overwrite_for_grad_tts
=
True
,
)
)
- self.mid.resnet_1 = self.mid_block1
+ self.mid.resnets[0] = self.mid_block1
- self.mid.attn = self.mid_attn
+ self.mid.attentions[0] = self.mid_attn
- self.mid.resnet_2 = self.mid_block2
+ self.mid.resnets[1] = self.mid_block2
for
ind
,
(
dim_in
,
dim_out
)
in
enumerate
(
reversed
(
in_out
[
1
:])):
for
ind
,
(
dim_in
,
dim_out
)
in
enumerate
(
reversed
(
in_out
[
1
:])):
self
.
ups
.
append
(
self
.
ups
.
append
(
...
...
src/diffusers/models/unet_ldm.py
View file @
ea8d58ea
...
@@ -408,9 +408,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
...
@@ -408,9 +408,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
overwrite_for_ldm
=
True
,
overwrite_for_ldm
=
True
,
),
),
)
)
- self.mid.resnet_1 = self.middle_block[0]
+ self.mid.resnets[0] = self.middle_block[0]
- self.mid.attn = self.middle_block[1]
+ self.mid.attentions[0] = self.middle_block[1]
- self.mid.resnet_2 = self.middle_block[2]
+ self.mid.resnets[1] = self.middle_block[2]
self
.
_feature_size
+=
ch
self
.
_feature_size
+=
ch
...
...
src/diffusers/models/unet_new.py
View file @
ea8d58ea
...
@@ -24,6 +24,7 @@ class UNetMidBlock2D(nn.Module):
...
@@ -24,6 +24,7 @@ class UNetMidBlock2D(nn.Module):
in_channels
:
int
,
in_channels
:
int
,
temb_channels
:
int
,
temb_channels
:
int
,
dropout
:
float
=
0.0
,
dropout
:
float
=
0.0
,
num_blocks
:
int
=
1
,
resnet_eps
:
float
=
1e-6
,
resnet_eps
:
float
=
1e-6
,
resnet_time_scale_shift
:
str
=
"default"
,
resnet_time_scale_shift
:
str
=
"default"
,
resnet_act_fn
:
str
=
"swish"
,
resnet_act_fn
:
str
=
"swish"
,
...
@@ -41,7 +42,9 @@ class UNetMidBlock2D(nn.Module):
...
@@ -41,7 +42,9 @@ class UNetMidBlock2D(nn.Module):
):
):
super
().
__init__
()
super
().
__init__
()
self
.
resnet_1
=
ResnetBlock2D
(
# there is always at least one resnet
resnets
=
[
ResnetBlock2D
(
in_channels
=
in_channels
,
in_channels
=
in_channels
,
out_channels
=
in_channels
,
out_channels
=
in_channels
,
temb_channels
=
temb_channels
,
temb_channels
=
temb_channels
,
...
@@ -52,9 +55,13 @@ class UNetMidBlock2D(nn.Module):
...
@@ -52,9 +55,13 @@ class UNetMidBlock2D(nn.Module):
output_scale_factor
=
output_scale_factor
,
output_scale_factor
=
output_scale_factor
,
pre_norm
=
resnet_pre_norm
,
pre_norm
=
resnet_pre_norm
,
)
)
]
attentions
=
[]
for
_
in
range
(
num_blocks
):
if
attention_layer_type
==
"self"
:
if
attention_layer_type
==
"self"
:
self
.
attn
=
AttentionBlock
(
attentions
.
append
(
AttentionBlock
(
in_channels
,
in_channels
,
num_heads
=
attn_num_heads
,
num_heads
=
attn_num_heads
,
num_head_channels
=
attn_num_head_channels
,
num_head_channels
=
attn_num_head_channels
,
...
@@ -62,17 +69,22 @@ class UNetMidBlock2D(nn.Module):
...
@@ -62,17 +69,22 @@ class UNetMidBlock2D(nn.Module):
overwrite_qkv
=
overwrite_qkv
,
overwrite_qkv
=
overwrite_qkv
,
rescale_output_factor
=
output_scale_factor
,
rescale_output_factor
=
output_scale_factor
,
)
)
)
elif
attention_layer_type
==
"spatial"
:
elif
attention_layer_type
==
"spatial"
:
self
.
attn
=
SpatialTransformer
(
attentions
.
append
(
SpatialTransformer
(
in_channels
,
attn_num_heads
,
attn_num_heads
,
attn_num_head_channels
,
attn_num_head_channels
,
depth
=
attn_depth
,
depth
=
attn_depth
,
context_dim
=
attn_encoder_channels
,
context_dim
=
attn_encoder_channels
,
)
)
)
elif
attention_layer_type
==
"linear"
:
elif
attention_layer_type
==
"linear"
:
self
.
attn
=
LinearAttention
(
in_channels
)
attentions
.
append
(
LinearAttention
(
in_channels
)
)
self
.
resnet_2
=
ResnetBlock2D
(
resnets
.
append
(
ResnetBlock2D
(
in_channels
=
in_channels
,
in_channels
=
in_channels
,
out_channels
=
in_channels
,
out_channels
=
in_channels
,
temb_channels
=
temb_channels
,
temb_channels
=
temb_channels
,
...
@@ -83,49 +95,42 @@ class UNetMidBlock2D(nn.Module):
...
@@ -83,49 +95,42 @@ class UNetMidBlock2D(nn.Module):
output_scale_factor
=
output_scale_factor
,
output_scale_factor
=
output_scale_factor
,
pre_norm
=
resnet_pre_norm
,
pre_norm
=
resnet_pre_norm
,
)
)
# TODO(Patrick) - delete all of the following code
self
.
is_overwritten
=
False
self
.
overwrite_unet
=
overwrite_unet
if
self
.
overwrite_unet
:
block_in
=
in_channels
self
.
temb_ch
=
temb_channels
self
.
block_1
=
ResnetBlock2D
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
,
eps
=
resnet_eps
,
)
self
.
attn_1
=
AttentionBlock
(
block_in
,
num_heads
=
attn_num_heads
,
num_head_channels
=
attn_num_head_channels
,
encoder_channels
=
attn_encoder_channels
,
overwrite_qkv
=
True
,
)
self
.
block_2
=
ResnetBlock2D
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
,
eps
=
resnet_eps
,
)
)
self
.
attentions
=
nn
.
ModuleList
(
attentions
)
self
.
resnets
=
nn
.
ModuleList
(
resnets
)
def
forward
(
self
,
hidden_states
,
temb
=
None
,
encoder_states
=
None
,
mask
=
1.0
):
def
forward
(
self
,
hidden_states
,
temb
=
None
,
encoder_states
=
None
,
mask
=
1.0
):
if
not
self
.
is_overwritten
and
self
.
overwrite_unet
:
hidden_states
=
self
.
resnets
[
0
](
hidden_states
,
temb
,
mask
=
mask
)
self
.
resnet_1
=
self
.
block_1
self
.
attn
=
self
.
attn_1
self
.
resnet_2
=
self
.
block_2
self
.
is_overwritten
=
True
hidden_states
=
self
.
resnet_1
(
hidden_states
,
temb
,
mask
=
mask
)
for
attn
,
resnet
in
zip
(
self
.
attentions
,
self
.
resnets
[
1
:]):
hidden_states
=
attn
(
hidden_states
,
encoder_states
)
hidden_states
=
resnet
(
hidden_states
,
temb
,
mask
=
mask
)
if
encoder_states
is
None
:
return
hidden_states
hidden_states
=
self
.
attn
(
hidden_states
)
else
:
hidden_states
=
self
.
attn
(
hidden_states
,
encoder_states
)
hidden_states
=
self
.
resnet_2
(
hidden_states
,
temb
,
mask
=
mask
)
return
hidden_states
# class UNetResAttnDownBlock(nn.Module):
# def __init__(
# self,
# in_channels: int,
# out_channels: int,
# temb_channels: int,
# dropout: float = 0.0,
# resnet_eps: float = 1e-6,
# resnet_time_scale_shift: str = "default",
# resnet_act_fn: str = "swish",
# resnet_groups: int = 32,
# resnet_pre_norm: bool = True,
# attention_layer_type: str = "self",
# attn_num_heads=1,
# attn_num_head_channels=None,
# attn_encoder_channels=None,
# attn_dim_head=None,
# attn_depth=None,
# output_scale_factor=1.0,
# overwrite_qkv=False,
# overwrite_unet=False,
# ):
#
# self.resnets =
src/diffusers/models/unet_sde_score_estimation.py
View file @
ea8d58ea
...
@@ -249,9 +249,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
...
@@ -249,9 +249,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
overwrite_for_score_vde
=
True
,
overwrite_for_score_vde
=
True
,
)
)
)
)
- self.mid.resnet_1 = modules[len(modules) - 3]
+ self.mid.resnets[0] = modules[len(modules) - 3]
- self.mid.attn = modules[len(modules) - 2]
+ self.mid.attentions[0] = modules[len(modules) - 2]
- self.mid.resnet_2 = modules[len(modules) - 1]
+ self.mid.resnets[1] = modules[len(modules) - 1]
pyramid_ch
=
0
pyramid_ch
=
0
# Upsampling block
# Upsampling block
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment