utils.py 8.21 KB
Newer Older
comfyanonymous's avatar
comfyanonymous committed
1
import torch
comfyanonymous's avatar
comfyanonymous committed
2
import math
BlenderNeko's avatar
BlenderNeko committed
3
import einops
comfyanonymous's avatar
comfyanonymous committed
4

5
def load_torch_file(ckpt, safe_load=False):
    """Load a checkpoint file and return its state dict.

    Files whose name ends in ``.safetensors`` (case-insensitive) are read with
    ``safetensors.torch.load_file``; every other file goes through
    ``torch.load``. In both cases tensors are mapped to the CPU.

    Args:
        ckpt: Path of the checkpoint file.
        safe_load: When true, non-safetensors files are loaded with
            ``weights_only=True``.

    Returns:
        For safetensors files, the loaded dict. Otherwise the value stored
        under ``"state_dict"`` when the loaded object has that key, else the
        loaded object itself.
    """
    if ckpt.lower().endswith(".safetensors"):
        # Imported here so safetensors is only needed on this branch.
        import safetensors.torch
        return safetensors.torch.load_file(ckpt, device="cpu")

    if safe_load:
        pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
    else:
        # SECURITY: without weights_only=True, torch.load may unpickle
        # arbitrary objects, which can execute code embedded in the file.
        # Only use this path for checkpoints from a trusted source.
        pl_sd = torch.load(ckpt, map_location="cpu")

    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")

    # Some checkpoints nest the weights under "state_dict"; unwrap if so.
    if "state_dict" in pl_sd:
        return pl_sd["state_dict"]
    return pl_sd

def transformers_convert(sd, prefix_from, prefix_to, number):
    """Rename transformer resblock keys of ``sd`` in place.

    For each layer index in ``range(number)``, keys of the form
    ``<prefix_from>.transformer.resblocks.<i>.<part>.<weight|bias>`` are moved
    to ``<prefix_to>.encoder.layers.<i>.<new part>.<weight|bias>``. The fused
    ``attn.in_proj_<weight|bias>`` entry is split along dimension 0 into three
    equal slices stored as ``q_proj``, ``k_proj`` and ``v_proj``.

    Keys that are absent are skipped. ``sd`` is mutated and also returned.
    """
    part_renames = (
        ("ln_1", "layer_norm1"),
        ("ln_2", "layer_norm2"),
        ("mlp.c_fc", "mlp.fc1"),
        ("mlp.c_proj", "mlp.fc2"),
        ("attn.out_proj", "self_attn.out_proj"),
    )
    qkv_targets = ("self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj")
    suffixes = ("weight", "bias")

    for layer in range(number):
        # Straight one-to-one renames.
        for old_part, new_part in part_renames:
            for suffix in suffixes:
                src_key = "{}.transformer.resblocks.{}.{}.{}".format(prefix_from, layer, old_part, suffix)
                dst_key = "{}.encoder.layers.{}.{}.{}".format(prefix_to, layer, new_part, suffix)
                if src_key in sd:
                    sd[dst_key] = sd.pop(src_key)

        # Fused q/k/v projection -> three separate entries.
        for suffix in suffixes:
            fused_key = "{}.transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, layer, suffix)
            if fused_key not in sd:
                continue
            fused = sd.pop(fused_key)
            chunk = fused.shape[0] // 3
            for idx, target in enumerate(qkv_targets):
                dst_key = "{}.encoder.layers.{}.{}.{}".format(prefix_to, layer, target, suffix)
                sd[dst_key] = fused[chunk * idx:chunk * (idx + 1)]
    return sd

50
def bislerp(samples, width, height):
    """Resize ``samples`` to ``(height, width)`` with separable slerp.

    The resize runs in two passes, first along the height axis and then along
    the width axis. In each pass every output position blends its two
    neighbouring source positions by spherical linear interpolation across the
    channel dimension, with the magnitude interpolated linearly.

    Args:
        samples: 4-D tensor, unpacked as ``(n, c, h, w)``.
        width: Target size of the last dimension.
        height: Target size of the second-to-last dimension.

    Returns:
        Tensor of shape ``(n, c, height, width)``.
    """
    def slerp(b1, b2, r):
        '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''

        c = b1.shape[-1]

        #norms
        b1_norms = torch.norm(b1, dim=-1, keepdim=True)
        b2_norms = torch.norm(b2, dim=-1, keepdim=True)

        #normalize
        b1_normalized = b1 / b1_norms
        b2_normalized = b2 / b2_norms

        #zero when norms are zero
        b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0
        b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0

        #slerp
        dot = (b1_normalized*b2_normalized).sum(1)
        omega = torch.acos(dot)
        so = torch.sin(omega)

        #technically not mathematically correct, but more pleasing?
        res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized
        res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c)

        #edge cases for same or polar opposites
        # (these rows divide by sin(omega) ~ 0 above, so they are overwritten)
        res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
        res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
        return res

    def generate_bilinear_data(length_old, length_new, device):
        '''returns blend ratios plus lower/upper source indices for one axis'''
        coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1))
        coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
        ratios = coords_1 - coords_1.floor()
        coords_1 = coords_1.to(torch.int64)

        coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) + 1
        # the last position has no upper neighbour, so it points at itself
        coords_2[:,:,:,-1] -= 1
        coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
        coords_2 = coords_2.to(torch.int64)
        return ratios, coords_1, coords_2

    def flatten(t):
        '''n c h w -> (n h w) c'''
        return t.movedim(1, -1).reshape((-1, t.shape[1]))

    n,c,h,w = samples.shape
    h_new, w_new = (height, width)
    # Build index/ratio tensors on the input's device so gather() and the
    # slerp arithmetic do not mix devices.
    device = samples.device

    #linear h
    ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new, device)

    coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w))
    coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w))
    ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w))

    pass_1 = flatten(samples.gather(-2,coords_1))
    pass_2 = flatten(samples.gather(-2,coords_2))
    ratios = flatten(ratios)

    result = slerp(pass_1, pass_2, ratios)
    # (n h w) c -> n c h w
    result = result.reshape(n, h_new, w, c).movedim(-1, 1)

    #linear w
    ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, device)

    coords_1 = coords_1.expand((n, c, h_new, -1))
    coords_2 = coords_2.expand((n, c, h_new, -1))
    ratios = ratios.expand((n, 1, h_new, -1))

    pass_1 = flatten(result.gather(-1,coords_1))
    pass_2 = flatten(result.gather(-1,coords_2))
    ratios = flatten(ratios)

    result = slerp(pass_1, pass_2, ratios)
    # (n h w) c -> n c h w
    result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
    return result
125

comfyanonymous's avatar
comfyanonymous committed
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
def common_upscale(samples, width, height, upscale_method, crop):
    """Resize ``samples`` to ``(height, width)``, optionally center-cropping first.

    Args:
        samples: 4-D tensor; dimension 2 is read as height and dimension 3 as
            width.
        width: Target width.
        height: Target height.
        upscale_method: ``"bislerp"`` to use ``bislerp``; any other value is
            passed as ``mode`` to ``torch.nn.functional.interpolate``.
        crop: ``"center"`` trims the input symmetrically to the target aspect
            ratio before resizing; any other value resizes the input as is.

    Returns:
        The resized tensor.
    """
    if crop == "center":
        old_width = samples.shape[3]
        old_height = samples.shape[2]
        old_aspect = old_width / old_height
        new_aspect = width / height
        # x / y are the margins removed from each side of the width / height.
        x = 0
        y = 0
        if old_aspect > new_aspect:
            # Input is too wide for the target: trim columns.
            x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
        elif old_aspect < new_aspect:
            # Input is too tall for the target: trim rows.
            y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
        s = samples[:, :, y:old_height - y, x:old_width - x]
    else:
        s = samples

    if upscale_method == "bislerp":
        return bislerp(s, width, height)
    return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
146

pythongosssss's avatar
pythongosssss committed
147
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
    """Return the number of tiles needed to cover a ``width`` x ``height`` area.

    Tiles advance by ``tile_x - overlap`` horizontally and ``tile_y - overlap``
    vertically; the result is the product of the tile counts along both axes.
    """
    rows = math.ceil(height / (tile_y - overlap))
    cols = math.ceil(width / (tile_x - overlap))
    return rows * cols
pythongosssss's avatar
pythongosssss committed
149

150
@torch.inference_mode()
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, pbar = None):
    """Apply ``function`` to ``samples`` tile by tile and blend the results.

    Each batch item is cut into tiles of at most ``tile_y`` x ``tile_x`` that
    advance by ``tile - overlap``. Every tile is passed through ``function``,
    weighted by a mask that ramps linearly over ``round(overlap *
    upscale_amount)`` pixels from each edge, and accumulated. The accumulated
    sum is finally divided by the accumulated mask weights.

    Args:
        samples: 4-D input tensor; dimensions 2 and 3 are tiled.
        function: Callable applied to each tile. Its output is expected to be
            the tile scaled by ``upscale_amount`` with ``out_channels``
            channels (it is added into an output slice of that size).
        tile_x: Tile width in input pixels.
        tile_y: Tile height in input pixels.
        overlap: Overlap between neighbouring tiles in input pixels.
        upscale_amount: Scale factor between input and output sizes.
        out_channels: Channel count of the output tensor.
        pbar: Optional object whose ``update(1)`` is called once per tile.

    Returns:
        CPU tensor of shape ``(batch, out_channels, round(h * upscale_amount),
        round(w * upscale_amount))``.
    """
    out_h = round(samples.shape[2] * upscale_amount)
    out_w = round(samples.shape[3] * upscale_amount)
    output = torch.empty((samples.shape[0], out_channels, out_h, out_w), device="cpu")

    # Width (in output pixels) of the linear ramp at each tile edge.
    feather = round(overlap * upscale_amount)

    for b in range(samples.shape[0]):
        s = samples[b:b+1]
        # Weighted sum of tile outputs and the matching sum of weights.
        out = torch.zeros((s.shape[0], out_channels, out_h, out_w), device="cpu")
        out_div = torch.zeros((s.shape[0], out_channels, out_h, out_w), device="cpu")

        for y in range(0, s.shape[2], tile_y - overlap):
            for x in range(0, s.shape[3], tile_x - overlap):
                s_in = s[:,:,y:y+tile_y,x:x+tile_x]

                ps = function(s_in).cpu()
                mask = torch.ones_like(ps)
                # NOTE(review): if a tile output is smaller than `feather`,
                # the negative start indices below wrap around; the result is
                # still normalised by out_div.
                for t in range(feather):
                    weight = (1.0/feather) * (t + 1)
                    mask[:,:,t:1+t,:] *= weight
                    mask[:,:,mask.shape[2] -1 -t: mask.shape[2]-t,:] *= weight
                    mask[:,:,:,t:1+t] *= weight
                    mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= weight

                y0, y1 = round(y*upscale_amount), round((y+tile_y)*upscale_amount)
                x0, x1 = round(x*upscale_amount), round((x+tile_x)*upscale_amount)
                out[:,:,y0:y1,x0:x1] += ps * mask
                out_div[:,:,y0:y1,x0:x1] += mask

                if pbar is not None:
                    pbar.update(1)

        output[b:b+1] = out/out_div
    return output
176
177
178
179
180
181
182
183
184
185
186
187
188
189


# Module-wide progress callback; None means no callback is installed.
PROGRESS_BAR_HOOK = None


def set_progress_bar_global_hook(function):
    """Store ``function`` in the module-level ``PROGRESS_BAR_HOOK``."""
    global PROGRESS_BAR_HOOK
    PROGRESS_BAR_HOOK = function

class ProgressBar:
    """Progress counter that reports to the hook captured at construction.

    The value of the module-level ``PROGRESS_BAR_HOOK`` is read once in
    ``__init__``; when it is not ``None`` it is called as
    ``hook(current, total)`` after every update.
    """

    def __init__(self, total):
        self.total = total
        self.current = 0
        # Snapshot of the global hook; later changes to the global do not
        # affect this instance.
        self.hook = PROGRESS_BAR_HOOK

    def update_absolute(self, value, total=None):
        """Set the current progress to ``value``, clamped to the total.

        Args:
            value: New absolute progress value.
            total: Optional new total; applied before clamping.
        """
        if total is not None:
            self.total = total
        if value > self.total:
            value = self.total
        self.current = value
        if self.hook is not None:
            self.hook(self.current, self.total)

    def update(self, value):
        """Advance the current progress by ``value``."""
        self.update_absolute(self.current + value)