chenpangpang / ComfyUI

Commit c1f5855a
Authored Mar 03, 2023 by comfyanonymous
Make some cross attention functions work on the CPU.
Parent: 1a612e1c
Showing 2 changed files with 24 additions and 20 deletions.
comfy/ldm/modules/attention.py   +5 -12
comfy/model_management.py        +19 -8
comfy/ldm/modules/attention.py @ c1f5855a
@@ -9,6 +9,8 @@ from typing import Optional, Any
 from ldm.modules.diffusionmodules.util import checkpoint
 from .sub_quadratic_attention import efficient_dot_product_attention

+import model_management
+
 try:
     import xformers
     import xformers.ops
@@ -189,12 +191,8 @@ class CrossAttentionBirchSan(nn.Module):
         _, _, k_tokens = key_t.shape
         qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

-        stats = torch.cuda.memory_stats(query.device)
-        mem_active = stats['active_bytes.all.current']
-        mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
-        mem_free_torch = mem_reserved - mem_active
-        mem_free_total = mem_free_cuda + mem_free_torch
+        mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
+
         chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD

         kv_chunk_size_min = None
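The deleted block is exactly what pinned this path to CUDA: it read the caching-allocator stats for query.device, and torch.cuda.memory_stats() rejects a CPU device. A minimal sketch of that failure mode (illustrative, not part of the diff):

import torch

# The pre-refactor code ran this query unconditionally; for a CPU tensor
# it raises before any attention math happens.
query = torch.randn(2, 4096, 40, device='cpu')
try:
    torch.cuda.memory_stats(query.device)
except (ValueError, RuntimeError) as err:
    print('CUDA-only memory query fails on CPU:', err)

With the lookup routed through model_management.get_free_memory(query.device, True), the same two values (total free memory, free memory inside the torch allocator) come back regardless of device type.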
@@ -276,12 +274,7 @@ class CrossAttentionDoggettx(nn.Module):
         r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)

-        stats = torch.cuda.memory_stats(q.device)
-        mem_active = stats['active_bytes.all.current']
-        mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
-        mem_free_torch = mem_reserved - mem_active
-        mem_free_total = mem_free_cuda + mem_free_torch
+        mem_free_total = model_management.get_free_memory(q.device)

         gb = 1024 ** 3
         tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
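For context, the mem_free_total computed here is consumed a few lines below this hunk to decide how many slices to split the attention matrix into. A rough sketch of that power-of-two slicing pattern; the safety modifier and the function shape are illustrative, not taken from this diff:

import math

def doggettx_steps(tensor_size, mem_free_total, modifier=2.5):
    # Doggettx-style attention slices the softmax(QK^T)V computation when
    # the full attention matrix would not fit in free memory. The modifier
    # is a safety factor; 2.5 is illustrative, not from this commit.
    mem_required = tensor_size * modifier
    steps = 1
    if mem_required > mem_free_total:
        # Round the ratio up to the next power of two so every slice of
        # the query dimension has the same size.
        steps = 2 ** math.ceil(math.log2(mem_required / mem_free_total))
    return steps

# e.g. a 4 GiB requirement against 3 GiB free -> 2 slices
print(doggettx_steps(4 * 1024**3, 3 * 1024**3, modifier=1.0))

Because the slicing logic only cares about the mem_free_total number, it now works identically whether that figure came from the CUDA driver or from system RAM.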
comfy/model_management.py @ c1f5855a
@@ -145,14 +145,25 @@ def unload_if_low_vram(model):
     return model

-def get_free_memory():
-    dev = torch.cuda.current_device()
-    stats = torch.cuda.memory_stats(dev)
-    mem_active = stats['active_bytes.all.current']
-    mem_reserved = stats['reserved_bytes.all.current']
-    mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
-    mem_free_torch = mem_reserved - mem_active
-    return mem_free_cuda + mem_free_torch
+def get_free_memory(dev=None, torch_free_too=False):
+    if dev is None:
+        dev = torch.cuda.current_device()
+
+    if hasattr(dev, 'type') and dev.type == 'cpu':
+        mem_free_total = psutil.virtual_memory().available
+        mem_free_torch = mem_free_total
+    else:
+        stats = torch.cuda.memory_stats(dev)
+        mem_active = stats['active_bytes.all.current']
+        mem_reserved = stats['reserved_bytes.all.current']
+        mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
+        mem_free_torch = mem_reserved - mem_active
+        mem_free_total = mem_free_cuda + mem_free_torch
+
+    if torch_free_too:
+        return (mem_free_total, mem_free_torch)
+    else:
+        return mem_free_total

 def maximum_batch_area():
     global vram_state
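Two details are worth noting. The default dev is torch.cuda.current_device(), which returns a plain int; an int has no .type attribute, so the hasattr(dev, 'type') guard correctly sends it down the CUDA branch. And on CPU, mem_free_torch simply mirrors the psutil total, since there is no separate caching-allocator pool to reclaim. A hedged usage sketch, assuming the module is importable as model_management exactly as attention.py imports it above:

import torch
import model_management  # same flat import used by attention.py above

# CPU path: backed by psutil.virtual_memory().available.
cpu_total = model_management.get_free_memory(torch.device('cpu'))

# CUDA path (guarded, since this sketch may run on a CPU-only machine):
# driver-free bytes plus memory PyTorch reserved but is not actively using.
if torch.cuda.is_available():
    total, torch_free = model_management.get_free_memory(
        torch.device('cuda'), torch_free_too=True)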