Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
b33b64f5
Unverified
Commit
b33b64f5
authored
Mar 08, 2024
by
Martin Müller
Committed by
GitHub
Mar 08, 2024
Browse files
Make mid block optional for flax UNet (#7083)
* make mid block optional for flax UNet * make style
parent
9d974407
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
11 deletions
+20
-11
src/diffusers/models/unets/unet_2d_condition_flax.py
src/diffusers/models/unets/unet_2d_condition_flax.py
+20
-11
No files found.
src/diffusers/models/unets/unet_2d_condition_flax.py
View file @
b33b64f5
...
@@ -75,6 +75,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
...
@@ -75,6 +75,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
The tuple of downsample blocks to use.
The tuple of downsample blocks to use.
up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`):
up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`):
The tuple of upsample blocks to use.
The tuple of upsample blocks to use.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`. If `None`, the mid block layer is skipped.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
The tuple of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
layers_per_block (`int`, *optional*, defaults to 2):
...
@@ -107,6 +109,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
...
@@ -107,6 +109,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
"DownBlock2D"
,
"DownBlock2D"
,
)
)
up_block_types
:
Tuple
[
str
,
...]
=
(
"UpBlock2D"
,
"CrossAttnUpBlock2D"
,
"CrossAttnUpBlock2D"
,
"CrossAttnUpBlock2D"
)
up_block_types
:
Tuple
[
str
,
...]
=
(
"UpBlock2D"
,
"CrossAttnUpBlock2D"
,
"CrossAttnUpBlock2D"
,
"CrossAttnUpBlock2D"
)
mid_block_type
:
Optional
[
str
]
=
"UNetMidBlock2DCrossAttn"
only_cross_attention
:
Union
[
bool
,
Tuple
[
bool
]]
=
False
only_cross_attention
:
Union
[
bool
,
Tuple
[
bool
]]
=
False
block_out_channels
:
Tuple
[
int
,
...]
=
(
320
,
640
,
1280
,
1280
)
block_out_channels
:
Tuple
[
int
,
...]
=
(
320
,
640
,
1280
,
1280
)
layers_per_block
:
int
=
2
layers_per_block
:
int
=
2
...
@@ -252,6 +255,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
...
@@ -252,6 +255,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
self
.
down_blocks
=
down_blocks
self
.
down_blocks
=
down_blocks
# mid
# mid
if
self
.
config
.
mid_block_type
==
"UNetMidBlock2DCrossAttn"
:
self
.
mid_block
=
FlaxUNetMidBlock2DCrossAttn
(
self
.
mid_block
=
FlaxUNetMidBlock2DCrossAttn
(
in_channels
=
block_out_channels
[
-
1
],
in_channels
=
block_out_channels
[
-
1
],
dropout
=
self
.
dropout
,
dropout
=
self
.
dropout
,
...
@@ -262,6 +266,10 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
...
@@ -262,6 +266,10 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
split_head_dim
=
self
.
split_head_dim
,
split_head_dim
=
self
.
split_head_dim
,
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
)
)
elif
self
.
config
.
mid_block_type
is
None
:
self
.
mid_block
=
None
else
:
raise
ValueError
(
f
"Unexpected mid_block_type
{
self
.
config
.
mid_block_type
}
"
)
# up
# up
up_blocks
=
[]
up_blocks
=
[]
...
@@ -412,6 +420,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
...
@@ -412,6 +420,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
down_block_res_samples
=
new_down_block_res_samples
down_block_res_samples
=
new_down_block_res_samples
# 4. mid
# 4. mid
if
self
.
mid_block
is
not
None
:
sample
=
self
.
mid_block
(
sample
,
t_emb
,
encoder_hidden_states
,
deterministic
=
not
train
)
sample
=
self
.
mid_block
(
sample
,
t_emb
,
encoder_hidden_states
,
deterministic
=
not
train
)
if
mid_block_additional_residual
is
not
None
:
if
mid_block_additional_residual
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment