Commit 51d54775 authored by comfyanonymous's avatar comfyanonymous
Browse files

Add key to indicate checkpoint is v_prediction when saving.

parent ff6b047a
......@@ -99,6 +99,10 @@ class BaseModel(torch.nn.Module):
if self.get_dtype() == torch.float16:
clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16)
vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16)
if self.model_type == ModelType.V_PREDICTION:
unet_state_dict["v_pred"] = torch.tensor([])
return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment