Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
89fd5ed5
Commit
89fd5ed5
authored
Mar 24, 2023
by
Yurii Mazurevich
Browse files
Added MPS device support
parent
dd095efc
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
4 deletions
+21
-4
comfy/model_management.py
comfy/model_management.py
+21
-4
No files found.
comfy/model_management.py
View file @
89fd5ed5
@@ -4,6 +4,7 @@ NO_VRAM = 1
 LOW_VRAM = 2
 NORMAL_VRAM = 3
 HIGH_VRAM = 4
+MPS = 4
 accelerate_enabled = False
 vram_state = NORMAL_VRAM
@@ -61,7 +62,8 @@ if "--novram" in sys.argv:
     set_vram_to = NO_VRAM
 if "--highvram" in sys.argv:
     vram_state = HIGH_VRAM
+if torch.backends.mps.is_available():
+    vram_state = MPS
 if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
     try:
@@ -79,7 +81,7 @@ if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
 if "--cpu" in sys.argv:
     vram_state = CPU
-print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM"][vram_state])
+print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS"][vram_state])
 current_loaded_model = None
@@ -128,6 +130,12 @@ def load_model_gpu(model):
     current_loaded_model = model
     if vram_state == CPU:
         pass
+    elif vram_state == MPS:
+        # print(inspect.getmro(real_model.__class__))
+        # print(dir(real_model))
+        mps_device = torch.device("mps")
+        real_model.to(mps_device)
+        pass
     elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM:
         model_accelerated = False
         real_model.cuda()
@@ -146,6 +154,9 @@ def load_controlnet_gpu(models):
     global vram_state
     if vram_state == CPU:
         return
+    if vram_state == MPS:
+        return
     if vram_state == LOW_VRAM or vram_state == NO_VRAM:
         #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
@@ -173,6 +184,8 @@ def unload_if_low_vram(model):
     return model

 def get_torch_device():
+    if vram_state == MPS:
+        return torch.device("mps")
     if vram_state == CPU:
         return torch.device("cpu")
     else:
@@ -195,7 +208,7 @@ def get_free_memory(dev=None, torch_free_too=False):
     if dev is None:
         dev = get_torch_device()

-    if hasattr(dev, 'type') and dev.type == 'cpu':
+    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
         mem_free_total = psutil.virtual_memory().available
         mem_free_torch = mem_free_total
     else:
@@ -224,8 +237,12 @@ def cpu_mode():
     global vram_state
     return vram_state == CPU

+def mps_mode():
+    global vram_state
+    return vram_state == MPS
+
 def should_use_fp16():
-    if cpu_mode():
+    if cpu_mode() or mps_mode():
         return False #TODO ?

     if torch.cuda.is_bf16_supported():
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment