"...targets/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "c2381e10238e16524c5a89579efb82c601e0f6a4"
Commit 2c096e42 authored by Simon Lui, committed by comfyanonymous

Add ipex optimize and other enhancements for Intel GPUs based on recent memory changes.

parent 8ee04736
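For context, the core mechanism in this commit is Intel Extension for PyTorch's optimize pass, which rewrites a model for faster inference on XPU devices; the torch.xpu.optimize call in the diff appears to be the alias that intel_extension_for_pytorch registers for ipex.optimize. A minimal standalone sketch, assuming intel_extension_for_pytorch is installed and an Intel GPU is present (the toy model and tensor below are illustrative, not part of this commit):

    import torch
    import torch.nn as nn
    import intel_extension_for_pytorch as ipex  # the import also exposes the "xpu" device and torch.xpu helpers

    # Toy model purely for illustration.
    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4)).eval().to("xpu")

    # ipex.optimize applies inference-time optimizations such as weight prepacking;
    # inplace=True mirrors the call added to model_load below.
    model = ipex.optimize(model, inplace=True)

    with torch.no_grad():
        out = model(torch.randn(1, 16, device="xpu"))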
@@ -58,6 +58,8 @@ fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in
 parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")

+parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
+
 class LatentPreviewMethod(enum.Enum):
     NoPreviews = "none"
     Auto = "auto"
@@ -88,8 +88,10 @@ def get_total_memory(dev=None, torch_total_too=False):
             mem_total = 1024 * 1024 * 1024 #TODO
             mem_total_torch = mem_total
         elif xpu_available:
+            stats = torch.xpu.memory_stats(dev)
+            mem_reserved = stats['reserved_bytes.all.current']
             mem_total = torch.xpu.get_device_properties(dev).total_memory
-            mem_total_torch = mem_total
+            mem_total_torch = mem_reserved
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_reserved = stats['reserved_bytes.all.current']
@@ -208,6 +210,7 @@ if DISABLE_SMART_MEMORY:
     print("Disabling smart memory management")

 def get_torch_device_name(device):
+    global xpu_available
     if hasattr(device, 'type'):
         if device.type == "cuda":
             try:
@@ -217,6 +220,8 @@ def get_torch_device_name(device):
             return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
         else:
             return "{}".format(device.type)
+    elif xpu_available:
+        return "{} {}".format(device, torch.xpu.get_device_name(device))
     else:
         return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
@@ -244,6 +249,7 @@ class LoadedModel:
         return self.model_memory()

     def model_load(self, lowvram_model_memory=0):
+        global xpu_available
         patch_model_to = None
         if lowvram_model_memory == 0:
             patch_model_to = self.device
@@ -264,6 +270,10 @@
             accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device)
             self.model_accelerated = True

+        if xpu_available and not args.disable_ipex_optimize:
+            self.real_model.training = False
+            self.real_model = torch.xpu.optimize(self.real_model, inplace=True)
+
         return self.real_model

     def model_unload(self):
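One subtlety worth noting: the new block assigns real_model.training = False instead of calling real_model.eval(). Assigning the attribute only flips the flag on the top-level module (presumably enough for ipex.optimize to take its inference path), whereas .eval() recurses into every submodule. A quick illustration, not taken from the commit:

    import torch.nn as nn

    m = nn.Sequential(nn.Linear(4, 4), nn.Dropout())
    m.training = False
    print(m.training, m[1].training)   # False True  -> the Dropout child is still in training mode
    m.eval()
    print(m.training, m[1].training)   # False False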
@@ -500,8 +510,12 @@ def get_free_memory(dev=None, torch_free_too=False):
             mem_free_total = 1024 * 1024 * 1024 #TODO
             mem_free_torch = mem_free_total
         elif xpu_available:
-            mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
-            mem_free_torch = mem_free_total
+            stats = torch.xpu.memory_stats(dev)
+            mem_active = stats['active_bytes.all.current']
+            mem_allocated = stats['allocated_bytes.all.current']
+            mem_reserved = stats['reserved_bytes.all.current']
+            mem_free_torch = mem_reserved - mem_active
+            mem_free_total = torch.xpu.get_device_properties(dev).total_memory - mem_allocated + mem_free_torch
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_active = stats['active_bytes.all.current']
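The reworked XPU branch parallels the CUDA accounting below it: mem_free_torch is memory the caching allocator has reserved but is not actively using, and mem_free_total adds the device memory PyTorch has not allocated at all. A worked example with made-up byte counts (hypothetical values, not from any real device):

    total     = 16 * 1024**3   # get_device_properties(dev).total_memory
    allocated =  6 * 1024**3   # allocated_bytes.all.current
    reserved  =  8 * 1024**3   # reserved_bytes.all.current
    active    =  5 * 1024**3   # active_bytes.all.current

    mem_free_torch = reserved - active                   # 3 GiB reusable inside the allocator pool
    mem_free_total = total - allocated + mem_free_torch  # 13 GiB available overall
    print(mem_free_torch / 1024**3, mem_free_total / 1024**3)  # 3.0 13.0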
@@ -573,10 +587,10 @@ def should_use_fp16(device=None, model_params=0):
     if directml_enabled:
         return False

-    if cpu_mode() or mps_mode() or xpu_available:
+    if cpu_mode() or mps_mode():
         return False #TODO ?

-    if torch.cuda.is_bf16_supported():
+    if torch.cuda.is_bf16_supported() or xpu_available:
         return True

     props = torch.cuda.get_device_properties("cuda")
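With this change, Intel GPUs are no longer lumped in with CPU and MPS as fp32-only: should_use_fp16 now returns True for XPU devices via the same branch that covers bf16-capable CUDA hardware, so fp16 is preferred on Intel GPUs by default unless overridden elsewhere.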