"git@developer.sourcefind.cn:dadigang/Ventoy.git" did not exist on "8ef9732931bb298c8d5ffd66efcc4d79ad2cdbe2"
Commit 0cf4e869 authored by comfyanonymous

Add some command line arguments to store text encoder weights in fp8.

PyTorch supports two variants of fp8:
--fp8_e4m3fn-text-enc (the one that seems to give better results)
--fp8_e5m2-text-enc
parent 107e78b1
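For context, a minimal sketch of the two fp8 formats, assuming PyTorch 2.1 or later (where both dtypes are available). e4m3fn spends more bits on the mantissa (finer precision, smaller range, no infinities), while e5m2 spends more bits on the exponent (wider range, coarser precision):

import torch

# e4m3fn: 1 sign, 4 exponent, 3 mantissa bits -> finer precision, smaller range
# e5m2:   1 sign, 5 exponent, 2 mantissa bits -> wider range, coarser precision
w = torch.randn(4, 4)
print(w.to(torch.float8_e4m3fn))  # storage cast; values are rounded to fp8
print(w.to(torch.float8_e5m2))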
@@ -62,6 +62,13 @@ fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in
 fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
 fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
+fpte_group = parser.add_mutually_exclusive_group()
+fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
+fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
+fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
+fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
 parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
 parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
...
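A standalone sketch of how such a mutually exclusive group behaves (a hypothetical trimmed-down parser, not ComfyUI's actual one). Note that argparse maps dashes to underscores in attribute names, which is why the model_management code below reads args.fp8_e4m3fn_text_enc:

import argparse

parser = argparse.ArgumentParser()
fpte_group = parser.add_mutually_exclusive_group()
fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true")
fpte_group.add_argument("--fp16-text-enc", action="store_true")

args = parser.parse_args(["--fp8_e4m3fn-text-enc"])
print(args.fp8_e4m3fn_text_enc)  # True; '-' becomes '_' in the attribute name

# Passing both flags at once exits with an error like:
# "argument --fp16-text-enc: not allowed with argument --fp8_e4m3fn-text-enc"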
@@ -482,6 +482,21 @@ def text_encoder_device():
     else:
         return torch.device("cpu")

+def text_encoder_dtype(device=None):
+    if args.fp8_e4m3fn_text_enc:
+        return torch.float8_e4m3fn
+    elif args.fp8_e5m2_text_enc:
+        return torch.float8_e5m2
+    elif args.fp16_text_enc:
+        return torch.float16
+    elif args.fp32_text_enc:
+        return torch.float32
+
+    if should_use_fp16(device, prioritize_performance=False):
+        return torch.float16
+    else:
+        return torch.float32
+
 def vae_device():
     return get_torch_device()
...
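Note that fp8 here is a storage format: most PyTorch ops do not accept fp8 inputs directly, so weights stored this way are upcast before compute. A rough illustration of that pattern (not ComfyUI code):

import torch

w8 = torch.randn(16, 16).to(torch.float8_e4m3fn)  # weights held in fp8, half the memory of fp16
x = torch.randn(1, 16, dtype=torch.float16)
y = x @ w8.to(torch.float16)  # upcast before the matmul; plain fp8 matmul is generally unsupported
print(y.dtype)                # torch.float16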
@@ -95,10 +95,7 @@ class CLIP:
         load_device = model_management.text_encoder_device()
         offload_device = model_management.text_encoder_offload_device()
         params['device'] = offload_device
-        if model_management.should_use_fp16(load_device, prioritize_performance=False):
-            params['dtype'] = torch.float16
-        else:
-            params['dtype'] = torch.float32
+        params['dtype'] = model_management.text_encoder_dtype(load_device)
         self.cond_stage_model = clip(**(params))
...
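With this refactor, CLIP loading picks up the dtype from a single place. A hedged usage sketch, assuming ComfyUI's comfy package is importable:

from comfy import model_management

load_device = model_management.text_encoder_device()
dtype = model_management.text_encoder_dtype(load_device)
# No text-encoder flag: falls back to the should_use_fp16() heuristic (fp16 or fp32).
# With --fp8_e4m3fn-text-enc: returns torch.float8_e4m3fn, and so on for the other flags.
print(dtype)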