chenpangpang / ComfyUI

Commit 37713e3b, authored Apr 05, 2023 by 藍+85CD
Parent: 30f274bf

    Add basic XPU device support

    closed #387

Showing 1 changed file with 22 additions and 2 deletions
comfy/model_management.py  (+22 -2)
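Before the hunks, a minimal standalone sketch of the detection logic this commit adds. It assumes the intel_extension_for_pytorch package is installed; on the PyTorch builds of this era, importing it is what registers the torch.xpu namespace:

import torch

try:
    import intel_extension_for_pytorch  # noqa: F401 -- the import registers torch.xpu
except ImportError:
    pass

if hasattr(torch, "xpu") and torch.xpu.is_available():
    print("Intel XPU available")
else:
    print("No XPU device found; the CPU/CUDA code paths apply")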
@@ -5,6 +5,7 @@ LOW_VRAM = 2
 NORMAL_VRAM = 3
 HIGH_VRAM = 4
 MPS = 5
+XPU = 6

 accelerate_enabled = False
 vram_state = NORMAL_VRAM
@@ -85,10 +86,17 @@ try:
 except:
     pass

+try:
+    import intel_extension_for_pytorch
+    if torch.xpu.is_available():
+        vram_state = XPU
+except:
+    pass
+
 if forced_cpu:
     vram_state = CPU

-print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS"][vram_state])
+print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS", "XPU"][vram_state])

 current_loaded_model = None
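The print above indexes a list of state names with the integer vram_state, so the new XPU state (6) needs a matching seventh entry. A small illustration of the failure the extra entry prevents:

names = ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS"]
XPU = 6

try:
    names[XPU]                    # pre-patch list: index 6 is out of range
except IndexError as e:
    print("pre-patch:", e)

names.append("XPU")               # what the patch effectively does
print("post-patch:", names[XPU])  # -> XPU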
@@ -141,6 +149,9 @@ def load_model_gpu(model):
         mps_device = torch.device("mps")
         real_model.to(mps_device)
         pass
+    elif vram_state == XPU:
+        real_model.to("xpu")
+        pass
     elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM:
         model_accelerated = False
         real_model.cuda()
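The XPU branch mirrors the MPS one above it: .to() accepts a backend string, so moving a model is a one-liner. A hedged sketch, guarded so it also runs on machines without an Intel GPU:

import torch
import torch.nn as nn

model = nn.Linear(8, 8)
if hasattr(torch, "xpu") and torch.xpu.is_available():
    model = model.to("xpu")  # same call shape as real_model.to("xpu") in the patch
print(next(model.parameters()).device)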
@@ -189,6 +200,8 @@ def unload_if_low_vram(model):
 def get_torch_device():
     if vram_state == MPS:
         return torch.device("mps")
+    if vram_state == XPU:
+        return torch.device("xpu")
     if vram_state == CPU:
         return torch.device("cpu")
     else:
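Callers elsewhere in the module ask get_torch_device() for a device once and allocate against it, so this single extra branch is enough to route new tensors to the Intel GPU. A reduced stand-in for illustration; the real function dispatches on the module-level vram_state rather than on runtime availability:

import torch

def get_torch_device():
    # Reduced stand-in: prefer XPU, then CUDA, else CPU.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

x = torch.randn(4, 4, device=get_torch_device())
print(x.device)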
@@ -228,6 +241,9 @@ def get_free_memory(dev=None, torch_free_too=False):
     if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
         mem_free_total = psutil.virtual_memory().available
         mem_free_torch = mem_free_total
+    elif hasattr(dev, 'type') and (dev.type == 'xpu'):
+        mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
+        mem_free_torch = mem_free_total
     else:
         stats = torch.cuda.memory_stats(dev)
         mem_active = stats['active_bytes.all.current']
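The new branch estimates free XPU memory as total device memory minus what the torch allocator has claimed, which is allocator-visible headroom rather than OS-level free memory. The same arithmetic extracted into a helper, assuming the torch.xpu memory APIs that intel_extension_for_pytorch registers:

import torch

def xpu_free_memory(dev):
    # total_memory - memory_allocated: mirrors the patch's computation
    props = torch.xpu.get_device_properties(dev)
    return props.total_memory - torch.xpu.memory_allocated(dev)

if hasattr(torch, "xpu") and torch.xpu.is_available():
    print(xpu_free_memory(torch.device("xpu")), "bytes free")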
@@ -258,8 +274,12 @@ def mps_mode():
     global vram_state
     return vram_state == MPS

+def xpu_mode():
+    global vram_state
+    return vram_state == XPU
+
 def should_use_fp16():
-    if cpu_mode() or mps_mode():
+    if cpu_mode() or mps_mode() or xpu_mode():
         return False #TODO ?

     if torch.cuda.is_bf16_supported():
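With xpu_mode() added, XPU joins CPU and MPS in the set of backends where fp16 is declined outright; the trailing "#TODO ?" suggests a conservative default rather than a known limitation. A toy, self-contained version of the dispatch, with the module-level vram_state reduced to a local constant:

CPU, MPS, XPU = 0, 5, 6
vram_state = XPU

def cpu_mode():
    return vram_state == CPU

def mps_mode():
    return vram_state == MPS

def xpu_mode():
    return vram_state == XPU  # the predicate this commit introduces

def should_use_fp16():
    if cpu_mode() or mps_mode() or xpu_mode():
        return False  # conservative default, per the patch's "#TODO ?"
    return True

print(should_use_fp16())  # -> False while vram_state == XPU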