Commit aeaeca10 authored by comfyanonymous's avatar comfyanonymous
Browse files

Small refactor of is_device_* functions.

parent 7f89cb48
...@@ -684,17 +684,20 @@ def mps_mode(): ...@@ -684,17 +684,20 @@ def mps_mode():
global cpu_state global cpu_state
return cpu_state == CPUState.MPS return cpu_state == CPUState.MPS
def is_device_cpu(device): def is_device_type(device, type):
if hasattr(device, 'type'): if hasattr(device, 'type'):
if (device.type == 'cpu'): if (device.type == type):
return True return True
return False return False
def is_device_cpu(device):
return is_device_type(device, 'cpu')
def is_device_mps(device): def is_device_mps(device):
if hasattr(device, 'type'): return is_device_type(device, 'mps')
if (device.type == 'mps'):
return True def is_device_cuda(device):
return False return is_device_type(device, 'cuda')
def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False): def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
global directml_enabled global directml_enabled
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment