chenpangpang / ComfyUI · Commits

Commit c1f5855a, authored Mar 03, 2023 by comfyanonymous
parent 1a612e1c

Make some cross attention functions work on the CPU.
Showing 2 changed files with 24 additions and 20 deletions (+24 -20)
comfy/ldm/modules/attention.py   +5 -12
comfy/model_management.py        +19 -8
comfy/ldm/modules/attention.py
@@ -9,6 +9,8 @@ from typing import Optional, Any
 from ldm.modules.diffusionmodules.util import checkpoint
 from .sub_quadratic_attention import efficient_dot_product_attention
+import model_management

 try:
     import xformers
     import xformers.ops
@@ -189,12 +191,8 @@ class CrossAttentionBirchSan(nn.Module):
         _, _, k_tokens = key_t.shape
         qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

-        stats = torch.cuda.memory_stats(query.device)
-        mem_active = stats['active_bytes.all.current']
-        mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
-        mem_free_torch = mem_reserved - mem_active
-        mem_free_total = mem_free_cuda + mem_free_torch
+        mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)

         chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD

         kv_chunk_size_min = None
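For context, a minimal standalone sketch (not part of the commit) of the chunking decision this call site feeds: the free-torch figure, halved, becomes the threshold that decides whether the full QK^T score matrix fits in memory at once. The get_free_memory stand-in below only implements the commit's CPU/psutil branch, and the tensor shapes are invented:

import psutil
import torch

def get_free_memory(device, torch_free_too=False):
    # Hypothetical stand-in for model_management.get_free_memory; this
    # sketch only covers the CPU branch (psutil available-RAM lookup).
    free = psutil.virtual_memory().available
    return (free, free) if torch_free_too else free

query = torch.randn(8, 4096, 64)   # (batch*heads, q_tokens, head_dim), invented
key_t = torch.randn(8, 64, 4096)   # transposed key: (batch*heads, head_dim, k_tokens)

batch_x_heads, q_tokens, _ = query.shape
_, _, k_tokens = key_t.shape
bytes_per_token = query.element_size()

# Bytes needed to materialize the full attention score matrix in one shot.
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

mem_free_total, mem_free_torch = get_free_memory(query.device, True)
chunk_threshold_bytes = mem_free_torch * 0.5   # same heuristic as the patch

# Chunk the key/value sequence only if the one-shot matmul would not fit.
print("chunk:", qk_matmul_size_bytes > chunk_threshold_bytes)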
@@ -276,12 +274,7 @@ class CrossAttentionDoggettx(nn.Module):
         r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)

-        stats = torch.cuda.memory_stats(q.device)
-        mem_active = stats['active_bytes.all.current']
-        mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
-        mem_free_torch = mem_reserved - mem_active
-        mem_free_total = mem_free_cuda + mem_free_torch
+        mem_free_total = model_management.get_free_memory(q.device)

         gb = 1024 ** 3
         tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
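Similarly, a sketch (again, not from the commit) of how Doggettx-style slicing could consume the single return value: the estimated size of the q @ k.T intermediate is compared against free memory to pick a power-of-two number of query slices. The 2.5x safety factor and the shapes here are assumptions for illustration:

import math
import psutil
import torch

def get_free_memory(device):
    # Hypothetical CPU-only stand-in for model_management.get_free_memory.
    return psutil.virtual_memory().available

q = torch.randn(8, 4096, 64)   # invented (batch*heads, tokens, head_dim)
k = torch.randn(8, 4096, 64)

gb = 1024 ** 3
# Bytes of one full q @ k.T intermediate, as in the hunk above.
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
mem_free_total = get_free_memory(q.device)

mem_required = tensor_size * 2.5   # assumed safety factor, not from this diff
steps = 1
if mem_required > mem_free_total:
    # Round the deficit up to the next power-of-two slice count.
    steps = 2 ** math.ceil(math.log2(mem_required / mem_free_total))
print("steps:", steps)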
comfy/model_management.py
@@ -145,14 +145,25 @@ def unload_if_low_vram(model):
     return model

-def get_free_memory():
-    dev = torch.cuda.current_device()
-    stats = torch.cuda.memory_stats(dev)
-    mem_active = stats['active_bytes.all.current']
-    mem_reserved = stats['reserved_bytes.all.current']
-    mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
-    mem_free_torch = mem_reserved - mem_active
-    return mem_free_cuda + mem_free_torch
+def get_free_memory(dev=None, torch_free_too=False):
+    if dev is None:
+        dev = torch.cuda.current_device()
+
+    if hasattr(dev, 'type') and dev.type == 'cpu':
+        mem_free_total = psutil.virtual_memory().available
+        mem_free_torch = mem_free_total
+    else:
+        stats = torch.cuda.memory_stats(dev)
+        mem_active = stats['active_bytes.all.current']
+        mem_reserved = stats['reserved_bytes.all.current']
+        mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
+        mem_free_torch = mem_reserved - mem_active
+        mem_free_total = mem_free_cuda + mem_free_torch
+
+    if torch_free_too:
+        return (mem_free_total, mem_free_torch)
+    else:
+        return mem_free_total

 def maximum_batch_area():
     global vram_state
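A quick usage sketch of the new signature (assumed to run inside this repo, where comfy/model_management.py and its psutil import are on the path):

import torch
import model_management

# No argument: falls back to torch.cuda.current_device(), an int, so
# hasattr(dev, 'type') is False and the CUDA branch runs (needs a GPU).
# free_total = model_management.get_free_memory()

# A torch.device('cpu') has .type == 'cpu', so the new psutil branch
# reports available system RAM instead of querying the CUDA allocator.
free_total = model_management.get_free_memory(torch.device('cpu'))

# torch_free_too=True also returns what PyTorch itself could reclaim
# (reserved minus active on CUDA); on CPU both numbers are the same.
free_total, free_torch = model_management.get_free_memory(torch.device('cpu'), torch_free_too=True)
print(free_total, free_torch)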