utils.py 8.44 KB
Newer Older
comfyanonymous's avatar
comfyanonymous committed
1
import torch
comfyanonymous's avatar
comfyanonymous committed
2
import math
BlenderNeko's avatar
BlenderNeko committed
3
import einops
comfyanonymous's avatar
comfyanonymous committed
4

5
def load_torch_file(ckpt, safe_load=False):
    """Load a checkpoint file and return its state dict, mapped to the CPU.

    Files whose name ends in ".safetensors" (case-insensitive) are read with
    ``safetensors.torch.load_file``; every other file goes through
    ``torch.load``.

    Args:
        ckpt: Path of the checkpoint file.
        safe_load: Request ``weights_only=True`` from ``torch.load``. Has no
            effect on safetensors files. When the installed ``torch.load``
            does not accept ``weights_only``, a warning is printed and the
            file is loaded without it.

    Returns:
        For safetensors files, the loaded dict. Otherwise the value stored
        under ``"state_dict"`` when the loaded object contains that key, or
        the loaded object itself.
    """
    if ckpt.lower().endswith(".safetensors"):
        import safetensors.torch
        return safetensors.torch.load_file(ckpt, device="cpu")

    load_kwargs = {"map_location": "cpu"}
    if safe_load:
        import inspect
        # Inspect the real signature rather than the code object's variable
        # names, which also list locals and miss wrapped functions.
        if "weights_only" in inspect.signature(torch.load).parameters:
            load_kwargs["weights_only"] = True
        else:
            print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")

    # NOTE(security): without weights_only, torch.load unpickles the file and
    # can execute arbitrary code; only use that path for trusted checkpoints.
    pl_sd = torch.load(ckpt, **load_kwargs)

    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    if "state_dict" in pl_sd:
        return pl_sd["state_dict"]
    return pl_sd

def transformers_convert(sd, prefix_from, prefix_to, number):
    """Rename transformer keys of ``sd`` in place and return ``sd``.

    For each layer index below ``number``, keys of the form
    ``{prefix_from}.transformer.resblocks.{i}.*`` are moved to
    ``{prefix_to}.encoder.layers.{i}.*`` using the sub-module names in the
    table below. The fused ``attn.in_proj_{weight,bias}`` entry is split along
    its first dimension into three equal chunks stored as q/k/v projections.
    Keys that are absent are skipped.
    """
    renames = (
        ("ln_1", "layer_norm1"),
        ("ln_2", "layer_norm2"),
        ("mlp.c_fc", "mlp.fc1"),
        ("mlp.c_proj", "mlp.fc2"),
        ("attn.out_proj", "self_attn.out_proj"),
    )
    suffixes = ("weight", "bias")
    qkv_names = ("self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj")

    for layer in range(number):
        src_base = f"{prefix_from}.transformer.resblocks.{layer}"
        dst_base = f"{prefix_to}.encoder.layers.{layer}"

        # Straight one-to-one renames.
        for old_name, new_name in renames:
            for suffix in suffixes:
                src_key = f"{src_base}.{old_name}.{suffix}"
                if src_key in sd:
                    sd[f"{dst_base}.{new_name}.{suffix}"] = sd.pop(src_key)

        # Fused attention input projection -> separate q, k, v entries.
        for suffix in suffixes:
            fused_key = f"{src_base}.attn.in_proj_{suffix}"
            if fused_key not in sd:
                continue
            fused = sd.pop(fused_key)
            chunk = fused.shape[0] // 3
            for idx, proj in enumerate(qkv_names):
                sd[f"{dst_base}.{proj}.{suffix}"] = fused[chunk * idx:chunk * (idx + 1)]
    return sd

54
def bislerp(samples, width, height):
    """Resize ``samples`` to ``(height, width)`` with separable slerp.

    The input is resampled along the height axis first and then along the
    width axis. For every output position the two neighbouring source
    positions are blended per pixel across the channel dimension: the
    direction of the channel vector is spherically interpolated while its
    norm is linearly interpolated.

    Args:
        samples: 4-D tensor whose shape is unpacked as ``(n, c, h, w)``.
        width: Target size of the last dimension.
        height: Target size of the second-to-last dimension.

    Returns:
        Tensor of shape ``(n, c, height, width)``.
    """
    def slerp(b1, b2, r):
        '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''

        c = b1.shape[-1]

        #norms
        b1_norms = torch.norm(b1, dim=-1, keepdim=True)
        b2_norms = torch.norm(b2, dim=-1, keepdim=True)

        #normalize
        b1_normalized = b1 / b1_norms
        b2_normalized = b2 / b2_norms

        #zero when norms are zero
        b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0
        b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0

        #slerp
        dot = (b1_normalized*b2_normalized).sum(1)
        omega = torch.acos(dot)
        so = torch.sin(omega)

        #technically not mathematically correct, but more pleasing?
        res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized
        res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c)

        #edge cases for same or polar opposites
        # (these overwrite the rows where sin(omega) is ~0 and the division
        # above is not meaningful)
        res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
        res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
        return res

    def generate_bilinear_data(length_old, length_new, device):
        """Return (ratios, lower indices, upper indices) for one axis.

        The positions are obtained by bilinearly resizing an index ramp; the
        results are moved to ``device`` so they can index tensors living there.
        """
        coords_1 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32)
        coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
        ratios = coords_1 - coords_1.floor()
        coords_1 = coords_1.to(torch.int64)

        coords_2 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + 1
        # the last position has no right-hand neighbour, so it points at itself
        coords_2[:,:,:,-1] -= 1
        coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
        coords_2 = coords_2.to(torch.int64)
        return ratios.to(device), coords_1.to(device), coords_2.to(device)

    def to_rows(t):
        """Lay out ``n c h w`` as ``(n h w) c`` so each row is one pixel."""
        return t.permute(0, 2, 3, 1).reshape(-1, t.shape[1])

    def from_rows(t, n_, h_, w_):
        """Inverse of ``to_rows``: ``(n h w) c`` back to ``n c h w``."""
        return t.reshape(n_, h_, w_, -1).permute(0, 3, 1, 2)

    n,c,h,w = samples.shape
    h_new, w_new = (height, width)
    device = samples.device

    #linear h
    ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new, device)

    coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w))
    coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w))
    ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w))

    pass_1 = to_rows(samples.gather(-2,coords_1))
    pass_2 = to_rows(samples.gather(-2,coords_2))
    ratios = to_rows(ratios)

    result = slerp(pass_1, pass_2, ratios)
    result = from_rows(result, n, h_new, w)

    #linear w
    ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, device)

    coords_1 = coords_1.expand((n, c, h_new, -1))
    coords_2 = coords_2.expand((n, c, h_new, -1))
    ratios = ratios.expand((n, 1, h_new, -1))

    pass_1 = to_rows(result.gather(-1,coords_1))
    pass_2 = to_rows(result.gather(-1,coords_2))
    ratios = to_rows(ratios)

    result = slerp(pass_1, pass_2, ratios)
    result = from_rows(result, n, h_new, w_new)
    return result
129

comfyanonymous's avatar
comfyanonymous committed
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
def common_upscale(samples, width, height, upscale_method, crop):
    """Resize ``samples`` to ``(height, width)``, optionally center-cropping first.

    Args:
        samples: Tensor indexed as ``[:, :, y, x]``; ``shape[2]`` is used as
            the height and ``shape[3]`` as the width.
        width: Target width.
        height: Target height.
        upscale_method: ``"bislerp"`` selects ``bislerp``; any other value is
            passed as ``mode`` to ``torch.nn.functional.interpolate``.
        crop: ``"center"`` trims the longer side symmetrically so the input
            matches the target aspect ratio before resizing; any other value
            leaves the input uncropped.

    Returns:
        The resized tensor.
    """
    if crop == "center":
        old_width = samples.shape[3]
        old_height = samples.shape[2]
        old_aspect = old_width / old_height
        new_aspect = width / height
        x = 0
        y = 0
        if old_aspect > new_aspect:
            # input is too wide: drop x columns from each side
            x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
        elif old_aspect < new_aspect:
            # input is too tall: drop y rows from top and bottom
            y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
        s = samples[:, :, y:old_height - y, x:old_width - x]
    else:
        s = samples

    if upscale_method == "bislerp":
        return bislerp(s, width, height)
    return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
150

pythongosssss's avatar
pythongosssss committed
151
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
    """Return how many tiles cover a ``width`` x ``height`` area.

    Tiles advance by ``tile - overlap`` along each axis, so the count per axis
    is the size divided by that stride, rounded up. ``overlap`` must be
    smaller than both tile sizes; an equal value divides by zero.
    """
    rows = math.ceil(height / (tile_y - overlap))
    cols = math.ceil(width / (tile_x - overlap))
    return rows * cols
pythongosssss's avatar
pythongosssss committed
153

154
@torch.inference_mode()
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amount=4, out_channels=3, pbar=None):
    """Apply ``function`` to overlapping tiles of ``samples`` and blend the results.

    Each batch item is processed separately. Tiles of up to
    ``tile_y`` x ``tile_x`` are taken every ``tile - overlap`` pixels, passed
    through ``function``, moved to the CPU and accumulated into the output
    with a mask that ramps up linearly over ``overlap * upscale_amount``
    pixels from every tile edge. The accumulated values are finally divided
    by the accumulated mask.

    Args:
        samples: Tensor indexed as ``[batch, channel, y, x]``.
        function: Callable applied to each tile; its result is expected to
            have ``out_channels`` channels and spatial size scaled by
            ``upscale_amount``.
        tile_x: Tile width in input pixels.
        tile_y: Tile height in input pixels.
        overlap: Overlap between neighbouring tiles in input pixels.
        upscale_amount: Spatial scale factor between input and output.
        out_channels: Channel count of the output.
        pbar: Optional object whose ``update(1)`` is called after each tile.

    Returns:
        CPU tensor of shape ``(batch, out_channels, round(h * upscale_amount),
        round(w * upscale_amount))``.
    """
    out_h = round(samples.shape[2] * upscale_amount)
    out_w = round(samples.shape[3] * upscale_amount)
    output = torch.empty((samples.shape[0], out_channels, out_h, out_w), device="cpu")

    # width of the blending ramp, in output pixels
    feather = round(overlap * upscale_amount)

    for b in range(samples.shape[0]):
        s = samples[b:b+1]
        # weighted sum of tile outputs and the sum of weights, respectively
        out = torch.zeros((s.shape[0], out_channels, out_h, out_w), device="cpu")
        out_div = torch.zeros_like(out)

        for y in range(0, s.shape[2], tile_y - overlap):
            for x in range(0, s.shape[3], tile_x - overlap):
                s_in = s[:, :, y:y + tile_y, x:x + tile_x]

                ps = function(s_in).cpu()
                mask = torch.ones_like(ps)
                for t in range(feather):
                    weight = (1.0 / feather) * (t + 1)
                    mask[:, :, t:1 + t, :] *= weight
                    mask[:, :, mask.shape[2] - 1 - t:mask.shape[2] - t, :] *= weight
                    mask[:, :, :, t:1 + t] *= weight
                    mask[:, :, :, mask.shape[3] - 1 - t:mask.shape[3] - t] *= weight

                # destination window; slicing clamps it at the output border
                y0 = round(y * upscale_amount)
                y1 = round((y + tile_y) * upscale_amount)
                x0 = round(x * upscale_amount)
                x1 = round((x + tile_x) * upscale_amount)
                out[:, :, y0:y1, x0:x1] += ps * mask
                out_div[:, :, y0:y1, x0:x1] += mask

                if pbar is not None:
                    pbar.update(1)

        output[b:b+1] = out / out_div
    return output
180
181
182
183
184
185
186
187
188
189
190
191
192
193


# Module-wide progress callback. ProgressBar copies this value when an
# instance is created and later calls it as hook(current, total).
PROGRESS_BAR_HOOK = None
def set_progress_bar_global_hook(function):
    """Install ``function`` as the module-level progress hook.

    Only ProgressBar instances created after this call pick up the new hook;
    existing instances keep the one they captured at construction.
    """
    global PROGRESS_BAR_HOOK
    PROGRESS_BAR_HOOK = function

class ProgressBar:
    """Track progress towards ``total`` and report it through a hook.

    The hook is the value of the module-level ``PROGRESS_BAR_HOOK`` at
    construction time. When it is not ``None`` it is called as
    ``hook(current, total)`` after every update.
    """

    def __init__(self, total):
        self.total = total
        self.current = 0
        # snapshot of the module-level hook; later changes to the global do
        # not affect this instance
        self.hook = PROGRESS_BAR_HOOK

    def update_absolute(self, value, total=None):
        """Set progress to ``value`` (capped at ``total``) and notify the hook.

        Args:
            value: New absolute progress.
            total: Optional replacement for the stored total.
        """
        if total is not None:
            self.total = total
        if value > self.total:
            value = self.total
        self.current = value
        if self.hook is not None:
            self.hook(self.current, self.total)

    def update(self, value):
        """Advance progress by ``value`` relative to the current position."""
        self.update_absolute(self.current + value)