chenpangpang / ComfyUI · Commits · e6bc42df

Commit e6bc42df authored Oct 22, 2023 by comfyanonymous

    Make sub_quad and split work with hypertile.

parent 8cfce083
Showing 1 changed file with 29 additions and 12 deletions.

comfy/ldm/modules/attention.py  (+29, -12)
@@ -124,11 +124,14 @@ def attention_basic(q, k, v, heads, mask=None):
 def attention_sub_quad(query, key, value, heads, mask=None):
-    scale = (query.shape[-1] // heads) ** -0.5
-    query = query.unflatten(-1, (heads, -1)).transpose(1, 2).flatten(end_dim=1)
-    key_t = key.transpose(1, 2).unflatten(1, (heads, -1)).flatten(end_dim=1)
-    del key
-    value = value.unflatten(-1, (heads, -1)).transpose(1, 2).flatten(end_dim=1)
+    b, _, dim_head = query.shape
+    dim_head //= heads
+
+    scale = dim_head ** -0.5
+    query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
+    value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
+
+    key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)

     dtype = query.dtype
     upcast_attention = _ATTN_PRECISION == "fp32" and query.dtype != torch.float32
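Editor's note: the reshape chain added above performs the same head split as the removed unflatten/transpose lines, and the new key is produced directly in the transposed (b*heads, dim_head, k_tokens) layout that the removed key_t used to hold (which is why key_t is renamed to key in the later hunks). A minimal sketch of that equivalence; the sizes b, n, heads, dim_head below are illustrative, not from the commit:

    import torch

    b, n, heads, dim_head = 2, 64, 8, 40           # illustrative sizes
    query = torch.randn(b, n, heads * dim_head)
    key = torch.randn(b, n, heads * dim_head)

    # old head split for the query (removed lines)
    q_old = query.unflatten(-1, (heads, -1)).transpose(1, 2).flatten(end_dim=1)
    # new head split (added lines)
    q_new = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
    print(torch.equal(q_old, q_new))               # True, both are (b*heads, n, dim_head)

    # old transposed key (key_t) vs. the new key in the same layout
    kt_old = key.transpose(1, 2).unflatten(1, (heads, -1)).flatten(end_dim=1)
    k_new = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
    print(torch.equal(kt_old, k_new))              # True, both are (b*heads, dim_head, n)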
@@ -137,7 +140,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
     else:
         bytes_per_token = torch.finfo(query.dtype).bits // 8
     batch_x_heads, q_tokens, _ = query.shape
-    _, _, k_tokens = key_t.shape
+    _, _, k_tokens = key.shape
     qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

     mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
@@ -171,7 +174,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
     hidden_states = efficient_dot_product_attention(
         query,
-        key_t,
+        key,
         value,
         query_chunk_size=query_chunk_size,
         kv_chunk_size=kv_chunk_size,
@@ -186,9 +189,19 @@ def attention_sub_quad(query, key, value, heads, mask=None):
     return hidden_states

 def attention_split(q, k, v, heads, mask=None):
-    scale = (q.shape[-1] // heads) ** -0.5
+    b, _, dim_head = q.shape
+    dim_head //= heads
+
+    scale = dim_head ** -0.5

     h = heads
-    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+    q, k, v = map(
+        lambda t: t.unsqueeze(3)
+        .reshape(b, -1, heads, dim_head)
+        .permute(0, 2, 1, 3)
+        .reshape(b * heads, -1, dim_head)
+        .contiguous(),
+        (q, k, v),
+    )

     r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
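Editor's note: the map(...) block added above replaces the einops rearrange with plain tensor ops; it is the same (b h) head split made explicit, finished with .contiguous(). A small sketch of the equivalence, again with illustrative sizes rather than anything from the commit:

    import torch
    from einops import rearrange

    b, n, heads, dim_head = 2, 64, 8, 40           # illustrative sizes
    q = torch.randn(b, n, heads * dim_head)

    q_old = rearrange(q, 'b n (h d) -> (b h) n d', h=heads)   # removed path
    q_new = (q.unsqueeze(3)                                    # added path
              .reshape(b, -1, heads, dim_head)
              .permute(0, 2, 1, 3)
              .reshape(b * heads, -1, dim_head)
              .contiguous())

    print(torch.equal(q_old, q_new))               # True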
@@ -248,9 +261,13 @@ def attention_split(q, k, v, heads, mask=None):
         del q, k, v

-    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
-    del r1
-    return r2
+    r1 = (
+        r1.unsqueeze(0)
+        .reshape(b, heads, -1, dim_head)
+        .permute(0, 2, 1, 3)
+        .reshape(b, -1, heads * dim_head)
+    )
+    return r1

 def attention_xformers(q, k, v, heads, mask=None):
     b, _, dim_head = q.shape
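Editor's note: likewise, the r1 reshape added in this hunk merges the heads back to (b, n, heads*dim_head) exactly as the removed rearrange('(b h) n d -> b n (h d)') did. A quick sketch with illustrative sizes:

    import torch
    from einops import rearrange

    b, n, heads, dim_head = 2, 64, 8, 40           # illustrative sizes
    r1 = torch.randn(b * heads, n, dim_head)

    r_old = rearrange(r1, '(b h) n d -> b n (h d)', h=heads)  # removed path
    r_new = (r1.unsqueeze(0)                                   # added path
               .reshape(b, heads, -1, dim_head)
               .permute(0, 2, 1, 3)
               .reshape(b, -1, heads * dim_head))

    print(torch.equal(r_old, r_new))               # True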