Commit 5e12d5c6 (unverified) in renzhc/diffusers_dcu
Authored Jul 13, 2022 by Patrick von Platen; committed by GitHub on Jul 13, 2022.
Parent: 8aed37c1

Clean uncond unet more (#85)

* up
* finished clean up
* remove @
Showing 3 changed files with 254 additions and 205 deletions:

* src/diffusers/models/unet_new.py (+68, -0)
* src/diffusers/models/unet_unconditional.py (+183, -204)
* tests/test_modeling_utils.py (+3, -1)
src/diffusers/models/unet_new.py
@@ -19,6 +19,74 @@ from .attention import AttentionBlockNew
 from .resnet import Downsample2D, ResnetBlock, Upsample2D
 
 
+def get_down_block(
+    down_block_type,
+    num_layers,
+    in_channels,
+    out_channels,
+    temb_channels,
+    add_downsample,
+    resnet_eps,
+    resnet_act_fn,
+    attn_num_head_channels,
+):
+    if down_block_type == "UNetResDownBlock2D":
+        return UNetResAttnDownBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            add_downsample=add_downsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+        )
+    elif down_block_type == "UNetResAttnDownBlock2D":
+        return UNetResAttnDownBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            add_downsample=add_downsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            attn_num_head_channels=attn_num_head_channels,
+        )
+
+
+def get_up_block(
+    up_block_type,
+    num_layers,
+    in_channels,
+    next_channels,
+    temb_channels,
+    add_upsample,
+    resnet_eps,
+    resnet_act_fn,
+    attn_num_head_channels,
+):
+    if up_block_type == "UNetResUpBlock2D":
+        return UNetResUpBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            next_channels=next_channels,
+            temb_channels=temb_channels,
+            add_upsample=add_upsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+        )
+    elif up_block_type == "UNetResAttnUpBlock2D":
+        return UNetResAttnUpBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            next_channels=next_channels,
+            temb_channels=temb_channels,
+            add_upsample=add_upsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            attn_num_head_channels=attn_num_head_channels,
+        )
+
+
 class UNetMidBlock2D(nn.Module):
     def __init__(
         self,
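For orientation, here is a minimal sketch (not part of the commit) of how these factory functions are meant to be called: UNetUnconditionalModel passes a block-type string plus the shared resnet/attention hyperparameters and gets back a configured block module. All concrete values below are illustrative assumptions, and the import path assumes this revision of the repository.

    from diffusers.models.unet_new import get_down_block, get_up_block

    # Build one attention down-block the same way UNetUnconditionalModel.__init__ does.
    down_block = get_down_block(
        "UNetResAttnDownBlock2D",
        num_layers=2,                # num_res_blocks
        in_channels=32,
        out_channels=64,
        temb_channels=128,           # time_embed_dim
        add_downsample=True,         # False only for the final resolution level
        resnet_eps=1e-5,
        resnet_act_fn="silu",
        attn_num_head_channels=32,
    )

    # The matching up-block takes `next_channels` (the mirrored down-level width) instead of `out_channels`.
    up_block = get_up_block(
        "UNetResAttnUpBlock2D",
        num_layers=3,                # num_res_blocks + 1
        in_channels=64,
        next_channels=32,
        temb_channels=128,
        add_upsample=True,
        resnet_eps=1e-5,
        resnet_act_fn="silu",
        attn_num_head_channels=32,
    )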
src/diffusers/models/unet_unconditional.py
@@ -6,13 +6,7 @@ from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
-from .unet_new import (
-    UNetMidBlock2D,
-    UNetResAttnDownBlock2D,
-    UNetResAttnUpBlock2D,
-    UNetResDownBlock2D,
-    UNetResUpBlock2D,
-)
+from .unet_new import UNetMidBlock2D, get_down_block, get_up_block
 
 
 class UNetUnconditionalModel(ModelMixin, ConfigMixin):
@@ -39,6 +33,188 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
         increased efficiency.
     """
 
+    def __init__(
+        self,
+        image_size,
+        in_channels,
+        out_channels,
+        num_res_blocks,
+        dropout=0,
+        block_input_channels=(224, 224, 448, 672),
+        block_output_channels=(224, 448, 672, 896),
+        down_blocks=(
+            "UNetResDownBlock2D",
+            "UNetResAttnDownBlock2D",
+            "UNetResAttnDownBlock2D",
+            "UNetResAttnDownBlock2D",
+        ),
+        up_blocks=("UNetResAttnUpBlock2D", "UNetResAttnUpBlock2D", "UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
+        resnet_act_fn="silu",
+        resnet_eps=1e-5,
+        conv_resample=True,
+        num_head_channels=32,
+        # To delete once weights are converted
+        attention_resolutions=(8, 4, 2),
+    ):
+        super().__init__()
+
+        # register all __init__ params with self.register
+        self.register_to_config(
+            image_size=image_size,
+            in_channels=in_channels,
+            block_input_channels=block_input_channels,
+            block_output_channels=block_output_channels,
+            out_channels=out_channels,
+            num_res_blocks=num_res_blocks,
+            down_blocks=down_blocks,
+            up_blocks=up_blocks,
+            dropout=dropout,
+            conv_resample=conv_resample,
+            num_head_channels=num_head_channels,
+            # (TODO(PVP) - To delete once weights are converted
+            attention_resolutions=attention_resolutions,
+        )
+
+        # To delete - replace with config values
+        self.image_size = image_size
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.num_res_blocks = num_res_blocks
+        self.dropout = dropout
+
+        time_embed_dim = block_input_channels[0] * 4
+
+        # ======================== Input ===================
+        self.conv_in = nn.Conv2d(in_channels, block_input_channels[0], kernel_size=3, padding=(1, 1))
+
+        # ======================== Time ====================
+        self.time_embed = nn.Sequential(
+            nn.Linear(block_input_channels[0], time_embed_dim),
+            nn.SiLU(),
+            nn.Linear(time_embed_dim, time_embed_dim),
+        )
+
+        # ======================== Down ====================
+        input_channels = list(block_input_channels)
+        output_channels = list(block_output_channels)
+
+        self.downsample_blocks = nn.ModuleList([])
+        for i, (input_channel, output_channel) in enumerate(zip(input_channels, output_channels)):
+            down_block_type = down_blocks[i]
+            is_final_block = i == len(input_channels) - 1
+
+            down_block = get_down_block(
+                down_block_type,
+                num_layers=num_res_blocks,
+                in_channels=input_channel,
+                out_channels=output_channel,
+                temb_channels=time_embed_dim,
+                add_downsample=not is_final_block,
+                resnet_eps=resnet_eps,
+                resnet_act_fn=resnet_act_fn,
+                attn_num_head_channels=num_head_channels,
+            )
+            self.downsample_blocks.append(down_block)
+
+        # ======================== Mid ====================
+        self.mid = UNetMidBlock2D(
+            in_channels=output_channels[-1],
+            dropout=dropout,
+            temb_channels=time_embed_dim,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            resnet_time_scale_shift="default",
+            attn_num_head_channels=num_head_channels,
+        )
+
+        self.upsample_blocks = nn.ModuleList([])
+        for i, (input_channel, output_channel) in enumerate(zip(reversed(input_channels), reversed(output_channels))):
+            up_block_type = up_blocks[i]
+            is_final_block = i == len(input_channels) - 1
+
+            up_block = get_up_block(
+                up_block_type,
+                num_layers=num_res_blocks + 1,
+                in_channels=output_channel,
+                next_channels=input_channel,
+                temb_channels=time_embed_dim,
+                add_upsample=not is_final_block,
+                resnet_eps=resnet_eps,
+                resnet_act_fn=resnet_act_fn,
+                attn_num_head_channels=num_head_channels,
+            )
+            self.upsample_blocks.append(up_block)
+
+        # ======================== Out ====================
+        self.out = nn.Sequential(
+            nn.GroupNorm(num_channels=output_channels[0], num_groups=32, eps=1e-5),
+            nn.SiLU(),
+            nn.Conv2d(block_input_channels[0], out_channels, 3, padding=1),
+        )
+
+        # =========== TO DELETE AFTER CONVERSION ==========
+        transformer_depth = 1
+        context_dim = None
+        legacy = True
+        num_heads = -1
+        model_channels = block_input_channels[0]
+        channel_mult = tuple([x // model_channels for x in block_output_channels])
+        self.init_for_ldm(
+            in_channels,
+            model_channels,
+            channel_mult,
+            num_res_blocks,
+            dropout,
+            time_embed_dim,
+            attention_resolutions,
+            num_head_channels,
+            num_heads,
+            legacy,
+            False,
+            transformer_depth,
+            context_dim,
+            conv_resample,
+            out_channels,
+        )
+
+    def forward(self, sample, timesteps=None):
+        # 1. time step embeddings
+        if not torch.is_tensor(timesteps):
+            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
+
+        t_emb = get_timestep_embedding(
+            timesteps, self.config.block_input_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0
+        )
+        emb = self.time_embed(t_emb)
+
+        # 2. pre-process sample
+        # sample = sample.type(self.dtype_)
+        sample = self.conv_in(sample)
+
+        # 3. down blocks
+        down_block_res_samples = (sample,)
+        for downsample_block in self.downsample_blocks:
+            sample, res_samples = downsample_block(sample, emb)
+
+            # append to tuple
+            down_block_res_samples += res_samples
+
+        # 4. mid block
+        sample = self.mid(sample, emb)
+
+        # 5. up blocks
+        for upsample_block in self.upsample_blocks:
+            # pop from tuple
+            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+            sample = upsample_block(sample, res_samples, emb)
+
+        # 6. post-process sample
+        sample = self.out(sample)
+
+        return sample
+
     def init_for_ldm(
         self,
         in_channels,
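To make the wiring in the new constructor concrete, here is a short, self-contained trace (not part of the commit) that replays the two loops above with the default channel tuples and prints which block type, channel widths, and resampling flag each resolution level receives:

    # Illustrative trace of the constructor wiring; values are the defaults from the diff.
    block_input_channels = (224, 224, 448, 672)
    block_output_channels = (224, 448, 672, 896)
    down_blocks = ("UNetResDownBlock2D", "UNetResAttnDownBlock2D", "UNetResAttnDownBlock2D", "UNetResAttnDownBlock2D")
    up_blocks = ("UNetResAttnUpBlock2D", "UNetResAttnUpBlock2D", "UNetResAttnUpBlock2D", "UNetResUpBlock2D")

    # Down path: level i maps input_channel -> output_channel; only the last level skips downsampling.
    for i, (input_channel, output_channel) in enumerate(zip(block_input_channels, block_output_channels)):
        is_final = i == len(block_input_channels) - 1
        print(f"down {i}: {down_blocks[i]:<24} in={input_channel:>3} out={output_channel:>3} downsample={not is_final}")

    # Up path: channels are walked in reverse; `next` is the mirrored down-level width (the skip connection).
    for i, (input_channel, output_channel) in enumerate(zip(reversed(block_input_channels), reversed(block_output_channels))):
        is_final = i == len(block_input_channels) - 1
        print(f"up   {i}: {up_blocks[i]:<24} in={output_channel:>3} next={input_channel:>3} upsample={not is_final}")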
@@ -252,200 +428,3 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
         self.conv_in.weight.data = self.input_blocks[0][0].weight.data
         self.conv_in.bias.data = self.input_blocks[0][0].bias.data
 
-    def __init__(
-        self,
-        image_size,
-        in_channels,
-        out_channels,
-        num_res_blocks,
-        attention_resolutions,
-        dropout=0,
-        resnet_input_channels=(224, 224, 448, 672),
-        resnet_output_channels=(224, 448, 672, 896),
-        conv_resample=True,
-        num_head_channels=32,
-    ):
-        super().__init__()
-        # register all __init__ params with self.register
-        self.register_to_config(
-            image_size=image_size,
-            in_channels=in_channels,
-            resnet_input_channels=resnet_input_channels,
-            resnet_output_channels=resnet_output_channels,
-            out_channels=out_channels,
-            num_res_blocks=num_res_blocks,
-            attention_resolutions=attention_resolutions,
-            dropout=dropout,
-            conv_resample=conv_resample,
-            num_head_channels=num_head_channels,
-        )
-
-        # To delete - replace with config values
-        self.image_size = image_size
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.num_res_blocks = num_res_blocks
-        self.attention_resolutions = attention_resolutions
-        self.dropout = dropout
-
-        time_embed_dim = resnet_input_channels[0] * 4
-
-        # ======================== Input ===================
-        self.conv_in = nn.Conv2d(in_channels, resnet_input_channels[0], kernel_size=3, padding=(1, 1))
-
-        # ======================== Time ====================
-        self.time_embed = nn.Sequential(
-            nn.Linear(resnet_input_channels[0], time_embed_dim),
-            nn.SiLU(),
-            nn.Linear(time_embed_dim, time_embed_dim),
-        )
-
-        # ======================== Down ====================
-        input_channels = list(resnet_input_channels)
-        output_channels = list(resnet_output_channels)
-
-        ds_new = 1
-        self.downsample_blocks = nn.ModuleList([])
-        for i, (input_channel, output_channel) in enumerate(zip(input_channels, output_channels)):
-            is_final_block = i == len(input_channels) - 1
-
-            if ds_new in attention_resolutions:
-                down_block = UNetResAttnDownBlock2D(
-                    num_layers=num_res_blocks,
-                    in_channels=input_channel,
-                    out_channels=output_channel,
-                    temb_channels=time_embed_dim,
-                    add_downsample=not is_final_block,
-                    resnet_eps=1e-5,
-                    resnet_act_fn="silu",
-                    attn_num_head_channels=num_head_channels,
-                )
-            else:
-                down_block = UNetResDownBlock2D(
-                    num_layers=num_res_blocks,
-                    in_channels=input_channel,
-                    out_channels=output_channel,
-                    temb_channels=time_embed_dim,
-                    add_downsample=not is_final_block,
-                    resnet_eps=1e-5,
-                    resnet_act_fn="silu",
-                )
-
-            self.downsample_blocks.append(down_block)
-            ds_new *= 2
-
-        ds_new = ds_new / 2
-
-        # ======================== Mid ====================
-        self.mid = UNetMidBlock2D(
-            in_channels=output_channels[-1],
-            dropout=dropout,
-            temb_channels=time_embed_dim,
-            resnet_eps=1e-5,
-            resnet_act_fn="silu",
-            resnet_time_scale_shift="default",
-            attn_num_head_channels=num_head_channels,
-        )
-
-        self.upsample_blocks = nn.ModuleList([])
-        for i, (input_channel, output_channel) in enumerate(zip(reversed(input_channels), reversed(output_channels))):
-            is_final_block = i == len(input_channels) - 1
-
-            if ds_new in attention_resolutions:
-                up_block = UNetResAttnUpBlock2D(
-                    num_layers=num_res_blocks + 1,
-                    in_channels=output_channel,
-                    next_channels=input_channel,
-                    temb_channels=time_embed_dim,
-                    add_upsample=not is_final_block,
-                    resnet_eps=1e-5,
-                    resnet_act_fn="silu",
-                    attn_num_head_channels=num_head_channels,
-                )
-            else:
-                up_block = UNetResUpBlock2D(
-                    num_layers=num_res_blocks + 1,
-                    in_channels=output_channel,
-                    next_channels=input_channel,
-                    temb_channels=time_embed_dim,
-                    add_upsample=not is_final_block,
-                    resnet_eps=1e-5,
-                    resnet_act_fn="silu",
-                )
-
-            self.upsample_blocks.append(up_block)
-            ds_new /= 2
-
-        # ======================== Out ====================
-        self.out = nn.Sequential(
-            nn.GroupNorm(num_channels=output_channels[0], num_groups=32, eps=1e-5),
-            nn.SiLU(),
-            nn.Conv2d(resnet_input_channels[0], out_channels, 3, padding=1),
-        )
-
-        # =========== TO DELETE AFTER CONVERSION ==========
-        transformer_depth = 1
-        context_dim = None
-        legacy = True
-        num_heads = -1
-        model_channels = resnet_input_channels[0]
-        channel_mult = tuple([x // model_channels for x in resnet_output_channels])
-        self.init_for_ldm(
-            in_channels,
-            model_channels,
-            channel_mult,
-            num_res_blocks,
-            dropout,
-            time_embed_dim,
-            attention_resolutions,
-            num_head_channels,
-            num_heads,
-            legacy,
-            False,
-            transformer_depth,
-            context_dim,
-            conv_resample,
-            out_channels,
-        )
-
-    def forward(self, sample, timesteps=None):
-        # 1. time step embeddings
-        if not torch.is_tensor(timesteps):
-            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
-
-        t_emb = get_timestep_embedding(
-            timesteps, self.config.resnet_input_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0
-        )
-        emb = self.time_embed(t_emb)
-
-        # 2. pre-process sample
-        # sample = sample.type(self.dtype_)
-        sample = self.conv_in(sample)
-
-        # 3. down blocks
-        down_block_res_samples = (sample,)
-        for downsample_block in self.downsample_blocks:
-            sample, res_samples = downsample_block(sample, emb)
-
-            # append to tuple
-            down_block_res_samples += res_samples
-
-        # 4. mid block
-        sample = self.mid(sample, emb)
-
-        # 5. up blocks
-        for upsample_block in self.upsample_blocks:
-            # pop from tuple
-            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
-            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
-
-            sample = upsample_block(sample, res_samples, emb)
-
-        # 6. post-process sample
-        sample = self.out(sample)
-
-        return sample
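The constructor deleted above picked attention versus plain blocks implicitly, by tracking the running downsampling factor ds_new against attention_resolutions; the new constructor takes the block types explicitly. A hypothetical helper (not in the repository) that replays the old rule as an explicit down_blocks tuple, shown here only to make the migration mapping clear:

    def attention_resolutions_to_down_blocks(attention_resolutions, num_levels):
        """Replay the old ds-based rule: use an attention block at levels whose downsampling factor is listed."""
        block_types = []
        ds = 1
        for _ in range(num_levels):
            if ds in attention_resolutions:
                block_types.append("UNetResAttnDownBlock2D")
            else:
                block_types.append("UNetResDownBlock2D")
            ds *= 2
        return tuple(block_types)

    # The old defaults, attention_resolutions=(8, 4, 2) over four levels, map to the new default down_blocks:
    assert attention_resolutions_to_down_blocks((8, 4, 2), 4) == (
        "UNetResDownBlock2D",
        "UNetResAttnDownBlock2D",
        "UNetResAttnDownBlock2D",
        "UNetResAttnDownBlock2D",
    )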
tests/test_modeling_utils.py
@@ -492,10 +492,12 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
             "out_channels": 4,
             "num_res_blocks": 2,
             "attention_resolutions": (16,),
-            "resnet_input_channels": [32, 32],
+            "block_input_channels": [32, 32],
             "resnet_output_channels": [32, 64],
             "num_head_channels": 32,
             "conv_resample": True,
+            "down_blocks": ("UNetResDownBlock2D", "UNetResDownBlock2D"),
+            "up_blocks": ("UNetResUpBlock2D", "UNetResUpBlock2D"),
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
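For completeness, a hedged sketch (not part of the test suite) of what a config like this drives in the ModelTesterMixin-style tests: build the model and check that a dummy forward pass returns a sample of the expected shape. The image_size and in_channels values are assumptions (they are not visible in this hunk), and block_output_channels is used here to match the new __init__ signature, even though the hunk above still passes resnet_output_channels. It assumes this revision of diffusers is installed.

    import torch
    from diffusers.models.unet_unconditional import UNetUnconditionalModel

    init_dict = {
        "image_size": 32,                    # assumption, not shown in the hunk
        "in_channels": 4,                    # assumption, chosen to match out_channels
        "out_channels": 4,
        "num_res_blocks": 2,
        "attention_resolutions": (16,),      # only feeds the legacy init_for_ldm path
        "block_input_channels": [32, 32],
        "block_output_channels": [32, 64],   # renamed from "resnet_output_channels" to fit the new signature
        "num_head_channels": 32,
        "conv_resample": True,
        "down_blocks": ("UNetResDownBlock2D", "UNetResDownBlock2D"),
        "up_blocks": ("UNetResUpBlock2D", "UNetResUpBlock2D"),
    }

    model = UNetUnconditionalModel(**init_dict)
    sample = torch.randn(4, init_dict["in_channels"], init_dict["image_size"], init_dict["image_size"])
    timesteps = torch.tensor([10] * 4, dtype=torch.long)
    output = model(sample, timesteps)
    # Per the commit's test setup, the unconditional UNet is expected to predict in input space:
    assert output.shape == sample.shape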