"test/onnx/parse/randomnormallike_test.cpp" did not exist on "d2532d0eee52c8e12d05990ce6a7c8bec6480df7"
Commit 929e266f authored by comfyanonymous

Manual cast for bf16 on older GPUs.

parent 6c875d84
@@ -499,7 +499,7 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
     if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
         if torch.float16 in supported_dtypes:
             return torch.float16
-    if should_use_bf16(device):
+    if should_use_bf16(device, model_params=model_params, manual_cast=True):
        if torch.bfloat16 in supported_dtypes:
            return torch.bfloat16
     return torch.float32
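
This hunk only changes which arguments unet_dtype forwards; the selection order is unchanged: fp16 first, then bf16, then fp32. Below is a minimal standalone sketch of that ordering, with the two helpers stubbed out (the stub return values are assumptions for illustration, not the real logic in model_management.py).

import torch

# Hypothetical stubs standing in for the real helpers in model_management.py;
# the fixed return values are purely illustrative.
def should_use_fp16(device=None, model_params=0, manual_cast=True):
    return False  # e.g. a GPU where fp16 is not worth using

def should_use_bf16(device=None, model_params=0, manual_cast=True):
    return True   # e.g. bf16 usable natively or via manual cast

def unet_dtype_sketch(device=None, model_params=0,
                      supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
    # Same ordering as the patched unet_dtype(): fp16, then bf16, else fp32.
    if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
        if torch.float16 in supported_dtypes:
            return torch.float16
    if should_use_bf16(device, model_params=model_params, manual_cast=True):
        if torch.bfloat16 in supported_dtypes:
            return torch.bfloat16
    return torch.float32

print(unet_dtype_sketch(model_params=859_000_000))  # -> torch.bfloat16
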
@@ -771,10 +771,24 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     return True
-def should_use_bf16(device=None):
+def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
     if device is not None:
         if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
             return False
+    if device is not None: #TODO not sure about mps bf16 support
+        if is_device_mps(device):
+            return False
+    if FORCE_FP32:
+        return False
+    if directml_enabled:
+        return False
+    if cpu_mode() or mps_mode():
+        return False
     if is_intel_xpu():
         return True
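
The new manual_cast parameter matters because an older GPU may lack usable bf16 kernels yet can still keep the weights in bf16 and cast them to a supported dtype right before each op. The following is a conceptual sketch of that idea, assuming a cast-to-fp32 fallback; it is not ComfyUI's actual ops code, and the class name is made up.

import torch

class ManualCastLinear(torch.nn.Linear):
    # Weights stay in a low-memory storage dtype (e.g. bf16) and are cast
    # to a dtype the GPU can actually compute in just before the matmul.
    compute_dtype = torch.float32  # assumed fallback compute dtype for this sketch

    def forward(self, x):
        weight = self.weight.to(self.compute_dtype)
        bias = self.bias.to(self.compute_dtype) if self.bias is not None else None
        return torch.nn.functional.linear(x.to(self.compute_dtype), weight, bias)

layer = ManualCastLinear(16, 16).to(torch.bfloat16)          # stored in bf16
out = layer(torch.randn(2, 16, dtype=torch.bfloat16))
print(out.dtype)  # torch.float32: computed in fp32, weights still bf16 in memory
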
@@ -785,6 +799,13 @@ def should_use_bf16(device=None):
     if props.major >= 8:
         return True
+    bf16_works = torch.cuda.is_bf16_supported()
+    if bf16_works or manual_cast:
+        free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
+        if (not prioritize_performance) or model_params * 4 > free_model_memory:
+            return True
     return False
def soft_empty_cache(force=False):
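
The added memory check is what makes "manual cast for older GPUs" kick in: model_params * 4 is the model's fp32 footprint in bytes, so when that would not fit in free VRAM (after the 10% margin and the inference reserve), bf16 is chosen even though the GPU will only run it via manual casting. Rough arithmetic with assumed figures; the parameter count and memory numbers below are illustrative only.

GiB = 1024 ** 3

model_params = 859_000_000                     # roughly an SD1.5-class UNet (assumed)
fp32_bytes = model_params * 4                  # ~3.2 GiB needed to keep weights in fp32
free_model_memory = 4 * GiB * 0.9 - 1 * GiB    # assumed 4 GiB free, ~1 GiB inference reserve

# Mirrors the condition in the diff: prefer bf16 when fp32 would not fit.
print(fp32_bytes > free_model_memory)          # True -> return True (use bf16)
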