# controlnet.py
import logging
import math
import os

import torch

import comfy.cldm.cldm
import comfy.ldm.cascade.controlnet
import comfy.model_detection
import comfy.model_management
import comfy.model_patcher
import comfy.ops
import comfy.t2i_adapter.adapter
import comfy.utils

def broadcast_image_to(tensor, target_batch_size, batched_number):
    """Adapt a hint tensor's batch dimension to a sampling batch.

    A hint with a batch of one is returned untouched (left to broadcasting).
    Otherwise the hint is cut down, or cycled up, to
    ``target_batch_size // batched_number`` entries; if that chunk still differs
    from ``target_batch_size`` it is concatenated ``batched_number`` times.
    """
    if tensor.shape[0] == 1:
        return tensor

    per_chunk = target_batch_size // batched_number
    chunk = tensor[:per_chunk]
    available = chunk.shape[0]
    if per_chunk > available:
        # Repeat whole copies, then pad with a leading slice for the remainder.
        whole, extra = divmod(per_chunk, available)
        chunk = torch.cat([chunk] * whole + [chunk[:extra]], dim=0)

    if chunk.shape[0] != target_batch_size:
        chunk = torch.cat([chunk] * batched_number, dim=0)
    return chunk

class ControlBase:
    """State and behaviour shared by ControlNet, ControlLora and T2IAdapter.

    Control objects can be chained through ``previous_controlnet``; pre_run,
    cleanup, get_models and inference_memory_requirements all forward to that
    predecessor.
    """

    def __init__(self, device=None):
        self.cond_hint_original = None  # hint exactly as passed to set_cond_hint
        self.cond_hint = None  # resized/cached hint, rebuilt by subclasses
        self.strength = 1.0
        self.timestep_percent_range = (0.0, 1.0)
        self.global_average_pooling = False
        self.timestep_range = None  # filled by pre_run from timestep_percent_range
        self.compression_ratio = 8
        self.upscale_algorithm = 'nearest-exact'
        self.device = comfy.model_management.get_torch_device() if device is None else device
        self.previous_controlnet = None

    def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)):
        """Store the raw hint, its strength and active percent range; returns self."""
        self.cond_hint_original, self.strength, self.timestep_percent_range = cond_hint, strength, timestep_percent_range
        return self

    def pre_run(self, model, percent_to_timestep_function):
        """Convert the percent range into timesteps here and in every predecessor."""
        start_percent = self.timestep_percent_range[0]
        end_percent = self.timestep_percent_range[1]
        self.timestep_range = (percent_to_timestep_function(start_percent), percent_to_timestep_function(end_percent))
        previous = self.previous_controlnet
        if previous is not None:
            previous.pre_run(model, percent_to_timestep_function)

    def set_previous_controlnet(self, controlnet):
        """Link ``controlnet`` as the predecessor in the chain; returns self."""
        self.previous_controlnet = controlnet
        return self

    def cleanup(self):
        """Drop per-run state here and in every predecessor."""
        previous = self.previous_controlnet
        if previous is not None:
            previous.cleanup()
        self.cond_hint = None  # release the cached hint tensor
        self.timestep_range = None

    def get_models(self):
        """Return a new list of the models required by the predecessors."""
        previous = self.previous_controlnet
        return [] if previous is None else list(previous.get_models())

    def copy_to(self, c):
        """Copy the user-facing hint settings (not the chain link) onto ``c``."""
        for name in ("cond_hint_original", "strength", "timestep_percent_range",
                     "global_average_pooling", "compression_ratio", "upscale_algorithm"):
            setattr(c, name, getattr(self, name))

    def inference_memory_requirements(self, dtype):
        """Extra inference memory needed by the chain; 0 for a bare ControlBase."""
        previous = self.previous_controlnet
        return 0 if previous is None else previous.inference_memory_requirements(dtype)

    def control_merge(self, control_input, control_output, control_prev, output_dtype):
        """Scale, cast and combine this node's control tensors with ``control_prev``.

        Args:
            control_input: tensors (or None) for the 'input' slot; stored in reverse order.
            control_output: tensors whose last entry becomes 'middle' and the rest 'output'.
            control_prev: dict with the same three keys from the predecessor, or None.
            output_dtype: dtype this node's tensors are cast to.

        Returns:
            dict with 'input', 'middle' and 'output' lists.

        Tensors are multiplied by ``strength`` in place, so callers must pass
        tensors they own.
        """
        strength = self.strength

        def prepare(tensor, pool):
            if tensor is None:
                return None
            if pool:
                # Replace every spatial position with the per-channel mean.
                tensor = torch.mean(tensor, dim=(2, 3), keepdim=True).repeat(1, 1, tensor.shape[2], tensor.shape[3])
            tensor *= strength
            return tensor if tensor.dtype == output_dtype else tensor.to(output_dtype)

        merged = {'input': [], 'middle': [], 'output': []}

        if control_input is not None:
            prepared = [prepare(tensor, False) for tensor in control_input]
            prepared.reverse()
            merged['input'].extend(prepared)

        if control_output is not None:
            last_index = len(control_output) - 1
            for index, tensor in enumerate(control_output):
                slot = 'middle' if index == last_index else 'output'
                merged[slot].append(prepare(tensor, self.global_average_pooling))

        if control_prev is not None:
            for slot in ('input', 'middle', 'output'):
                current = merged[slot]
                for index, prev_val in enumerate(control_prev[slot]):
                    if index >= len(current):
                        current.append(prev_val)
                    elif prev_val is None:
                        continue
                    elif current[index] is None:
                        current[index] = prev_val
                    elif current[index].shape[0] < prev_val.shape[0]:
                        # Out-of-place add so the larger batch dimension wins.
                        current[index] = prev_val + current[index]
                    else:
                        current[index] += prev_val
        return merged

class ControlNet(ControlBase):
    """ControlNet driven by a standalone control model.

    Args:
        control_model: network producing the control residuals; may be None when
            the caller (e.g. ``copy``) assigns the model and wrapper afterwards.
        global_average_pooling: collapse each output control map to its spatial mean.
        device: device the conditioning hint is moved to (ControlBase picks a default).
        load_device: device handed to the ModelPatcher wrapper.
        manual_cast_dtype: if set, used instead of the control model's dtype for inputs.
    """
    def __init__(self, control_model=None, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
        super().__init__(device)
        self.control_model = control_model
        self.load_device = load_device
        if control_model is not None:
            self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())

        self.global_average_pooling = global_average_pooling
        self.model_sampling_current = None  # set by pre_run, cleared by cleanup
        self.manual_cast_dtype = manual_cast_dtype

    def get_control(self, x_noisy, t, cond, batched_number):
        """Run the control model for one step and merge with the predecessors' output.

        When t[0] lies outside ``timestep_range`` the predecessor's control
        (possibly None) is returned unchanged.
        """
        control_prev = None
        if self.previous_controlnet is not None:
            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)

        if self.timestep_range is not None:
            # Active only while timestep_range[1] <= t[0] <= timestep_range[0].
            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
                return control_prev

        dtype = self.control_model.dtype
        if self.manual_cast_dtype is not None:
            dtype = self.manual_cast_dtype

        output_dtype = x_noisy.dtype
        target_h = x_noisy.shape[2] * self.compression_ratio
        target_w = x_noisy.shape[3] * self.compression_ratio
        if self.cond_hint is None or self.cond_hint.shape[2] != target_h or self.cond_hint.shape[3] != target_w:
            self.cond_hint = None  # release the old hint before building a new one
            self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, target_w, target_h, self.upscale_algorithm, "center").to(dtype).to(self.device)
        if x_noisy.shape[0] != self.cond_hint.shape[0]:
            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)

        context = cond.get('crossattn_controlnet', cond['c_crossattn'])
        y = cond.get('y', None)
        if y is not None:
            y = y.to(dtype)

        timestep = self.model_sampling_current.timestep(t)
        x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)

        control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
        return self.control_merge(None, control, control_prev, output_dtype)

    def copy(self):
        """Return a ControlNet sharing this model/wrapper, with the same device and hint settings."""
        c = ControlNet(None, global_average_pooling=self.global_average_pooling, device=self.device, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
        c.control_model = self.control_model
        c.control_model_wrapped = self.control_model_wrapped
        self.copy_to(c)
        return c

    def get_models(self):
        """Predecessors' models plus this ControlNet's ModelPatcher wrapper."""
        out = super().get_models()
        out.append(self.control_model_wrapped)
        return out

    def pre_run(self, model, percent_to_timestep_function):
        """Resolve timestep ranges and remember the model's sampling object for get_control."""
        super().pre_run(model, percent_to_timestep_function)
        self.model_sampling_current = model.model_sampling

    def cleanup(self):
        """Forget the sampling object, then run the shared cleanup."""
        self.model_sampling_current = None
        super().cleanup()

class ControlLoraOps:
    """Layer classes used by ControlLora when building its control model.

    The layers allocate no parameters: ``weight``, ``bias``, ``up`` and ``down``
    start as None. They are presumably assigned later (ControlLora.pre_run sets
    parameters via comfy.utils.set_attr_param). When ``up`` is set, the
    low-rank product ``up @ down`` is added to the base weight at forward time.
    """

    class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
        def __init__(self, in_features: int, out_features: int, bias: bool = True,
                    device=None, dtype=None) -> None:
            # bias/device/dtype are accepted for signature compatibility only.
            super().__init__()
            self.in_features = in_features
            self.out_features = out_features
            self.weight = None
            self.up = None
            self.down = None
            self.bias = None

        def forward(self, input):
            weight, bias = comfy.ops.cast_bias_weight(self, input)
            if self.up is not None:
                # Low-rank update reshaped to the base weight's shape, in the input dtype.
                delta = torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1)).reshape(self.weight.shape).type(input.dtype)
                return torch.nn.functional.linear(input, weight + delta, bias)
            return torch.nn.functional.linear(input, weight, bias)

    class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
        def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=True,
            padding_mode='zeros',
            device=None,
            dtype=None
        ):
            # bias/device/dtype are accepted for signature compatibility only.
            super().__init__()
            self.in_channels = in_channels
            self.out_channels = out_channels
            self.kernel_size = kernel_size
            self.stride = stride
            self.padding = padding
            self.dilation = dilation
            self.transposed = False
            self.output_padding = 0
            self.groups = groups
            self.padding_mode = padding_mode

            self.weight = None
            self.bias = None
            self.up = None
            self.down = None

        def forward(self, input):
            weight, bias = comfy.ops.cast_bias_weight(self, input)
            if self.up is not None:
                # Low-rank update reshaped to the base kernel's shape, in the input dtype.
                delta = torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1)).reshape(self.weight.shape).type(input.dtype)
                return torch.nn.functional.conv2d(input, weight + delta, bias, self.stride, self.padding, self.dilation, self.groups)
            return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)

class ControlLora(ControlNet):
    """ControlNet assembled in pre_run from the base UNet's weights plus stored control weights.

    Args:
        control_weights: checkpoint state dict (load_controlnet builds this class
            when the "lora_controlnet" key is present).
        global_average_pooling: collapse each output control map to its spatial mean.
        device: device the conditioning hint is moved to.
    """
    def __init__(self, control_weights, global_average_pooling=False, device=None):
        ControlBase.__init__(self, device)
        self.control_weights = control_weights
        self.global_average_pooling = global_average_pooling
        # Built/set in pre_run; initialised here so cleanup() is safe before any run.
        self.control_model = None
        self.manual_cast_dtype = None
        self.model_sampling_current = None

    def pre_run(self, model, percent_to_timestep_function):
        """Build the control model from ``model``'s UNet config and weights, then apply control_weights."""
        super().pre_run(model, percent_to_timestep_function)
        controlnet_config = model.model_config.unet_config.copy()
        controlnet_config.pop("out_channels")
        controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
        self.manual_cast_dtype = model.manual_cast_dtype
        dtype = model.get_dtype()
        if self.manual_cast_dtype is None:
            class control_lora_ops(ControlLoraOps, comfy.ops.disable_weight_init):
                pass
        else:
            class control_lora_ops(ControlLoraOps, comfy.ops.manual_cast):
                pass
            dtype = self.manual_cast_dtype

        controlnet_config["operations"] = control_lora_ops
        controlnet_config["dtype"] = dtype
        self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
        torch_device = comfy.model_management.get_torch_device()
        self.control_model.to(torch_device)

        for k, weight in model.diffusion_model.state_dict().items():
            try:
                comfy.utils.set_attr_param(self.control_model, k, weight)
            except Exception:
                # Best effort: UNet keys the control model cannot take are skipped
                # (presumably those with no matching attribute path).
                pass

        for k, weight in self.control_weights.items():
            if k != "lora_controlnet":  # marker key, not a parameter
                comfy.utils.set_attr_param(self.control_model, k, weight.to(dtype).to(torch_device))

    def copy(self):
        """Return a ControlLora sharing control_weights, with the same device and hint settings."""
        c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling, device=self.device)
        self.copy_to(c)
        return c

    def cleanup(self):
        """Drop the model built in pre_run, then run the inherited cleanup."""
        self.control_model = None
        super().cleanup()

    def get_models(self):
        """Only the predecessors' models; this class has no ModelPatcher wrapper."""
        out = ControlBase.get_models(self)
        return out

    def inference_memory_requirements(self, dtype):
        """Size of control_weights in ``dtype`` plus the predecessors' requirements."""
        return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)

def _map_zero_conv_keys(controlnet_data, diffusers_keys):
    """Add ``controlnet_down_blocks.N`` -> ``zero_convs.N.0`` renames to diffusers_keys.

    Stops at the first missing key; a ``.weight`` found before a missing
    ``.bias`` is still mapped.
    """
    count = 0
    while True:
        for s in (".weight", ".bias"):
            k_in = "controlnet_down_blocks.{}{}".format(count, s)
            if k_in not in controlnet_data:
                return
            diffusers_keys[k_in] = "zero_convs.{}.0{}".format(count, s)
        count += 1


def _map_hint_block_keys(controlnet_data, diffusers_keys):
    """Add ``controlnet_cond_embedding`` -> ``input_hint_block`` renames to diffusers_keys.

    Entry 0 is ``conv_in`` and entry N > 0 is ``blocks.{N-1}``; entry N lands at
    ``input_hint_block.{2N}``. The first missing key is mapped from
    ``conv_out`` instead, and that finishes the scan.
    """
    count = 0
    while True:
        reached_conv_out = False
        for s in (".weight", ".bias"):
            if count == 0:
                k_in = "controlnet_cond_embedding.conv_in{}".format(s)
            else:
                k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
            if k_in not in controlnet_data:
                k_in = "controlnet_cond_embedding.conv_out{}".format(s)
                reached_conv_out = True
            diffusers_keys[k_in] = "input_hint_block.{}{}".format(count * 2, s)
        if reached_conv_out:
            return
        count += 1


def _convert_diffusers_controlnet(controlnet_data, controlnet_config):
    """Return a state dict with diffusers controlnet keys renamed to the ldm layout.

    Converted keys are popped from ``controlnet_data``; anything left over is logged.
    """
    diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
    diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
    diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
    _map_zero_conv_keys(controlnet_data, diffusers_keys)
    _map_hint_block_keys(controlnet_data, diffusers_keys)

    new_sd = {}
    for k in diffusers_keys:
        if k in controlnet_data:
            new_sd[diffusers_keys[k]] = controlnet_data.pop(k)

    leftover_keys = controlnet_data.keys()
    if len(leftover_keys) > 0:
        logging.warning("leftover keys: {}".format(leftover_keys))
    return new_sd


def load_controlnet(ckpt_path, model=None):
    """Load a ControlNet, control-LoRA or T2I-adapter checkpoint.

    Args:
        ckpt_path: checkpoint path. A file name (without extension) ending in
            ``_shuffle`` or ``_shuffle_fp16`` enables global average pooling.
        model: optional base model; needed to resolve "difference" checkpoints,
            whose stored weights are deltas added to the model's UNet weights.

    Returns:
        A ControlLora, ControlNet or T2IAdapter, or None when the checkpoint
        matches no known layout (an error is logged).
    """
    controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
    if "lora_controlnet" in controlnet_data:
        return ControlLora(controlnet_data)

    controlnet_config = None
    supported_inference_dtypes = None

    if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
        controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data)
        controlnet_data = _convert_diffusers_controlnet(controlnet_data, controlnet_config)

    pth = False
    if 'control_model.zero_convs.0.0.weight' in controlnet_data:
        pth = True  # keys are nested under "control_model."
        prefix = "control_model."
    elif 'zero_convs.0.0.weight' in controlnet_data:
        prefix = ""
    else:
        net = load_t2i_adapter(controlnet_data)
        if net is None:
            logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
        return net

    if controlnet_config is None:
        model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
        supported_inference_dtypes = model_config.supported_inference_dtypes
        controlnet_config = model_config.unet_config

    load_device = comfy.model_management.get_torch_device()
    if supported_inference_dtypes is None:
        unet_dtype = comfy.model_management.unet_dtype()
    else:
        unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)

    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
    if manual_cast_dtype is not None:
        controlnet_config["operations"] = comfy.ops.manual_cast
    controlnet_config["dtype"] = unet_dtype
    controlnet_config.pop("out_channels")
    controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
    control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)

    if pth:
        if 'difference' in controlnet_data:
            if model is not None:
                comfy.model_management.load_models_gpu([model])
                model_sd = model.model_state_dict()
                c_m = "control_model."
                for x in controlnet_data:
                    if x.startswith(c_m):
                        sd_key = "diffusion_model.{}".format(x[len(c_m):])
                        if sd_key in model_sd:
                            cd = controlnet_data[x]
                            cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
            else:
                logging.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")

        # Wrapper module so the "control_model."-prefixed keys resolve onto control_model.
        class WeightsLoader(torch.nn.Module):
            pass
        w = WeightsLoader()
        w.control_model = control_model
        missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
    else:
        missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)

    if len(missing) > 0:
        logging.warning("missing controlnet keys: {}".format(missing))

    if len(unexpected) > 0:
        logging.debug("unexpected controlnet keys: {}".format(unexpected))

    filename = os.path.splitext(ckpt_path)[0]
    #TODO: smarter way of enabling global_average_pooling
    global_average_pooling = filename.endswith(("_shuffle", "_shuffle_fp16"))

    control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
    return control

class T2IAdapter(ControlBase):
    """Control wrapper around an adapter-style model (T2I adapter or cascade controlnet).

    The adapter output is cached in ``control_input`` and reused across
    sampling steps until the hint has to be rebuilt or re-batched.

    Args:
        t2i_model: adapter module; must expose ``unshuffle_amount`` and ``xl``.
        channels_in: adapter input channels; 1 means the hint is averaged to grey.
        compression_ratio: pixels per latent unit used to size the hint.
        upscale_algorithm: resampling mode passed to comfy.utils.common_upscale.
        device: device the hint is moved to.
    """
    def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
        super().__init__(device)
        self.t2i_model = t2i_model
        self.channels_in = channels_in
        self.control_input = None  # cached adapter output for the current hint
        self.compression_ratio = compression_ratio
        self.upscale_algorithm = upscale_algorithm

    def scale_image_to(self, width, height):
        """Round width and height up to multiples of the adapter's unshuffle_amount."""
        unshuffle_amount = self.t2i_model.unshuffle_amount
        width = math.ceil(width / unshuffle_amount) * unshuffle_amount
        height = math.ceil(height / unshuffle_amount) * unshuffle_amount
        return width, height

    def get_control(self, x_noisy, t, cond, batched_number):
        """Return the (cached) adapter control merged with the predecessors' output.

        When t[0] lies outside ``timestep_range`` the predecessor's control
        (possibly None) is returned unchanged.
        """
        control_prev = None
        if self.previous_controlnet is not None:
            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)

        if self.timestep_range is not None:
            # Active only while timestep_range[1] <= t[0] <= timestep_range[0].
            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
                return control_prev

        # The hint is built at the rounded size, so the cache must be keyed on it.
        # Comparing against the unrounded size rebuilt the hint and re-ran the
        # adapter on every step whenever the two differed. (Assumes
        # common_upscale returns exactly width x height.)
        width, height = self.scale_image_to(x_noisy.shape[3] * self.compression_ratio, x_noisy.shape[2] * self.compression_ratio)
        if self.cond_hint is None or self.cond_hint.shape[2] != height or self.cond_hint.shape[3] != width:
            self.control_input = None
            self.cond_hint = None  # release the old hint before building a new one
            self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, width, height, self.upscale_algorithm, "center").float().to(self.device)
            if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
                self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
        if x_noisy.shape[0] != self.cond_hint.shape[0]:
            broadcast_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
            if broadcast_hint is not self.cond_hint:
                # The cached adapter output was computed from the previous batch layout.
                self.control_input = None
            self.cond_hint = broadcast_hint
        if self.control_input is None:
            self.t2i_model.to(x_noisy.dtype)
            self.t2i_model.to(self.device)
            self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
            self.t2i_model.cpu()

        # Clone because control_merge scales tensors in place.
        control_input = [None if a is None else a.clone() for a in self.control_input]
        mid = None
        if self.t2i_model.xl:
            # The last xl adapter output feeds the middle block.
            mid = control_input[-1:]
            control_input = control_input[:-1]
        return self.control_merge(control_input, mid, control_prev, x_noisy.dtype)

    def copy(self):
        """Return a T2IAdapter sharing this model, with the same device and hint settings."""
        c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm, device=self.device)
        self.copy_to(c)
        return c

def load_t2i_adapter(t2i_data):
    """Build a T2IAdapter from an adapter state dict, or return None if unrecognised.

    Layouts are checked in this order:
    - diffusers T2I adapters, whose keys are renamed first;
    - ``Adapter_light`` ("body.0.in_conv.weight");
    - ``Adapter`` ("conv_in.weight");
    - two ``comfy.ldm.cascade.controlnet.ControlNet`` variants keyed on ``backbone.*``.
    A top-level "adapter" entry, if present, is unwrapped first.
    """
    compression_ratio = 8
    upscale_algorithm = 'nearest-exact'

    if 'adapter' in t2i_data:
        t2i_data = t2i_data['adapter']
    if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: #diffusers format
        prefix_replace = {}
        for i in range(4):
            for j in range(2):
                prefix_replace["adapter.body.{}.resnets.{}.".format(i, j)] = "body.{}.".format(i * 2 + j)
            prefix_replace["adapter.body.{}.".format(i)] = "body.{}.".format(i * 2)
        prefix_replace["adapter."] = ""
        t2i_data = comfy.utils.state_dict_prefix_replace(t2i_data, prefix_replace)
    keys = t2i_data.keys()

    if "body.0.in_conv.weight" in keys:
        cin = t2i_data['body.0.in_conv.weight'].shape[1]
        model_ad = comfy.t2i_adapter.adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
    elif 'conv_in.weight' in keys:
        cin = t2i_data['conv_in.weight'].shape[1]
        channel = t2i_data['conv_in.weight'].shape[0]
        ksize = t2i_data['body.0.block2.weight'].shape[2]
        use_conv = any(k.endswith("down_opt.op.weight") for k in keys)
        # These input-channel counts select the xl adapter layout.
        xl = cin in (256, 768)
        model_ad = comfy.t2i_adapter.adapter.Adapter(cin=cin, channels=[channel, channel * 2, channel * 4, channel * 4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
    elif "backbone.0.0.weight" in keys:
        model_ad = comfy.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.0.weight'].shape[1], proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
        compression_ratio = 32
        upscale_algorithm = 'bilinear'
    elif "backbone.10.blocks.0.weight" in keys:
        model_ad = comfy.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.weight'].shape[1], bottleneck_mode="large", proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
        compression_ratio = 1
        upscale_algorithm = 'nearest-exact'
    else:
        return None

    # NOTE(review): load_state_dict defaults to strict=True, which raises on any
    # key mismatch, so the lists below are presumably always empty.
    missing, unexpected = model_ad.load_state_dict(t2i_data)
    if len(missing) > 0:
        logging.warning("t2i missing {}".format(missing))

    if len(unexpected) > 0:
        logging.debug("t2i unexpected {}".format(unexpected))

    return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm)