Commit 929e266f authored by comfyanonymous's avatar comfyanonymous
Browse files

Manual cast for bf16 on older GPUs.

parent 6c875d84
...@@ -499,7 +499,7 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor ...@@ -499,7 +499,7 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
if should_use_fp16(device=device, model_params=model_params, manual_cast=True): if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
if torch.float16 in supported_dtypes: if torch.float16 in supported_dtypes:
return torch.float16 return torch.float16
if should_use_bf16(device): if should_use_bf16(device, model_params=model_params, manual_cast=True):
if torch.bfloat16 in supported_dtypes: if torch.bfloat16 in supported_dtypes:
return torch.bfloat16 return torch.bfloat16
return torch.float32 return torch.float32
...@@ -771,10 +771,24 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma ...@@ -771,10 +771,24 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
return True return True
def should_use_bf16(device=None): def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
if device is not None:
if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
return False
if device is not None: #TODO not sure about mps bf16 support
if is_device_mps(device):
return False
if FORCE_FP32: if FORCE_FP32:
return False return False
if directml_enabled:
return False
if cpu_mode() or mps_mode():
return False
if is_intel_xpu(): if is_intel_xpu():
return True return True
...@@ -785,6 +799,13 @@ def should_use_bf16(device=None): ...@@ -785,6 +799,13 @@ def should_use_bf16(device=None):
if props.major >= 8: if props.major >= 8:
return True return True
bf16_works = torch.cuda.is_bf16_supported()
if bf16_works or manual_cast:
free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
if (not prioritize_performance) or model_params * 4 > free_model_memory:
return True
return False return False
def soft_empty_cache(force=False): def soft_empty_cache(force=False):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment