Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
5e036921
Unverified
Commit
5e036921
authored
Dec 07, 2022
by
Pedro Cuenca
Committed by
GitHub
Dec 07, 2022
Browse files
Make cross-attention check more robust (#1560)
* Make cross-attention check more robust. * Fix copies.
parent
bea7eb43
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
10 additions
and
4 deletions
+10
-4
src/diffusers/models/unet_2d_blocks.py
src/diffusers/models/unet_2d_blocks.py
+3
-0
src/diffusers/models/unet_2d_condition.py
src/diffusers/models/unet_2d_condition.py
+2
-2
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
...users/pipelines/versatile_diffusion/modeling_text_unet.py
+5
-2
No files found.
src/diffusers/models/unet_2d_blocks.py
View file @
5e036921
...
@@ -343,6 +343,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
...
@@ -343,6 +343,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
):
):
super
().
__init__
()
super
().
__init__
()
self
.
has_cross_attention
=
True
self
.
attention_type
=
attention_type
self
.
attention_type
=
attention_type
self
.
attn_num_head_channels
=
attn_num_head_channels
self
.
attn_num_head_channels
=
attn_num_head_channels
resnet_groups
=
resnet_groups
if
resnet_groups
is
not
None
else
min
(
in_channels
//
4
,
32
)
resnet_groups
=
resnet_groups
if
resnet_groups
is
not
None
else
min
(
in_channels
//
4
,
32
)
...
@@ -526,6 +527,7 @@ class CrossAttnDownBlock2D(nn.Module):
...
@@ -526,6 +527,7 @@ class CrossAttnDownBlock2D(nn.Module):
resnets
=
[]
resnets
=
[]
attentions
=
[]
attentions
=
[]
self
.
has_cross_attention
=
True
self
.
attention_type
=
attention_type
self
.
attention_type
=
attention_type
self
.
attn_num_head_channels
=
attn_num_head_channels
self
.
attn_num_head_channels
=
attn_num_head_channels
...
@@ -1110,6 +1112,7 @@ class CrossAttnUpBlock2D(nn.Module):
...
@@ -1110,6 +1112,7 @@ class CrossAttnUpBlock2D(nn.Module):
resnets
=
[]
resnets
=
[]
attentions
=
[]
attentions
=
[]
self
.
has_cross_attention
=
True
self
.
attention_type
=
attention_type
self
.
attention_type
=
attention_type
self
.
attn_num_head_channels
=
attn_num_head_channels
self
.
attn_num_head_channels
=
attn_num_head_channels
...
...
src/diffusers/models/unet_2d_condition.py
View file @
5e036921
...
@@ -377,7 +377,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
...
@@ -377,7 +377,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
# 3. down
# 3. down
down_block_res_samples
=
(
sample
,)
down_block_res_samples
=
(
sample
,)
for
downsample_block
in
self
.
down_blocks
:
for
downsample_block
in
self
.
down_blocks
:
if
hasattr
(
downsample_block
,
"attention
s
"
)
and
downsample_block
.
attentions
is
not
None
:
if
hasattr
(
downsample_block
,
"
has_cross_
attention"
)
and
downsample_block
.
has_cross_attention
:
sample
,
res_samples
=
downsample_block
(
sample
,
res_samples
=
downsample_block
(
hidden_states
=
sample
,
hidden_states
=
sample
,
temb
=
emb
,
temb
=
emb
,
...
@@ -403,7 +403,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
...
@@ -403,7 +403,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
if
not
is_final_block
and
forward_upsample_size
:
if
not
is_final_block
and
forward_upsample_size
:
upsample_size
=
down_block_res_samples
[
-
1
].
shape
[
2
:]
upsample_size
=
down_block_res_samples
[
-
1
].
shape
[
2
:]
if
hasattr
(
upsample_block
,
"attention
s
"
)
and
upsample_block
.
attentions
is
not
None
:
if
hasattr
(
upsample_block
,
"
has_cross_
attention"
)
and
upsample_block
.
has_cross_attention
:
sample
=
upsample_block
(
sample
=
upsample_block
(
hidden_states
=
sample
,
hidden_states
=
sample
,
temb
=
emb
,
temb
=
emb
,
...
...
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
View file @
5e036921
...
@@ -455,7 +455,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
...
@@ -455,7 +455,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
# 3. down
# 3. down
down_block_res_samples
=
(
sample
,)
down_block_res_samples
=
(
sample
,)
for
downsample_block
in
self
.
down_blocks
:
for
downsample_block
in
self
.
down_blocks
:
if
hasattr
(
downsample_block
,
"attention
s
"
)
and
downsample_block
.
attentions
is
not
None
:
if
hasattr
(
downsample_block
,
"
has_cross_
attention"
)
and
downsample_block
.
has_cross_attention
:
sample
,
res_samples
=
downsample_block
(
sample
,
res_samples
=
downsample_block
(
hidden_states
=
sample
,
hidden_states
=
sample
,
temb
=
emb
,
temb
=
emb
,
...
@@ -481,7 +481,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
...
@@ -481,7 +481,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
if
not
is_final_block
and
forward_upsample_size
:
if
not
is_final_block
and
forward_upsample_size
:
upsample_size
=
down_block_res_samples
[
-
1
].
shape
[
2
:]
upsample_size
=
down_block_res_samples
[
-
1
].
shape
[
2
:]
if
hasattr
(
upsample_block
,
"attention
s
"
)
and
upsample_block
.
attentions
is
not
None
:
if
hasattr
(
upsample_block
,
"
has_cross_
attention"
)
and
upsample_block
.
has_cross_attention
:
sample
=
upsample_block
(
sample
=
upsample_block
(
hidden_states
=
sample
,
hidden_states
=
sample
,
temb
=
emb
,
temb
=
emb
,
...
@@ -726,6 +726,7 @@ class CrossAttnDownBlockFlat(nn.Module):
...
@@ -726,6 +726,7 @@ class CrossAttnDownBlockFlat(nn.Module):
resnets
=
[]
resnets
=
[]
attentions
=
[]
attentions
=
[]
self
.
has_cross_attention
=
True
self
.
attention_type
=
attention_type
self
.
attention_type
=
attention_type
self
.
attn_num_head_channels
=
attn_num_head_channels
self
.
attn_num_head_channels
=
attn_num_head_channels
...
@@ -924,6 +925,7 @@ class CrossAttnUpBlockFlat(nn.Module):
...
@@ -924,6 +925,7 @@ class CrossAttnUpBlockFlat(nn.Module):
resnets
=
[]
resnets
=
[]
attentions
=
[]
attentions
=
[]
self
.
has_cross_attention
=
True
self
.
attention_type
=
attention_type
self
.
attention_type
=
attention_type
self
.
attn_num_head_channels
=
attn_num_head_channels
self
.
attn_num_head_channels
=
attn_num_head_channels
...
@@ -1043,6 +1045,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
...
@@ -1043,6 +1045,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
):
):
super
().
__init__
()
super
().
__init__
()
self
.
has_cross_attention
=
True
self
.
attention_type
=
attention_type
self
.
attention_type
=
attention_type
self
.
attn_num_head_channels
=
attn_num_head_channels
self
.
attn_num_head_channels
=
attn_num_head_channels
resnet_groups
=
resnet_groups
if
resnet_groups
is
not
None
else
min
(
in_channels
//
4
,
32
)
resnet_groups
=
resnet_groups
if
resnet_groups
is
not
None
else
min
(
in_channels
//
4
,
32
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment