chenpangpang / ComfyUI · Commits

Commit 7cb924f6
Authored Apr 06, 2023 by 藍+85CD
Parent: 84b9c0ac

Use separate variables instead of `vram_state`

Showing 1 changed file with 37 additions and 33 deletions

comfy/model_management.py (+37 -33)
@@ -5,9 +5,9 @@ LOW_VRAM = 2
 NORMAL_VRAM = 3
 HIGH_VRAM = 4
 MPS = 5
-XPU = 6
 
 accelerate_enabled = False
+xpu_available = False
 vram_state = NORMAL_VRAM
 
 total_vram = 0
@@ -22,7 +22,12 @@ set_vram_to = NORMAL_VRAM
 
 try:
     import torch
-    total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
+    import intel_extension_for_pytorch as ipex
+    if torch.xpu.is_available():
+        xpu_available = True
+        total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
+    else:
+        total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
     total_ram = psutil.virtual_memory().total / (1024 * 1024)
     forced_normal_vram = "--normalvram" in sys.argv
     if not forced_normal_vram and not forced_cpu:
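As a side note on the detection change above, the new logic resolves total VRAM roughly as in the standalone sketch below (illustrative only, not part of the commit; it assumes `torch` and `psutil` are importable and falls back to CUDA when `intel_extension_for_pytorch` is absent):

import psutil
import torch

xpu_available = False
total_vram = 0

try:
    # Prefer the Intel XPU backend when the extension is installed
    # and an XPU device is actually present.
    import intel_extension_for_pytorch as ipex  # noqa: F401
    if torch.xpu.is_available():
        xpu_available = True
        props = torch.xpu.get_device_properties(torch.xpu.current_device())
        total_vram = props.total_memory / (1024 * 1024)  # MiB
except ImportError:
    pass

if not xpu_available:
    # CUDA path: mem_get_info() returns (free, total) in bytes.
    total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)

total_ram = psutil.virtual_memory().total / (1024 * 1024)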
@@ -86,17 +91,10 @@ try:
 except:
     pass
 
-try:
-    import intel_extension_for_pytorch as ipex
-    if torch.xpu.is_available():
-        vram_state = XPU
-except:
-    pass
-
 if forced_cpu:
     vram_state = CPU
 
-print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS", "XPU"][vram_state])
+print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS"][vram_state])
 
 current_loaded_model = None
@@ -133,6 +131,7 @@ def load_model_gpu(model):
     global current_loaded_model
     global vram_state
     global model_accelerated
+    global xpu_available
 
     if model is current_loaded_model:
         return
@@ -149,19 +148,19 @@ def load_model_gpu(model):
         mps_device = torch.device("mps")
         real_model.to(mps_device)
         pass
-    elif vram_state == XPU:
-        real_model.to("xpu")
-        pass
     elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM:
         model_accelerated = False
-        real_model.cuda()
+        if xpu_available:
+            real_model.to("xpu")
+        else:
+            real_model.cuda()
     else:
         if vram_state == NO_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
         elif vram_state == LOW_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
 
-        accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda")
+        accelerate.dispatch_model(real_model, device_map=device_map, main_device="xpu" if xpu_available else "cuda")
         model_accelerated = True
     return current_loaded_model
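The normal/high-VRAM and accelerate paths above now key device placement off the `xpu_available` flag rather than a dedicated XPU state. A hedged sketch of that pattern (the helper name and signature are illustrative, not part of the commit; it assumes `accelerate` is installed):

import accelerate

def place_model(real_model, xpu_available, device_map=None):
    # Illustrative helper: pick the main device from the xpu_available flag
    # instead of carrying a separate XPU entry in vram_state.
    if device_map is None:
        if xpu_available:
            real_model.to("xpu")
        else:
            real_model.cuda()
    else:
        main_device = "xpu" if xpu_available else "cuda"
        accelerate.dispatch_model(real_model, device_map=device_map, main_device=main_device)
    return real_model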
@@ -187,8 +186,12 @@ def load_controlnet_gpu(models):
 
 def load_if_low_vram(model):
     global vram_state
+    global xpu_available
     if vram_state == LOW_VRAM or vram_state == NO_VRAM:
-        return model.cuda()
+        if xpu_available:
+            return model.to("xpu")
+        else:
+            return model.cuda()
     return model
 
 def unload_if_low_vram(model):
@@ -198,14 +201,16 @@ def unload_if_low_vram(model):
     return model
 
 def get_torch_device():
+    global xpu_available
     if vram_state == MPS:
         return torch.device("mps")
-    if vram_state == XPU:
-        return torch.device("xpu")
     if vram_state == CPU:
         return torch.device("cpu")
     else:
-        return torch.cuda.current_device()
+        if xpu_available:
+            return torch.device("xpu")
+        else:
+            return torch.cuda.current_device()
 
 def get_autocast_device(dev):
     if hasattr(dev, 'type'):
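With the XPU state removed, `get_torch_device()` only consults `xpu_available` in the default branch. A rough usage sketch of the resulting behaviour (illustrative function, with `vram_state` and `xpu_available` passed in explicitly):

import torch

def pick_device(vram_state, xpu_available, MPS=5, CPU=0):
    # Illustrative only: mirrors the branch order after this commit.
    if vram_state == MPS:
        return torch.device("mps")
    if vram_state == CPU:
        return torch.device("cpu")
    if xpu_available:
        return torch.device("xpu")
    return torch.cuda.current_device()  # index of the current CUDA device

# e.g. pick_device(3, False) -> 0 on a single-GPU CUDA machine,
#      pick_device(3, True)  -> device(type='xpu') on an Intel GPU.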
@@ -235,22 +240,24 @@ def pytorch_attention_enabled():
     return ENABLE_PYTORCH_ATTENTION
 
 def get_free_memory(dev=None, torch_free_too=False):
+    global xpu_available
     if dev is None:
         dev = get_torch_device()
 
     if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
         mem_free_total = psutil.virtual_memory().available
         mem_free_torch = mem_free_total
-    elif hasattr(dev, 'type') and (dev.type == 'xpu'):
-        mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
-        mem_free_torch = mem_free_total
     else:
-        stats = torch.cuda.memory_stats(dev)
-        mem_active = stats['active_bytes.all.current']
-        mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
-        mem_free_torch = mem_reserved - mem_active
-        mem_free_total = mem_free_cuda + mem_free_torch
+        if xpu_available:
+            mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
+            mem_free_torch = mem_free_total
+        else:
+            stats = torch.cuda.memory_stats(dev)
+            mem_active = stats['active_bytes.all.current']
+            mem_reserved = stats['reserved_bytes.all.current']
+            mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
+            mem_free_torch = mem_reserved - mem_active
+            mem_free_total = mem_free_cuda + mem_free_torch
 
     if torch_free_too:
         return (mem_free_total, mem_free_torch)
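The free-memory computation now selects the XPU branch by flag rather than by device type. As a self-contained sketch (illustrative; `dev` is assumed to be a CUDA device index or an XPU `torch.device`):

import torch

def free_memory_bytes(dev, xpu_available):
    # Illustrative sketch of the branch structure after this commit.
    if xpu_available:
        # XPU: total device memory minus what this process has allocated.
        props = torch.xpu.get_device_properties(dev)
        mem_free_total = props.total_memory - torch.xpu.memory_allocated(dev)
        mem_free_torch = mem_free_total
    else:
        # CUDA: driver-reported free memory plus memory PyTorch has
        # reserved but is not actively using.
        stats = torch.cuda.memory_stats(dev)
        mem_active = stats['active_bytes.all.current']
        mem_reserved = stats['reserved_bytes.all.current']
        mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
        mem_free_torch = mem_reserved - mem_active
        mem_free_total = mem_free_cuda + mem_free_torch
    return mem_free_total, mem_free_torch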
@@ -274,12 +281,9 @@ def mps_mode():
     global vram_state
     return vram_state == MPS
 
-def xpu_mode():
-    global vram_state
-    return vram_state == XPU
-
 def should_use_fp16():
-    if cpu_mode() or mps_mode() or xpu_mode():
+    global xpu_available
+    if cpu_mode() or mps_mode() or xpu_available:
        return False #TODO ?
 
     if torch.cuda.is_bf16_supported():
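`should_use_fp16()` now reads `xpu_available` directly, so the `xpu_mode()` helper is dropped. A short hedged sketch of the new gate (illustrative signature; the real function continues with further CUDA capability checks):

import torch

def fp16_gate(cpu_mode_on, mps_mode_on, xpu_available):
    # Illustrative only: fp16 stays disabled on CPU, MPS, and Intel XPU.
    if cpu_mode_on or mps_mode_on or xpu_available:
        return False
    # On CUDA, hardware capability (e.g. bf16 support) is consulted next.
    return torch.cuda.is_bf16_supported()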