Commit b0ab31d0 authored May 14, 2024 by comfyanonymous

Refactor attention upcasting code part 1.

parent 2de3b69b
Showing 1 changed file with 15 additions and 13 deletions.
comfy/ldm/modules/attention.py
@@ -22,9 +22,9 @@ ops = comfy.ops.disable_weight_init
 # CrossAttn precision handling
 if args.dont_upcast_attention:
     logging.info("disabling upcasting of attention")
-    _ATTN_PRECISION = "fp16"
+    _ATTN_PRECISION = None
 else:
-    _ATTN_PRECISION = "fp32"
+    _ATTN_PRECISION = torch.float32


 def exists(val):
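In short, the hunk above swaps the string flag ("fp16"/"fp32") for a real value: None when upcasting is disabled via args.dont_upcast_attention, torch.float32 when it is requested. A minimal sketch of the new convention, assuming only what the hunk shows (the helper name wants_upcast is illustrative and not part of the diff):

import torch

_ATTN_PRECISION = torch.float32  # or None when args.dont_upcast_attention is set

def wants_upcast(attn_precision, q):
    # Upcast only when fp32 is requested and the tensor is not already fp32.
    return attn_precision == torch.float32 and q.dtype != torch.float32

q = torch.randn(2, 16, 64, dtype=torch.float16)
print(wants_upcast(_ATTN_PRECISION, q))   # True
print(wants_upcast(None, q))              # False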
@@ -85,7 +85,7 @@ class FeedForward(nn.Module):
 def Normalize(in_channels, dtype=None, device=None):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)

-def attention_basic(q, k, v, heads, mask=None):
+def attention_basic(q, k, v, heads, mask=None, attn_precision=None):
     b, _, dim_head = q.shape
     dim_head //= heads
     scale = dim_head ** -0.5
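Each backend now takes the same extra keyword, attn_precision=None, while the head arithmetic is unchanged: the packed inner dimension is split across heads and the logits are scaled by 1/sqrt(dim_head). A quick worked example with made-up sizes (heads=8, inner dim 512):

q_shape = (2, 4096, 512)     # (batch, tokens, heads * dim_head), hypothetical
heads = 8

b, _, dim_head = q_shape
dim_head //= heads           # 512 // 8 = 64
scale = dim_head ** -0.5     # 1 / sqrt(64) = 0.125
print(dim_head, scale)       # 64 0.125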
@@ -101,7 +101,7 @@ def attention_basic(q, k, v, heads, mask=None):
     )

     # force cast to fp32 to avoid overflowing
-    if _ATTN_PRECISION == "fp32":
+    if attn_precision == torch.float32:
         sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
     else:
         sim = einsum('b i d, b j d -> b i j', q, k) * scale
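The comment "force cast to fp32 to avoid overflowing" is the whole point of the precision knob: q·k logits can exceed the fp16 range (max ≈ 65504) before softmax ever normalizes them. A rough illustration with made-up values, not taken from the diff; fp16 matmul support varies by backend, so treat it as a sketch:

import torch

q = torch.full((1, 1, 64), 64.0, dtype=torch.float16)
k = torch.full((1, 1, 64), 64.0, dtype=torch.float16)

# Each logit sums 64 products of 64*64 = 4096, i.e. 262144 > 65504 (fp16 max).
sim_fp16 = torch.einsum('b i d, b j d -> b i j', q, k)                  # overflows to inf in fp16
sim_fp32 = torch.einsum('b i d, b j d -> b i j', q.float(), k.float())  # 262144.0
print(sim_fp16.flatten(), sim_fp32.flatten())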
@@ -135,7 +135,7 @@ def attention_basic(q, k, v, heads, mask=None):
     return out


-def attention_sub_quad(query, key, value, heads, mask=None):
+def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None):
     b, _, dim_head = query.shape
     dim_head //= heads
@@ -146,7 +146,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
     key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)

     dtype = query.dtype
-    upcast_attention = _ATTN_PRECISION == "fp32" and query.dtype != torch.float32
+    upcast_attention = attn_precision == torch.float32 and query.dtype != torch.float32
     if upcast_attention:
         bytes_per_token = torch.finfo(torch.float32).bits//8
     else:
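For the sub-quadratic path the same comparison also drives the per-token memory budget: torch.finfo(torch.float32).bits // 8 is 4 bytes when an upcast will happen, and the upcast is skipped entirely when the query is already fp32 (the else branch falls outside this hunk). A small check of that logic, outside the diff; the fallback to query.element_size() below is only an approximation of the cut-off branch:

import torch

attn_precision = torch.float32
query = torch.randn(1, 77, 64, dtype=torch.float16)

upcast_attention = attn_precision == torch.float32 and query.dtype != torch.float32
bytes_per_token = torch.finfo(torch.float32).bits // 8 if upcast_attention else query.element_size()
print(upcast_attention, bytes_per_token)   # True 4

fp32_query = query.float()
print(attn_precision == torch.float32 and fp32_query.dtype != torch.float32)  # False: already fp32, no upcast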
@@ -195,7 +195,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
     hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1, 2).flatten(start_dim=2)

     return hidden_states

-def attention_split(q, k, v, heads, mask=None):
+def attention_split(q, k, v, heads, mask=None, attn_precision=None):
     b, _, dim_head = q.shape
     dim_head //= heads
     scale = dim_head ** -0.5
@@ -214,10 +214,12 @@ def attention_split(q, k, v, heads, mask=None):

     mem_free_total = model_management.get_free_memory(q.device)

-    if _ATTN_PRECISION == "fp32":
+    if attn_precision == torch.float32:
         element_size = 4
+        upcast = True
     else:
         element_size = q.element_size()
+        upcast = False

     gb = 1024 ** 3
     tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
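The element_size/upcast pair feeds the memory estimate right below: the logits tensor is roughly (batch*heads) x q_len x k_len x element_size bytes, and the steps/slice_size machinery further down splits the query so that estimate fits in the free VRAM reported by get_free_memory. A worked example with hypothetical shapes:

# Hypothetical: after head folding, q is (b*heads, q_len, dim_head) = (16, 4096, 64)
q_shape = (16, 4096, 64)
k_len = 4096
element_size = 4  # fp32 upcast requested, as in the new branch above

tensor_size = q_shape[0] * q_shape[1] * k_len * element_size
gb = 1024 ** 3
print(tensor_size / gb)   # 1.0 GiB for the attention logits alone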
@@ -251,7 +253,7 @@ def attention_split(q, k, v, heads, mask=None):
         slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
         for i in range(0, q.shape[1], slice_size):
             end = i + slice_size
-            if _ATTN_PRECISION == "fp32":
+            if upcast:
                 with torch.autocast(enabled=False, device_type='cuda'):
                     s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
             else:
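The upcast flag computed earlier replaces the string check here, and the surrounding torch.autocast(enabled=False, ...) guard is what makes the explicit .float() stick: inside an active autocast region, einsum would otherwise be re-cast back down to the autocast dtype. A minimal illustration of that pattern (CUDA assumed, as in the diff; shapes are made up):

import torch

q = torch.randn(1, 128, 64, device='cuda', dtype=torch.float16)
k = torch.randn(1, 128, 64, device='cuda', dtype=torch.float16)

with torch.autocast(device_type='cuda', dtype=torch.float16):
    with torch.autocast(enabled=False, device_type='cuda'):
        # With autocast disabled, the fp32 inputs stay fp32 through the einsum.
        s1 = torch.einsum('b i d, b j d -> b i j', q.float(), k.float())
print(s1.dtype)   # torch.float32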
@@ -302,7 +304,7 @@ try:
 except:
     pass

-def attention_xformers(q, k, v, heads, mask=None):
+def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
     b, _, dim_head = q.shape
     dim_head //= heads

     if BROKEN_XFORMERS:
@@ -334,7 +336,7 @@ def attention_xformers(q, k, v, heads, mask=None):
     )

     return out

-def attention_pytorch(q, k, v, heads, mask=None):
+def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None):
     b, _, dim_head = q.shape
     dim_head //= heads

     q, k, v = map(
@@ -409,9 +411,9 @@ class CrossAttention(nn.Module):
         v = self.to_v(context)

         if mask is None:
-            out = optimized_attention(q, k, v, self.heads)
+            out = optimized_attention(q, k, v, self.heads, attn_precision=_ATTN_PRECISION)
         else:
-            out = optimized_attention_masked(q, k, v, self.heads, mask)
+            out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=_ATTN_PRECISION)

         return self.to_out(out)
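Net effect of the commit: every backend (attention_basic, attention_sub_quad, attention_split, attention_xformers, attention_pytorch) now shares one signature, and CrossAttention forwards the module-level default through the new attn_precision keyword instead of each backend reading the global. In principle a caller can therefore choose precision per call; a hedged sketch of direct use, assuming the comfy package is importable in your environment and using attention_basic only as an example (ComfyUI itself dispatches through optimized_attention elsewhere in attention.py):

import torch
from comfy.ldm.modules.attention import attention_basic

q = torch.randn(2, 77, 512, device='cuda', dtype=torch.float16)
k = torch.randn(2, 77, 512, device='cuda', dtype=torch.float16)
v = torch.randn(2, 77, 512, device='cuda', dtype=torch.float16)

# Upcast the q*k logits to fp32 for this call only; attn_precision=None keeps the native dtype.
out = attention_basic(q, k, v, heads=8, attn_precision=torch.float32)
print(out.shape)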