"src/vscode:/vscode.git/clone" did not exist on "18fc40c169a82da2fca188b5d0083bda6ac044ab"
Commit cc44ade7 authored by comfyanonymous

Always shift text encoder to GPU when the device supports fp16.

parent a6ef08a4
@@ -432,8 +432,7 @@ def text_encoder_device():
     if args.gpu_only:
         return get_torch_device()
     elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
-        #NOTE: on a Ryzen 5 7600X with 4080 it's faster to shift to GPU
-        if should_use_fp16() or torch.get_num_threads() < 8: #leaving the text encoder on the CPU is faster than shifting it if the CPU is fast enough.
+        if should_use_fp16(prioritize_performance=False):
             return get_torch_device()
         else:
             return torch.device("cpu")
@@ -569,7 +568,7 @@ def is_device_mps(device):
             return True
     return False

-def should_use_fp16(device=None, model_params=0):
+def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
     global xpu_available
     global directml_enabled
@@ -614,7 +613,7 @@ def should_use_fp16(device=None, model_params=0):
     if fp16_works:
         free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
-        if model_params * 4 > free_model_memory:
+        if (not prioritize_performance) or model_params * 4 > free_model_memory:
             return True

     if props.major < 7:

@@ -545,7 +545,7 @@ class CLIP:
         load_device = model_management.text_encoder_device()
         offload_device = model_management.text_encoder_offload_device()
         params['device'] = load_device
-        if model_management.should_use_fp16(load_device):
+        if model_management.should_use_fp16(load_device, prioritize_performance=False):
             params['dtype'] = torch.float16
         else:
             params['dtype'] = torch.float32
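For readers who only want the net effect of the patch, below is a minimal, self-contained sketch of the decision the changed functions now make. The prioritize_performance parameter comes straight from the diff; everything else (the VRAMState values, the simplified hardware probes, and passing vram_state/gpu_only as plain arguments) is a stand-in for helpers in the real module, not the actual ComfyUI implementation.

import torch
from enum import Enum

class VRAMState(Enum):
    NORMAL_VRAM = 0
    HIGH_VRAM = 1
    LOW_VRAM = 2

def get_torch_device():
    # Simplified stand-in: the real helper also handles XPU, DirectML, MPS, etc.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
    # The patch adds prioritize_performance: when False, fp16 is preferred
    # whenever the hardware supports it, regardless of whether the model
    # would also fit in free memory at fp32.
    if not torch.cuda.is_available():
        return False
    props = torch.cuda.get_device_properties(device or get_torch_device())
    if props.major < 7:   # pre-Volta GPUs: stay in fp32
        return False
    return True

def text_encoder_device(vram_state=VRAMState.NORMAL_VRAM, gpu_only=False):
    # After the patch: in NORMAL/HIGH VRAM modes the text encoder moves to the
    # GPU whenever fp16 works there, instead of also checking CPU thread count.
    if gpu_only:
        return get_torch_device()
    elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM):
        if should_use_fp16(prioritize_performance=False):
            return get_torch_device()
        else:
            return torch.device("cpu")
    return torch.device("cpu")

if __name__ == "__main__":
    dev = text_encoder_device()
    dtype = torch.float16 if should_use_fp16(dev, prioritize_performance=False) else torch.float32
    print(f"text encoder -> {dev}, dtype {dtype}")

In other words, the thread-count heuristic for keeping the text encoder on a fast CPU is dropped: any fp16-capable GPU now hosts the text encoder (in fp16), and fp32 on CPU remains the fallback for hardware where fp16 is not supported.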