Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
ea8d58ea
Unverified
Commit
ea8d58ea
authored
Jul 05, 2022
by
Patrick von Platen
Committed by
GitHub
Jul 05, 2022
Browse files
[MidBlock] Fix mid block (#78)
* upload files * finish
parent
c352faea
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
112 additions
and
99 deletions
+112
-99
src/diffusers/models/attention.py
src/diffusers/models/attention.py
+1
-1
src/diffusers/models/unet.py
src/diffusers/models/unet.py
+13
-5
src/diffusers/models/unet_glide.py
src/diffusers/models/unet_glide.py
+4
-4
src/diffusers/models/unet_grad_tts.py
src/diffusers/models/unet_grad_tts.py
+5
-5
src/diffusers/models/unet_ldm.py
src/diffusers/models/unet_ldm.py
+3
-3
src/diffusers/models/unet_new.py
src/diffusers/models/unet_new.py
+83
-78
src/diffusers/models/unet_sde_score_estimation.py
src/diffusers/models/unet_sde_score_estimation.py
+3
-3
No files found.
src/diffusers/models/attention.py
View file @
ea8d58ea
...
...
@@ -17,7 +17,7 @@ class LinearAttention(torch.nn.Module):
self
.
to_qkv
=
torch
.
nn
.
Conv2d
(
dim
,
hidden_dim
*
3
,
1
,
bias
=
False
)
self
.
to_out
=
torch
.
nn
.
Conv2d
(
hidden_dim
,
dim
,
1
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
,
encoder_states
=
None
):
b
,
c
,
h
,
w
=
x
.
shape
qkv
=
self
.
to_qkv
(
x
)
q
,
k
,
v
=
(
...
...
src/diffusers/models/unet.py
View file @
ea8d58ea
...
...
@@ -106,9 +106,20 @@ class UNetModel(ModelMixin, ConfigMixin):
self
.
down
.
append
(
down
)
# middle
self
.
mid
=
UNetMidBlock2D
(
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock2D
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid
.
attn_1
=
AttentionBlock
(
block_in
,
overwrite_qkv
=
True
)
self
.
mid
.
block_2
=
ResnetBlock2D
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid_new
=
UNetMidBlock2D
(
in_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
,
overwrite_qkv
=
True
,
overwrite_unet
=
True
)
self
.
mid_new
.
resnets
[
0
]
=
self
.
mid
.
block_1
self
.
mid_new
.
attentions
[
0
]
=
self
.
mid
.
attn_1
self
.
mid_new
.
resnets
[
1
]
=
self
.
mid
.
block_2
# upsampling
self
.
up
=
nn
.
ModuleList
()
...
...
@@ -167,10 +178,7 @@ class UNetModel(ModelMixin, ConfigMixin):
hs
.
append
(
self
.
down
[
i_level
].
downsample
(
hs
[
-
1
]))
# middle
h
=
self
.
mid
(
hs
[
-
1
],
temb
)
# h = self.mid.block_1(h, temb)
# h = self.mid.attn_1(h)
# h = self.mid.block_2(h, temb)
h
=
self
.
mid_new
(
hs
[
-
1
],
temb
)
# upsampling
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
...
...
src/diffusers/models/unet_glide.py
View file @
ea8d58ea
...
...
@@ -266,9 +266,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
overwrite_for_glide
=
True
,
),
)
self
.
mid
.
resnet
_1
=
self
.
middle_block
[
0
]
self
.
mid
.
att
n
=
self
.
middle_block
[
1
]
self
.
mid
.
resnet
_2
=
self
.
middle_block
[
2
]
self
.
mid
.
resnet
s
[
0
]
=
self
.
middle_block
[
0
]
self
.
mid
.
att
entions
[
0
]
=
self
.
middle_block
[
1
]
self
.
mid
.
resnet
s
[
1
]
=
self
.
middle_block
[
2
]
self
.
_feature_size
+=
ch
...
...
@@ -542,7 +542,7 @@ class GlideSuperResUNetModel(GlideUNetModel):
for
module
in
self
.
input_blocks
:
h
=
module
(
h
,
emb
)
hs
.
append
(
h
)
h
=
self
.
mid
dle_block
(
h
,
emb
)
h
=
self
.
mid
(
h
,
emb
)
for
module
in
self
.
output_blocks
:
h
=
torch
.
cat
([
h
,
hs
.
pop
()],
dim
=
1
)
h
=
module
(
h
,
emb
)
...
...
src/diffusers/models/unet_grad_tts.py
View file @
ea8d58ea
...
...
@@ -19,8 +19,8 @@ class Rezero(torch.nn.Module):
self
.
fn
=
fn
self
.
g
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
1
))
def
forward
(
self
,
x
):
return
self
.
fn
(
x
)
*
self
.
g
def
forward
(
self
,
x
,
encoder_out
=
None
):
return
self
.
fn
(
x
,
encoder_out
)
*
self
.
g
class
Block
(
torch
.
nn
.
Module
):
...
...
@@ -144,9 +144,9 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
non_linearity
=
"mish"
,
overwrite_for_grad_tts
=
True
,
)
self
.
mid
.
resnet
_1
=
self
.
mid_block1
self
.
mid
.
att
n
=
self
.
mid_attn
self
.
mid
.
resnet
_2
=
self
.
mid_block2
self
.
mid
.
resnet
s
[
0
]
=
self
.
mid_block1
self
.
mid
.
att
entions
[
0
]
=
self
.
mid_attn
self
.
mid
.
resnet
s
[
1
]
=
self
.
mid_block2
for
ind
,
(
dim_in
,
dim_out
)
in
enumerate
(
reversed
(
in_out
[
1
:])):
self
.
ups
.
append
(
...
...
src/diffusers/models/unet_ldm.py
View file @
ea8d58ea
...
...
@@ -408,9 +408,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
overwrite_for_ldm
=
True
,
),
)
self
.
mid
.
resnet
_1
=
self
.
middle_block
[
0
]
self
.
mid
.
att
n
=
self
.
middle_block
[
1
]
self
.
mid
.
resnet
_2
=
self
.
middle_block
[
2
]
self
.
mid
.
resnet
s
[
0
]
=
self
.
middle_block
[
0
]
self
.
mid
.
att
entions
[
0
]
=
self
.
middle_block
[
1
]
self
.
mid
.
resnet
s
[
1
]
=
self
.
middle_block
[
2
]
self
.
_feature_size
+=
ch
...
...
src/diffusers/models/unet_new.py
View file @
ea8d58ea
...
...
@@ -24,6 +24,7 @@ class UNetMidBlock2D(nn.Module):
in_channels
:
int
,
temb_channels
:
int
,
dropout
:
float
=
0.0
,
num_blocks
:
int
=
1
,
resnet_eps
:
float
=
1e-6
,
resnet_time_scale_shift
:
str
=
"default"
,
resnet_act_fn
:
str
=
"swish"
,
...
...
@@ -41,91 +42,95 @@ class UNetMidBlock2D(nn.Module):
):
super
().
__init__
()
self
.
resnet_1
=
ResnetBlock2D
(
in_channels
=
in_channels
,
out_channels
=
in_channels
,
temb_channels
=
temb_channels
,
groups
=
resnet_groups
,
dropout
=
dropout
,
time_embedding_norm
=
resnet_time_scale_shift
,
non_linearity
=
resnet_act_fn
,
output_scale_factor
=
output_scale_factor
,
pre_norm
=
resnet_pre_norm
,
)
if
attention_layer_type
==
"self"
:
self
.
attn
=
AttentionBlock
(
in_channels
,
num_heads
=
attn_num_heads
,
num_head_channels
=
attn_num_head_channels
,
encoder_channels
=
attn_encoder_channels
,
overwrite_qkv
=
overwrite_qkv
,
rescale_output_factor
=
output_scale_factor
,
)
elif
attention_layer_type
==
"spatial"
:
self
.
attn
=
SpatialTransformer
(
attn_num_heads
,
attn_num_head_channels
,
depth
=
attn_depth
,
context_dim
=
attn_encoder_channels
,
# there is always at least one resnet
resnets
=
[
ResnetBlock2D
(
in_channels
=
in_channels
,
out_channels
=
in_channels
,
temb_channels
=
temb_channels
,
groups
=
resnet_groups
,
dropout
=
dropout
,
time_embedding_norm
=
resnet_time_scale_shift
,
non_linearity
=
resnet_act_fn
,
output_scale_factor
=
output_scale_factor
,
pre_norm
=
resnet_pre_norm
,
)
elif
attention_layer_type
==
"linear"
:
self
.
attn
=
LinearAttention
(
in_channels
)
]
attentions
=
[]
self
.
resnet_2
=
ResnetBlock2D
(
in_channels
=
in_channels
,
out_channels
=
in_channels
,
temb_channels
=
temb_channels
,
groups
=
resnet_groups
,
dropout
=
dropout
,
time_embedding_norm
=
resnet_time_scale_shift
,
non_linearity
=
resnet_act_fn
,
output_scale_factor
=
output_scale_factor
,
pre_norm
=
resnet_pre_norm
,
)
for
_
in
range
(
num_blocks
):
if
attention_layer_type
==
"self"
:
attentions
.
append
(
AttentionBlock
(
in_channels
,
num_heads
=
attn_num_heads
,
num_head_channels
=
attn_num_head_channels
,
encoder_channels
=
attn_encoder_channels
,
overwrite_qkv
=
overwrite_qkv
,
rescale_output_factor
=
output_scale_factor
,
)
)
elif
attention_layer_type
==
"spatial"
:
attentions
.
append
(
SpatialTransformer
(
in_channels
,
attn_num_heads
,
attn_num_head_channels
,
depth
=
attn_depth
,
context_dim
=
attn_encoder_channels
,
)
)
elif
attention_layer_type
==
"linear"
:
attentions
.
append
(
LinearAttention
(
in_channels
))
# TODO(Patrick) - delete all of the following code
self
.
is_overwritten
=
False
self
.
overwrite_unet
=
overwrite_unet
if
self
.
overwrite_unet
:
block_in
=
in_channels
self
.
temb_ch
=
temb_channels
self
.
block_1
=
ResnetBlock2D
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
,
eps
=
resnet_eps
,
)
self
.
attn_1
=
AttentionBlock
(
block_in
,
num_heads
=
attn_num_heads
,
num_head_channels
=
attn_num_head_channels
,
encoder_channels
=
attn_encoder_channels
,
overwrite_qkv
=
True
,
)
self
.
block_2
=
ResnetBlock2D
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
,
eps
=
resnet_eps
,
resnets
.
append
(
ResnetBlock2D
(
in_channels
=
in_channels
,
out_channels
=
in_channels
,
temb_channels
=
temb_channels
,
groups
=
resnet_groups
,
dropout
=
dropout
,
time_embedding_norm
=
resnet_time_scale_shift
,
non_linearity
=
resnet_act_fn
,
output_scale_factor
=
output_scale_factor
,
pre_norm
=
resnet_pre_norm
,
)
)
self
.
attentions
=
nn
.
ModuleList
(
attentions
)
self
.
resnets
=
nn
.
ModuleList
(
resnets
)
def
forward
(
self
,
hidden_states
,
temb
=
None
,
encoder_states
=
None
,
mask
=
1.0
):
if
not
self
.
is_overwritten
and
self
.
overwrite_unet
:
self
.
resnet_1
=
self
.
block_1
self
.
attn
=
self
.
attn_1
self
.
resnet_2
=
self
.
block_2
self
.
is_overwritten
=
True
hidden_states
=
self
.
resnets
[
0
](
hidden_states
,
temb
,
mask
=
mask
)
hidden_states
=
self
.
resnet_1
(
hidden_states
,
temb
,
mask
=
mask
)
for
attn
,
resnet
in
zip
(
self
.
attentions
,
self
.
resnets
[
1
:]):
hidden_states
=
attn
(
hidden_states
,
encoder_states
)
hidden_states
=
resnet
(
hidden_states
,
temb
,
mask
=
mask
)
if
encoder_states
is
None
:
hidden_states
=
self
.
attn
(
hidden_states
)
else
:
hidden_states
=
self
.
attn
(
hidden_states
,
encoder_states
)
return
hidden_states
hidden_states
=
self
.
resnet_2
(
hidden_states
,
temb
,
mask
=
mask
)
return
hidden_states
# class UNetResAttnDownBlock(nn.Module):
# def __init__(
# self,
# in_channels: int,
# out_channels: int,
# temb_channels: int,
# dropout: float = 0.0,
# resnet_eps: float = 1e-6,
# resnet_time_scale_shift: str = "default",
# resnet_act_fn: str = "swish",
# resnet_groups: int = 32,
# resnet_pre_norm: bool = True,
# attention_layer_type: str = "self",
# attn_num_heads=1,
# attn_num_head_channels=None,
# attn_encoder_channels=None,
# attn_dim_head=None,
# attn_depth=None,
# output_scale_factor=1.0,
# overwrite_qkv=False,
# overwrite_unet=False,
# ):
#
# self.resents =
src/diffusers/models/unet_sde_score_estimation.py
View file @
ea8d58ea
...
...
@@ -249,9 +249,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
overwrite_for_score_vde
=
True
,
)
)
self
.
mid
.
resnet
_1
=
modules
[
len
(
modules
)
-
3
]
self
.
mid
.
att
n
=
modules
[
len
(
modules
)
-
2
]
self
.
mid
.
resnet
_2
=
modules
[
len
(
modules
)
-
1
]
self
.
mid
.
resnet
s
[
0
]
=
modules
[
len
(
modules
)
-
3
]
self
.
mid
.
att
entions
[
0
]
=
modules
[
len
(
modules
)
-
2
]
self
.
mid
.
resnet
s
[
1
]
=
modules
[
len
(
modules
)
-
1
]
pyramid_ch
=
0
# Upsampling block
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment