chenpangpang / ComfyUI · Commits · 67892b5a

Commit 67892b5a, authored Jun 02, 2023 by comfyanonymous
Refactor and improve model_management code related to free memory.
parent 499641eb
Showing 2 changed files with 68 additions and 69 deletions:

  comfy/model_management.py   +66 -65
  server.py                    +2  -4
comfy/model_management.py @ 67892b5a
 import psutil
 from enum import Enum
 from comfy.cli_args import args
+import torch

 class VRAMState(Enum):
     CPU = 0
...
@@ -33,28 +34,67 @@ if args.directml is not None:
     lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.

 try:
-    import torch
-    if directml_enabled:
-        pass #TODO
-    else:
-        try:
-            import intel_extension_for_pytorch as ipex
-            if torch.xpu.is_available():
-                xpu_available = True
-                total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
-        except:
-            total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
-    total_ram = psutil.virtual_memory().total / (1024 * 1024)
-    if not args.normalvram and not args.cpu:
-        if lowvram_available and total_vram <= 4096:
-            print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
-            set_vram_to = VRAMState.LOW_VRAM
-        elif total_vram > total_ram * 1.1 and total_vram > 14336:
-            print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
-            vram_state = VRAMState.HIGH_VRAM
+    import intel_extension_for_pytorch as ipex
+    if torch.xpu.is_available():
+        xpu_available = True
 except:
     pass

+def get_torch_device():
+    global xpu_available
+    global directml_enabled
+    if directml_enabled:
+        global directml_device
+        return directml_device
+    if vram_state == VRAMState.MPS:
+        return torch.device("mps")
+    if vram_state == VRAMState.CPU:
+        return torch.device("cpu")
+    else:
+        if xpu_available:
+            return torch.device("xpu")
+        else:
+            return torch.device(torch.cuda.current_device())
+
+def get_total_memory(dev=None, torch_total_too=False):
+    global xpu_available
+    global directml_enabled
+    if dev is None:
+        dev = get_torch_device()
+
+    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
+        mem_total = psutil.virtual_memory().total
+        mem_total_torch = mem_total
+    else:
+        if directml_enabled:
+            mem_total = 1024 * 1024 * 1024 #TODO
+            mem_total_torch = mem_total
+        elif xpu_available:
+            mem_total = torch.xpu.get_device_properties(dev).total_memory
+            mem_total_torch = mem_total
+        else:
+            stats = torch.cuda.memory_stats(dev)
+            mem_reserved = stats['reserved_bytes.all.current']
+            _, mem_total_cuda = torch.cuda.mem_get_info(dev)
+            mem_total_torch = mem_reserved
+            mem_total = mem_total_cuda
+
+    if torch_total_too:
+        return (mem_total, mem_total_torch)
+    else:
+        return mem_total
+
+total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
+total_ram = psutil.virtual_memory().total / (1024 * 1024)
+print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
+if not args.normalvram and not args.cpu:
+    if lowvram_available and total_vram <= 4096:
+        print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
+        set_vram_to = VRAMState.LOW_VRAM
+    elif total_vram > total_ram * 1.1 and total_vram > 14336:
+        print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
+        vram_state = VRAMState.HIGH_VRAM
+
 try:
     OOM_EXCEPTION = torch.cuda.OutOfMemoryError
 except:
...
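For orientation, a minimal usage sketch of the two helpers introduced in the hunk above; it is not part of the commit and only composes calls whose signatures appear in the diff, assuming ComfyUI's comfy package is importable:

import comfy.model_management as mm

# Pick the device the same way the module now does at import time.
dev = mm.get_torch_device()

# torch_total_too=True returns a (total, torch_total) pair; on CUDA the second
# value is the allocator's 'reserved_bytes.all.current', elsewhere it mirrors the first.
mem_total, mem_total_torch = mm.get_total_memory(dev, torch_total_too=True)
print("Total VRAM {:0.0f} MB, reserved by torch {:0.0f} MB".format(
    mem_total / (1024 * 1024), mem_total_torch / (1024 * 1024)))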
@@ -128,29 +168,17 @@ if args.cpu:
 print(f"Set vram state to: {vram_state.name}")

-def get_torch_device():
-    global xpu_available
-    global directml_enabled
-    if directml_enabled:
-        global directml_device
-        return directml_device
-    if vram_state == VRAMState.MPS:
-        return torch.device("mps")
-    if vram_state == VRAMState.CPU:
-        return torch.device("cpu")
-    else:
-        if xpu_available:
-            return torch.device("xpu")
-        else:
-            return torch.cuda.current_device()
-
 def get_torch_device_name(device):
     if hasattr(device, 'type'):
-        return "{}".format(device.type)
-    return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
+        if device.type == "cuda":
+            return "{} {}".format(device, torch.cuda.get_device_name(device))
+        else:
+            return "{}".format(device.type)
+    else:
+        return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))

 try:
-    print("Using device:", get_torch_device_name(get_torch_device()))
+    print("Device:", get_torch_device_name(get_torch_device()))
 except:
     print("Could not pick default device.")
...
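A hedged sketch of the strings the reworked get_torch_device_name() produces under the branches above; the concrete values in the comments are illustrative, not output captured from this commit:

import comfy.model_management as mm

device = mm.get_torch_device()
name = mm.get_torch_device_name(device)
# torch.device with type "cuda"      -> "cuda:0 <name from torch.cuda.get_device_name>"
# torch.device("mps"/"cpu"/"xpu")    -> just the type string, e.g. "mps"
# a bare CUDA device index (no .type) -> "CUDA 0: <name from torch.cuda.get_device_name>"
print(name)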
@@ -308,33 +336,6 @@ def pytorch_attention_flash_attention():
         return True
     return False

-def get_total_memory(dev=None, torch_total_too=False):
-    global xpu_available
-    global directml_enabled
-    if dev is None:
-        dev = get_torch_device()
-
-    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
-        mem_total = psutil.virtual_memory().total
-    else:
-        if directml_enabled:
-            mem_total = 1024 * 1024 * 1024 #TODO
-            mem_total_torch = mem_total
-        elif xpu_available:
-            mem_total = torch.xpu.get_device_properties(dev).total_memory
-            mem_total_torch = mem_total
-        else:
-            stats = torch.cuda.memory_stats(dev)
-            mem_reserved = stats['reserved_bytes.all.current']
-            _, mem_total_cuda = torch.cuda.mem_get_info(dev)
-            mem_total_torch = mem_reserved
-            mem_total = mem_total_cuda
-
-    if torch_total_too:
-        return (mem_total, mem_total_torch)
-    else:
-        return mem_total
-
 def get_free_memory(dev=None, torch_free_too=False):
     global xpu_available
     global directml_enabled
...
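One behavioural detail worth noting (my reading of the removed definition above, not stated in the commit message): the old get_total_memory() never assigned mem_total_torch in the cpu/mps branch, so calling it with torch_total_too=True on those devices would raise UnboundLocalError. The relocated version assigns it, so a call like the sketch below should now return a pair:

import torch
import comfy.model_management as mm

# On cpu/mps both values now come from psutil.virtual_memory().total.
total, total_torch = mm.get_total_memory(torch.device("cpu"), torch_total_too=True)
assert total == total_torch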
server.py @ 67892b5a
...
@@ -7,7 +7,6 @@ import execution
 import uuid
 import json
 import glob
-import torch
 from PIL import Image
 from io import BytesIO
...
@@ -284,9 +283,8 @@ class PromptServer():
         @routes.get("/system_stats")
         async def get_queue(request):
-            device_index = comfy.model_management.get_torch_device()
-            device = torch.device(device_index)
-            device_name = comfy.model_management.get_torch_device_name(device_index)
+            device = comfy.model_management.get_torch_device()
+            device_name = comfy.model_management.get_torch_device_name(device)
             vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
             vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
             system_stats = {
...
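A hypothetical client for the /system_stats route touched above; the default address 127.0.0.1:8188 is an assumption, and the exact keys of the system_stats payload are collapsed in this diff, so only the raw JSON is printed:

import json
import urllib.request

with urllib.request.urlopen("http://127.0.0.1:8188/system_stats") as resp:
    stats = json.loads(resp.read().decode("utf-8"))

# Expected to include the totals/free values gathered from get_total_memory()
# and get_free_memory() above, but the schema is not shown in this diff.
print(json.dumps(stats, indent=2))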