Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
5adb0a7b
Unverified
Commit
5adb0a7b
authored
Sep 09, 2022
by
Suraj Patil
Committed by
GitHub
Sep 09, 2022
Browse files
use torch.matmul instead of einsum in attnetion. (#445)
* use torch.matmul instead of einsum * fix softmax
parent
b2b3b1a8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
4 deletions
+2
-4
src/diffusers/models/attention.py
src/diffusers/models/attention.py
+2
-4
No files found.
src/diffusers/models/attention.py
View file @
5adb0a7b
...
@@ -275,11 +275,9 @@ class CrossAttention(nn.Module):
...
@@ -275,11 +275,9 @@ class CrossAttention(nn.Module):
for
i
in
range
(
hidden_states
.
shape
[
0
]
//
slice_size
):
for
i
in
range
(
hidden_states
.
shape
[
0
]
//
slice_size
):
start_idx
=
i
*
slice_size
start_idx
=
i
*
slice_size
end_idx
=
(
i
+
1
)
*
slice_size
end_idx
=
(
i
+
1
)
*
slice_size
attn_slice
=
(
attn_slice
=
torch
.
matmul
(
query
[
start_idx
:
end_idx
],
key
[
start_idx
:
end_idx
].
transpose
(
1
,
2
))
*
self
.
scale
torch
.
einsum
(
"b i d, b j d -> b i j"
,
query
[
start_idx
:
end_idx
],
key
[
start_idx
:
end_idx
])
*
self
.
scale
)
attn_slice
=
attn_slice
.
softmax
(
dim
=-
1
)
attn_slice
=
attn_slice
.
softmax
(
dim
=-
1
)
attn_slice
=
torch
.
einsum
(
"b i j, b j d -> b i d"
,
attn_slice
,
value
[
start_idx
:
end_idx
])
attn_slice
=
torch
.
matmul
(
attn_slice
,
value
[
start_idx
:
end_idx
])
hidden_states
[
start_idx
:
end_idx
]
=
attn_slice
hidden_states
[
start_idx
:
end_idx
]
=
attn_slice
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment