Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
67892b5a
Commit
67892b5a
authored
Jun 02, 2023
by
comfyanonymous
Browse files
Refactor and improve model_management code related to free memory.
parent
499641eb
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
68 additions
and
69 deletions
+68
-69
comfy/model_management.py
comfy/model_management.py
+66
-65
server.py
server.py
+2
-4
No files found.
comfy/model_management.py
View file @
67892b5a
import
psutil
import
psutil
from
enum
import
Enum
from
enum
import
Enum
from
comfy.cli_args
import
args
from
comfy.cli_args
import
args
import
torch
class
VRAMState
(
Enum
):
class
VRAMState
(
Enum
):
CPU
=
0
CPU
=
0
...
@@ -33,28 +34,67 @@ if args.directml is not None:
...
@@ -33,28 +34,67 @@ if args.directml is not None:
lowvram_available
=
False
#TODO: need to find a way to get free memory in directml before this can be enabled by default.
lowvram_available
=
False
#TODO: need to find a way to get free memory in directml before this can be enabled by default.
try
:
try
:
import
torch
import
intel_extension_for_pytorch
as
ipex
if
directml_enabled
:
if
torch
.
xpu
.
is_available
():
pass
#TODO
xpu_available
=
True
else
:
try
:
import
intel_extension_for_pytorch
as
ipex
if
torch
.
xpu
.
is_available
():
xpu_available
=
True
total_vram
=
torch
.
xpu
.
get_device_properties
(
torch
.
xpu
.
current_device
()).
total_memory
/
(
1024
*
1024
)
except
:
total_vram
=
torch
.
cuda
.
mem_get_info
(
torch
.
cuda
.
current_device
())[
1
]
/
(
1024
*
1024
)
total_ram
=
psutil
.
virtual_memory
().
total
/
(
1024
*
1024
)
if
not
args
.
normalvram
and
not
args
.
cpu
:
if
lowvram_available
and
total_vram
<=
4096
:
print
(
"Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram"
)
set_vram_to
=
VRAMState
.
LOW_VRAM
elif
total_vram
>
total_ram
*
1.1
and
total_vram
>
14336
:
print
(
"Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram"
)
vram_state
=
VRAMState
.
HIGH_VRAM
except
:
except
:
pass
pass
def
get_torch_device
():
global
xpu_available
global
directml_enabled
if
directml_enabled
:
global
directml_device
return
directml_device
if
vram_state
==
VRAMState
.
MPS
:
return
torch
.
device
(
"mps"
)
if
vram_state
==
VRAMState
.
CPU
:
return
torch
.
device
(
"cpu"
)
else
:
if
xpu_available
:
return
torch
.
device
(
"xpu"
)
else
:
return
torch
.
device
(
torch
.
cuda
.
current_device
())
def
get_total_memory
(
dev
=
None
,
torch_total_too
=
False
):
global
xpu_available
global
directml_enabled
if
dev
is
None
:
dev
=
get_torch_device
()
if
hasattr
(
dev
,
'type'
)
and
(
dev
.
type
==
'cpu'
or
dev
.
type
==
'mps'
):
mem_total
=
psutil
.
virtual_memory
().
total
mem_total_torch
=
mem_total
else
:
if
directml_enabled
:
mem_total
=
1024
*
1024
*
1024
#TODO
mem_total_torch
=
mem_total
elif
xpu_available
:
mem_total
=
torch
.
xpu
.
get_device_properties
(
dev
).
total_memory
mem_total_torch
=
mem_total
else
:
stats
=
torch
.
cuda
.
memory_stats
(
dev
)
mem_reserved
=
stats
[
'reserved_bytes.all.current'
]
_
,
mem_total_cuda
=
torch
.
cuda
.
mem_get_info
(
dev
)
mem_total_torch
=
mem_reserved
mem_total
=
mem_total_cuda
if
torch_total_too
:
return
(
mem_total
,
mem_total_torch
)
else
:
return
mem_total
total_vram
=
get_total_memory
(
get_torch_device
())
/
(
1024
*
1024
)
total_ram
=
psutil
.
virtual_memory
().
total
/
(
1024
*
1024
)
print
(
"Total VRAM {:0.0f} MB, total RAM {:0.0f} MB"
.
format
(
total_vram
,
total_ram
))
if
not
args
.
normalvram
and
not
args
.
cpu
:
if
lowvram_available
and
total_vram
<=
4096
:
print
(
"Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram"
)
set_vram_to
=
VRAMState
.
LOW_VRAM
elif
total_vram
>
total_ram
*
1.1
and
total_vram
>
14336
:
print
(
"Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram"
)
vram_state
=
VRAMState
.
HIGH_VRAM
try
:
try
:
OOM_EXCEPTION
=
torch
.
cuda
.
OutOfMemoryError
OOM_EXCEPTION
=
torch
.
cuda
.
OutOfMemoryError
except
:
except
:
...
@@ -128,29 +168,17 @@ if args.cpu:
...
@@ -128,29 +168,17 @@ if args.cpu:
print
(
f
"Set vram state to:
{
vram_state
.
name
}
"
)
print
(
f
"Set vram state to:
{
vram_state
.
name
}
"
)
def
get_torch_device
():
global
xpu_available
global
directml_enabled
if
directml_enabled
:
global
directml_device
return
directml_device
if
vram_state
==
VRAMState
.
MPS
:
return
torch
.
device
(
"mps"
)
if
vram_state
==
VRAMState
.
CPU
:
return
torch
.
device
(
"cpu"
)
else
:
if
xpu_available
:
return
torch
.
device
(
"xpu"
)
else
:
return
torch
.
cuda
.
current_device
()
def
get_torch_device_name
(
device
):
def
get_torch_device_name
(
device
):
if
hasattr
(
device
,
'type'
):
if
hasattr
(
device
,
'type'
):
return
"{}"
.
format
(
device
.
type
)
if
device
.
type
==
"cuda"
:
return
"CUDA {}: {}"
.
format
(
device
,
torch
.
cuda
.
get_device_name
(
device
))
return
"{} {}"
.
format
(
device
,
torch
.
cuda
.
get_device_name
(
device
))
else
:
return
"{}"
.
format
(
device
.
type
)
else
:
return
"CUDA {}: {}"
.
format
(
device
,
torch
.
cuda
.
get_device_name
(
device
))
try
:
try
:
print
(
"
Using d
evice:"
,
get_torch_device_name
(
get_torch_device
()))
print
(
"
D
evice:"
,
get_torch_device_name
(
get_torch_device
()))
except
:
except
:
print
(
"Could not pick default device."
)
print
(
"Could not pick default device."
)
...
@@ -308,33 +336,6 @@ def pytorch_attention_flash_attention():
...
@@ -308,33 +336,6 @@ def pytorch_attention_flash_attention():
return
True
return
True
return
False
return
False
def
get_total_memory
(
dev
=
None
,
torch_total_too
=
False
):
global
xpu_available
global
directml_enabled
if
dev
is
None
:
dev
=
get_torch_device
()
if
hasattr
(
dev
,
'type'
)
and
(
dev
.
type
==
'cpu'
or
dev
.
type
==
'mps'
):
mem_total
=
psutil
.
virtual_memory
().
total
else
:
if
directml_enabled
:
mem_total
=
1024
*
1024
*
1024
#TODO
mem_total_torch
=
mem_total
elif
xpu_available
:
mem_total
=
torch
.
xpu
.
get_device_properties
(
dev
).
total_memory
mem_total_torch
=
mem_total
else
:
stats
=
torch
.
cuda
.
memory_stats
(
dev
)
mem_reserved
=
stats
[
'reserved_bytes.all.current'
]
_
,
mem_total_cuda
=
torch
.
cuda
.
mem_get_info
(
dev
)
mem_total_torch
=
mem_reserved
mem_total
=
mem_total_cuda
if
torch_total_too
:
return
(
mem_total
,
mem_total_torch
)
else
:
return
mem_total
def
get_free_memory
(
dev
=
None
,
torch_free_too
=
False
):
def
get_free_memory
(
dev
=
None
,
torch_free_too
=
False
):
global
xpu_available
global
xpu_available
global
directml_enabled
global
directml_enabled
...
...
server.py
View file @
67892b5a
...
@@ -7,7 +7,6 @@ import execution
...
@@ -7,7 +7,6 @@ import execution
import
uuid
import
uuid
import
json
import
json
import
glob
import
glob
import
torch
from
PIL
import
Image
from
PIL
import
Image
from
io
import
BytesIO
from
io
import
BytesIO
...
@@ -284,9 +283,8 @@ class PromptServer():
...
@@ -284,9 +283,8 @@ class PromptServer():
@
routes
.
get
(
"/system_stats"
)
@
routes
.
get
(
"/system_stats"
)
async
def
get_queue
(
request
):
async
def
get_queue
(
request
):
device_index
=
comfy
.
model_management
.
get_torch_device
()
device
=
comfy
.
model_management
.
get_torch_device
()
device
=
torch
.
device
(
device_index
)
device_name
=
comfy
.
model_management
.
get_torch_device_name
(
device
)
device_name
=
comfy
.
model_management
.
get_torch_device_name
(
device_index
)
vram_total
,
torch_vram_total
=
comfy
.
model_management
.
get_total_memory
(
device
,
torch_total_too
=
True
)
vram_total
,
torch_vram_total
=
comfy
.
model_management
.
get_total_memory
(
device
,
torch_total_too
=
True
)
vram_free
,
torch_vram_free
=
comfy
.
model_management
.
get_free_memory
(
device
,
torch_free_too
=
True
)
vram_free
,
torch_vram_free
=
comfy
.
model_management
.
get_free_memory
(
device
,
torch_free_too
=
True
)
system_stats
=
{
system_stats
=
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment