Unverified commit a56d02ef authored by Simon Lui, committed by GitHub
Browse files

Change torch.xpu to ipex.optimize, xpu device initialization and remove...

Change torch.xpu to ipex.optimize, xpu device initialization and remove workaround for text node issue from older IPEX. (#3388)
parent f81a6fad
...@@ -83,7 +83,7 @@ def get_torch_device(): ...@@ -83,7 +83,7 @@ def get_torch_device():
return torch.device("cpu") return torch.device("cpu")
else: else:
if is_intel_xpu(): if is_intel_xpu():
return torch.device("xpu") return torch.device("xpu", torch.xpu.current_device())
else: else:
return torch.device(torch.cuda.current_device()) return torch.device(torch.cuda.current_device())
...@@ -304,7 +304,7 @@ class LoadedModel: ...@@ -304,7 +304,7 @@ class LoadedModel:
raise e raise e
if is_intel_xpu() and not args.disable_ipex_optimize: if is_intel_xpu() and not args.disable_ipex_optimize:
self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True) self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True)
self.weights_loaded = True self.weights_loaded = True
return self.real_model return self.real_model
...@@ -552,8 +552,6 @@ def text_encoder_device(): ...@@ -552,8 +552,6 @@ def text_encoder_device():
if args.gpu_only: if args.gpu_only:
return get_torch_device() return get_torch_device()
elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM: elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
if is_intel_xpu():
return torch.device("cpu")
if should_use_fp16(prioritize_performance=False): if should_use_fp16(prioritize_performance=False):
return get_torch_device() return get_torch_device()
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment