Commit 340177e6 authored by comfyanonymous's avatar comfyanonymous
Browse files

Disable non blocking on mps.

parent 614b7e73
......@@ -553,15 +553,19 @@ def cast_to_device(tensor, device, dtype, copy=False):
elif is_intel_xpu():
device_supports_cast = True
non_blocking = True
if is_device_mps(device):
non_blocking = False #pytorch bug? mps doesn't support non blocking
if device_supports_cast:
if copy:
if tensor.device == device:
return tensor.to(dtype, copy=copy, non_blocking=True)
return tensor.to(device, copy=copy, non_blocking=True).to(dtype, non_blocking=True)
return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
else:
return tensor.to(device, non_blocking=True).to(dtype, non_blocking=True)
return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
else:
return tensor.to(device, dtype, copy=copy, non_blocking=True)
return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
def xformers_enabled():
global directml_enabled
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment