chenpangpang / ComfyUI · Commits

Commit 4a0c4ce4, authored Sep 02, 2023 by Simon Lui

Some fixes to generalize CUDA specific functionality to Intel or other GPUs.

Parent: 62efc78a

Showing 3 changed files with 38 additions and 26 deletions (+38 −26):

comfy/ldm/modules/attention.py               +1  -2
comfy/ldm/modules/diffusionmodules/util.py   +17 -7
comfy/model_management.py                    +20 -17
comfy/ldm/modules/attention.py
@@ -323,8 +323,7 @@ class CrossAttentionDoggettx(nn.Module):
                 break
             except model_management.OOM_EXCEPTION as e:
                 if first_op_done == False:
-                    torch.cuda.empty_cache()
-                    torch.cuda.ipc_collect()
+                    model_management.soft_empty_cache()
                     if cleared_cache == False:
                         cleared_cache = True
                         print("out of memory error, emptying cache and trying again")
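Note: the hunk above swaps the CUDA-only cache calls for the backend-agnostic helper in comfy.model_management, so the out-of-memory retry path also works on Intel XPU. A minimal, self-contained sketch of that retry pattern, with a simplified stand-in for the helper (not ComfyUI's exact code):

import torch

def soft_empty_cache():
    # Simplified stand-in for comfy.model_management.soft_empty_cache():
    # release cached allocator blocks on whichever backend is active.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()
    elif torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

def run_with_oom_retry(op, max_retries=1):
    # Hypothetical wrapper mirroring the loop in CrossAttentionDoggettx.forward:
    # on an out-of-memory error, flush the cache once and try again.
    for attempt in range(max_retries + 1):
        try:
            return op()
        except torch.cuda.OutOfMemoryError:  # model_management.OOM_EXCEPTION in ComfyUI
            if attempt == max_retries:
                raise
            print("out of memory error, emptying cache and trying again")
            soft_empty_cache()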
comfy/ldm/modules/diffusionmodules/util.py
@@ -15,6 +15,7 @@ import torch.nn as nn
 import numpy as np
 from einops import repeat
+from comfy import model_management
 from comfy.ldm.util import instantiate_from_config
 import comfy.ops
@@ -139,6 +140,7 @@ class CheckpointFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, *output_grads):
         ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+        if model_management.is_nvidia():
             with torch.enable_grad(), \
                 torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
                 # Fixes a bug where the first op in run_function modifies the
@@ -146,6 +148,14 @@ class CheckpointFunction(torch.autograd.Function):
                 # Tensors.
                 shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
                 output_tensors = ctx.run_function(*shallow_copies)
+        elif model_management.is_intel_xpu():
+            with torch.enable_grad(), \
+                torch.xpu.amp.autocast(**ctx.gpu_autocast_kwargs):
+                # Fixes a bug where the first op in run_function modifies the
+                # Tensor storage in place, which is not allowed for detach()'d
+                # Tensors.
+                shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+                output_tensors = ctx.run_function(*shallow_copies)
         input_grads = torch.autograd.grad(
             output_tensors,
             ctx.input_tensors + ctx.input_params,
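For context, the backward pass above now chooses its autocast context per backend instead of assuming CUDA. A condensed sketch of that dispatch, assuming only the is_nvidia()/is_intel_xpu() split introduced by this commit (the helper name autocast_for_backend is hypothetical):

import contextlib
import torch
from comfy import model_management

def autocast_for_backend(gpu_autocast_kwargs):
    # Mirror the if/elif added to CheckpointFunction.backward: pick the
    # autocast context manager that matches the active GPU backend.
    if model_management.is_nvidia():
        return torch.cuda.amp.autocast(**gpu_autocast_kwargs)
    if model_management.is_intel_xpu():
        # torch.xpu.amp.autocast is available when Intel Extension for PyTorch is loaded.
        return torch.xpu.amp.autocast(**gpu_autocast_kwargs)
    # Neither backend active: run without mixed precision.
    return contextlib.nullcontext()

Used as "with torch.enable_grad(), autocast_for_backend(ctx.gpu_autocast_kwargs):", such a helper would collapse the two duplicated branches at the cost of a small indirection.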
comfy/model_management.py
@@ -58,8 +58,15 @@ except:
 if args.cpu:
     cpu_state = CPUState.CPU
 
-def get_torch_device():
+def is_intel_xpu():
+    global cpu_state
+    global xpu_available
+    if cpu_state == CPUState.GPU:
+        if xpu_available:
+            return True
+    return False
+
+def get_torch_device():
     global directml_enabled
     global cpu_state
     if directml_enabled:
@@ -70,13 +77,12 @@ def get_torch_device():
     if cpu_state == CPUState.CPU:
         return torch.device("cpu")
     else:
-        if xpu_available:
+        if is_intel_xpu():
            return torch.device("xpu")
        else:
            return torch.device(torch.cuda.current_device())
 
 def get_total_memory(dev=None, torch_total_too=False):
-    global xpu_available
     global directml_enabled
     if dev is None:
         dev = get_torch_device()
@@ -88,7 +94,7 @@ def get_total_memory(dev=None, torch_total_too=False):
     if directml_enabled:
         mem_total = 1024 * 1024 * 1024 #TODO
         mem_total_torch = mem_total
-    elif xpu_available:
+    elif is_intel_xpu():
         stats = torch.xpu.memory_stats(dev)
         mem_reserved = stats['reserved_bytes.all.current']
         mem_total = torch.xpu.get_device_properties(dev).total_memory
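As a usage note for the hunk above, the XPU memory statistics mirror the CUDA ones, so reserved and total memory can be read with the same keys on either backend. A rough, backend-aware sketch (simplified relative to get_total_memory(); the helper name is hypothetical):

import torch

def reserved_and_total_memory(dev):
    # Return (reserved_bytes, total_bytes) for a CUDA or XPU torch.device,
    # using the same stats keys as the diff above. Not ComfyUI's exact code.
    if dev.type == "xpu":
        stats = torch.xpu.memory_stats(dev)
        reserved = stats['reserved_bytes.all.current']
        total = torch.xpu.get_device_properties(dev).total_memory
    else:
        stats = torch.cuda.memory_stats(dev)
        reserved = stats['reserved_bytes.all.current']
        _free, total = torch.cuda.mem_get_info(dev)
    return reserved, total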
@@ -146,11 +152,11 @@ def is_nvidia():
     if cpu_state == CPUState.GPU:
         if torch.version.cuda:
             return True
     return False
 
 ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
 VAE_DTYPE = torch.float32
 
 try:
     if is_nvidia():
         torch_version = torch.version.__version__
@@ -162,6 +168,9 @@ try:
 except:
     pass
 
+if is_intel_xpu():
+    VAE_DTYPE = torch.bfloat16
+
 if args.fp16_vae:
     VAE_DTYPE = torch.float16
 elif args.bf16_vae:
@@ -220,7 +229,6 @@ if DISABLE_SMART_MEMORY:
     print("Disabling smart memory management")
 
 def get_torch_device_name(device):
-    global xpu_available
     if hasattr(device, 'type'):
         if device.type == "cuda":
             try:
@@ -230,7 +238,7 @@ def get_torch_device_name(device):
             return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
         else:
             return "{}".format(device.type)
-    elif xpu_available:
+    elif is_intel_xpu():
         return "{} {}".format(device, torch.xpu.get_device_name(device))
     else:
         return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
@@ -260,7 +268,6 @@ class LoadedModel:
         return self.model_memory()
 
     def model_load(self, lowvram_model_memory=0):
-        global xpu_available
         patch_model_to = None
         if lowvram_model_memory == 0:
             patch_model_to = self.device
@@ -281,7 +288,7 @@ class LoadedModel:
             accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device)
             self.model_accelerated = True
 
-        if xpu_available and not args.disable_ipex_optimize:
+        if is_intel_xpu() and not args.disable_ipex_optimize:
             self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
 
         return self.real_model
@@ -471,12 +478,11 @@ def get_autocast_device(dev):
 def xformers_enabled():
-    global xpu_available
     global directml_enabled
     global cpu_state
     if cpu_state != CPUState.GPU:
         return False
-    if xpu_available:
+    if is_intel_xpu():
         return False
     if directml_enabled:
         return False
@@ -503,7 +509,6 @@ def pytorch_attention_flash_attention():
     return False
 
 def get_free_memory(dev=None, torch_free_too=False):
-    global xpu_available
     global directml_enabled
     if dev is None:
         dev = get_torch_device()
@@ -515,7 +520,7 @@ def get_free_memory(dev=None, torch_free_too=False):
     if directml_enabled:
         mem_free_total = 1024 * 1024 * 1024 #TODO
         mem_free_torch = mem_free_total
-    elif xpu_available:
+    elif is_intel_xpu():
         stats = torch.xpu.memory_stats(dev)
         mem_active = stats['active_bytes.all.current']
         mem_allocated = stats['allocated_bytes.all.current']
@@ -577,7 +582,6 @@ def is_device_mps(device):
     return False
 
 def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
-    global xpu_available
     global directml_enabled
 
     if device is not None:
@@ -600,7 +604,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
     if cpu_mode() or mps_mode():
         return False #TODO ?
 
-    if xpu_available:
+    if is_intel_xpu():
         return True
 
     if torch.cuda.is_bf16_supported():
@@ -636,11 +640,10 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
     return True
 
 def soft_empty_cache():
-    global xpu_available
     global cpu_state
     if cpu_state == CPUState.MPS:
         torch.mps.empty_cache()
-    elif xpu_available:
+    elif is_intel_xpu():
         torch.xpu.empty_cache()
     elif torch.cuda.is_available():
         if is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
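Taken together, these hunks route every backend check through is_nvidia()/is_intel_xpu() instead of reading module-level flags such as xpu_available directly. A condensed, standalone sketch of that predicate-plus-dispatch pattern (simplified stand-ins, not ComfyUI's exact module state):

import torch

def is_nvidia():
    # Simplified stand-in for comfy.model_management.is_nvidia():
    # a CUDA build of PyTorch with a usable device.
    return torch.cuda.is_available() and torch.version.cuda is not None

def is_intel_xpu():
    # Simplified stand-in for comfy.model_management.is_intel_xpu(); the real
    # helper also consults ComfyUI's cpu_state and xpu_available flags.
    return hasattr(torch, "xpu") and torch.xpu.is_available()

def pick_device():
    # Backend-aware device selection in the spirit of get_torch_device().
    if is_intel_xpu():
        return torch.device("xpu")
    if is_nvidia():
        return torch.device(torch.cuda.current_device())
    return torch.device("cpu")

Keeping the checks behind functions means call sites such as soft_empty_cache() or should_use_fp16() stay backend-neutral, and supporting another accelerator later only touches the predicates.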