chenpangpang / ComfyUI · Commits

Commit 7cb924f6
Authored Apr 06, 2023 by 藍+85CD
Parent: 84b9c0ac

Use separate variables instead of `vram_state`

The `XPU` value is dropped from the vram-state constants; whether an Intel XPU is present is now tracked in a standalone `xpu_available` flag that is consulted alongside `vram_state`.
Showing 1 changed file, comfy/model_management.py, with 37 additions and 33 deletions.
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -5,9 +5,9 @@ LOW_VRAM = 2
 NORMAL_VRAM = 3
 HIGH_VRAM = 4
 MPS = 5
-XPU = 6
 
 accelerate_enabled = False
+xpu_available = False
 vram_state = NORMAL_VRAM
 
 total_vram = 0
@@ -22,6 +22,11 @@ set_vram_to = NORMAL_VRAM
 
 try:
     import torch
-    total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
+    import intel_extension_for_pytorch as ipex
+    if torch.xpu.is_available():
+        xpu_available = True
+        total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
+    else:
+        total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
     total_ram = psutil.virtual_memory().total / (1024 * 1024)
     forced_normal_vram = "--normalvram" in sys.argv
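
For reference, a minimal sketch of the probe order the new code follows: prefer torch.xpu when intel_extension_for_pytorch can be imported, otherwise fall back to torch.cuda. This is a hypothetical standalone snippet, not the commit's code (the commit runs the probe inside the module's single outer try block rather than guarding the ipex import separately):

    # Hypothetical sketch of the XPU-vs-CUDA probe mirrored from the hunk above.
    import torch

    xpu_available = False
    total_vram = 0  # MiB

    try:
        import intel_extension_for_pytorch as ipex  # noqa: F401  (registers torch.xpu)
        if torch.xpu.is_available():
            xpu_available = True
            props = torch.xpu.get_device_properties(torch.xpu.current_device())
            total_vram = props.total_memory / (1024 * 1024)
    except ImportError:
        pass

    if not xpu_available and torch.cuda.is_available():
        # mem_get_info() returns (free_bytes, total_bytes) for the given CUDA device.
        total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)

    print("xpu_available:", xpu_available, "total_vram (MiB):", total_vram)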
@@ -86,17 +91,10 @@ try:
 except:
     pass
 
-try:
-    import intel_extension_for_pytorch as ipex
-    if torch.xpu.is_available():
-        vram_state = XPU
-except:
-    pass
-
 if forced_cpu:
     vram_state = CPU
 
-print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS", "XPU"][vram_state])
+print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS"][vram_state])
 
 current_loaded_model = None
@@ -133,6 +131,7 @@ def load_model_gpu(model):
     global current_loaded_model
     global vram_state
     global model_accelerated
+    global xpu_available
     if model is current_loaded_model:
         return
@@ -149,11 +148,11 @@ def load_model_gpu(model):
         mps_device = torch.device("mps")
         real_model.to(mps_device)
         pass
-    elif vram_state == XPU:
-        real_model.to("xpu")
-        pass
     elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM:
         model_accelerated = False
-        real_model.cuda()
+        if xpu_available:
+            real_model.to("xpu")
+        else:
+            real_model.cuda()
     else:
         if vram_state == NO_VRAM:
@@ -161,7 +160,7 @@ def load_model_gpu(model):
         elif vram_state == LOW_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
 
-        accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda")
+        accelerate.dispatch_model(real_model, device_map=device_map, main_device="xpu" if xpu_available else "cuda")
         model_accelerated = True
     return current_loaded_model
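
The same flag then drives model placement: the NORMAL_VRAM/HIGH_VRAM path calls .to("xpu") instead of .cuda(), and the LOW_VRAM accelerate path only swaps main_device. A hedged sketch of that selection, where place_model is a hypothetical helper rather than a function in this file:

    import torch

    def place_model(model: torch.nn.Module, xpu_available: bool) -> torch.nn.Module:
        # Mirrors the NORMAL_VRAM / HIGH_VRAM branch above: the whole model is
        # moved onto whichever accelerator is actually present.
        if xpu_available:
            return model.to("xpu")
        return model.cuda()

    # In the LOW_VRAM branch the commit keeps accelerate's device_map unchanged and
    # only switches the main device: main_device="xpu" if xpu_available else "cuda".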
@@ -187,7 +186,11 @@ def load_controlnet_gpu(models):
 
 def load_if_low_vram(model):
     global vram_state
+    global xpu_available
     if vram_state == LOW_VRAM or vram_state == NO_VRAM:
-        return model.cuda()
+        if xpu_available:
+            return model.to("xpu")
+        else:
+            return model.cuda()
     return model
@@ -198,12 +201,14 @@ def unload_if_low_vram(model):
     return model
 
 def get_torch_device():
+    global xpu_available
     if vram_state == MPS:
         return torch.device("mps")
-    if vram_state == XPU:
-        return torch.device("xpu")
     if vram_state == CPU:
         return torch.device("cpu")
     else:
-        return torch.cuda.current_device()
+        if xpu_available:
+            return torch.device("xpu")
+        else:
+            return torch.cuda.current_device()
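
Because the XPU case is now folded into the final else branch, callers of get_torch_device() need no changes. An illustrative usage sketch (the tensor shape is arbitrary, not taken from the commit):

    import torch
    from comfy import model_management

    # Ask the module for the active device: CPU, MPS, or the current XPU/CUDA device.
    device = model_management.get_torch_device()

    # Allocate work directly on that device; callers stay backend-agnostic.
    latent = torch.zeros((1, 4, 64, 64), device=device)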
@@ -235,13 +240,15 @@ def pytorch_attention_enabled():
     return ENABLE_PYTORCH_ATTENTION
 
 def get_free_memory(dev=None, torch_free_too=False):
+    global xpu_available
     if dev is None:
         dev = get_torch_device()
 
     if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
         mem_free_total = psutil.virtual_memory().available
         mem_free_torch = mem_free_total
-    elif hasattr(dev, 'type') and (dev.type == 'xpu'):
-        mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
-        mem_free_torch = mem_free_total
     else:
+        if xpu_available:
+            mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
+            mem_free_torch = mem_free_total
+        else:
@@ -274,12 +281,9 @@ def mps_mode():
     global vram_state
     return vram_state == MPS
 
-def xpu_mode():
-    global vram_state
-    return vram_state == XPU
-
 def should_use_fp16():
-    if cpu_mode() or mps_mode() or xpu_mode():
+    global xpu_available
+    if cpu_mode() or mps_mode() or xpu_available:
         return False #TODO ?
 
     if torch.cuda.is_bf16_supported():
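
Finally, a brief illustrative example of how the fp16 decision is typically consumed (hypothetical usage, not part of the commit):

    import torch
    from comfy import model_management

    # should_use_fp16() now returns False on CPU, MPS, and XPU; on CUDA it checks
    # the hardware before allowing half precision.
    dtype = torch.float16 if model_management.should_use_fp16() else torch.float32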