Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
8aed37c1
Commit
8aed37c1
authored
Jul 12, 2022
by
Patrick von Platen
Browse files
some more refactor
parent
06c79730
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
26 additions
and
54 deletions
+26
-54
src/diffusers/models/unet_unconditional.py
src/diffusers/models/unet_unconditional.py
+24
-52
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+2
-2
No files found.
src/diffusers/models/unet_unconditional.py
View file @
8aed37c1
...
@@ -41,7 +41,6 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -41,7 +41,6 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
def
init_for_ldm
(
def
init_for_ldm
(
self
,
self
,
dims
,
in_channels
,
in_channels
,
model_channels
,
model_channels
,
channel_mult
,
channel_mult
,
...
@@ -80,6 +79,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -80,6 +79,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
return
nn
.
Conv3d
(
*
args
,
**
kwargs
)
return
nn
.
Conv3d
(
*
args
,
**
kwargs
)
raise
ValueError
(
f
"unsupported dimensions:
{
dims
}
"
)
raise
ValueError
(
f
"unsupported dimensions:
{
dims
}
"
)
dims
=
2
self
.
input_blocks
=
nn
.
ModuleList
(
self
.
input_blocks
=
nn
.
ModuleList
(
[
TimestepEmbedSequential
(
conv_nd
(
dims
,
in_channels
,
model_channels
,
3
,
padding
=
1
))]
[
TimestepEmbedSequential
(
conv_nd
(
dims
,
in_channels
,
model_channels
,
3
,
padding
=
1
))]
)
)
...
@@ -257,27 +257,14 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -257,27 +257,14 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
self
,
self
,
image_size
,
image_size
,
in_channels
,
in_channels
,
model_channels
,
out_channels
,
out_channels
,
num_res_blocks
,
num_res_blocks
,
attention_resolutions
,
attention_resolutions
,
dropout
=
0
,
dropout
=
0
,
channel_mult
=
(
1
,
2
,
4
,
8
),
resnet_input_channels
=
(
224
,
224
,
448
,
672
),
resnet_output_channels
=
(
224
,
448
,
672
,
896
),
conv_resample
=
True
,
conv_resample
=
True
,
dims
=
2
,
num_head_channels
=
32
,
num_classes
=
None
,
use_checkpoint
=
False
,
use_fp16
=
False
,
num_heads
=-
1
,
num_head_channels
=-
1
,
num_heads_upsample
=-
1
,
use_scale_shift_norm
=
False
,
resblock_updown
=
False
,
use_new_attention_order
=
False
,
transformer_depth
=
1
,
# custom transformer support
context_dim
=
None
,
# custom transformer support
n_embed
=
None
,
# custom support for prediction of discrete ids into codebook of first stage vq model
legacy
=
True
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -285,57 +272,39 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -285,57 +272,39 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
self
.
register_to_config
(
self
.
register_to_config
(
image_size
=
image_size
,
image_size
=
image_size
,
in_channels
=
in_channels
,
in_channels
=
in_channels
,
model_channels
=
model_channels
,
resnet_input_channels
=
resnet_input_channels
,
resnet_output_channels
=
resnet_output_channels
,
out_channels
=
out_channels
,
out_channels
=
out_channels
,
num_res_blocks
=
num_res_blocks
,
num_res_blocks
=
num_res_blocks
,
attention_resolutions
=
attention_resolutions
,
attention_resolutions
=
attention_resolutions
,
dropout
=
dropout
,
dropout
=
dropout
,
channel_mult
=
channel_mult
,
conv_resample
=
conv_resample
,
conv_resample
=
conv_resample
,
dims
=
dims
,
num_classes
=
num_classes
,
use_fp16
=
use_fp16
,
num_heads
=
num_heads
,
num_heads_upsample
=
num_heads_upsample
,
num_head_channels
=
num_head_channels
,
num_head_channels
=
num_head_channels
,
use_scale_shift_norm
=
use_scale_shift_norm
,
resblock_updown
=
resblock_updown
,
transformer_depth
=
transformer_depth
,
context_dim
=
context_dim
,
n_embed
=
n_embed
,
legacy
=
legacy
,
)
)
# To delete - replace with config values
self
.
image_size
=
image_size
self
.
image_size
=
image_size
self
.
in_channels
=
in_channels
self
.
in_channels
=
in_channels
self
.
model_channels
=
model_channels
self
.
out_channels
=
out_channels
self
.
out_channels
=
out_channels
self
.
num_res_blocks
=
num_res_blocks
self
.
num_res_blocks
=
num_res_blocks
self
.
attention_resolutions
=
attention_resolutions
self
.
attention_resolutions
=
attention_resolutions
self
.
dropout
=
dropout
self
.
dropout
=
dropout
self
.
channel_mult
=
channel_mult
self
.
conv_resample
=
conv_resample
self
.
num_classes
=
num_classes
self
.
dtype_
=
torch
.
float16
if
use_fp16
else
torch
.
float32
self
.
num_heads
=
num_heads
self
.
num_heads_upsample
=
num_heads_upsample
self
.
predict_codebook_ids
=
n_embed
is
not
None
time_embed_dim
=
model
_channels
*
4
time_embed_dim
=
resnet_input
_channels
[
0
]
*
4
# ======================== Input ===================
# ======================== Input ===================
self
.
conv_in
=
nn
.
Conv2d
(
in_channels
,
model
_channels
,
kernel_size
=
3
,
padding
=
(
1
,
1
))
self
.
conv_in
=
nn
.
Conv2d
(
in_channels
,
resnet_input
_channels
[
0
]
,
kernel_size
=
3
,
padding
=
(
1
,
1
))
# ======================== Time ====================
# ======================== Time ====================
self
.
time_embed
=
nn
.
Sequential
(
self
.
time_embed
=
nn
.
Sequential
(
nn
.
Linear
(
model
_channels
,
time_embed_dim
),
nn
.
Linear
(
resnet_input
_channels
[
0
]
,
time_embed_dim
),
nn
.
SiLU
(),
nn
.
SiLU
(),
nn
.
Linear
(
time_embed_dim
,
time_embed_dim
),
nn
.
Linear
(
time_embed_dim
,
time_embed_dim
),
)
)
# ======================== Down ====================
# ======================== Down ====================
input_channels
=
[
model_channels
*
mult
for
mult
in
[
1
]
+
list
(
channel_mult
[:
-
1
])]
input_channels
=
list
(
resnet_input_channels
)
output_channels
=
[
model_channels
*
mult
for
mult
in
channel
_mult
]
output_channels
=
list
(
resnet_output_
channel
s
)
ds_new
=
1
ds_new
=
1
self
.
downsample_blocks
=
nn
.
ModuleList
([])
self
.
downsample_blocks
=
nn
.
ModuleList
([])
...
@@ -377,14 +346,10 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -377,14 +346,10 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
temb_channels
=
time_embed_dim
,
temb_channels
=
time_embed_dim
,
resnet_eps
=
1e-5
,
resnet_eps
=
1e-5
,
resnet_act_fn
=
"silu"
,
resnet_act_fn
=
"silu"
,
resnet_time_scale_shift
=
"scale_shift"
if
use_scale_shift_norm
else
"default"
,
resnet_time_scale_shift
=
"default"
,
attn_num_head_channels
=
num_head_channels
,
attn_num_head_channels
=
num_head_channels
,
)
)
# ======================== Up =====================
# input_channels = [model_channels * mult for mult in channel_mult]
# output_channels = [model_channels * mult for mult in channel_mult]
self
.
upsample_blocks
=
nn
.
ModuleList
([])
self
.
upsample_blocks
=
nn
.
ModuleList
([])
for
i
,
(
input_channel
,
output_channel
)
in
enumerate
(
zip
(
reversed
(
input_channels
),
reversed
(
output_channels
))):
for
i
,
(
input_channel
,
output_channel
)
in
enumerate
(
zip
(
reversed
(
input_channels
),
reversed
(
output_channels
))):
is_final_block
=
i
==
len
(
input_channels
)
-
1
is_final_block
=
i
==
len
(
input_channels
)
-
1
...
@@ -419,12 +384,17 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -419,12 +384,17 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
self
.
out
=
nn
.
Sequential
(
self
.
out
=
nn
.
Sequential
(
nn
.
GroupNorm
(
num_channels
=
output_channels
[
0
],
num_groups
=
32
,
eps
=
1e-5
),
nn
.
GroupNorm
(
num_channels
=
output_channels
[
0
],
num_groups
=
32
,
eps
=
1e-5
),
nn
.
SiLU
(),
nn
.
SiLU
(),
nn
.
Conv2d
(
model
_channels
,
out_channels
,
3
,
padding
=
1
),
nn
.
Conv2d
(
resnet_input
_channels
[
0
]
,
out_channels
,
3
,
padding
=
1
),
)
)
# =========== TO DELETE AFTER CONVERSION ==========
# =========== TO DELETE AFTER CONVERSION ==========
transformer_depth
=
1
context_dim
=
None
legacy
=
True
num_heads
=
-
1
model_channels
=
resnet_input_channels
[
0
]
channel_mult
=
tuple
([
x
//
model_channels
for
x
in
resnet_output_channels
])
self
.
init_for_ldm
(
self
.
init_for_ldm
(
dims
,
in_channels
,
in_channels
,
model_channels
,
model_channels
,
channel_mult
,
channel_mult
,
...
@@ -446,11 +416,13 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -446,11 +416,13 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
# 1. time step embeddings
# 1. time step embeddings
if
not
torch
.
is_tensor
(
timesteps
):
if
not
torch
.
is_tensor
(
timesteps
):
timesteps
=
torch
.
tensor
([
timesteps
],
dtype
=
torch
.
long
,
device
=
sample
.
device
)
timesteps
=
torch
.
tensor
([
timesteps
],
dtype
=
torch
.
long
,
device
=
sample
.
device
)
t_emb
=
get_timestep_embedding
(
timesteps
,
self
.
model_channels
,
flip_sin_to_cos
=
True
,
downscale_freq_shift
=
0
)
t_emb
=
get_timestep_embedding
(
timesteps
,
self
.
config
.
resnet_input_channels
[
0
],
flip_sin_to_cos
=
True
,
downscale_freq_shift
=
0
)
emb
=
self
.
time_embed
(
t_emb
)
emb
=
self
.
time_embed
(
t_emb
)
# 2. pre-process sample
# 2. pre-process sample
sample
=
sample
.
type
(
self
.
dtype_
)
#
sample = sample.type(self.dtype_)
sample
=
self
.
conv_in
(
sample
)
sample
=
self
.
conv_in
(
sample
)
# 3. down blocks
# 3. down blocks
...
...
tests/test_modeling_utils.py
View file @
8aed37c1
...
@@ -490,10 +490,10 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -490,10 +490,10 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
"image_size"
:
32
,
"image_size"
:
32
,
"in_channels"
:
4
,
"in_channels"
:
4
,
"out_channels"
:
4
,
"out_channels"
:
4
,
"model_channels"
:
32
,
"num_res_blocks"
:
2
,
"num_res_blocks"
:
2
,
"attention_resolutions"
:
(
16
,),
"attention_resolutions"
:
(
16
,),
"channel_mult"
:
(
1
,
2
),
"resnet_input_channels"
:
[
32
,
32
],
"resnet_output_channels"
:
[
32
,
64
],
"num_head_channels"
:
32
,
"num_head_channels"
:
32
,
"conv_resample"
:
True
,
"conv_resample"
:
True
,
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment