# nodes_model_merging.py
1
2
3
4
5
import comfy.sd
import comfy.utils
import folder_paths
import json
import os
6
7
8
9
10
11
12
13
14
15
16

class ModelMergeSimple:
    """Blend the diffusion-model weights of two MODEL inputs with one global ratio.

    A clone of ``model1`` receives every key patch that ``model2`` exposes under
    the ``diffusion_model.`` prefix. Each patch is added with
    ``add_patches(patch, 1.0 - ratio, ratio)``.

    NOTE(review): assuming ``add_patches``' positional arguments are
    (strength_patch, strength_model), ``ratio`` is the share kept from ``model1``
    and ``1.0 - ratio`` the share taken from ``model2``. Under that reading the
    default ``ratio=1.0`` leaves ``model1`` unchanged. Confirm against
    ModelPatcher.add_patches.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model1": ("MODEL",),
                              "model2": ("MODEL",),
                              "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                              }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge"

    CATEGORY = "advanced/model_merging"

    def merge(self, model1, model2, ratio):
        """Return a 1-tuple holding a clone of ``model1`` patched with ``model2``.

        Args:
            model1: base model; it is cloned, so the input object itself is not patched.
            model2: model whose ``diffusion_model.`` key patches are merged in.
            ratio: value in [0, 1] passed as the second strength argument of
                ``add_patches``; ``1.0 - ratio`` is passed as the first.
        """
        m = model1.clone()
        key_patches = model2.get_key_patches("diffusion_model.")
        for key, patch in key_patches.items():
            # One add_patches call per key, matching the per-key merge used elsewhere in this file.
            m.add_patches({key: patch}, 1.0 - ratio, ratio)
        return (m, )

26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class CLIPMergeSimple:
    """Blend two CLIP inputs key-by-key using a single ratio.

    ``clip2``'s key patches are applied to a clone of ``clip1``. Keys ending in
    ``.position_ids`` or ``.logit_scale`` are left untouched.
    """

    @classmethod
    def INPUT_TYPES(s):
        ratio_spec = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
        required = {"clip1": ("CLIP",), "clip2": ("CLIP",), "ratio": ratio_spec}
        return {"required": required}
    RETURN_TYPES = ("CLIP",)
    FUNCTION = "merge"

    CATEGORY = "advanced/model_merging"

    def merge(self, clip1, clip2, ratio):
        """Return a 1-tuple with a clone of ``clip1`` patched from ``clip2``.

        Every kept key is passed to ``add_patches`` as
        ``(patch, 1.0 - ratio, ratio)``, one key per call.
        """
        merged = clip1.clone()
        # These entries are skipped rather than blended.
        skipped_suffixes = (".position_ids", ".logit_scale")
        for key, patch in clip2.get_key_patches().items():
            if key.endswith(skipped_suffixes):
                continue
            merged.add_patches({key: patch}, 1.0 - ratio, ratio)
        return (merged, )

47
48
49
50
51
52
53
54
55
56
57
58
class ModelMergeBlocks:
    """Merge two MODEL inputs with a separate ratio per block-name prefix.

    Each keyword argument to ``merge`` maps a prefix of the key name, with the
    ``diffusion_model.`` prefix stripped, to a ratio. A key takes the ratio of
    the longest keyword that prefixes it. Keys that match no keyword take the
    ratio of the first keyword supplied.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model1": ("MODEL",),
                              "model2": ("MODEL",),
                              "input": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                              "middle": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                              "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
                              }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge"

    CATEGORY = "advanced/model_merging"

    def merge(self, model1, model2, **kwargs):
        """Return a 1-tuple holding a clone of ``model1`` patched from ``model2``.

        Args:
            model1: base model; it is cloned, so the input object is not patched.
            model2: model whose ``diffusion_model.`` key patches are merged in.
            **kwargs: block-prefix -> ratio. The first entry doubles as the
                fallback ratio. NOTE(review): that is presumably ``input`` when
                the caller follows INPUT_TYPES order; confirm.

        Raises:
            ValueError: if no ratios are given. Previously this surfaced as a
                bare StopIteration.
        """
        if not kwargs:
            raise ValueError("ModelMergeBlocks.merge requires at least one block ratio")

        prefix = "diffusion_model."
        prefix_len = len(prefix)
        default_ratio = next(iter(kwargs.values()))

        m = model1.clone()
        key_patches = model2.get_key_patches(prefix)
        for key, patch in key_patches.items():
            k_unet = key[prefix_len:]
            # Longest matching prefix wins. max() keeps the first of equal-length
            # matches. Empty names never match, so they cannot override the fallback.
            matches = [arg for arg in kwargs if arg and k_unet.startswith(arg)]
            ratio = kwargs[max(matches, key=len)] if matches else default_ratio
            m.add_patches({key: patch}, 1.0 - ratio, ratio)
        return (m, )

79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
class CheckpointSave:
    """Output node that saves a MODEL, CLIP and VAE together as one .safetensors checkpoint."""

    def __init__(self):
        # Base directory that save() resolves its output paths against.
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
                              "clip": ("CLIP",),
                              "vae": ("VAE",),
                              "filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True

    CATEGORY = "advanced/model_merging"

    def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
        """Write ``model``/``clip``/``vae`` to ``<folder>/<filename>_<counter:05>_.safetensors``.

        The folder, base filename and counter come from
        ``folder_paths.get_save_image_path``. Metadata stores ``prompt`` as JSON
        under ``"prompt"``, or ``""`` when there is no prompt. Each
        ``extra_pnginfo`` entry is stored JSON-encoded under its own key.
        Writing is delegated to ``comfy.sd.save_checkpoint``.

        Returns:
            An empty dict; the node produces no outputs.
        """
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)

        metadata = {"prompt": json.dumps(prompt) if prompt is not None else ""}
        if extra_pnginfo is not None:
            metadata.update({key: json.dumps(value) for key, value in extra_pnginfo.items()})

        output_checkpoint = os.path.join(full_output_folder, f"{filename}_{counter:05}_.safetensors")

        comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata)
        return {}


114
115
# Node registry: maps the node type name to its implementing class.
NODE_CLASS_MAPPINGS = {
    "ModelMergeSimple": ModelMergeSimple,
    "ModelMergeBlocks": ModelMergeBlocks,
    "CheckpointSave": CheckpointSave,
    "CLIPMergeSimple": CLIPMergeSimple,
}