"deploy/vscode:/vscode.git/clone" did not exist on "ee05c913cb7a81d003f8987dc5e08bbd62d06ae0"
Commit 2c096e42 authored by Simon Lui, committed by comfyanonymous

Add ipex optimize and other enhancements for Intel GPUs based on recent memory changes.

parent 8ee04736
@@ -58,6 +58,8 @@ fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in
 parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
 
+parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
+
 class LatentPreviewMethod(enum.Enum):
     NoPreviews = "none"
     Auto = "auto"
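The new switch is a standard argparse boolean flag: left unset, `args.disable_ipex_optimize` is False and the IPEX optimization path added below stays enabled. A minimal sketch of that behavior, using a standalone parser for illustration rather than ComfyUI's actual cli_args module:

```python
# Sketch only: standalone parser showing how the store_true flag toggles.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--disable-ipex-optimize", action="store_true",
                    help="Disables ipex.optimize when loading models with Intel GPUs.")

args = parser.parse_args([])                            # default run
assert args.disable_ipex_optimize is False              # ipex.optimize stays on

args = parser.parse_args(["--disable-ipex-optimize"])   # explicit opt-out
assert args.disable_ipex_optimize is True
```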
@@ -88,8 +88,10 @@ def get_total_memory(dev=None, torch_total_too=False):
             mem_total = 1024 * 1024 * 1024 #TODO
             mem_total_torch = mem_total
         elif xpu_available:
+            stats = torch.xpu.memory_stats(dev)
+            mem_reserved = stats['reserved_bytes.all.current']
             mem_total = torch.xpu.get_device_properties(dev).total_memory
-            mem_total_torch = mem_total
+            mem_total_torch = mem_reserved
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_reserved = stats['reserved_bytes.all.current']
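The XPU branch now mirrors the CUDA branch: total memory still comes from the device properties, while the torch-side total is the allocator's reserved bytes from `torch.xpu.memory_stats`. A small sketch of that accounting, assuming Intel Extension for PyTorch is installed so `torch.xpu` is available and an XPU device is present:

```python
# Sketch of the new XPU accounting in get_total_memory; assumes
# intel_extension_for_pytorch is installed so torch.xpu exists.
import torch
import intel_extension_for_pytorch  # noqa: F401  (registers the torch.xpu namespace)

dev = torch.device("xpu")
stats = torch.xpu.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']              # bytes held by the caching allocator
mem_total = torch.xpu.get_device_properties(dev).total_memory   # physical device capacity

print("device total: {:.2f} GB, torch reserved: {:.2f} GB".format(
    mem_total / (1024 ** 3), mem_reserved / (1024 ** 3)))
```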
@@ -208,6 +210,7 @@ if DISABLE_SMART_MEMORY:
     print("Disabling smart memory management")
 
 def get_torch_device_name(device):
+    global xpu_available
     if hasattr(device, 'type'):
         if device.type == "cuda":
             try:
@@ -217,6 +220,8 @@ def get_torch_device_name(device):
             return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
         else:
             return "{}".format(device.type)
+    elif xpu_available:
+        return "{} {}".format(device, torch.xpu.get_device_name(device))
     else:
         return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
@@ -244,6 +249,7 @@ class LoadedModel:
         return self.model_memory()
 
     def model_load(self, lowvram_model_memory=0):
+        global xpu_available
         patch_model_to = None
         if lowvram_model_memory == 0:
             patch_model_to = self.device
@@ -264,6 +270,10 @@ class LoadedModel:
             accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device)
             self.model_accelerated = True
 
+        if xpu_available and not args.disable_ipex_optimize:
+            self.real_model.training = False
+            self.real_model = torch.xpu.optimize(self.real_model, inplace=True)
+
         return self.real_model
 
     def model_unload(self):
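With the flag left at its default, model_load now runs the freshly loaded model through IPEX's inference optimization. `torch.xpu.optimize` is the alias Intel Extension for PyTorch registers for `ipex.optimize`, which expects a non-training model, hence `training` being forced to False first. A hedged sketch of the same call on a plain module, assuming IPEX is installed and an XPU device is available:

```python
# Sketch of the ipex optimize step applied in model_load; not ComfyUI code.
import torch
import intel_extension_for_pytorch as ipex  # provides torch.xpu and ipex.optimize

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).to("xpu")
model.eval()  # optimize() expects an inference-mode (non-training) model

# Equivalent to the commit's torch.xpu.optimize(self.real_model, inplace=True)
model = ipex.optimize(model, inplace=True)

with torch.no_grad():
    out = model(torch.randn(1, 64, device="xpu"))
```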
@@ -500,8 +510,12 @@ def get_free_memory(dev=None, torch_free_too=False):
             mem_free_total = 1024 * 1024 * 1024 #TODO
             mem_free_torch = mem_free_total
         elif xpu_available:
-            mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
-            mem_free_torch = mem_free_total
+            stats = torch.xpu.memory_stats(dev)
+            mem_active = stats['active_bytes.all.current']
+            mem_allocated = stats['allocated_bytes.all.current']
+            mem_reserved = stats['reserved_bytes.all.current']
+            mem_free_torch = mem_reserved - mem_active
+            mem_free_total = torch.xpu.get_device_properties(dev).total_memory - mem_allocated + mem_free_torch
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_active = stats['active_bytes.all.current']
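The free-memory estimate for XPU now uses the same formula as the CUDA branch: bytes the allocator has reserved but is not actively using count as free to torch, and overall free memory is device capacity minus allocated bytes plus that slack. A worked example of the arithmetic with made-up numbers (the real code reads them from `torch.xpu.memory_stats` and `get_device_properties`):

```python
# Worked example of the new free-memory formula; values are illustrative only.
GiB = 1024 ** 3
total_memory  = 16 * GiB   # device capacity
mem_allocated =  6 * GiB   # allocated_bytes.all.current
mem_active    =  5 * GiB   # active_bytes.all.current
mem_reserved  =  8 * GiB   # reserved_bytes.all.current

mem_free_torch = mem_reserved - mem_active                       # 3 GiB of allocator slack
mem_free_total = total_memory - mem_allocated + mem_free_torch   # 13 GiB usable overall

print(mem_free_torch / GiB, mem_free_total / GiB)  # 3.0 13.0
```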
@@ -573,10 +587,10 @@ def should_use_fp16(device=None, model_params=0):
     if directml_enabled:
         return False
 
-    if cpu_mode() or mps_mode() or xpu_available:
+    if cpu_mode() or mps_mode():
         return False #TODO ?
 
-    if torch.cuda.is_bf16_supported():
+    if torch.cuda.is_bf16_supported() or xpu_available:
         return True
 
     props = torch.cuda.get_device_properties("cuda")
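Before this change, any XPU device hit the early `return False` and never used fp16; now XPU falls through and is treated as fp16-capable the same way a bf16-capable CUDA card is. A condensed, hypothetical sketch of that decision, not the full should_use_fp16 from model_management.py:

```python
# Hypothetical standalone helper illustrating only the changed branches.
import torch

def xpu_is_available():
    # guarded so this runs even where intel_extension_for_pytorch is absent
    return hasattr(torch, "xpu") and torch.xpu.is_available()

def pick_unet_dtype():
    if not torch.cuda.is_available() and not xpu_is_available():
        return torch.float32                      # cpu / mps style fallback
    if xpu_is_available() or torch.cuda.is_bf16_supported():
        return torch.float16                      # Intel GPUs now take the fp16 path
    return torch.float32

print(pick_unet_dtype())
```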