"benchmark/mmlu/bench_other.py" did not exist on "f6d40df0ee1e1fc53db3edc04bf90575f221cf23"
nodes_perpneg.py 5.14 KB
Newer Older
Hari's avatar
Hari committed
1
2
import torch
import comfy.model_management
comfyanonymous's avatar
comfyanonymous committed
3
import comfy.sampler_helpers
Hari's avatar
Hari committed
4
5
import comfy.samplers
import comfy.utils
6
import node_helpers
Hari's avatar
Hari committed
7

8
9
10
def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale):
    """Combine three model predictions with Perp-Neg guidance.

    The positive and negative directions are measured relative to the
    unconditional ("nocond") prediction. Only the component of the negative
    direction that is perpendicular to the positive direction is subtracted,
    so the negative prompt cannot cancel what the positive prompt asks for.

    Args:
        x: Current model input. Unused; kept so existing callers keep working.
        noise_pred_pos: Prediction for the positive conditioning.
        noise_pred_neg: Prediction for the negative conditioning.
        noise_pred_nocond: Prediction for the empty conditioning.
        neg_scale: Weight applied to the perpendicular negative component.
        cond_scale: Overall guidance scale.

    Returns:
        ``noise_pred_nocond + cond_scale * (pos - neg_scale * perp)``, with the
        same shape and dtype as ``noise_pred_pos``.
    """
    pos = noise_pred_pos - noise_pred_nocond
    neg = noise_pred_neg - noise_pred_nocond

    # Project per sample so images in one batch do not influence each other's
    # projection. Dim 0 is treated as the batch dim (assumed; TODO confirm for
    # every latent layout). Tensors with ndim <= 1 are one single sample.
    if pos.ndim > 1:
        batch = pos.shape[0]
        coef_shape = (batch,) + (1,) * (pos.ndim - 1)
    else:
        batch = 1
        coef_shape = (1,) * pos.ndim

    # Accumulate in at least float32: a squared norm over a whole latent
    # easily exceeds the float16 range and would become inf.
    acc_dtype = torch.promote_types(pos.dtype, torch.float32)
    pos_flat = pos.reshape(batch, -1).to(acc_dtype)
    neg_flat = neg.reshape(batch, -1).to(acc_dtype)
    dot = (neg_flat * pos_flat).sum(dim=1)
    sq_norm = (pos_flat * pos_flat).sum(dim=1)

    # A zero positive direction would give 0/0 = NaN. In that case `dot` is
    # zero too, so dividing by 1 yields a projection coefficient of 0.
    safe_sq_norm = torch.where(sq_norm > 0, sq_norm, torch.ones_like(sq_norm))
    coef = (dot / safe_sq_norm).to(pos.dtype).reshape(coef_shape)

    perp = neg - coef * pos
    scaled_perp = perp * neg_scale
    cfg_result = noise_pred_nocond + cond_scale * (pos - scaled_perp)
    return cfg_result

# TODO: This node is deprecated and should be removed; it has been replaced by PerpNegGuider.
Hari's avatar
Hari committed
18
19
20
21
class PerpNeg:
    """Legacy Perp-Neg model patch, superseded by PerpNegGuider.

    Clones the incoming model and installs a sampler CFG function. At every
    step that function evaluates an extra "empty" conditioning and combines
    the positive, negative and empty predictions with ``perp_neg``.
    """

    @classmethod
    def INPUT_TYPES(s):
        neg_scale_spec = {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}
        required = {
            "model": ("MODEL", ),
            "empty_conditioning": ("CONDITIONING", ),
            "neg_scale": ("FLOAT", neg_scale_spec),
        }
        return {"required": required}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "_for_testing"

    def patch(self, model, empty_conditioning, neg_scale):
        """Return a one-tuple holding a clone of ``model`` with the Perp-Neg CFG hook."""
        patched = model.clone()
        nocond = comfy.sampler_helpers.convert_cond(empty_conditioning)

        def cfg_function(args):
            # Read every argument up front, in the same order the sampler provides them.
            inner_model = args["model"]
            pred_pos = args["cond_denoised"]
            pred_neg = args["uncond_denoised"]
            guidance = args["cond_scale"]
            x = args["input"]
            sigma = args["sigma"]
            options = args["model_options"]

            # Encoded per call because encode_model_conds takes the current input and its device.
            nocond_encoded = comfy.samplers.encode_model_conds(inner_model.extra_conds, nocond, x, x.device, "negative")
            (pred_nocond,) = comfy.samplers.calc_cond_batch(inner_model, [nocond_encoded], x, sigma, options)

            combined = perp_neg(x, pred_pos, pred_neg, pred_nocond, neg_scale, guidance)
            # Hook returns x minus the combined prediction (presumably the sampler
            # converts back with x - result; confirm against comfy.samplers).
            return x - combined

        patched.set_model_sampler_cfg_function(cfg_function)
        return (patched, )


54
55
56
57
58
59
60
61
62
63
class Guider_PerpNeg(comfy.samplers.CFGGuider):
    """CFG guider that applies Perp-Neg using a single batched model call.

    Holds three conditionings: positive, negative, and an "empty" negative
    prompt. Their predictions are combined with ``perp_neg`` instead of
    plain classifier-free guidance.
    """

    def set_conds(self, positive, negative, empty_negative_prompt):
        """Register the three conditionings.

        The empty prompt is tagged with ``prompt_type="negative"`` before it is stored.
        """
        empty_negative_prompt = node_helpers.conditioning_set_values(empty_negative_prompt, {"prompt_type": "negative"})
        self.inner_set_conds({"positive": positive, "empty_negative_prompt": empty_negative_prompt, "negative": negative})

    def set_cfg(self, cfg, neg_scale):
        """Store the guidance scale ``cfg`` and the Perp-Neg negative scale."""
        self.cfg = cfg
        self.neg_scale = neg_scale

    def predict_noise(self, x, timestep, model_options=None, seed=None):
        """Compute the Perp-Neg guided prediction for one sampling step.

        Args:
            x: Current model input.
            timestep: Current sigma/timestep, forwarded to calc_cond_batch and hooks.
            model_options: Sampler options dict. ``None`` (the default) means an empty dict.
            seed: Accepted for signature compatibility; unused here.

        Returns:
            The ``perp_neg`` combination of the positive, negative and empty
            predictions, after each ``sampler_post_cfg_function`` hook in
            ``model_options`` has been applied in order.
        """
        # in CFGGuider.predict_noise, we call sampling_function(), which uses cfg_function() to compute pos & neg
        # but we'd rather do a single batch of sampling pos, neg, and empty, so we call calc_cond_batch([pos,neg,empty]) directly

        # Use a fresh dict per call. A `{}` default would be one object shared by
        # every call, and model_options is passed on to calc_cond_batch and to hooks.
        if model_options is None:
            model_options = {}

        positive_cond = self.conds.get("positive", None)
        negative_cond = self.conds.get("negative", None)
        empty_cond = self.conds.get("empty_negative_prompt", None)

        (noise_pred_pos, noise_pred_neg, noise_pred_empty) = \
            comfy.samplers.calc_cond_batch(self.inner_model, [positive_cond, negative_cond, empty_cond], x, timestep, model_options)
        cfg_result = perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_empty, self.neg_scale, self.cfg)

        # normally this would be done in cfg_function, but we skipped
        # that for efficiency: we can compute the noise predictions in
        # a single call to calc_cond_batch() (rather than two)
        # so we replicate the hook here
        for fn in model_options.get("sampler_post_cfg_function", []):
            args = {
                "denoised": cfg_result,
                "cond": positive_cond,
                "uncond": negative_cond,
                # guidance scale under the same key the cfg-function hooks read
                "cond_scale": self.cfg,
                "model": self.inner_model,
                "uncond_denoised": noise_pred_neg,
                "cond_denoised": noise_pred_pos,
                "sigma": timestep,
                "model_options": model_options,
                "input": x,
                # not in the original call in samplers.py:cfg_function, but made available for future hooks
                "empty_cond": empty_cond,
                "empty_cond_denoised": noise_pred_empty,
            }
            cfg_result = fn(args)

        return cfg_result
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120

class PerpNegGuider:
    """Node that builds a ``Guider_PerpNeg`` from a model and three conditionings."""

    @classmethod
    def INPUT_TYPES(s):
        cfg_spec = {"default": 8.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}
        neg_scale_spec = {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}

        required = {"model": ("MODEL",)}
        for cond_name in ("positive", "negative", "empty_conditioning"):
            required[cond_name] = ("CONDITIONING", )
        required["cfg"] = ("FLOAT", cfg_spec)
        required["neg_scale"] = ("FLOAT", neg_scale_spec)
        return {"required": required}

    RETURN_TYPES = ("GUIDER",)

    FUNCTION = "get_guider"
    CATEGORY = "_for_testing"

    def get_guider(self, model, positive, negative, empty_conditioning, cfg, neg_scale):
        """Return a one-tuple holding a configured ``Guider_PerpNeg`` for ``model``."""
        perp_guider = Guider_PerpNeg(model)
        perp_guider.set_conds(positive, negative, empty_conditioning)
        perp_guider.set_cfg(cfg, neg_scale)
        return (perp_guider,)

Hari's avatar
Hari committed
121
122
# Node classes exported by this module, keyed by node type name.
NODE_CLASS_MAPPINGS = dict(
    PerpNeg=PerpNeg,
    PerpNegGuider=PerpNegGuider,
)

# Display labels; only the legacy PerpNeg node gets one, marking it as deprecated.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    PerpNeg="Perp-Neg (DEPRECATED by PerpNegGuider)",
)