Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
c4a3b09a
Unverified
Commit
c4a3b09a
authored
Aug 05, 2022
by
Suraj Patil
Committed by
GitHub
Aug 05, 2022
Browse files
[UNet2DConditionModel] add cross_attention_dim as an argument (#155)
add cross_attention_dim as an argument
parent
616c3a42
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
0 deletions
+12
-0
src/diffusers/models/unet_2d_condition.py
src/diffusers/models/unet_2d_condition.py
+4
-0
src/diffusers/models/unet_blocks.py
src/diffusers/models/unet_blocks.py
+8
-0
No files found.
src/diffusers/models/unet_2d_condition.py
View file @
c4a3b09a
...
...
@@ -28,6 +28,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
act_fn
=
"silu"
,
norm_num_groups
=
32
,
norm_eps
=
1e-5
,
cross_attention_dim
=
1280
,
attention_head_dim
=
8
,
):
super
().
__init__
()
...
...
@@ -64,6 +65,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
add_downsample
=
not
is_final_block
,
resnet_eps
=
norm_eps
,
resnet_act_fn
=
act_fn
,
cross_attention_dim
=
cross_attention_dim
,
attn_num_head_channels
=
attention_head_dim
,
downsample_padding
=
downsample_padding
,
)
...
...
@@ -77,6 +79,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
resnet_act_fn
=
act_fn
,
output_scale_factor
=
mid_block_scale_factor
,
resnet_time_scale_shift
=
"default"
,
cross_attention_dim
=
cross_attention_dim
,
attn_num_head_channels
=
attention_head_dim
,
resnet_groups
=
norm_num_groups
,
)
...
...
@@ -101,6 +104,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
add_upsample
=
not
is_final_block
,
resnet_eps
=
norm_eps
,
resnet_act_fn
=
act_fn
,
cross_attention_dim
=
cross_attention_dim
,
attn_num_head_channels
=
attention_head_dim
,
)
self
.
up_blocks
.
append
(
up_block
)
...
...
src/diffusers/models/unet_blocks.py
View file @
c4a3b09a
...
...
@@ -31,6 +31,7 @@ def get_down_block(
resnet_eps
,
resnet_act_fn
,
attn_num_head_channels
,
cross_attention_dim
=
None
,
downsample_padding
=
None
,
):
down_block_type
=
down_block_type
[
7
:]
if
down_block_type
.
startswith
(
"UNetRes"
)
else
down_block_type
...
...
@@ -58,6 +59,8 @@ def get_down_block(
attn_num_head_channels
=
attn_num_head_channels
,
)
elif
down_block_type
==
"CrossAttnDownBlock2D"
:
if
cross_attention_dim
is
None
:
raise
ValueError
(
"cross_attention_dim must be specified for CrossAttnDownBlock2D"
)
return
CrossAttnDownBlock2D
(
num_layers
=
num_layers
,
in_channels
=
in_channels
,
...
...
@@ -67,6 +70,7 @@ def get_down_block(
resnet_eps
=
resnet_eps
,
resnet_act_fn
=
resnet_act_fn
,
downsample_padding
=
downsample_padding
,
cross_attention_dim
=
cross_attention_dim
,
attn_num_head_channels
=
attn_num_head_channels
,
)
elif
down_block_type
==
"SkipDownBlock2D"
:
...
...
@@ -115,6 +119,7 @@ def get_up_block(
resnet_eps
,
resnet_act_fn
,
attn_num_head_channels
,
cross_attention_dim
=
None
,
):
up_block_type
=
up_block_type
[
7
:]
if
up_block_type
.
startswith
(
"UNetRes"
)
else
up_block_type
if
up_block_type
==
"UpBlock2D"
:
...
...
@@ -129,6 +134,8 @@ def get_up_block(
resnet_act_fn
=
resnet_act_fn
,
)
elif
up_block_type
==
"CrossAttnUpBlock2D"
:
if
cross_attention_dim
is
None
:
raise
ValueError
(
"cross_attention_dim must be specified for CrossAttnUpBlock2D"
)
return
CrossAttnUpBlock2D
(
num_layers
=
num_layers
,
in_channels
=
in_channels
,
...
...
@@ -138,6 +145,7 @@ def get_up_block(
add_upsample
=
add_upsample
,
resnet_eps
=
resnet_eps
,
resnet_act_fn
=
resnet_act_fn
,
cross_attention_dim
=
cross_attention_dim
,
attn_num_head_channels
=
attn_num_head_channels
,
)
elif
up_block_type
==
"AttnUpBlock2D"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment