Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
5e12d5c6
Unverified
Commit
5e12d5c6
authored
Jul 13, 2022
by
Patrick von Platen
Committed by
GitHub
Jul 13, 2022
Browse files
Clean uncond unet more (#85)
* up * finished clean up * remove @
parent
8aed37c1
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
254 additions
and
205 deletions
+254
-205
src/diffusers/models/unet_new.py
src/diffusers/models/unet_new.py
+68
-0
src/diffusers/models/unet_unconditional.py
src/diffusers/models/unet_unconditional.py
+183
-204
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+3
-1
No files found.
src/diffusers/models/unet_new.py
View file @
5e12d5c6
...
@@ -19,6 +19,74 @@ from .attention import AttentionBlockNew
...
@@ -19,6 +19,74 @@ from .attention import AttentionBlockNew
from
.resnet
import
Downsample2D
,
ResnetBlock
,
Upsample2D
from
.resnet
import
Downsample2D
,
ResnetBlock
,
Upsample2D
def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
):
    """Factory for the down-sampling blocks of the unconditional UNet.

    Args:
        down_block_type: Either ``"UNetResDownBlock2D"`` (plain resnet block) or
            ``"UNetResAttnDownBlock2D"`` (resnet block with self-attention).
        num_layers: Number of resnet layers inside the block.
        in_channels: Channels entering the block.
        out_channels: Channels produced by the block.
        temb_channels: Width of the time-embedding vector fed to each resnet.
        add_downsample: Whether the block ends with a spatial downsample.
        resnet_eps: Epsilon for the resnets' normalization layers.
        resnet_act_fn: Activation function name used inside the resnets.
        attn_num_head_channels: Channels per attention head (attention variant only).

    Returns:
        The constructed down block module.

    Raises:
        ValueError: If ``down_block_type`` is not a recognized type name.
    """
    if down_block_type == "UNetResDownBlock2D":
        # Bug fix: this branch previously constructed UNetResAttnDownBlock2D,
        # so requesting the attention-free block silently yielded an attention
        # block (while also dropping attn_num_head_channels).
        return UNetResDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
        )
    if down_block_type == "UNetResAttnDownBlock2D":
        return UNetResAttnDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            attn_num_head_channels=attn_num_head_channels,
        )
    # Previously an unknown type fell through and returned None, which only
    # surfaced later as a confusing AttributeError; fail loudly instead.
    raise ValueError(f"{down_block_type} is not a supported down block type.")
def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    next_channels,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
):
    """Factory for the up-sampling blocks of the unconditional UNet.

    Args:
        up_block_type: Either ``"UNetResUpBlock2D"`` (plain resnet block) or
            ``"UNetResAttnUpBlock2D"`` (resnet block with self-attention).
        num_layers: Number of resnet layers inside the block.
        in_channels: Channels entering the block.
        next_channels: Channels of the block this one feeds into (skip-side width).
        temb_channels: Width of the time-embedding vector fed to each resnet.
        add_upsample: Whether the block ends with a spatial upsample.
        resnet_eps: Epsilon for the resnets' normalization layers.
        resnet_act_fn: Activation function name used inside the resnets.
        attn_num_head_channels: Channels per attention head (attention variant only).

    Returns:
        The constructed up block module.

    Raises:
        ValueError: If ``up_block_type`` is not a recognized type name.
    """
    if up_block_type == "UNetResUpBlock2D":
        return UNetResUpBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            next_channels=next_channels,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
        )
    if up_block_type == "UNetResAttnUpBlock2D":
        return UNetResAttnUpBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            next_channels=next_channels,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            attn_num_head_channels=attn_num_head_channels,
        )
    # Previously an unknown type fell through and returned None; raise
    # immediately so misconfiguration is caught at construction time.
    raise ValueError(f"{up_block_type} is not a supported up block type.")
class
UNetMidBlock2D
(
nn
.
Module
):
class
UNetMidBlock2D
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
self
,
...
...
src/diffusers/models/unet_unconditional.py
View file @
5e12d5c6
...
@@ -6,13 +6,7 @@ from ..modeling_utils import ModelMixin
...
@@ -6,13 +6,7 @@ from ..modeling_utils import ModelMixin
from
.attention
import
AttentionBlock
from
.attention
import
AttentionBlock
from
.embeddings
import
get_timestep_embedding
from
.embeddings
import
get_timestep_embedding
from
.resnet
import
Downsample2D
,
ResnetBlock2D
,
Upsample2D
from
.resnet
import
Downsample2D
,
ResnetBlock2D
,
Upsample2D
from
.unet_new
import
(
from
.unet_new
import
UNetMidBlock2D
,
get_down_block
,
get_up_block
UNetMidBlock2D
,
UNetResAttnDownBlock2D
,
UNetResAttnUpBlock2D
,
UNetResDownBlock2D
,
UNetResUpBlock2D
,
)
class
UNetUnconditionalModel
(
ModelMixin
,
ConfigMixin
):
class
UNetUnconditionalModel
(
ModelMixin
,
ConfigMixin
):
...
@@ -39,6 +33,188 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -39,6 +33,188 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
increased efficiency.
increased efficiency.
"""
"""
    def __init__(
        self,
        image_size,
        in_channels,
        out_channels,
        num_res_blocks,
        dropout=0,
        block_input_channels=(224, 224, 448, 672),
        block_output_channels=(224, 448, 672, 896),
        down_blocks=(
            "UNetResDownBlock2D",
            "UNetResAttnDownBlock2D",
            "UNetResAttnDownBlock2D",
            "UNetResAttnDownBlock2D",
        ),
        up_blocks=("UNetResAttnUpBlock2D", "UNetResAttnUpBlock2D", "UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
        resnet_act_fn="silu",
        resnet_eps=1e-5,
        conv_resample=True,
        num_head_channels=32,
        # To delete once weights are converted
        attention_resolutions=(8, 4, 2),
    ):
        """Build the unconditional UNet: input conv, time embedding, a string-typed
        list of down blocks, a mid block, mirrored up blocks, and an output head.

        ``down_blocks``/``up_blocks`` are type names resolved through
        ``get_down_block``/``get_up_block``; ``block_input_channels[i]`` feeds
        ``block_output_channels[i]`` at depth ``i`` (the four sequences must have
        equal length — not validated here; TODO confirm with callers).
        """
        super().__init__()

        # register all __init__ params with self.register
        self.register_to_config(
            image_size=image_size,
            in_channels=in_channels,
            block_input_channels=block_input_channels,
            block_output_channels=block_output_channels,
            out_channels=out_channels,
            num_res_blocks=num_res_blocks,
            down_blocks=down_blocks,
            up_blocks=up_blocks,
            dropout=dropout,
            conv_resample=conv_resample,
            num_head_channels=num_head_channels,
            # TODO(PVP): delete once weights are converted
            attention_resolutions=attention_resolutions,
        )

        # To delete - replace with config values
        self.image_size = image_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.dropout = dropout

        # Time-embedding width is 4x the first block width (standard UNet convention here).
        time_embed_dim = block_input_channels[0] * 4

        # ======================== Input ===================
        self.conv_in = nn.Conv2d(in_channels, block_input_channels[0], kernel_size=3, padding=(1, 1))

        # ======================== Time ====================
        # Two-layer MLP applied to the sinusoidal timestep embedding.
        self.time_embed = nn.Sequential(
            nn.Linear(block_input_channels[0], time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

        # ======================== Down ====================
        input_channels = list(block_input_channels)
        output_channels = list(block_output_channels)

        self.downsample_blocks = nn.ModuleList([])
        for i, (input_channel, output_channel) in enumerate(zip(input_channels, output_channels)):
            down_block_type = down_blocks[i]
            # The deepest block keeps spatial resolution (no trailing downsample).
            is_final_block = i == len(input_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=num_res_blocks,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=resnet_eps,
                resnet_act_fn=resnet_act_fn,
                attn_num_head_channels=num_head_channels,
            )
            self.downsample_blocks.append(down_block)

        # ======================== Mid ====================
        self.mid = UNetMidBlock2D(
            in_channels=output_channels[-1],
            dropout=dropout,
            temb_channels=time_embed_dim,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_time_scale_shift="default",
            attn_num_head_channels=num_head_channels,
        )

        # ======================== Up =====================
        # Mirrors the down path: channel lists reversed, one extra resnet layer
        # per block to consume the skip connections.
        self.upsample_blocks = nn.ModuleList([])
        for i, (input_channel, output_channel) in enumerate(zip(reversed(input_channels), reversed(output_channels))):
            up_block_type = up_blocks[i]
            # The last up block keeps spatial resolution (no trailing upsample).
            is_final_block = i == len(input_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=num_res_blocks + 1,
                in_channels=output_channel,
                next_channels=input_channel,
                temb_channels=time_embed_dim,
                add_upsample=not is_final_block,
                resnet_eps=resnet_eps,
                resnet_act_fn=resnet_act_fn,
                attn_num_head_channels=num_head_channels,
            )
            self.upsample_blocks.append(up_block)

        # ======================== Out ====================
        self.out = nn.Sequential(
            nn.GroupNorm(num_channels=output_channels[0], num_groups=32, eps=1e-5),
            nn.SiLU(),
            nn.Conv2d(block_input_channels[0], out_channels, 3, padding=1),
        )

        # =========== TO DELETE AFTER CONVERSION ==========
        # Legacy LDM construction kept only so old checkpoints can be converted;
        # these locals reproduce the original LDM hyperparameter names.
        transformer_depth = 1
        context_dim = None
        legacy = True
        num_heads = -1
        model_channels = block_input_channels[0]
        channel_mult = tuple([x // model_channels for x in block_output_channels])
        self.init_for_ldm(
            in_channels,
            model_channels,
            channel_mult,
            num_res_blocks,
            dropout,
            time_embed_dim,
            attention_resolutions,
            num_head_channels,
            num_heads,
            legacy,
            False,
            transformer_depth,
            context_dim,
            conv_resample,
            out_channels,
        )
def
forward
(
self
,
sample
,
timesteps
=
None
):
# 1. time step embeddings
if
not
torch
.
is_tensor
(
timesteps
):
timesteps
=
torch
.
tensor
([
timesteps
],
dtype
=
torch
.
long
,
device
=
sample
.
device
)
t_emb
=
get_timestep_embedding
(
timesteps
,
self
.
config
.
block_input_channels
[
0
],
flip_sin_to_cos
=
True
,
downscale_freq_shift
=
0
)
emb
=
self
.
time_embed
(
t_emb
)
# 2. pre-process sample
# sample = sample.type(self.dtype_)
sample
=
self
.
conv_in
(
sample
)
# 3. down blocks
down_block_res_samples
=
(
sample
,)
for
downsample_block
in
self
.
downsample_blocks
:
sample
,
res_samples
=
downsample_block
(
sample
,
emb
)
# append to tuple
down_block_res_samples
+=
res_samples
# 4. mid block
sample
=
self
.
mid
(
sample
,
emb
)
# 5. up blocks
for
upsample_block
in
self
.
upsample_blocks
:
# pop from tuple
res_samples
=
down_block_res_samples
[
-
len
(
upsample_block
.
resnets
)
:]
down_block_res_samples
=
down_block_res_samples
[:
-
len
(
upsample_block
.
resnets
)]
sample
=
upsample_block
(
sample
,
res_samples
,
emb
)
# 6. post-process sample
sample
=
self
.
out
(
sample
)
return
sample
def
init_for_ldm
(
def
init_for_ldm
(
self
,
self
,
in_channels
,
in_channels
,
...
@@ -252,200 +428,3 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -252,200 +428,3 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
self
.
conv_in
.
weight
.
data
=
self
.
input_blocks
[
0
][
0
].
weight
.
data
self
.
conv_in
.
weight
.
data
=
self
.
input_blocks
[
0
][
0
].
weight
.
data
self
.
conv_in
.
bias
.
data
=
self
.
input_blocks
[
0
][
0
].
bias
.
data
self
.
conv_in
.
bias
.
data
=
self
.
input_blocks
[
0
][
0
].
bias
.
data
    def __init__(
        self,
        image_size,
        in_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        resnet_input_channels=(224, 224, 448, 672),
        resnet_output_channels=(224, 448, 672, 896),
        conv_resample=True,
        num_head_channels=32,
    ):
        """Build the unconditional UNet (legacy variant).

        Unlike the string-typed variant, block types are chosen implicitly: a
        running downsampling factor ``ds_new`` is compared against
        ``attention_resolutions`` to decide whether a given depth gets an
        attention block or a plain resnet block.
        """
        super().__init__()

        # register all __init__ params with self.register
        self.register_to_config(
            image_size=image_size,
            in_channels=in_channels,
            resnet_input_channels=resnet_input_channels,
            resnet_output_channels=resnet_output_channels,
            out_channels=out_channels,
            num_res_blocks=num_res_blocks,
            attention_resolutions=attention_resolutions,
            dropout=dropout,
            conv_resample=conv_resample,
            num_head_channels=num_head_channels,
        )

        # To delete - replace with config values
        self.image_size = image_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout

        # Time-embedding width is 4x the first block width.
        time_embed_dim = resnet_input_channels[0] * 4

        # ======================== Input ===================
        self.conv_in = nn.Conv2d(in_channels, resnet_input_channels[0], kernel_size=3, padding=(1, 1))

        # ======================== Time ====================
        # Two-layer MLP applied to the sinusoidal timestep embedding.
        self.time_embed = nn.Sequential(
            nn.Linear(resnet_input_channels[0], time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

        # ======================== Down ====================
        input_channels = list(resnet_input_channels)
        output_channels = list(resnet_output_channels)

        # ds_new tracks the current downsampling factor (1, 2, 4, ...) and
        # selects attention blocks at the resolutions listed in
        # attention_resolutions.
        ds_new = 1

        self.downsample_blocks = nn.ModuleList([])
        for i, (input_channel, output_channel) in enumerate(zip(input_channels, output_channels)):
            # The deepest block keeps spatial resolution (no trailing downsample).
            is_final_block = i == len(input_channels) - 1

            if ds_new in attention_resolutions:
                down_block = UNetResAttnDownBlock2D(
                    num_layers=num_res_blocks,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    temb_channels=time_embed_dim,
                    add_downsample=not is_final_block,
                    resnet_eps=1e-5,
                    resnet_act_fn="silu",
                    attn_num_head_channels=num_head_channels,
                )
            else:
                down_block = UNetResDownBlock2D(
                    num_layers=num_res_blocks,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    temb_channels=time_embed_dim,
                    add_downsample=not is_final_block,
                    resnet_eps=1e-5,
                    resnet_act_fn="silu",
                )

            self.downsample_blocks.append(down_block)
            ds_new *= 2

        # Undo the last doubling: the final block did not actually downsample,
        # so ds_new must match the bottom resolution before the up path.
        # NOTE: true division here makes ds_new a float from this point on —
        # membership tests against attention_resolutions rely on e.g. 4.0 == 4.
        ds_new = ds_new / 2

        # ======================== Mid ====================
        self.mid = UNetMidBlock2D(
            in_channels=output_channels[-1],
            dropout=dropout,
            temb_channels=time_embed_dim,
            resnet_eps=1e-5,
            resnet_act_fn="silu",
            resnet_time_scale_shift="default",
            attn_num_head_channels=num_head_channels,
        )

        # ======================== Up =====================
        # Mirrors the down path: channel lists reversed, one extra resnet layer
        # per block to consume the skip connections; ds_new halves per block.
        self.upsample_blocks = nn.ModuleList([])
        for i, (input_channel, output_channel) in enumerate(zip(reversed(input_channels), reversed(output_channels))):
            # The last up block keeps spatial resolution (no trailing upsample).
            is_final_block = i == len(input_channels) - 1

            if ds_new in attention_resolutions:
                up_block = UNetResAttnUpBlock2D(
                    num_layers=num_res_blocks + 1,
                    in_channels=output_channel,
                    next_channels=input_channel,
                    temb_channels=time_embed_dim,
                    add_upsample=not is_final_block,
                    resnet_eps=1e-5,
                    resnet_act_fn="silu",
                    attn_num_head_channels=num_head_channels,
                )
            else:
                up_block = UNetResUpBlock2D(
                    num_layers=num_res_blocks + 1,
                    in_channels=output_channel,
                    next_channels=input_channel,
                    temb_channels=time_embed_dim,
                    add_upsample=not is_final_block,
                    resnet_eps=1e-5,
                    resnet_act_fn="silu",
                )

            self.upsample_blocks.append(up_block)
            ds_new /= 2

        # ======================== Out ====================
        self.out = nn.Sequential(
            nn.GroupNorm(num_channels=output_channels[0], num_groups=32, eps=1e-5),
            nn.SiLU(),
            nn.Conv2d(resnet_input_channels[0], out_channels, 3, padding=1),
        )

        # =========== TO DELETE AFTER CONVERSION ==========
        # Legacy LDM construction kept only so old checkpoints can be converted;
        # these locals reproduce the original LDM hyperparameter names.
        transformer_depth = 1
        context_dim = None
        legacy = True
        num_heads = -1
        model_channels = resnet_input_channels[0]
        channel_mult = tuple([x // model_channels for x in resnet_output_channels])
        self.init_for_ldm(
            in_channels,
            model_channels,
            channel_mult,
            num_res_blocks,
            dropout,
            time_embed_dim,
            attention_resolutions,
            num_head_channels,
            num_heads,
            legacy,
            False,
            transformer_depth,
            context_dim,
            conv_resample,
            out_channels,
        )
    def forward(self, sample, timesteps=None):
        """Run one denoising pass (legacy variant; reads
        ``config.resnet_input_channels`` rather than ``block_input_channels``)."""
        # 1. time step embeddings — promote a scalar timestep to a 1-element tensor.
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)

        t_emb = get_timestep_embedding(
            timesteps, self.config.resnet_input_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0
        )
        emb = self.time_embed(t_emb)

        # 2. pre-process sample
        # sample = sample.type(self.dtype_)
        sample = self.conv_in(sample)

        # 3. down blocks — accumulate every residual for the skip connections.
        down_block_res_samples = (sample,)
        for downsample_block in self.downsample_blocks:
            sample, res_samples = downsample_block(sample, emb)
            # append to tuple
            down_block_res_samples += res_samples

        # 4. mid block
        sample = self.mid(sample, emb)

        # 5. up blocks — each block consumes as many skip states as it has resnets.
        for upsample_block in self.upsample_blocks:
            # pop from tuple
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            sample = upsample_block(sample, res_samples, emb)

        # 6. post-process sample
        sample = self.out(sample)

        return sample
tests/test_modeling_utils.py
View file @
5e12d5c6
...
@@ -492,10 +492,12 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -492,10 +492,12 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
"out_channels"
:
4
,
"out_channels"
:
4
,
"num_res_blocks"
:
2
,
"num_res_blocks"
:
2
,
"attention_resolutions"
:
(
16
,),
"attention_resolutions"
:
(
16
,),
"
resnet
_input_channels"
:
[
32
,
32
],
"
block
_input_channels"
:
[
32
,
32
],
"resnet_output_channels"
:
[
32
,
64
],
"resnet_output_channels"
:
[
32
,
64
],
"num_head_channels"
:
32
,
"num_head_channels"
:
32
,
"conv_resample"
:
True
,
"conv_resample"
:
True
,
"down_blocks"
:
(
"UNetResDownBlock2D"
,
"UNetResDownBlock2D"
),
"up_blocks"
:
(
"UNetResUpBlock2D"
,
"UNetResUpBlock2D"
),
}
}
inputs_dict
=
self
.
dummy_input
inputs_dict
=
self
.
dummy_input
return
init_dict
,
inputs_dict
return
init_dict
,
inputs_dict
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment