Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
b13cdbb2
Unverified
Commit
b13cdbb2
authored
Jan 08, 2025
by
hlky
Committed by
GitHub
Jan 08, 2025
Browse files
UNet2DModel mid_block_type (#10469)
parent
a0acbdc9
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
49 additions
and
15 deletions
+49
-15
src/diffusers/models/unets/unet_2d.py
src/diffusers/models/unets/unet_2d.py
+20
-15
tests/models/unets/test_models_unet_2d.py
tests/models/unets/test_models_unet_2d.py
+29
-0
No files found.
src/diffusers/models/unets/unet_2d.py
View file @
b13cdbb2
...
@@ -58,7 +58,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
...
@@ -58,7 +58,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
Tuple of downsample block types.
Tuple of downsample block types.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
Block type for middle of UNet, it can be either `UNetMidBlock2D` or `
UnCLIPUNetMidBlock2D
`.
Block type for middle of UNet, it can be either `UNetMidBlock2D` or `
None
`.
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
Tuple of upsample block types.
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
...
@@ -103,6 +103,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
...
@@ -103,6 +103,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
freq_shift
:
int
=
0
,
freq_shift
:
int
=
0
,
flip_sin_to_cos
:
bool
=
True
,
flip_sin_to_cos
:
bool
=
True
,
down_block_types
:
Tuple
[
str
,
...]
=
(
"DownBlock2D"
,
"AttnDownBlock2D"
,
"AttnDownBlock2D"
,
"AttnDownBlock2D"
),
down_block_types
:
Tuple
[
str
,
...]
=
(
"DownBlock2D"
,
"AttnDownBlock2D"
,
"AttnDownBlock2D"
,
"AttnDownBlock2D"
),
mid_block_type
:
Optional
[
str
]
=
"UNetMidBlock2D"
,
up_block_types
:
Tuple
[
str
,
...]
=
(
"AttnUpBlock2D"
,
"AttnUpBlock2D"
,
"AttnUpBlock2D"
,
"UpBlock2D"
),
up_block_types
:
Tuple
[
str
,
...]
=
(
"AttnUpBlock2D"
,
"AttnUpBlock2D"
,
"AttnUpBlock2D"
,
"UpBlock2D"
),
block_out_channels
:
Tuple
[
int
,
...]
=
(
224
,
448
,
672
,
896
),
block_out_channels
:
Tuple
[
int
,
...]
=
(
224
,
448
,
672
,
896
),
layers_per_block
:
int
=
2
,
layers_per_block
:
int
=
2
,
...
@@ -194,6 +195,9 @@ class UNet2DModel(ModelMixin, ConfigMixin):
...
@@ -194,6 +195,9 @@ class UNet2DModel(ModelMixin, ConfigMixin):
self
.
down_blocks
.
append
(
down_block
)
self
.
down_blocks
.
append
(
down_block
)
# mid
# mid
if
mid_block_type
is
None
:
self
.
mid_block
=
None
else
:
self
.
mid_block
=
UNetMidBlock2D
(
self
.
mid_block
=
UNetMidBlock2D
(
in_channels
=
block_out_channels
[
-
1
],
in_channels
=
block_out_channels
[
-
1
],
temb_channels
=
time_embed_dim
,
temb_channels
=
time_embed_dim
,
...
@@ -322,6 +326,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
...
@@ -322,6 +326,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
down_block_res_samples
+=
res_samples
down_block_res_samples
+=
res_samples
# 4. mid
# 4. mid
if
self
.
mid_block
is
not
None
:
sample
=
self
.
mid_block
(
sample
,
emb
)
sample
=
self
.
mid_block
(
sample
,
emb
)
# 5. up
# 5. up
...
...
tests/models/unets/test_models_unet_2d.py
View file @
b13cdbb2
...
@@ -105,6 +105,35 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
...
@@ -105,6 +105,35 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
expected_shape
=
inputs_dict
[
"sample"
].
shape
expected_shape
=
inputs_dict
[
"sample"
].
shape
self
.
assertEqual
(
output
.
shape
,
expected_shape
,
"Input and output shapes do not match"
)
self
.
assertEqual
(
output
.
shape
,
expected_shape
,
"Input and output shapes do not match"
)
def
test_mid_block_none
(
self
):
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
mid_none_init_dict
,
mid_none_inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
mid_none_init_dict
[
"mid_block_type"
]
=
None
model
=
self
.
model_class
(
**
init_dict
)
model
.
to
(
torch_device
)
model
.
eval
()
mid_none_model
=
self
.
model_class
(
**
mid_none_init_dict
)
mid_none_model
.
to
(
torch_device
)
mid_none_model
.
eval
()
self
.
assertIsNone
(
mid_none_model
.
mid_block
,
"Mid block should not exist."
)
with
torch
.
no_grad
():
output
=
model
(
**
inputs_dict
)
if
isinstance
(
output
,
dict
):
output
=
output
.
to_tuple
()[
0
]
with
torch
.
no_grad
():
mid_none_output
=
mid_none_model
(
**
mid_none_inputs_dict
)
if
isinstance
(
mid_none_output
,
dict
):
mid_none_output
=
mid_none_output
.
to_tuple
()[
0
]
self
.
assertFalse
(
torch
.
allclose
(
output
,
mid_none_output
,
rtol
=
1e-3
),
"outputs should be different."
)
def
test_gradient_checkpointing_is_applied
(
self
):
def
test_gradient_checkpointing_is_applied
(
self
):
expected_set
=
{
expected_set
=
{
"AttnUpBlock2D"
,
"AttnUpBlock2D"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment