# nodes_model_merging.py (1.81 KB)
# Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16


class ModelMergeSimple:
    """Merge two models' diffusion weights using one global ratio.

    Every entry of ``model2.model_state_dict("diffusion_model.")`` is added as a
    patch on ``model1.clone()``; the patches are never applied to ``model1``
    directly.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model1": ("MODEL",),
                              "model2": ("MODEL",),
                              "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                              }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge"

    CATEGORY = "_for_testing/model_merging"

    def merge(self, model1, model2, ratio):
        """Return a clone of ``model1`` patched with ``model2``'s diffusion weights.

        Args:
            model1: Base model; patches are added to ``model1.clone()``.
            model2: Donor model whose ``"diffusion_model."`` state-dict entries
                are applied as patches.
            ratio: Float in [0, 1] (per INPUT_TYPES, default 1.0). Passed to
                ``add_patches`` as the second strength argument, with
                ``1.0 - ratio`` as the first. Presumably ``ratio`` weights
                model1 and ``1 - ratio`` weights model2 -- confirm against
                ``add_patches``.

        Returns:
            A 1-tuple ``(merged_model,)``, matching the single entry in
            RETURN_TYPES.
        """
        m = model1.clone()
        sd = model2.model_state_dict("diffusion_model.")
        # One patch per state-dict entry; each value is wrapped in a 1-tuple,
        # the patch-value format this module uses.
        for k in sd:
            m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio)
        return (m, )

class ModelMergeBlocks:
    """Merge two models' diffusion weights with a separate ratio per key prefix.

    Each keyword argument of :meth:`merge` names a key prefix (relative to
    ``"diffusion_model."``) and the ratio for keys starting with it. The node
    exposes ``input``, ``middle`` and ``out``. Because matching is by
    ``str.startswith``, ``out`` also matches any key beginning ``output...``.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model1": ("MODEL",),
                              "model2": ("MODEL",),
                              "input": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                              "middle": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                              "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
                              }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge"

    CATEGORY = "_for_testing/model_merging"

    def merge(self, model1, model2, **kwargs):
        """Return a clone of ``model1`` patched with ``model2``'s diffusion weights.

        Args:
            model1: Base model; patches are added to ``model1.clone()``.
            model2: Donor model. Its ``model_state_dict("diffusion_model.")``
                entries are applied. Keys are assumed to start with
                ``"diffusion_model."``, which is sliced off before matching.
            **kwargs: Mapping of key prefix -> ratio. Must be non-empty:
                - Each key uses the ratio of its LONGEST matching prefix.
                - Keys matching no prefix use the first ratio in ``kwargs``.
                - An empty ``kwargs`` raises ``StopIteration``.

        ``1.0 - ratio`` and ``ratio`` are passed to ``add_patches`` as its two
        strength arguments. Presumably ``ratio`` weights model1 -- confirm
        against ``add_patches``.

        Returns:
            A 1-tuple ``(merged_model,)``, matching the single entry in
            RETURN_TYPES.
        """
        prefix = "diffusion_model."
        m = model1.clone()
        sd = model2.model_state_dict(prefix)
        # Fallback for keys no prefix matches: the first supplied ratio.
        default_ratio = next(iter(kwargs.values()))

        for k in sd:
            k_unet = k[len(prefix):]
            ratio = default_ratio
            # Longest prefix wins, so a specific prefix (e.g. "input_blocks.1")
            # overrides a general one ("input") regardless of kwargs order.
            # Starting at -1 lets an empty-string prefix act as a catch-all.
            best_len = -1
            for arg, value in kwargs.items():
                if len(arg) > best_len and k_unet.startswith(arg):
                    ratio = value
                    best_len = len(arg)

            m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio)
        return (m, )

# Node type name -> implementing class for the nodes defined in this module
# (presumably read by the host application's node registry; not visible here).
NODE_CLASS_MAPPINGS = {
    "ModelMergeSimple": ModelMergeSimple,
    "ModelMergeBlocks": ModelMergeBlocks,
}