Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
3b9969c1
Commit
3b9969c1
authored
Feb 17, 2024
by
comfyanonymous
Browse files
Properly fix attention masks in CLIP with batches.
parent
5b40e7a5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
2 deletions
+9
-2
comfy/clip_model.py
comfy/clip_model.py
+1
-1
comfy/ldm/modules/attention.py
comfy/ldm/modules/attention.py
+8
-1
No files found.
comfy/clip_model.py
View file @
3b9969c1
...
...
@@ -97,7 +97,7 @@ class CLIPTextModel_(torch.nn.Module):
x
=
self
.
embeddings
(
input_tokens
)
mask
=
None
if
attention_mask
is
not
None
:
mask
=
1.0
-
attention_mask
.
to
(
x
.
dtype
).
reshape
((
attention_mask
.
shape
[
0
],
-
1
,
attention_mask
.
shape
[
-
1
])).
expand
(
attention_mask
.
shape
[
0
],
attention_mask
.
shape
[
-
1
],
attention_mask
.
shape
[
-
1
])
mask
=
1.0
-
attention_mask
.
to
(
x
.
dtype
).
reshape
((
attention_mask
.
shape
[
0
],
1
,
-
1
,
attention_mask
.
shape
[
-
1
])).
expand
(
attention_mask
.
shape
[
0
],
1
,
attention_mask
.
shape
[
-
1
],
attention_mask
.
shape
[
-
1
])
mask
=
mask
.
masked_fill
(
mask
.
to
(
torch
.
bool
),
float
(
"-inf"
))
causal_mask
=
torch
.
empty
(
x
.
shape
[
1
],
x
.
shape
[
1
],
dtype
=
x
.
dtype
,
device
=
x
.
device
).
fill_
(
float
(
"-inf"
)).
triu_
(
1
)
...
...
comfy/ldm/modules/attention.py
View file @
3b9969c1
...
...
@@ -114,7 +114,8 @@ def attention_basic(q, k, v, heads, mask=None):
mask
=
repeat
(
mask
,
'b j -> (b h) () j'
,
h
=
h
)
sim
.
masked_fill_
(
~
mask
,
max_neg_value
)
else
:
sim
+=
mask
mask
=
mask
.
reshape
(
mask
.
shape
[
0
],
-
1
,
mask
.
shape
[
-
2
],
mask
.
shape
[
-
1
]).
expand
(
-
1
,
heads
,
-
1
,
-
1
).
reshape
(
sim
.
shape
)
sim
.
add_
(
mask
)
# attention, what we cannot get enough of
sim
=
sim
.
softmax
(
dim
=-
1
)
...
...
@@ -165,6 +166,9 @@ def attention_sub_quad(query, key, value, heads, mask=None):
if
query_chunk_size
is
None
:
query_chunk_size
=
512
if
mask
is
not
None
:
mask
=
mask
.
reshape
(
mask
.
shape
[
0
],
-
1
,
mask
.
shape
[
-
2
],
mask
.
shape
[
-
1
]).
expand
(
-
1
,
heads
,
-
1
,
-
1
).
reshape
(
-
1
,
mask
.
shape
[
-
2
],
mask
.
shape
[
-
1
])
hidden_states
=
efficient_dot_product_attention
(
query
,
key
,
...
...
@@ -223,6 +227,9 @@ def attention_split(q, k, v, heads, mask=None):
raise
RuntimeError
(
f
'Not enough memory, use lower resolution (max approx.
{
max_res
}
x
{
max_res
}
). '
f
'Need:
{
mem_required
/
64
/
gb
:
0.1
f
}
GB free, Have:
{
mem_free_total
/
gb
:
0.1
f
}
GB free'
)
if
mask
is
not
None
:
mask
=
mask
.
reshape
(
mask
.
shape
[
0
],
-
1
,
mask
.
shape
[
-
2
],
mask
.
shape
[
-
1
]).
expand
(
-
1
,
heads
,
-
1
,
-
1
).
reshape
(
-
1
,
mask
.
shape
[
-
2
],
mask
.
shape
[
-
1
])
# print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
first_op_done
=
False
cleared_cache
=
False
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment