import json
import os.path as osp

import torch
import yaml
from safetensors.torch import load_file

import folder_paths
from comfy.sd import load_checkpoint

from . import diffusers_convert


def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None):
    """Load a diffusers-format model directory and convert it into the
    single-file checkpoint layout that comfy.sd.load_checkpoint expects."""
    with open(osp.join(model_path, "unet", "config.json")) as f:
        diffusers_unet_conf = json.load(f)
    with open(osp.join(model_path, "scheduler", "scheduler_config.json")) as f:
        diffusers_scheduler_conf = json.load(f)

    # Heuristic version check: v2 (768x768) UNets are configured with
    # sample_size 96 (96x96 latents), while v1 (512x512) models use 64.
    v2 = diffusers_unet_conf["sample_size"] == 96
    # Use .get() so v_pred is always defined, even when the scheduler config
    # omits prediction_type.
    v_pred = diffusers_scheduler_conf.get('prediction_type') == 'v_prediction'

    if v2:
        if v_pred:
            config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml')
        else:
            config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml')
    else:
        config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml')

    with open(config_path, 'r') as stream:
        config = yaml.safe_load(stream)

    # Pull the sub-configs out of the LDM yaml and patch in runtime options:
    # the VAE inherits the model's scale_factor and the UNet gets the
    # requested fp16 flag.
    model_config_params = config['model']['params']
    scale_factor = model_config_params['scale_factor']
    vae_config = model_config_params['first_stage_config']
    vae_config['scale_factor'] = scale_factor
    model_config_params["unet_config"]["params"]["use_fp16"] = fp16

    unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
    vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
    text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")

    # Load each component from safetensors when available; otherwise fall back
    # to the PyTorch .bin file.
    if osp.exists(unet_path):
        unet_state_dict = load_file(unet_path, device="cpu")
    else:
        unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
        unet_state_dict = torch.load(unet_path, map_location="cpu")

    if osp.exists(vae_path):
        vae_state_dict = load_file(vae_path, device="cpu")
    else:
        vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
        vae_state_dict = torch.load(vae_path, map_location="cpu")

    if osp.exists(text_enc_path):
        text_enc_dict = load_file(text_enc_path, device="cpu")
    else:
        text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
        text_enc_dict = torch.load(text_enc_path, map_location="cpu")

    # Convert the UNet model
    unet_state_dict = diffusers_convert.convert_unet_state_dict(unet_state_dict)
    unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}

    # Convert the VAE model
    vae_state_dict = diffusers_convert.convert_vae_state_dict(vae_state_dict)
    vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}

    # Easiest way to identify a v2.x model: its OpenCLIP text encoder is deeper
    # than v1's 12-layer CLIP, so layer index 22 only exists in v2 checkpoints.
    is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict

    if is_v20_model:
        # Prefix the keys with 'transformer.' in advance so the converter can
        # strip it back off of the final layer norm.
        text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
        text_enc_dict = diffusers_convert.convert_text_enc_state_dict_v20(text_enc_dict)
        text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
    else:
        text_enc_dict = diffusers_convert.convert_text_enc_state_dict(text_enc_dict)
        text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}

    # Put together new checkpoint
    sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
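    # The merged dict follows the original Stable Diffusion checkpoint layout
    # (model.diffusion_model.*, first_stage_model.*, cond_stage_model.*),
    # which is the format load_checkpoint expects.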

    return load_checkpoint(embedding_directory=embedding_directory, state_dict=sd, config=config,
                           output_vae=output_vae, output_clip=output_clip)
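

# A minimal usage sketch (an assumption, not part of the original module):
# given a local diffusers-format directory containing unet/, vae/, scheduler/,
# and text_encoder/ subfolders, a call might look like the following. The path
# is hypothetical, and load_checkpoint is expected to return a
# (model, clip, vae) tuple.
#
#   model, clip, vae = load_diffusers(
#       "models/diffusers/my-stable-diffusion-model",
#       fp16=True,
#       embedding_directory=folder_paths.get_folder_paths("embeddings"),
#   )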