renzhc / diffusers_dcu · Commits · 8aed37c1

Commit 8aed37c1, authored Jul 12, 2022 by Patrick von Platen
some more refactor
parent 06c79730
Showing 2 changed files with 26 additions and 54 deletions:

  src/diffusers/models/unet_unconditional.py   +24 -52
  tests/test_modeling_utils.py                  +2  -2
src/diffusers/models/unet_unconditional.py @ 8aed37c1
@@ -41,7 +41,6 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
     def init_for_ldm(
         self,
         dims,
         in_channels,
         model_channels,
         channel_mult,
@@ -80,6 +79,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
                 return nn.Conv3d(*args, **kwargs)
             raise ValueError(f"unsupported dimensions: {dims}")

+        dims = 2
         self.input_blocks = nn.ModuleList(
             [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
         )
@@ -257,27 +257,14 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
         self,
         image_size,
         in_channels,
         model_channels,
         out_channels,
         num_res_blocks,
         attention_resolutions,
         dropout=0,
-        channel_mult=(1, 2, 4, 8),
+        resnet_input_channels=(224, 224, 448, 672),
+        resnet_output_channels=(224, 448, 672, 896),
         conv_resample=True,
         dims=2,
         num_classes=None,
         use_checkpoint=False,
         use_fp16=False,
         num_heads=-1,
-        num_head_channels=-1,
         num_heads_upsample=-1,
         use_scale_shift_norm=False,
         resblock_updown=False,
         use_new_attention_order=False,
         transformer_depth=1,  # custom transformer support
         context_dim=None,  # custom transformer support
         n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
         legacy=True,
+        num_head_channels=32,
     ):
         super().__init__()
@@ -285,57 +272,39 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
         self.register_to_config(
             image_size=image_size,
             in_channels=in_channels,
             model_channels=model_channels,
+            resnet_input_channels=resnet_input_channels,
+            resnet_output_channels=resnet_output_channels,
             out_channels=out_channels,
             num_res_blocks=num_res_blocks,
             attention_resolutions=attention_resolutions,
             dropout=dropout,
             channel_mult=channel_mult,
             conv_resample=conv_resample,
             dims=dims,
             num_classes=num_classes,
             use_fp16=use_fp16,
             num_heads=num_heads,
             num_heads_upsample=num_heads_upsample,
             num_head_channels=num_head_channels,
             use_scale_shift_norm=use_scale_shift_norm,
             resblock_updown=resblock_updown,
             transformer_depth=transformer_depth,
             context_dim=context_dim,
             n_embed=n_embed,
             legacy=legacy,
         )

         # To delete - replace with config values
         self.image_size = image_size
         self.in_channels = in_channels
         self.model_channels = model_channels
         self.out_channels = out_channels
         self.num_res_blocks = num_res_blocks
         self.attention_resolutions = attention_resolutions
         self.dropout = dropout
         self.channel_mult = channel_mult
         self.conv_resample = conv_resample
         self.num_classes = num_classes
         self.dtype_ = torch.float16 if use_fp16 else torch.float32
         self.num_heads = num_heads
         self.num_heads_upsample = num_heads_upsample
         self.predict_codebook_ids = n_embed is not None

-        time_embed_dim = model_channels * 4
+        time_embed_dim = resnet_input_channels[0] * 4

         # ======================== Input ===================
-        self.conv_in = nn.Conv2d(in_channels, model_channels, kernel_size=3, padding=(1, 1))
+        self.conv_in = nn.Conv2d(in_channels, resnet_input_channels[0], kernel_size=3, padding=(1, 1))

         # ======================== Time ====================
         self.time_embed = nn.Sequential(
-            nn.Linear(model_channels, time_embed_dim),
+            nn.Linear(resnet_input_channels[0], time_embed_dim),
             nn.SiLU(),
             nn.Linear(time_embed_dim, time_embed_dim),
         )

         # ======================== Down ====================
-        input_channels = [model_channels * mult for mult in [1] + list(channel_mult[:-1])]
-        output_channels = [model_channels * mult for mult in channel_mult]
+        input_channels = list(resnet_input_channels)
+        output_channels = list(resnet_output_channels)
         ds_new = 1
         self.downsample_blocks = nn.ModuleList([])
@@ -377,14 +346,10 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
                 temb_channels=time_embed_dim,
                 resnet_eps=1e-5,
                 resnet_act_fn="silu",
-                resnet_time_scale_shift="scale_shift" if use_scale_shift_norm else "default",
+                resnet_time_scale_shift="default",
                 attn_num_head_channels=num_head_channels,
             )

         # ======================== Up =====================
-        # input_channels = [model_channels * mult for mult in channel_mult]
-        # output_channels = [model_channels * mult for mult in channel_mult]
         self.upsample_blocks = nn.ModuleList([])
         for i, (input_channel, output_channel) in enumerate(zip(reversed(input_channels), reversed(output_channels))):
             is_final_block = i == len(input_channels) - 1
@@ -419,12 +384,17 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
         self.out = nn.Sequential(
             nn.GroupNorm(num_channels=output_channels[0], num_groups=32, eps=1e-5),
             nn.SiLU(),
-            nn.Conv2d(model_channels, out_channels, 3, padding=1),
+            nn.Conv2d(resnet_input_channels[0], out_channels, 3, padding=1),
         )

         # =========== TO DELETE AFTER CONVERSION ==========
         transformer_depth = 1
         context_dim = None
         legacy = True
         num_heads = -1
+        model_channels = resnet_input_channels[0]
+        channel_mult = tuple([x // model_channels for x in resnet_output_channels])
         self.init_for_ldm(
             dims,
             in_channels,
             model_channels,
             channel_mult,
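The "TO DELETE AFTER CONVERSION" block above maps the new resnet channel lists back onto the legacy (model_channels, channel_mult) pair that init_for_ldm still expects. A minimal sketch (not part of the commit) of that arithmetic, using the default values from the new signature:

resnet_input_channels = (224, 224, 448, 672)
resnet_output_channels = (224, 448, 672, 896)

# recovered exactly as in the block above
model_channels = resnet_input_channels[0]                                   # 224
channel_mult = tuple(x // model_channels for x in resnet_output_channels)   # (1, 2, 3, 4)

# the replaced list comprehensions derived the per-block channels like this
input_channels = [model_channels * mult for mult in [1] + list(channel_mult[:-1])]  # [224, 224, 448, 672]
output_channels = [model_channels * mult for mult in channel_mult]                  # [224, 448, 672, 896]

assert tuple(input_channels) == resnet_input_channels
assert tuple(output_channels) == resnet_output_channels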
@@ -446,11 +416,13 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
         # 1. time step embeddings
         if not torch.is_tensor(timesteps):
             timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)

-        t_emb = get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+        t_emb = get_timestep_embedding(
+            timesteps, self.config.resnet_input_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0
+        )
         emb = self.time_embed(t_emb)

         # 2. pre-process sample
-        sample = sample.type(self.dtype_)
+        # sample = sample.type(self.dtype_)
         sample = self.conv_in(sample)

         # 3. down blocks
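The forward pass now sizes the timestep embedding from config.resnet_input_channels[0] instead of self.model_channels. For readers unfamiliar with get_timestep_embedding, here is a hedged sketch of the standard sinusoidal formulation it is based on; this is an illustration only, and the library's exact implementation may differ:

import math
import torch

def sinusoidal_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=True, downscale_freq_shift=0.0):
    # standard DDPM-style sinusoidal embedding (illustration, not the library code)
    half_dim = embedding_dim // 2
    exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
    freqs = torch.exp(exponent / (half_dim - downscale_freq_shift))
    emb = timesteps[:, None].float() * freqs[None, :]            # (batch, half_dim)
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)    # (batch, embedding_dim)
    if flip_sin_to_cos:
        # place the cosine half first, matching flip_sin_to_cos=True in the call above
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
    return emb

# e.g. an embedding of width resnet_input_channels[0] = 224 for a batch of two timesteps
t_emb = sinusoidal_timestep_embedding(torch.tensor([10, 500]), 224)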
tests/test_modeling_utils.py @ 8aed37c1
@@ -490,10 +490,10 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
             "image_size": 32,
             "in_channels": 4,
             "out_channels": 4,
-            "model_channels": 32,
             "num_res_blocks": 2,
             "attention_resolutions": (16,),
-            "channel_mult": (1, 2),
+            "resnet_input_channels": [32, 32],
+            "resnet_output_channels": [32, 64],
             "num_head_channels": 32,
             "conv_resample": True,
         }
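For context, a hedged sketch of how a config like the updated test dict might be used to build and call the model. The import path follows this commit's file layout; the exact constructor and forward signatures at this commit, the input shapes, and the positional call are assumptions, not something shown in the diff:

import torch
from diffusers.models.unet_unconditional import UNetUnconditionalModel

init_dict = {
    "image_size": 32,
    "in_channels": 4,
    "out_channels": 4,
    "num_res_blocks": 2,
    "attention_resolutions": (16,),
    "resnet_input_channels": [32, 32],
    "resnet_output_channels": [32, 64],
    "num_head_channels": 32,
    "conv_resample": True,
}

model = UNetUnconditionalModel(**init_dict)

# assumed noisy-sample shape: (batch, in_channels, image_size, image_size)
sample = torch.randn(1, init_dict["in_channels"], init_dict["image_size"], init_dict["image_size"])
timesteps = torch.tensor([10], dtype=torch.long)
output = model(sample, timesteps)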