renzhc / diffusers_dcu · Commits

Commit 5e036921 (unverified)
Authored Dec 07, 2022 by Pedro Cuenca, committed via GitHub on Dec 07, 2022

Make cross-attention check more robust (#1560)

* Make cross-attention check more robust.
* Fix copies.

Parent: bea7eb43
Showing 3 changed files with 10 additions and 4 deletions (+10 −4)

src/diffusers/models/unet_2d_blocks.py                              +3 −0
src/diffusers/models/unet_2d_condition.py                           +2 −2
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py   +5 −2
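For context: the commit swaps a duck-typed check, `hasattr(block, "attentions") and block.attentions is not None`, for an explicit `has_cross_attention` flag that the cross-attention block classes set on themselves. Below is a minimal sketch (not the library code; the class names here are hypothetical) of the failure mode the old check invites and the dispatch the new flag enables: a block can expose an `attentions` attribute without accepting cross-attention inputs, so keying dispatch on that attribute can route `encoder_hidden_states` to a `forward()` that does not take it.

```python
class SelfAttnBlock:
    """Hypothetical block: has `attentions`, but no cross-attention."""

    def __init__(self):
        self.attentions = ["self-attn"]  # attribute exists anyway

    def forward(self, hidden_states, temb=None):
        return hidden_states


class CrossAttnBlock:
    """Hypothetical block that declares the capability explicitly."""

    def __init__(self):
        self.attentions = ["cross-attn"]
        self.has_cross_attention = True  # the flag this commit introduces

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        return hidden_states


def run_block(block, sample, emb, encoder_hidden_states):
    # New-style check: an explicit flag instead of inferring from `attentions`.
    if hasattr(block, "has_cross_attention") and block.has_cross_attention:
        return block.forward(
            hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
        )
    return block.forward(hidden_states=sample, temb=emb)


# Under the old attribute check, SelfAttnBlock would have matched and received
# an unexpected `encoder_hidden_states` kwarg; with the flag it takes the plain path.
run_block(SelfAttnBlock(), sample=0.0, emb=None, encoder_hidden_states=None)
run_block(CrossAttnBlock(), sample=0.0, emb=None, encoder_hidden_states=None)
```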
src/diffusers/models/unet_2d_blocks.py
@@ -343,6 +343,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
     ):
         super().__init__()
 
+        self.has_cross_attention = True
         self.attention_type = attention_type
         self.attn_num_head_channels = attn_num_head_channels
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
@@ -526,6 +527,7 @@ class CrossAttnDownBlock2D(nn.Module):
         resnets = []
         attentions = []
 
+        self.has_cross_attention = True
         self.attention_type = attention_type
         self.attn_num_head_channels = attn_num_head_channels
@@ -1110,6 +1112,7 @@ class CrossAttnUpBlock2D(nn.Module):
         resnets = []
         attentions = []
 
+        self.has_cross_attention = True
         self.attention_type = attention_type
         self.attn_num_head_channels = attn_num_head_channels
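A side note on the context line `resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)` seen above: the default uses a quarter of the channel count for the normalization group count, capped at 32. A quick worked check (the wrapper function is illustrative, not from the library):

```python
def default_resnet_groups(in_channels, resnet_groups=None):
    # Mirrors the one-line default from the context lines above.
    return resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

assert default_resnet_groups(64) == 16                    # 64 // 4 = 16, below the cap
assert default_resnet_groups(320) == 32                   # 320 // 4 = 80, capped at 32
assert default_resnet_groups(320, resnet_groups=8) == 8   # an explicit value wins
```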
src/diffusers/models/unet_2d_condition.py
@@ -377,7 +377,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         # 3. down
         down_block_res_samples = (sample,)
         for downsample_block in self.down_blocks:
-            if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
+            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                 sample, res_samples = downsample_block(
                     hidden_states=sample,
                     temb=emb,
@@ -403,7 +403,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
             if not is_final_block and forward_upsample_size:
                 upsample_size = down_block_res_samples[-1].shape[2:]
 
-            if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
+            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                 sample = upsample_block(
                     hidden_states=sample,
                     temb=emb,
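The surrounding context also shows where `upsample_size` comes from: the spatial dimensions of the most recent skip connection. A small sketch of that slice (the tensor shape here is made up for illustration):

```python
import torch

# `down_block_res_samples[-1].shape[2:]` takes the (H, W) of an NCHW tensor,
# giving the upsample block an exact target size for inputs whose sizes
# don't halve cleanly at every level.
res_sample = torch.randn(1, 320, 17, 17)   # hypothetical skip tensor (N, C, H, W)
upsample_size = res_sample.shape[2:]       # torch.Size([17, 17])
assert tuple(upsample_size) == (17, 17)
```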
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -455,7 +455,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         # 3. down
         down_block_res_samples = (sample,)
         for downsample_block in self.down_blocks:
-            if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
+            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                 sample, res_samples = downsample_block(
                     hidden_states=sample,
                     temb=emb,
@@ -481,7 +481,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             if not is_final_block and forward_upsample_size:
                 upsample_size = down_block_res_samples[-1].shape[2:]
 
-            if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
+            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                 sample = upsample_block(
                     hidden_states=sample,
                     temb=emb,
@@ -726,6 +726,7 @@ class CrossAttnDownBlockFlat(nn.Module):
         resnets = []
         attentions = []
 
+        self.has_cross_attention = True
         self.attention_type = attention_type
         self.attn_num_head_channels = attn_num_head_channels
@@ -924,6 +925,7 @@ class CrossAttnUpBlockFlat(nn.Module):
         resnets = []
         attentions = []
 
+        self.has_cross_attention = True
         self.attention_type = attention_type
         self.attn_num_head_channels = attn_num_head_channels
@@ -1043,6 +1045,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
     ):
         super().__init__()
 
+        self.has_cross_attention = True
         self.attention_type = attention_type
         self.attn_num_head_channels = attn_num_head_channels
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
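The edits in this third file mirror the first two: modeling_text_unet.py contains flat variants of the UNet blocks (CrossAttnDownBlockFlat, CrossAttnUpBlockFlat, UNetMidBlockFlatCrossAttn) that duplicate the classes in unet_2d_blocks.py and unet_2d_condition.py. The "* Fix copies." line in the commit message refers to propagating the same flag additions and check changes into these copied classes, which the repository keeps in sync with the originals.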