Commit b0aab1e4 authored by comfyanonymous

Add an option --fp16-unet to force using fp16 for the unet.

parent ba07cb74
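
Usage note: assuming the usual ComfyUI entry point, the new flag would be passed at launch, e.g.

    python main.py --fp16-unet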
@@ -57,6 +57,7 @@ fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
 fpunet_group = parser.add_mutually_exclusive_group()
 fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
+fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
 fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
 fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
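
For context, a minimal self-contained sketch of how the mutually exclusive group behaves; this is an illustration only (the real parser defines many more options), restricted to the flags visible in the hunk above:

    import argparse

    # Sketch of the UNET dtype flags only; the real parser has many more options.
    parser = argparse.ArgumentParser()
    fpunet_group = parser.add_mutually_exclusive_group()
    fpunet_group.add_argument("--bf16-unet", action="store_true")
    fpunet_group.add_argument("--fp16-unet", action="store_true")
    fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true")
    fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true")

    args = parser.parse_args(["--fp16-unet"])
    print(args.fp16_unet)  # True; argparse maps --fp16-unet to args.fp16_unet

    # Passing two flags from the same exclusive group is an argparse error, e.g.
    # parser.parse_args(["--fp16-unet", "--bf16-unet"]) exits with
    # "argument --bf16-unet: not allowed with argument --fp16-unet"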
@@ -466,6 +466,8 @@ def unet_inital_load_device(parameters, dtype):
 def unet_dtype(device=None, model_params=0):
     if args.bf16_unet:
         return torch.bfloat16
+    if args.fp16_unet:
+        return torch.float16
     if args.fp8_e4m3fn_unet:
         return torch.float8_e4m3fn
     if args.fp8_e5m2_unet:
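
Read top to bottom, unet_dtype() gives the explicit flags a fixed precedence: bf16 first, then the new fp16 branch, then the fp8 variants. A hedged reconstruction of that selection logic follows; the real function also handles the no-flag case with device-based auto-detection, which is elided from the hunk, so the float32 default below is an assumption standing in for that fallback:

    import torch

    # Illustrative precedence chain; requires PyTorch >= 2.1 for the fp8 dtypes.
    def unet_dtype_sketch(args):
        if args.bf16_unet:
            return torch.bfloat16
        if args.fp16_unet:  # new branch added by this commit
            return torch.float16
        if args.fp8_e4m3fn_unet:
            return torch.float8_e4m3fn
        if args.fp8_e5m2_unet:
            return torch.float8_e5m2
        return torch.float32  # assumption: stand-in for the elided auto-detection

Because the flags live in a mutually exclusive group, at most one branch can fire, so the ordering only matters relative to the elided fallback.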