"git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "df7db0e0279f58d9f2f3f33ddb60bb238b6d0dc8"
Commit 4a0c4ce4 authored by Simon Lui

Some fixes to generalize CUDA specific functionality to Intel or other GPUs.

parent 62efc78a
@@ -323,8 +323,7 @@ class CrossAttentionDoggettx(nn.Module):
                 break
             except model_management.OOM_EXCEPTION as e:
                 if first_op_done == False:
-                    torch.cuda.empty_cache()
-                    torch.cuda.ipc_collect()
+                    model_management.soft_empty_cache()
                     if cleared_cache == False:
                         cleared_cache = True
                         print("out of memory error, emptying cache and trying again")
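This first hunk sits in the Doggettx cross-attention OOM handler (comfy/ldm/modules/attention.py): the CUDA-only `torch.cuda.empty_cache()` / `torch.cuda.ipc_collect()` pair is replaced with `model_management.soft_empty_cache()`, which picks the right call for the active backend. A heavily simplified sketch of the retry pattern this hunk lives in — not the real CrossAttentionDoggettx code, just an illustration of the idea:

```python
# Simplified sketch of the OOM-retry pattern: on the first OOM, free cached
# VRAM in a backend-agnostic way and try again with smaller slices.
import torch
from comfy import model_management

def attention_with_oom_retry(op, q, k, v, max_retries=3):
    slice_count = 1
    for _ in range(max_retries):
        try:
            # run the memory-hungry op in `slice_count` chunks along the batch dim
            outs = [op(qc, k, v) for qc in q.chunk(slice_count)]
            return torch.cat(outs)
        except model_management.OOM_EXCEPTION:
            # works on CUDA, XPU and MPS alike, unlike torch.cuda.empty_cache()
            model_management.soft_empty_cache()
            slice_count *= 2
    raise RuntimeError("out of memory even after slicing and emptying the cache")
```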
@@ -15,6 +15,7 @@ import torch.nn as nn
 import numpy as np
 from einops import repeat
 
+from comfy import model_management
 from comfy.ldm.util import instantiate_from_config
 import comfy.ops
@@ -139,13 +140,22 @@ class CheckpointFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, *output_grads):
         ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
-        with torch.enable_grad(), \
-                torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
-            # Fixes a bug where the first op in run_function modifies the
-            # Tensor storage in place, which is not allowed for detach()'d
-            # Tensors.
-            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
-            output_tensors = ctx.run_function(*shallow_copies)
+        if model_management.is_nvidia():
+            with torch.enable_grad(), \
+                    torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
+                # Fixes a bug where the first op in run_function modifies the
+                # Tensor storage in place, which is not allowed for detach()'d
+                # Tensors.
+                shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+                output_tensors = ctx.run_function(*shallow_copies)
+        elif model_management.is_intel_xpu():
+            with torch.enable_grad(), \
+                    torch.xpu.amp.autocast(**ctx.gpu_autocast_kwargs):
+                # Fixes a bug where the first op in run_function modifies the
+                # Tensor storage in place, which is not allowed for detach()'d
+                # Tensors.
+                shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+                output_tensors = ctx.run_function(*shallow_copies)
         input_grads = torch.autograd.grad(
             output_tensors,
             ctx.input_tensors + ctx.input_params,
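These two hunks are in the checkpointing helper (comfy/ldm/modules/diffusionmodules/util.py): the `model_management` import is added so `CheckpointFunction.backward` can branch on the backend. The two branches are identical except for the autocast namespace (`torch.cuda.amp` vs. `torch.xpu.amp`, the latter provided by intel_extension_for_pytorch). A hypothetical alternative — not what this commit does — is to pick the context manager from the backend once; `grad_autocast_context` below is an illustrative name, and the `nullcontext` fallback also covers backends that match neither check:

```python
# Hypothetical alternative (not part of the commit): select the autocast
# namespace from the detected backend instead of duplicating the block.
import contextlib
import torch
from comfy import model_management

def grad_autocast_context(gpu_autocast_kwargs):
    """Return the autocast context matching the active backend."""
    if model_management.is_nvidia():
        return torch.cuda.amp.autocast(**gpu_autocast_kwargs)
    if model_management.is_intel_xpu():
        # torch.xpu.amp.autocast is provided by intel_extension_for_pytorch
        return torch.xpu.amp.autocast(**gpu_autocast_kwargs)
    return contextlib.nullcontext()

# usage inside backward():
#     with torch.enable_grad(), grad_autocast_context(ctx.gpu_autocast_kwargs):
#         shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
#         output_tensors = ctx.run_function(*shallow_copies)
```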
@@ -58,8 +58,15 @@ except:
 if args.cpu:
     cpu_state = CPUState.CPU
 
+def is_intel_xpu():
+    global cpu_state
+    global xpu_available
+    if cpu_state == CPUState.GPU:
+        if xpu_available:
+            return True
+    return False
+
 def get_torch_device():
-    global xpu_available
     global directml_enabled
     global cpu_state
     if directml_enabled:
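The remaining hunks are in comfy/model_management.py. The new `is_intel_xpu()` helper centralizes the "GPU mode and an XPU is present" check, so callers no longer need to declare the module-level `xpu_available` flag as global. A small, purely illustrative usage sketch (not from the commit):

```python
# Illustrative only: downstream code asks model_management for the device
# and for backend-specific behaviour instead of reading xpu_available directly.
import torch
from comfy import model_management

device = model_management.get_torch_device()   # cuda / xpu / mps / cpu / directml
x = torch.randn(4, 4, device=device)

if model_management.is_intel_xpu():
    # example of an XPU-specific choice, e.g. preferring bfloat16
    x = x.to(torch.bfloat16)
```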
@@ -70,13 +77,12 @@ def get_torch_device():
     if cpu_state == CPUState.CPU:
         return torch.device("cpu")
     else:
-        if xpu_available:
+        if is_intel_xpu():
             return torch.device("xpu")
         else:
             return torch.device(torch.cuda.current_device())
 
 def get_total_memory(dev=None, torch_total_too=False):
-    global xpu_available
     global directml_enabled
     if dev is None:
         dev = get_torch_device()
@@ -88,7 +94,7 @@ def get_total_memory(dev=None, torch_total_too=False):
     if directml_enabled:
         mem_total = 1024 * 1024 * 1024 #TODO
         mem_total_torch = mem_total
-    elif xpu_available:
+    elif is_intel_xpu():
         stats = torch.xpu.memory_stats(dev)
         mem_reserved = stats['reserved_bytes.all.current']
         mem_total = torch.xpu.get_device_properties(dev).total_memory
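On XPU, total memory comes from `torch.xpu.get_device_properties(dev).total_memory` together with the allocator's reserved-bytes counter. A rough sketch of a backend-aware query in the same spirit; only the XPU branch mirrors lines shown in this diff, while the CUDA branch is an assumption based on standard `torch.cuda` calls, not taken from the commit:

```python
# Rough sketch of a backend-aware total-memory query; the XPU branch mirrors
# the diff above, the CUDA branch is an assumption about equivalent calls.
import torch

def total_and_reserved_bytes(dev):
    if dev.type == "xpu":
        stats = torch.xpu.memory_stats(dev)                 # requires intel_extension_for_pytorch
        reserved = stats['reserved_bytes.all.current']       # bytes held by torch's allocator
        total = torch.xpu.get_device_properties(dev).total_memory
        return total, reserved
    if dev.type == "cuda":
        _, total = torch.cuda.mem_get_info(dev)              # (free, total) in bytes
        reserved = torch.cuda.memory_stats(dev)['reserved_bytes.all.current']
        return total, reserved
    raise ValueError(f"unsupported device type: {dev.type}")
```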
@@ -146,11 +152,11 @@ def is_nvidia():
     if cpu_state == CPUState.GPU:
         if torch.version.cuda:
             return True
     return False
 
 ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
 VAE_DTYPE = torch.float32
 
 try:
     if is_nvidia():
         torch_version = torch.version.__version__
@@ -162,6 +168,9 @@ try:
 except:
     pass
 
+if is_intel_xpu():
+    VAE_DTYPE = torch.bfloat16
+
 if args.fp16_vae:
     VAE_DTYPE = torch.float16
 elif args.bf16_vae:
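With this change the VAE dtype defaults to bfloat16 on Intel XPU, while the explicit --fp16-vae / --bf16-vae flags still win because they are checked afterwards. The resulting precedence, written out as a small illustrative function (not ComfyUI code):

```python
# Illustrative precedence of the VAE dtype selection after this hunk;
# flags other than --fp16-vae/--bf16-vae are omitted.
import torch

def pick_vae_dtype(is_intel_xpu, fp16_vae=False, bf16_vae=False):
    dtype = torch.float32            # default
    if is_intel_xpu:
        dtype = torch.bfloat16       # new default on Intel XPU
    if fp16_vae:
        dtype = torch.float16        # explicit flags still override
    elif bf16_vae:
        dtype = torch.bfloat16
    return dtype

assert pick_vae_dtype(is_intel_xpu=True) == torch.bfloat16
assert pick_vae_dtype(is_intel_xpu=True, fp16_vae=True) == torch.float16
```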
@@ -220,7 +229,6 @@ if DISABLE_SMART_MEMORY:
     print("Disabling smart memory management")
 
 def get_torch_device_name(device):
-    global xpu_available
     if hasattr(device, 'type'):
         if device.type == "cuda":
             try:
@@ -230,7 +238,7 @@ def get_torch_device_name(device):
             return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
         else:
             return "{}".format(device.type)
-    elif xpu_available:
+    elif is_intel_xpu():
         return "{} {}".format(device, torch.xpu.get_device_name(device))
     else:
         return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
@@ -260,7 +268,6 @@ class LoadedModel:
         return self.model_memory()
 
     def model_load(self, lowvram_model_memory=0):
-        global xpu_available
         patch_model_to = None
         if lowvram_model_memory == 0:
             patch_model_to = self.device
@@ -281,7 +288,7 @@ class LoadedModel:
             accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device)
             self.model_accelerated = True
 
-        if xpu_available and not args.disable_ipex_optimize:
+        if is_intel_xpu() and not args.disable_ipex_optimize:
             self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
 
         return self.real_model
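When the device is an Intel XPU and --disable-ipex-optimize is not set, the freshly loaded model is run through `torch.xpu.optimize`, the API that intel_extension_for_pytorch patches into torch. A minimal sketch of that guard, assuming IPEX is installed and already imported; `maybe_ipex_optimize` and its arguments are placeholders, not ComfyUI symbols:

```python
# Minimal sketch of the guarded IPEX pass; torch.xpu.optimize is provided by
# intel_extension_for_pytorch once it has been imported.
import torch
from comfy import model_management

def maybe_ipex_optimize(model, disable_ipex_optimize=False):
    if model_management.is_intel_xpu() and not disable_ipex_optimize:
        # fuses ops / reorders weights for Intel GPUs; inplace avoids a copy
        model = torch.xpu.optimize(model.eval(), inplace=True,
                                   auto_kernel_selection=True, graph_mode=True)
    return model
```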
@@ -471,12 +478,11 @@ def get_autocast_device(dev):
 
 def xformers_enabled():
-    global xpu_available
     global directml_enabled
     global cpu_state
     if cpu_state != CPUState.GPU:
         return False
-    if xpu_available:
+    if is_intel_xpu():
         return False
     if directml_enabled:
         return False
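xformers only targets CUDA, so `xformers_enabled()` now reports False on Intel XPU as well as on DirectML and CPU. An illustrative sketch of how an attention-backend choice can key off these helpers; `pick_attention_backend` and `pytorch_attention_enabled()` are assumptions here, only `xformers_enabled()` appears in this diff:

```python
# Illustrative backend selection keyed off the model_management helpers;
# pytorch_attention_enabled() is assumed, not shown in this commit.
from comfy import model_management

def pick_attention_backend():
    if model_management.xformers_enabled():
        return "xformers"             # CUDA only: False on XPU/DirectML/CPU
    if model_management.pytorch_attention_enabled():
        return "pytorch_sdp"          # scaled_dot_product_attention path
    return "split"                    # fallback slicing implementation
```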
@@ -503,7 +509,6 @@ def pytorch_attention_flash_attention():
     return False
 
 def get_free_memory(dev=None, torch_free_too=False):
-    global xpu_available
     global directml_enabled
     if dev is None:
         dev = get_torch_device()
@@ -515,7 +520,7 @@ def get_free_memory(dev=None, torch_free_too=False):
     if directml_enabled:
         mem_free_total = 1024 * 1024 * 1024 #TODO
         mem_free_torch = mem_free_total
-    elif xpu_available:
+    elif is_intel_xpu():
         stats = torch.xpu.memory_stats(dev)
         mem_active = stats['active_bytes.all.current']
         mem_allocated = stats['allocated_bytes.all.current']
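`get_free_memory()` on XPU is likewise derived from `torch.xpu.memory_stats` rather than a CUDA-only query. A simplified sketch of the usual consumer pattern of sizing work to the reported free VRAM; the splitting heuristic and the 3x safety factor are illustrative, not taken from ComfyUI:

```python
# Sketch of the consumer pattern: split the computation until one slice
# comfortably fits in the backend-reported free memory.
from comfy import model_management

def steps_for_tensor(bytes_needed):
    dev = model_management.get_torch_device()
    free = model_management.get_free_memory(dev)     # backend-aware after this commit
    steps = 1
    while bytes_needed / steps > free / 3 and steps < 64:
        steps *= 2
    return steps
```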
@@ -577,7 +582,6 @@ def is_device_mps(device):
     return False
 
 def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
-    global xpu_available
     global directml_enabled
 
     if device is not None:
@@ -600,7 +604,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
     if cpu_mode() or mps_mode():
         return False #TODO ?
 
-    if xpu_available:
+    if is_intel_xpu():
         return True
 
     if torch.cuda.is_bf16_supported():
@@ -636,11 +640,10 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
     return True
 
 def soft_empty_cache():
-    global xpu_available
     global cpu_state
     if cpu_state == CPUState.MPS:
         torch.mps.empty_cache()
-    elif xpu_available:
+    elif is_intel_xpu():
         torch.xpu.empty_cache()
     elif torch.cuda.is_available():
         if is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
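Pieced together from the hunks above, the post-patch `soft_empty_cache()` dispatches MPS → XPU → CUDA. A hedged reconstruction as a standalone helper; the two `torch.cuda` calls at the end are an assumption (they mirror the calls removed from the attention code in the first hunk), since the diff cuts off right after the `is_nvidia()` check:

```python
# Hedged reconstruction of the post-patch behaviour as a standalone helper;
# the CUDA branch body is assumed from the calls removed in the first hunk.
import torch
from comfy import model_management

def soft_empty_cache_sketch():
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
    elif model_management.is_intel_xpu():
        torch.xpu.empty_cache()
    elif torch.cuda.is_available():
        if model_management.is_nvidia():  # skipped on ROCm, where it reportedly hurts
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
```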