Commit 0e836d52 authored by comfyanonymous's avatar comfyanonymous
Browse files

use half() on fp16 models loaded with config.

parent 986dd820
...@@ -733,6 +733,12 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
scale_factor = model_config_params['scale_factor']
vae_config = model_config_params['first_stage_config']
fp16 = False
if "unet_config" in model_config_params:
if "params" in model_config_params["unet_config"]:
if "use_fp16" in model_config_params["unet_config"]["params"]:
fp16 = model_config_params["unet_config"]["params"]["use_fp16"]
clip = None
vae = None
...@@ -754,6 +760,10 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
model = instantiate_from_config(config["model"])
sd = load_torch_file(ckpt_path)
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
if fp16:
model = model.half()
return (ModelPatcher(model), clip, vae)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment