nodes_differential_diffusion.py 3.51 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# code adapted from https://github.com/exx8/differential-diffusion

import torch
import inspect

class DifferentialDiffusion():
    """Node that installs a per-step denoise-mask function on a cloned model.

    ``apply`` registers ``forward`` under ``model_options["denoise_mask_function"]``.
    ``forward`` turns a soft mask into a hard 0/1 mask by comparing it with a
    threshold that falls linearly from 1 to 0 along the sigma schedule, so
    higher mask values stay selected for more steps.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"model": ("MODEL", ),
                            }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply"
    CATEGORY = "_for_testing"
    # Class-wide flag: raised by IS_CHANGED, cleared by __init__. While it is
    # raised, the next forward() call rebuilds all cached state.
    INIT = False

    @classmethod
    def IS_CHANGED(s, *args, **kwargs):
        """Raise the re-initialisation flag and return an empty string.

        The flag is set both on ``DifferentialDiffusion`` and on ``s`` (which
        may be a subclass).
        """
        DifferentialDiffusion.INIT = s.INIT = True
        return ""

    def __init__(self) -> None:
        """Reset all cached schedule state and clear the class-wide flag."""
        DifferentialDiffusion.INIT = False
        # Sigma schedule in use (discovered on the call stack or synthesised).
        self.sigmas: "torch.Tensor | None" = None
        # Linear ramp 1 -> 0 with one entry per sigma.
        self.thresholds: "torch.Tensor | None" = None
        # Iterator over the precomputed per-step masks (fast path only).
        self.mask_i = None
        # True while the precomputed per-step masks may be consumed in order.
        self.valid_sigmas = False
        # Set here so the attributes always exist; filled in by init_sigmas().
        self.sigmas_min = None
        self.sigmas_max = None
        self.thresholds_min_len = 0
        # Substrings of function names for which the one-mask-per-call fast
        # path is not used (see init_sigmas).
        self.varying_sigmas_samplers = ["dpmpp_2s", "dpmpp_sde", "dpm_2", "heun", "restart"]

    def apply(self, model):
        """Return a 1-tuple holding a clone of ``model`` with ``forward`` installed."""
        model = model.clone()
        model.model_options["denoise_mask_function"] = self.forward
        return (model,)

    @staticmethod
    def _build_step_masks(thresholds: torch.Tensor, denoise_mask: torch.Tensor) -> torch.Tensor:
        """Stack one hard mask per threshold (last threshold excluded) along a new dim 0."""
        # One singleton axis per mask dimension, so the comparison broadcasts
        # correctly whatever the rank of the mask is.
        per_step = thresholds[:-1].reshape(-1, *([1] * denoise_mask.ndim))
        return (denoise_mask.unsqueeze(0) >= per_step).to(denoise_mask.dtype)

    def init_sigmas(self, sigma: torch.Tensor, denoise_mask: torch.Tensor):
        """Rebuild the sigma schedule, the threshold ramp and (optionally) the mask stack.

        The schedule is looked up on the call stack as a local variable named
        "sigmas"; the name of the function owning that local decides whether
        the precomputed fast path is enabled. When nothing is found, a
        two-entry schedule ``[sigma[0], 0]`` is synthesised instead.

        Args:
            sigma: current sigma tensor; only its first element, dtype and
                device are used here.
            denoise_mask: soft mask the hard masks are derived from.
        """
        self.__init__()
        self.sigmas, sampler = find_outer_instance("sigmas", callback=get_sigmas_and_sampler) or (None, "")
        # Fast path: a "sample_*" function that is not in the varying list, or
        # any function whose name contains "generic".
        fixed_schedule = "sample_" in sampler and not any(s in sampler for s in self.varying_sigmas_samplers)
        self.valid_sigmas = fixed_schedule or "generic" in sampler
        if self.sigmas is None:
            self.sigmas = sigma[:1].repeat(2)
            self.sigmas[-1].zero_()
        self.sigmas_min = self.sigmas.min()
        self.sigmas_max = self.sigmas.max()
        self.thresholds = torch.linspace(1, 0, self.sigmas.shape[0], dtype=sigma.dtype, device=sigma.device)
        self.thresholds_min_len = self.thresholds.shape[0] - 1
        if self.valid_sigmas:
            self.mask_i = iter(self._build_step_masks(self.thresholds, denoise_mask))

    def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor):
        """Return the hard 0/1 mask (in ``denoise_mask.dtype``) for the current step.

        Args:
            sigma: current sigma tensor; only ``sigma[0]`` is read
                (presumably all entries are equal - confirm against caller).
            denoise_mask: soft mask to threshold.
        """
        if self.sigmas is None or DifferentialDiffusion.INIT:
            self.init_sigmas(sigma, denoise_mask)
        if self.valid_sigmas:
            try:
                return next(self.mask_i)
            except StopIteration:
                # More calls than precomputed masks: switch to the lookup
                # below for this and every later call.
                self.valid_sigmas = False
        if self.thresholds_min_len > 1:
            nearest_idx = (self.sigmas - sigma[0]).abs().argmin()
            if not self.thresholds_min_len > nearest_idx:
                # The final schedule entry maps to the second-to-last
                # threshold rather than to the last one.
                nearest_idx = -2
            threshold = self.thresholds[nearest_idx]
        else:
            # Schedule of at most two entries: interpolate the threshold from
            # where sigma lies between the schedule's min and max.
            threshold = (sigma[0] - self.sigmas_min) / (self.sigmas_max - self.sigmas_min)
        return (denoise_mask >= threshold).to(denoise_mask.dtype)

def get_sigmas_and_sampler(frame, target):
    """Accept a frame local as a sigma schedule if its last entry is below 0.1.

    Args:
        frame: frame object whose ``f_locals`` contains ``target``.
        target: name of the local variable to inspect.

    Returns:
        ``(tensor, function_name)`` when the local is a non-empty tensor whose
        last entry along dim 0 is a single value below 0.1; ``False``
        otherwise. ``function_name`` is the name of the code object the frame
        is executing.
    """
    found = frame.f_locals[target]
    # This runs against arbitrary frames, so an unrelated local with the same
    # name must be rejected quietly rather than raise: 0-dim and empty tensors
    # cannot be indexed with [-1].
    if not isinstance(found, torch.Tensor) or found.ndim == 0 or found.numel() == 0:
        return False
    last = found[-1]
    # A multi-element last slice has no single truth value.
    if last.numel() != 1 or not bool(last < 0.1):
        return False
    return found, frame.f_code.co_name

def find_outer_instance(target: str, target_type=None, callback=None):
    """Search the call stack for a local variable named ``target``.

    Frames are visited from the caller of this function outwards, at most 100
    of them.

    Args:
        target: name of the local variable to look for.
        target_type: type (or tuple of types) the value must be an instance
            of; ``None`` accepts any value. Ignored when ``callback`` is given.
        callback: optional ``callback(frame, target)``; it is called for every
            frame that has the local, and the first truthy result is returned.

    Returns:
        The callback's first truthy result, or the first matching value, or
        ``None`` when nothing matched.
    """
    current = inspect.currentframe()
    # Start at the caller: this function's own locals (``target``, ``callback``,
    # ...) must never count as a match.
    frame = current.f_back if current is not None else None
    try:
        for _ in range(100):
            if frame is None:
                break
            if target in frame.f_locals:
                if callback is not None:
                    res = callback(frame, target)
                    if res:
                        return res
                else:
                    found = frame.f_locals[target]
                    # isinstance(x, None) raises TypeError, so None means "any type".
                    if target_type is None or isinstance(found, target_type):
                        return found
            frame = frame.f_back
        return None
    finally:
        # Drop frame references promptly to avoid reference cycles.
        del current, frame

    
# Node identifier -> implementing class (presumably read by the host
# application's node loader - confirm).
NODE_CLASS_MAPPINGS = {"DifferentialDiffusion": DifferentialDiffusion}

# Node identifier -> human-readable display name.
NODE_DISPLAY_NAME_MAPPINGS = {"DifferentialDiffusion": "Differential Diffusion"}