chenpangpang / ComfyUI · Commits

Commit aaa90173, authored Jan 07, 2024 by comfyanonymous

    Add attention mask support to sub quad attention.

parent 0c2c9fbd

Changes: 2 changed files, with 27 additions and 4 deletions (+27 -4)

    comfy/ldm/modules/attention.py                +1  -0
    comfy/ldm/modules/sub_quadratic_attention.py  +26 -4
comfy/ldm/modules/attention.py

@@ -177,6 +177,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
         kv_chunk_size_min=kv_chunk_size_min,
         use_checkpoint=False,
         upcast_attention=upcast_attention,
+        mask=mask,
     )

     hidden_states = hidden_states.to(dtype)
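This hunk is the attention.py side of the change: attention_sub_quad already accepts mask=None and now forwards it into efficient_dot_product_attention. The mask is additive, applied to raw attention scores before softmax. A minimal sketch of how a caller might build one; make_additive_mask is a hypothetical helper for illustration, not part of this commit:

import torch

def make_additive_mask(keep: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    # keep: bool tensor, True where attention is allowed, e.g. (q_tokens, k_tokens).
    # Masked positions get a large negative bias so exp() drives them to ~0.
    # A finite value (rather than -inf) avoids NaN rows if an entire
    # key chunk happens to be masked out.
    return torch.zeros(keep.shape, dtype=dtype).masked_fill(~keep, -1e4)

# e.g. a causal mask for 8 query/key tokens, passed as mask= to attention_sub_quad:
causal_mask = make_additive_mask(torch.tril(torch.ones(8, 8, dtype=torch.bool)))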
comfy/ldm/modules/sub_quadratic_attention.py

@@ -61,6 +61,7 @@ def _summarize_chunk(
     value: Tensor,
     scale: float,
     upcast_attention: bool,
+    mask,
 ) -> AttnChunk:
     if upcast_attention:
         with torch.autocast(enabled=False, device_type='cuda'):
@@ -84,6 +85,8 @@ def _summarize_chunk(
     max_score, _ = torch.max(attn_weights, -1, keepdim=True)
     max_score = max_score.detach()
     attn_weights -= max_score
+    if mask is not None:
+        attn_weights += mask
     torch.exp(attn_weights, out=attn_weights)
     exp_weights = attn_weights.to(value.dtype)
     exp_values = torch.bmm(exp_weights, value)
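This hunk applies the mask inside _summarize_chunk's log-space accumulation: the running row max is subtracted first for numerical stability, then the (typically non-positive) additive mask is added before exponentiation. A standalone sketch of the same pattern, with illustrative names:

import torch
from typing import Optional

def masked_stable_softmax(scores: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
    # Subtract the row max before exp() so it cannot overflow, as
    # _summarize_chunk does; adding the mask afterwards is safe because
    # its entries are <= 0 and only push masked positions toward exp() == 0.
    max_score, _ = torch.max(scores, -1, keepdim=True)
    scores = scores - max_score.detach()
    if mask is not None:
        scores = scores + mask
    weights = torch.exp(scores)
    return weights / weights.sum(-1, keepdim=True)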
@@ -96,11 +99,12 @@ def _query_chunk_attention(
     value: Tensor,
     summarize_chunk: SummarizeChunk,
     kv_chunk_size: int,
+    mask,
 ) -> Tensor:
     batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
     _, _, v_channels_per_head = value.shape

-    def chunk_scanner(chunk_idx: int) -> AttnChunk:
+    def chunk_scanner(chunk_idx: int, mask) -> AttnChunk:
         key_chunk = dynamic_slice(
             key_t,
             (0, 0, chunk_idx),
@@ -111,10 +115,13 @@ def _query_chunk_attention(
             (0, chunk_idx, 0),
             (batch_x_heads, kv_chunk_size, v_channels_per_head)
         )
-        return summarize_chunk(query, key_chunk, value_chunk)
+        if mask is not None:
+            mask = mask[:,:,chunk_idx:chunk_idx + kv_chunk_size]
+
+        return summarize_chunk(query, key_chunk, value_chunk, mask=mask)

     chunks: List[AttnChunk] = [
-        chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
+        chunk_scanner(chunk, mask) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
     ]
     acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
     chunk_values, chunk_weights, chunk_max = acc_chunk
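Since keys and values are walked in chunks along the token axis, the mask must be sliced to match: chunk_scanner now keeps only the kv_chunk_size key columns belonging to the current chunk. A toy illustration of that slicing, with illustrative shapes:

import torch

batch_x_heads, q_tokens, k_tokens, kv_chunk_size = 2, 16, 64, 16
mask = torch.zeros(batch_x_heads, q_tokens, k_tokens)

# Mirrors mask[:,:,chunk_idx:chunk_idx + kv_chunk_size] in chunk_scanner:
for chunk_idx in range(0, k_tokens, kv_chunk_size):
    mask_chunk = mask[:, :, chunk_idx:chunk_idx + kv_chunk_size]
    assert mask_chunk.shape == (batch_x_heads, q_tokens, kv_chunk_size)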
@@ -135,6 +142,7 @@ def _get_attention_scores_no_kv_chunking(
     value: Tensor,
     scale: float,
     upcast_attention: bool,
+    mask,
 ) -> Tensor:
     if upcast_attention:
         with torch.autocast(enabled=False, device_type='cuda'):
@@ -156,6 +164,8 @@ def _get_attention_scores_no_kv_chunking(
         beta=0,
     )
+    if mask is not None:
+        attn_scores += mask
     try:
         attn_probs = attn_scores.softmax(dim=-1)
         del attn_scores
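The fast path (_get_attention_scores_no_kv_chunking) has no chunked renormalization, so the mask is simply added to the raw scores before an ordinary softmax. A quick numerical check of the effect, with illustrative values:

import torch

scores = torch.tensor([[2.0, 1.0, 0.5]])
mask = torch.tensor([[0.0, 0.0, -1e4]])   # bias the last key out

probs = (scores + mask).softmax(dim=-1)
# probs[..., -1] is ~0; the remaining weight renormalizes over keys 0 and 1.
print(probs)  # tensor([[0.7311, 0.2689, 0.0000]])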
@@ -183,6 +193,7 @@ def efficient_dot_product_attention(
     kv_chunk_size_min: Optional[int] = None,
     use_checkpoint=True,
     upcast_attention=False,
+    mask=None,
 ):
     """Computes efficient dot-product attention given query, transposed key, and value.
     This is efficient version of attention presented in
@@ -209,13 +220,22 @@ def efficient_dot_product_attention(
     if kv_chunk_size_min is not None:
         kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)

+    if mask is not None and len(mask.shape) == 2:
+        mask = mask.unsqueeze(0)
+
     def get_query_chunk(chunk_idx: int) -> Tensor:
         return dynamic_slice(
             query,
             (0, chunk_idx, 0),
             (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
         )

+    def get_mask_chunk(chunk_idx: int) -> Tensor:
+        if mask is None:
+            return None
+        chunk = min(query_chunk_size, q_tokens)
+        return mask[:,chunk_idx:chunk_idx + chunk]
+
     summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention)
     summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
     compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
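Two details in the hunk above: a plain 2D (q_tokens, k_tokens) mask is lifted to 3D with unsqueeze(0) so it broadcasts against the (batch*heads, q, k) score tensors, and get_mask_chunk slices along the query axis so each query chunk sees only its own rows; the key axis is sliced later, inside chunk_scanner. A small sketch of both steps, with illustrative shapes:

import torch

q_tokens, k_tokens, query_chunk_size = 32, 64, 8
mask = torch.zeros(q_tokens, k_tokens)      # 2D mask, no batch dim

if mask.ndim == 2:
    mask = mask.unsqueeze(0)                # -> (1, q, k), broadcasts over batch*heads

for chunk_idx in range(0, q_tokens, query_chunk_size):
    chunk = min(query_chunk_size, q_tokens)             # as in get_mask_chunk
    mask_chunk = mask[:, chunk_idx:chunk_idx + chunk]   # rows for this query chunk
    assert mask_chunk.shape == (1, query_chunk_size, k_tokens)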
@@ -237,6 +257,7 @@ def efficient_dot_product_attention(
         query=query,
         key_t=key_t,
         value=value,
+        mask=mask,
     )

     # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
@@ -246,6 +267,7 @@ def efficient_dot_product_attention(
             query=get_query_chunk(i * query_chunk_size),
             key_t=key_t,
             value=value,
+            mask=get_mask_chunk(i * query_chunk_size)
         ) for i in range(math.ceil(q_tokens / query_chunk_size))
     ], dim=1)
     return res
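End-to-end, the commit lets callers thread a mask through the whole sub-quadratic path. A hedged usage sketch, assuming the layouts this file already uses (query as (batch*heads, q_tokens, channels), key transposed to (batch*heads, channels, k_tokens)) and the defaults visible in the signature hunk; the sizes are illustrative:

import torch
from comfy.ldm.modules.sub_quadratic_attention import efficient_dot_product_attention

batch_x_heads, q_tokens, k_tokens, channels = 8, 128, 128, 64
query = torch.randn(batch_x_heads, q_tokens, channels)
key_t = torch.randn(batch_x_heads, channels, k_tokens)
value = torch.randn(batch_x_heads, k_tokens, channels)

# Additive mask: 0 where attention is allowed, large negative elsewhere;
# a 2D mask is broadcast over batch*heads by the unsqueeze(0) shown above.
disallow = torch.triu(torch.ones(q_tokens, k_tokens, dtype=torch.bool), diagonal=1)
mask = torch.zeros(q_tokens, k_tokens).masked_fill(disallow, -1e4)

out = efficient_dot_product_attention(query=query, key_t=key_t, value=value, mask=mask)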