chenpangpang / ComfyUI
"deploy/cpp_infer/src_det/postprocess_op.cpp" did not exist on "297edfaa03ec802dd1711cd3df72f520c1950aba"
Commit 3b9969c1, authored Feb 17, 2024 by comfyanonymous
Properly fix attention masks in CLIP with batches.
Parent: 5b40e7a5
Changes: 2 files, 9 additions and 2 deletions (+9, -2)

  comfy/clip_model.py             +1 -1
  comfy/ldm/modules/attention.py  +8 -1
comfy/clip_model.py:

@@ -97,7 +97,7 @@ class CLIPTextModel_(torch.nn.Module):
         x = self.embeddings(input_tokens)
         mask = None
         if attention_mask is not None:
-            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], attention_mask.shape[-1], attention_mask.shape[-1])
+            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
             mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))

         causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
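For readers skimming the change, here is a minimal sketch of what the amended line now builds, using a hypothetical two-prompt batch (the token count and padding pattern below are made up for illustration). The extra singleton dimension gives the padding mask an explicit head axis, so each batch element keeps its own additive mask:

```python
import torch

# Hypothetical example: batch of 2 prompts, 5 tokens each; the second prompt has 3 padding tokens.
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 0, 0, 0]])
B, T = attention_mask.shape

# Same construction as the new line: (B, T) -> (B, 1, 1, T) -> (B, 1, T, T),
# then padded positions become -inf (additive mask).
mask = 1.0 - attention_mask.to(torch.float32).reshape((B, 1, -1, T)).expand(B, 1, T, T)
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))

print(mask.shape)     # torch.Size([2, 1, 5, 5])
print(mask[1, 0, 0])  # tensor([0., 0., -inf, -inf, -inf])
```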
comfy/ldm/modules/attention.py:

@@ -114,7 +114,8 @@ def attention_basic(q, k, v, heads, mask=None):
             mask = repeat(mask, 'b j -> (b h) () j', h=h)
             sim.masked_fill_(~mask, max_neg_value)
         else:
-            sim += mask
+            mask = mask.reshape(mask.shape[0], -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(sim.shape)
+            sim.add_(mask)

     # attention, what we cannot get enough of
     sim = sim.softmax(dim=-1)
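A small sketch (toy sizes, not the real model dimensions) of the shape issue the commit message points at: attention_basic flattens batch and heads into the first dimension of sim, so a plain (B, T, T) mask cannot broadcast against it once B > 1, while the new reshape/expand lines the mask up slice for slice:

```python
import torch

B, heads, T = 2, 4, 5                 # toy sizes, purely illustrative
sim = torch.randn(B * heads, T, T)    # attention scores as attention_basic lays them out

# Old path: a 3-D (B, T, T) mask cannot broadcast against (B*heads, T, T) when B > 1.
old_mask = torch.zeros(B, T, T)
try:
    _ = sim + old_mask
except RuntimeError as err:
    print("broadcast fails:", err)

# New path: expand the (B, 1, T, T) mask over heads, then flatten to sim's shape.
new_mask = torch.zeros(B, 1, T, T)
aligned = (new_mask.reshape(new_mask.shape[0], -1, new_mask.shape[-2], new_mask.shape[-1])
                   .expand(-1, heads, -1, -1)
                   .reshape(sim.shape))
sim.add_(aligned)
print(aligned.shape)                  # torch.Size([8, 5, 5]), matches sim
```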
@@ -165,6 +166,9 @@ def attention_sub_quad(query, key, value, heads, mask=None):
     if query_chunk_size is None:
         query_chunk_size = 512

+    if mask is not None:
+        mask = mask.reshape(mask.shape[0], -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
+
     hidden_states = efficient_dot_product_attention(
         query,
         key,

@@ -223,6 +227,9 @@ def attention_split(q, k, v, heads, mask=None):
            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                               f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')

+    if mask is not None:
+        mask = mask.reshape(mask.shape[0], -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
+
     # print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
     first_op_done = False
     cleared_cache = False
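attention_sub_quad and attention_split apply the same expansion but flatten to (-1, Tq, Tk) rather than to a score tensor's shape, presumably because the mask is handed to those kernels before any scores exist. A quick shape check, with hypothetical sizes, of the layout this produces:

```python
import torch

B, heads, T = 2, 12, 77               # hypothetical CLIP-like sizes
mask = torch.zeros(B, 1, T, T)        # the additive mask produced by clip_model.py

flat = (mask.reshape(mask.shape[0], -1, mask.shape[-2], mask.shape[-1])
            .expand(-1, heads, -1, -1)
            .reshape(-1, mask.shape[-2], mask.shape[-1]))

print(flat.shape)  # torch.Size([24, 77, 77]): one (T, T) mask slice per (batch, head) pair
```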