chenpangpang / ComfyUI · Commit 3ad0191b

Implement attention mask on xformers.

Authored Jan 06, 2024 by comfyanonymous
Parent: af94eb14
Showing 1 changed file with 7 additions and 4 deletions (+7 -4)

comfy/ldm/modules/attention.py
@@ -294,11 +294,14 @@ def attention_xformers(q, k, v, heads, mask=None):
         (q, k, v),
     )
 
-    # actually compute the attention, what we cannot get enough of
-    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
+    if mask is not None:
+        pad = 8 - q.shape[1] % 8
+        mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
+        mask_out[:, :, :mask.shape[-1]] = mask
+        mask = mask_out[:, :, :mask.shape[-1]]
+
+    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
 
-    if exists(mask):
-        raise NotImplementedError
     out = (
         out.unsqueeze(0)
         .reshape(b, heads, -1, dim_head)
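For context, here is a minimal, self-contained sketch (not part of the commit) of the padding trick the added lines use. It assumes the behavior documented for xformers' memory_efficient_attention: the attn_bias tensor's last dimension must be memory-aligned (a multiple of 8 elements), so the mask is copied into a buffer whose last dimension is padded up to a multiple of 8 and then sliced back to the logical shape, leaving the underlying row stride aligned. The helper name pad_mask_for_xformers is hypothetical, introduced here only for illustration.

    # Hedged sketch of the mask-padding trick from the commit above.
    # Slicing a padded buffer keeps the visible shape unpadded while the
    # storage row stride stays a multiple of 8, which is what xformers'
    # memory_efficient_attention needs for its attn_bias argument.
    import torch

    def pad_mask_for_xformers(mask: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
        # mask: attention bias of shape [batch * heads, q_len, kv_len],
        # same dtype and device as q (q has shape [batch * heads, q_len, dim_head]).
        pad = 8 - q.shape[1] % 8
        mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad],
                               dtype=q.dtype, device=q.device)
        mask_out[:, :, :mask.shape[-1]] = mask
        # The slice shares mask_out's padded storage, so the stride of the
        # second-to-last dimension remains a multiple of 8.
        return mask_out[:, :, :mask.shape[-1]]

After this change, attention_xformers forwards the (padded) mask as attn_bias instead of raising NotImplementedError whenever a mask is supplied.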