chenpangpang / ComfyUI
"StyleTextRec/utils/logging.py" did not exist on "596947758fb02f49a151ef02d8beebd3121ee9e3"
Commit c1f5855a, authored Mar 03, 2023 by comfyanonymous
Parent commit: 1a612e1c

Make some cross attention functions work on the CPU.
Showing 2 changed files, with 24 additions and 20 deletions:

    comfy/ldm/modules/attention.py    +5  -12
    comfy/model_management.py         +19 -8
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -9,6 +9,8 @@ from typing import Optional, Any
 from ldm.modules.diffusionmodules.util import checkpoint
 from .sub_quadratic_attention import efficient_dot_product_attention
 
+import model_management
+
 try:
     import xformers
     import xformers.ops
@@ -189,12 +191,8 @@ class CrossAttentionBirchSan(nn.Module):
         _, _, k_tokens = key_t.shape
         qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
 
-        stats = torch.cuda.memory_stats(query.device)
-        mem_active = stats['active_bytes.all.current']
-        mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
-        mem_free_torch = mem_reserved - mem_active
-        mem_free_total = mem_free_cuda + mem_free_torch
+        mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
+
         chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD
         kv_chunk_size_min = None
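
To make the memory math in this hunk concrete, a hedged back-of-the-envelope example follows. Every number is made up for illustration; only the two formulas are taken from the code above.

    # Hypothetical sizes; only the formulas come from the hunk above.
    batch_x_heads = 16              # e.g. a batch of 2 with 8 attention heads
    bytes_per_token = 4             # float32
    q_tokens = k_tokens = 4096

    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
    # = 16 * 4 * 4096 * 4096 = 1073741824 bytes, i.e. exactly 1 GiB

    mem_free_torch = 2 * 1024**3    # pretend the torch allocator has 2 GiB free
    chunk_threshold_bytes = mem_free_torch * 0.5   # 1 GiB: this QK^T just fits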
@@ -276,12 +274,7 @@ class CrossAttentionDoggettx(nn.Module):
         r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
 
-        stats = torch.cuda.memory_stats(q.device)
-        mem_active = stats['active_bytes.all.current']
-        mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
-        mem_free_torch = mem_reserved - mem_active
-        mem_free_total = mem_free_cuda + mem_free_torch
+        mem_free_total = model_management.get_free_memory(q.device)
 
         gb = 1024 ** 3
         tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
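
The gb and tensor_size values above feed the chunk-count decision in the unchanged remainder of CrossAttentionDoggettx.forward. A simplified sketch of that general pattern follows; the function name and safety factor are illustrative, not the literal code from the file.

    import math

    def slice_count(tensor_size, mem_free_total, safety_factor=2.5):
        # Crude peak-memory estimate for the attention matmul.
        mem_required = tensor_size * safety_factor
        if mem_required <= mem_free_total:
            return 1  # everything fits; no chunking needed
        # Double the number of query slices until one slice fits in free memory.
        return 2 ** math.ceil(math.log2(mem_required / mem_free_total))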
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -145,14 +145,25 @@ def unload_if_low_vram(model):
     return model
 
-def get_free_memory():
-    dev = torch.cuda.current_device()
-    stats = torch.cuda.memory_stats(dev)
-    mem_active = stats['active_bytes.all.current']
-    mem_reserved = stats['reserved_bytes.all.current']
-    mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
-    mem_free_torch = mem_reserved - mem_active
-    return mem_free_cuda + mem_free_torch
+def get_free_memory(dev=None, torch_free_too=False):
+    if dev is None:
+        dev = torch.cuda.current_device()
+
+    if hasattr(dev, 'type') and dev.type == 'cpu':
+        mem_free_total = psutil.virtual_memory().available
+        mem_free_torch = mem_free_total
+    else:
+        stats = torch.cuda.memory_stats(dev)
+        mem_active = stats['active_bytes.all.current']
+        mem_reserved = stats['reserved_bytes.all.current']
+        mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
+        mem_free_torch = mem_reserved - mem_active
+        mem_free_total = mem_free_cuda + mem_free_torch
+
+    if torch_free_too:
+        return (mem_free_total, mem_free_torch)
+    else:
+        return mem_free_total
 
 def maximum_batch_area():
     global vram_state
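
A minimal usage sketch of the new signature (assuming model_management is importable, as the attention.py hunk above does; the device choices are illustrative):

    import torch
    import model_management

    # CPU device: free memory is now read via psutil instead of torch.cuda.
    free_total = model_management.get_free_memory(torch.device('cpu'))

    # CUDA device, additionally requesting the allocator-internal free amount.
    if torch.cuda.is_available():
        free_total, free_torch = model_management.get_free_memory(
            torch.cuda.current_device(), torch_free_too=True)

Note the hasattr(dev, 'type') guard: torch.cuda.current_device() returns a plain int, which has no .type attribute, so the default call still takes the CUDA branch.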