Commit 24129d78 authored by comfyanonymous

Speed up SDXL on 16xx series with fp16 weights and manual cast.

parent 98b80ad1
@@ -496,7 +496,7 @@ def unet_dtype(device=None, model_params=0):
         return torch.float8_e4m3fn
     if args.fp8_e5m2_unet:
         return torch.float8_e5m2
-    if should_use_fp16(device=device, model_params=model_params):
+    if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
         return torch.float16
     return torch.float32

@@ -696,7 +696,7 @@ def is_device_mps(device):
             return True
     return False

-def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
+def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
     global directml_enabled

     if device is not None:

@@ -738,7 +738,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
             if x in props.name.lower():
                 fp16_works = True

-    if fp16_works:
+    if fp16_works or manual_cast:
         free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
         if (not prioritize_performance) or model_params * 4 > free_model_memory:
             return True
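The gist of the change: on GPUs where fp16 arithmetic is broken or unverified (fp16_works stays False), should_use_fp16 can still answer True when the caller passes manual_cast=True, because the weights only need to be stored in fp16; each op then upcasts to a safe compute dtype on the fly. Below is a minimal sketch of that manual-cast pattern for a single linear layer; the class name and the fp32 compute dtype are assumptions for illustration, not ComfyUI's actual implementation.

import torch

# Sketch of the manual-cast pattern (illustrative; not ComfyUI's actual code).
# Weights live in fp16, halving the VRAM needed to hold the UNet, but the
# matmul itself runs in fp32, which 16xx-series cards compute correctly.
class ManualCastLinear(torch.nn.Linear):
    compute_dtype = torch.float32  # assumed safe compute dtype for these cards

    def forward(self, x):
        # Upcast the fp16 parameters and input just for this op, then cast the
        # result back so downstream layers still see the storage dtype.
        weight = self.weight.to(self.compute_dtype)
        bias = self.bias.to(self.compute_dtype) if self.bias is not None else None
        out = torch.nn.functional.linear(x.to(self.compute_dtype), weight, bias)
        return out.to(x.dtype)

Compared with loading the model in fp32 outright, this halves weight memory at the cost of a cast per op, which is what lets SDXL fit and run faster on memory-limited 16xx cards.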
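As a hypothetical smoke test of the new decision logic, assuming ComfyUI's comfy.model_management module is importable and the machine has a 16xx-series GPU:

import comfy.model_management as mm

# With manual_cast=True wired into unet_dtype, an SDXL-sized UNet on a 16xx
# card can now get torch.float16 instead of falling through to torch.float32.
# The parameter count below is illustrative, not an exact SDXL figure.
print(mm.unet_dtype(model_params=2_600_000_000))
print(mm.should_use_fp16(model_params=2_600_000_000, manual_cast=True))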