Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
8aed37c1
Commit
8aed37c1
authored
Jul 12, 2022
by
Patrick von Platen
Browse files
some more refactor
parent
06c79730
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
26 additions
and
54 deletions
+26
-54
src/diffusers/models/unet_unconditional.py
src/diffusers/models/unet_unconditional.py
+24
-52
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+2
-2
No files found.
src/diffusers/models/unet_unconditional.py
View file @
8aed37c1
...
...
@@ -41,7 +41,6 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
def
init_for_ldm
(
self
,
dims
,
in_channels
,
model_channels
,
channel_mult
,
...
...
@@ -80,6 +79,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
return
nn
.
Conv3d
(
*
args
,
**
kwargs
)
raise
ValueError
(
f
"unsupported dimensions:
{
dims
}
"
)
dims
=
2
self
.
input_blocks
=
nn
.
ModuleList
(
[
TimestepEmbedSequential
(
conv_nd
(
dims
,
in_channels
,
model_channels
,
3
,
padding
=
1
))]
)
...
...
@@ -257,27 +257,14 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
self
,
image_size
,
in_channels
,
model_channels
,
out_channels
,
num_res_blocks
,
attention_resolutions
,
dropout
=
0
,
channel_mult
=
(
1
,
2
,
4
,
8
),
resnet_input_channels
=
(
224
,
224
,
448
,
672
),
resnet_output_channels
=
(
224
,
448
,
672
,
896
),
conv_resample
=
True
,
dims
=
2
,
num_classes
=
None
,
use_checkpoint
=
False
,
use_fp16
=
False
,
num_heads
=-
1
,
num_head_channels
=-
1
,
num_heads_upsample
=-
1
,
use_scale_shift_norm
=
False
,
resblock_updown
=
False
,
use_new_attention_order
=
False
,
transformer_depth
=
1
,
# custom transformer support
context_dim
=
None
,
# custom transformer support
n_embed
=
None
,
# custom support for prediction of discrete ids into codebook of first stage vq model
legacy
=
True
,
num_head_channels
=
32
,
):
super
().
__init__
()
...
...
@@ -285,57 +272,39 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
self
.
register_to_config
(
image_size
=
image_size
,
in_channels
=
in_channels
,
model_channels
=
model_channels
,
resnet_input_channels
=
resnet_input_channels
,
resnet_output_channels
=
resnet_output_channels
,
out_channels
=
out_channels
,
num_res_blocks
=
num_res_blocks
,
attention_resolutions
=
attention_resolutions
,
dropout
=
dropout
,
channel_mult
=
channel_mult
,
conv_resample
=
conv_resample
,
dims
=
dims
,
num_classes
=
num_classes
,
use_fp16
=
use_fp16
,
num_heads
=
num_heads
,
num_heads_upsample
=
num_heads_upsample
,
num_head_channels
=
num_head_channels
,
use_scale_shift_norm
=
use_scale_shift_norm
,
resblock_updown
=
resblock_updown
,
transformer_depth
=
transformer_depth
,
context_dim
=
context_dim
,
n_embed
=
n_embed
,
legacy
=
legacy
,
)
# To delete - replace with config values
self
.
image_size
=
image_size
self
.
in_channels
=
in_channels
self
.
model_channels
=
model_channels
self
.
out_channels
=
out_channels
self
.
num_res_blocks
=
num_res_blocks
self
.
attention_resolutions
=
attention_resolutions
self
.
dropout
=
dropout
self
.
channel_mult
=
channel_mult
self
.
conv_resample
=
conv_resample
self
.
num_classes
=
num_classes
self
.
dtype_
=
torch
.
float16
if
use_fp16
else
torch
.
float32
self
.
num_heads
=
num_heads
self
.
num_heads_upsample
=
num_heads_upsample
self
.
predict_codebook_ids
=
n_embed
is
not
None
time_embed_dim
=
model
_channels
*
4
time_embed_dim
=
resnet_input
_channels
[
0
]
*
4
# ======================== Input ===================
self
.
conv_in
=
nn
.
Conv2d
(
in_channels
,
model
_channels
,
kernel_size
=
3
,
padding
=
(
1
,
1
))
self
.
conv_in
=
nn
.
Conv2d
(
in_channels
,
resnet_input
_channels
[
0
]
,
kernel_size
=
3
,
padding
=
(
1
,
1
))
# ======================== Time ====================
self
.
time_embed
=
nn
.
Sequential
(
nn
.
Linear
(
model
_channels
,
time_embed_dim
),
nn
.
Linear
(
resnet_input
_channels
[
0
]
,
time_embed_dim
),
nn
.
SiLU
(),
nn
.
Linear
(
time_embed_dim
,
time_embed_dim
),
)
# ======================== Down ====================
input_channels
=
[
model_channels
*
mult
for
mult
in
[
1
]
+
list
(
channel_mult
[:
-
1
])]
output_channels
=
[
model_channels
*
mult
for
mult
in
channel
_mult
]
input_channels
=
list
(
resnet_input_channels
)
output_channels
=
list
(
resnet_output_
channel
s
)
ds_new
=
1
self
.
downsample_blocks
=
nn
.
ModuleList
([])
...
...
@@ -377,14 +346,10 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
temb_channels
=
time_embed_dim
,
resnet_eps
=
1e-5
,
resnet_act_fn
=
"silu"
,
resnet_time_scale_shift
=
"scale_shift"
if
use_scale_shift_norm
else
"default"
,
resnet_time_scale_shift
=
"default"
,
attn_num_head_channels
=
num_head_channels
,
)
# ======================== Up =====================
# input_channels = [model_channels * mult for mult in channel_mult]
# output_channels = [model_channels * mult for mult in channel_mult]
self
.
upsample_blocks
=
nn
.
ModuleList
([])
for
i
,
(
input_channel
,
output_channel
)
in
enumerate
(
zip
(
reversed
(
input_channels
),
reversed
(
output_channels
))):
is_final_block
=
i
==
len
(
input_channels
)
-
1
...
...
@@ -419,12 +384,17 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
self
.
out
=
nn
.
Sequential
(
nn
.
GroupNorm
(
num_channels
=
output_channels
[
0
],
num_groups
=
32
,
eps
=
1e-5
),
nn
.
SiLU
(),
nn
.
Conv2d
(
model
_channels
,
out_channels
,
3
,
padding
=
1
),
nn
.
Conv2d
(
resnet_input
_channels
[
0
]
,
out_channels
,
3
,
padding
=
1
),
)
# =========== TO DELETE AFTER CONVERSION ==========
transformer_depth
=
1
context_dim
=
None
legacy
=
True
num_heads
=
-
1
model_channels
=
resnet_input_channels
[
0
]
channel_mult
=
tuple
([
x
//
model_channels
for
x
in
resnet_output_channels
])
self
.
init_for_ldm
(
dims
,
in_channels
,
model_channels
,
channel_mult
,
...
...
@@ -446,11 +416,13 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
# 1. time step embeddings
if
not
torch
.
is_tensor
(
timesteps
):
timesteps
=
torch
.
tensor
([
timesteps
],
dtype
=
torch
.
long
,
device
=
sample
.
device
)
t_emb
=
get_timestep_embedding
(
timesteps
,
self
.
model_channels
,
flip_sin_to_cos
=
True
,
downscale_freq_shift
=
0
)
t_emb
=
get_timestep_embedding
(
timesteps
,
self
.
config
.
resnet_input_channels
[
0
],
flip_sin_to_cos
=
True
,
downscale_freq_shift
=
0
)
emb
=
self
.
time_embed
(
t_emb
)
# 2. pre-process sample
sample
=
sample
.
type
(
self
.
dtype_
)
#
sample = sample.type(self.dtype_)
sample
=
self
.
conv_in
(
sample
)
# 3. down blocks
...
...
tests/test_modeling_utils.py
View file @
8aed37c1
...
...
@@ -490,10 +490,10 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
"image_size"
:
32
,
"in_channels"
:
4
,
"out_channels"
:
4
,
"model_channels"
:
32
,
"num_res_blocks"
:
2
,
"attention_resolutions"
:
(
16
,),
"channel_mult"
:
(
1
,
2
),
"resnet_input_channels"
:
[
32
,
32
],
"resnet_output_channels"
:
[
32
,
64
],
"num_head_channels"
:
32
,
"conv_resample"
:
True
,
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment