Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
14976500
Unverified
Commit
14976500
authored
Jan 26, 2023
by
Will Berman
Committed by
GitHub
Jan 26, 2023
Browse files
fuse attention mask (#2111)
* fuse attention mask * lint * use 0 beta when no attention mask re: @Birch-san
parent
96af5bf7
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
6 deletions
+13
-6
src/diffusers/models/cross_attention.py
src/diffusers/models/cross_attention.py
+13
-6
No files found.
src/diffusers/models/cross_attention.py
View file @
14976500
...
...
@@ -185,17 +185,23 @@ class CrossAttention(nn.Module):
query
=
query
.
float
()
key
=
key
.
float
()
if
attention_mask
is
None
:
baddbmm_input
=
torch
.
empty
(
query
.
shape
[
0
],
query
.
shape
[
1
],
key
.
shape
[
1
],
dtype
=
query
.
dtype
,
device
=
query
.
device
)
beta
=
0
else
:
baddbmm_input
=
attention_mask
beta
=
1
attention_scores
=
torch
.
baddbmm
(
torch
.
empty
(
query
.
shape
[
0
],
query
.
shape
[
1
],
key
.
shape
[
1
],
dtype
=
query
.
dtype
,
device
=
query
.
device
)
,
baddbmm_input
,
query
,
key
.
transpose
(
-
1
,
-
2
),
beta
=
0
,
beta
=
beta
,
alpha
=
self
.
scale
,
)
if
attention_mask
is
not
None
:
attention_scores
=
attention_scores
+
attention_mask
if
self
.
upcast_softmax
:
attention_scores
=
attention_scores
.
float
()
...
...
@@ -228,11 +234,12 @@ class CrossAttnProcessor:
attention_mask
=
attn
.
prepare_attention_mask
(
attention_mask
,
sequence_length
)
query
=
attn
.
to_q
(
hidden_states
)
query
=
attn
.
head_to_batch_dim
(
query
)
encoder_hidden_states
=
encoder_hidden_states
if
encoder_hidden_states
is
not
None
else
hidden_states
key
=
attn
.
to_k
(
encoder_hidden_states
)
value
=
attn
.
to_v
(
encoder_hidden_states
)
query
=
attn
.
head_to_batch_dim
(
query
)
key
=
attn
.
head_to_batch_dim
(
key
)
value
=
attn
.
head_to_batch_dim
(
value
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment