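# ComfyUI node definitions. Each node class declares its inputs via INPUT_TYPES and exposes
# RETURN_TYPES, FUNCTION and CATEGORY; the graph executor calls the method named by FUNCTION.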
import torch

import os
import sys
import json
import hashlib
import copy
import traceback

from PIL import Image
from PIL.PngImagePlugin import PngInfo
import numpy as np

sys.path.insert(0, os.path.join(sys.path[0], "comfy"))


import comfy.samplers
import comfy.sd
import comfy.utils

import model_management
import importlib.util

supported_ckpt_extensions = ['.ckpt', '.pth']
supported_pt_extensions = ['.ckpt', '.pt', '.bin', '.pth']
try:
    import safetensors.torch
    supported_ckpt_extensions += ['.safetensors']
    supported_pt_extensions += ['.safetensors']
except ImportError:
    print("Could not import safetensors, safetensors support disabled.")

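# Walk a directory tree (following symlinks) and return every file path relative to the given root.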
def recursive_search(directory):
    result = []
    for root, subdirs, files in os.walk(directory, followlinks=True):
        for filepath in files:
            #we os.path.join directory with a blank string to generate a path separator at the end.
            result.append(os.path.join(root, filepath).replace(os.path.join(directory, ''), ''))
    return result

def filter_files_extensions(files, extensions):
    return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files)))


def before_node_execution():
    model_management.throw_exception_if_processing_interrupted()

def interrupt_processing(value=True):
    model_management.interrupt_current_processing(value)

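# A CONDITIONING value is a list of [cond_tensor, options] pairs; the options dict carries
# per-entry settings such as 'area', 'strength' and 'control'.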
class CLIPTextEncode:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"text": ("STRING", {"multiline": True, "dynamic_prompt": True}), "clip": ("CLIP", )}}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "encode"

    CATEGORY = "conditioning"

    def encode(self, clip, text):
        return ([[clip.encode(text), {}]], )

class ConditioningCombine:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning_1": ("CONDITIONING", ), "conditioning_2": ("CONDITIONING", )}}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "combine"

    CATEGORY = "conditioning"

    def combine(self, conditioning_1, conditioning_2):
        return (conditioning_1 + conditioning_2, )

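# The region is stored on each conditioning entry in latent units (pixel values // 8) as (height, width, y, x).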
class ConditioningSetArea:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning": ("CONDITIONING", ),
                              "width": ("INT", {"default": 64, "min": 64, "max": 4096, "step": 64}),
                              "height": ("INT", {"default": 64, "min": 64, "max": 4096, "step": 64}),
                              "x": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 64}),
                              "y": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 64}),
                              "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                             }}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "append"

    CATEGORY = "conditioning"

    def append(self, conditioning, width, height, x, y, strength, min_sigma=0.0, max_sigma=99.0):
        c = []
        for t in conditioning:
            n = [t[0], t[1].copy()]
            n[1]['area'] = (height // 8, width // 8, y // 8, x // 8)
            n[1]['strength'] = strength
            n[1]['min_sigma'] = min_sigma
            n[1]['max_sigma'] = max_sigma
            c.append(n)
        return (c, )

class VAEDecode:
    def __init__(self, device="cpu"):
        self.device = device

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "decode"

    CATEGORY = "latent"

    def decode(self, vae, samples):
        return (vae.decode(samples["samples"]), )

class VAEDecodeTiled:
    def __init__(self, device="cpu"):
        self.device = device

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "decode"

    CATEGORY = "_for_testing"

    def decode(self, vae, samples):
        return (vae.decode_tiled(samples["samples"]), )

class VAEEncode:
    def __init__(self, device="cpu"):
        self.device = device

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "encode"

    CATEGORY = "latent"

    def encode(self, vae, pixels):
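        # Crop width and height down to the nearest multiple of 64 pixels so they map cleanly onto the latent grid.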
        x = (pixels.shape[1] // 64) * 64
        y = (pixels.shape[2] // 64) * 64
        if pixels.shape[1] != x or pixels.shape[2] != y:
            pixels = pixels[:,:x,:y,:]
        t = vae.encode(pixels[:,:,:,:3])

        return ({"samples":t}, )

class VAEEncodeForInpaint:
    def __init__(self, device="cpu"):
        self.device = device

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", )}}
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "encode"

    CATEGORY = "latent/inpaint"

    def encode(self, vae, pixels, mask):
        x = (pixels.shape[1] // 64) * 64
        y = (pixels.shape[2] // 64) * 64
        mask = torch.nn.functional.interpolate(mask[None,None,], size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")[0][0]

        if pixels.shape[1] != x or pixels.shape[2] != y:
            pixels = pixels[:,:x,:y,:]
            mask = mask[:x,:y]

        #grow mask by a few pixels to keep things seamless in latent space
        kernel_tensor = torch.ones((1, 1, 6, 6))
        mask_erosion = torch.clamp(torch.nn.functional.conv2d((mask.round())[None], kernel_tensor, padding=3), 0, 1)
        m = (1.0 - mask.round())
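        # Pull masked pixels to 0.5 grey so the VAE encodes neutral content inside the hole.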
        for i in range(3):
            pixels[:,:,:,i] -= 0.5
            pixels[:,:,:,i] *= m
            pixels[:,:,:,i] += 0.5
        t = vae.encode(pixels)

        return ({"samples":t, "noise_mask": (mask_erosion[0][:x,:y].round())}, )

class CheckpointLoader:
    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
    config_dir = os.path.join(models_dir, "configs")
    ckpt_dir = os.path.join(models_dir, "checkpoints")
    embedding_directory = os.path.join(models_dir, "embeddings")

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "config_name": (filter_files_extensions(recursive_search(s.config_dir), [".yaml"]), ),
                              "ckpt_name": (filter_files_extensions(recursive_search(s.ckpt_dir), supported_ckpt_extensions), )}}
    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
    FUNCTION = "load_checkpoint"

    CATEGORY = "loaders"

    def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
        config_path = os.path.join(self.config_dir, config_name)
        ckpt_path = os.path.join(self.ckpt_dir, ckpt_name)
        return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=output_vae, output_clip=output_clip, embedding_directory=self.embedding_directory)

class CheckpointLoaderSimple:
    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
    ckpt_dir = os.path.join(models_dir, "checkpoints")

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "ckpt_name": (filter_files_extensions(recursive_search(s.ckpt_dir), supported_ckpt_extensions), ),
                             }}
    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
    FUNCTION = "load_checkpoint"

    CATEGORY = "_for_testing"

    def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
        ckpt_path = os.path.join(self.ckpt_dir, ckpt_name)
        out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=output_vae, output_clip=output_clip, embedding_directory=CheckpointLoader.embedding_directory)
        return out

class LoraLoader:
    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
    lora_dir = os.path.join(models_dir, "loras")
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
                              "clip": ("CLIP", ),
                              "lora_name": (filter_files_extensions(recursive_search(s.lora_dir), supported_pt_extensions), ),
                              "strength_model": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                              "strength_clip": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                              }}
    RETURN_TYPES = ("MODEL", "CLIP")
    FUNCTION = "load_lora"

    CATEGORY = "loaders"

    def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
        lora_path = os.path.join(self.lora_dir, lora_name)
        model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
        return (model_lora, clip_lora)

class VAELoader:
    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
    vae_dir = os.path.join(models_dir, "vae")
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "vae_name": (filter_files_extensions(recursive_search(s.vae_dir), supported_pt_extensions), )}}
    RETURN_TYPES = ("VAE",)
    FUNCTION = "load_vae"

    CATEGORY = "loaders"

    #TODO: scale factor?
    def load_vae(self, vae_name):
        vae_path = os.path.join(self.vae_dir, vae_name)
        vae = comfy.sd.VAE(ckpt_path=vae_path)
        return (vae,)

class ControlNetLoader:
    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
    controlnet_dir = os.path.join(models_dir, "controlnet")
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}}

    RETURN_TYPES = ("CONTROL_NET",)
    FUNCTION = "load_controlnet"

    CATEGORY = "loaders"

    def load_controlnet(self, control_net_name):
        controlnet_path = os.path.join(self.controlnet_dir, control_net_name)
        controlnet = comfy.sd.load_controlnet(controlnet_path)
        return (controlnet,)

class DiffControlNetLoader:
    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
    controlnet_dir = os.path.join(models_dir, "controlnet")
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
                              "control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}}

    RETURN_TYPES = ("CONTROL_NET",)
    FUNCTION = "load_controlnet"

    CATEGORY = "loaders"

    def load_controlnet(self, model, control_net_name):
        controlnet_path = os.path.join(self.controlnet_dir, control_net_name)
        controlnet = comfy.sd.load_controlnet(controlnet_path, model)
        return (controlnet,)


class ControlNetApply:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning": ("CONDITIONING", ),
                             "control_net": ("CONTROL_NET", ),
                             "image": ("IMAGE", ),
                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01})
                             }}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "apply_controlnet"

    CATEGORY = "conditioning"

    def apply_controlnet(self, conditioning, control_net, image, strength):
        c = []
        control_hint = image.movedim(-1,1)
        for t in conditioning:
            n = [t[0], t[1].copy()]
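            # Chain onto any ControlNet already attached, so multiple ControlNetApply nodes stack.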
            c_net = control_net.copy().set_cond_hint(control_hint, strength)
            if 'control' in t[1]:
                c_net.set_previous_controlnet(t[1]['control'])
            n[1]['control'] = c_net
            c.append(n)
        return (c, )

class T2IAdapterLoader:
    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
    t2i_adapter_dir = os.path.join(models_dir, "t2i_adapter")
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "t2i_adapter_name": (filter_files_extensions(recursive_search(s.t2i_adapter_dir), supported_pt_extensions), )}}

    RETURN_TYPES = ("CONTROL_NET",)
    FUNCTION = "load_t2i_adapter"

    CATEGORY = "loaders"

    def load_t2i_adapter(self, t2i_adapter_name):
        t2i_path = os.path.join(self.t2i_adapter_dir, t2i_adapter_name)
        t2i_adapter = comfy.sd.load_t2i_adapter(t2i_path)
        return (t2i_adapter,)

class CLIPLoader:
    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
    clip_dir = os.path.join(models_dir, "clip")
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "clip_name": (filter_files_extensions(recursive_search(s.clip_dir), supported_pt_extensions), ),
                              "stop_at_clip_layer": ("INT", {"default": -1, "min": -24, "max": -1, "step": 1}),
                             }}
    RETURN_TYPES = ("CLIP",)
    FUNCTION = "load_clip"

    CATEGORY = "loaders"

    def load_clip(self, clip_name, stop_at_clip_layer):
        clip_path = os.path.join(self.clip_dir, clip_name)
        clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=CheckpointLoader.embedding_directory)
        clip.clip_layer(stop_at_clip_layer)
        return (clip,)

class EmptyLatentImage:
    def __init__(self, device="cpu"):
        self.device = device

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "width": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}),
                              "height": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}),
                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 64})}}
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "generate"

    CATEGORY = "latent"

    def generate(self, width, height, batch_size=1):
        latent = torch.zeros([batch_size, 4, height // 8, width // 8])
        return ({"samples":latent}, )

class LatentUpscale:
    upscale_methods = ["nearest-exact", "bilinear", "area"]
    crop_methods = ["disabled", "center"]

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,),
                              "width": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}),
                              "height": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}),
                              "crop": (s.crop_methods,)}}
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "upscale"

    CATEGORY = "latent"

    def upscale(self, samples, upscale_method, width, height, crop):
        s = samples.copy()
        s["samples"] = comfy.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop)
        return (s,)

class LatentRotate:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "samples": ("LATENT",),
                              "rotation": (["none", "90 degrees", "180 degrees", "270 degrees"],),
                              }}
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "rotate"

    CATEGORY = "latent"

    def rotate(self, samples, rotation):
        s = samples.copy()
        rotate_by = 0
        if rotation.startswith("90"):
            rotate_by = 1
        elif rotation.startswith("180"):
            rotate_by = 2
        elif rotation.startswith("270"):
            rotate_by = 3

        s["samples"] = torch.rot90(samples["samples"], k=rotate_by, dims=[3, 2])
        return (s,)

class LatentFlip:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "samples": ("LATENT",),
                              "flip_method": (["x-axis: vertically", "y-axis: horizontally"],),
                              }}
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "flip"

    CATEGORY = "latent"

    def flip(self, samples, flip_method):
        s = samples.copy()
        if flip_method.startswith("x"):
            s["samples"] = torch.flip(samples["samples"], dims=[2])
        elif flip_method.startswith("y"):
            s["samples"] = torch.flip(samples["samples"], dims=[3])

        return (s,)

class LatentComposite:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "samples_to": ("LATENT",),
                              "samples_from": ("LATENT",),
                              "x": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 8}),
                              "y": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 8}),
                              "feather": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 8}),
                              }}
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "composite"

    CATEGORY = "latent"

    def composite(self, samples_to, samples_from, x, y, composite_method="normal", feather=0):
        x = x // 8
        y = y // 8
        feather = feather // 8
        samples_out = samples_to.copy()
        s = samples_to["samples"].clone()
        samples_to = samples_to["samples"]
        samples_from = samples_from["samples"]
        if feather == 0:
            s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x]
        else:
            samples_from = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x]
            mask = torch.ones_like(samples_from)
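            # Build a linear feather ramp on each edge of the paste region that does not touch the destination border.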
            for t in range(feather):
                if y != 0:
                    mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1))

                if y + samples_from.shape[2] < samples_to.shape[2]:
                    mask[:,:,mask.shape[2] -1 -t: mask.shape[2]-t,:] *= ((1.0/feather) * (t + 1))
                if x != 0:
                    mask[:,:,:,t:1+t] *= ((1.0/feather) * (t + 1))
                if x + samples_from.shape[3] < samples_to.shape[3]:
                    mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1))
            rev_mask = torch.ones_like(mask) - mask
            s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] * mask + s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] * rev_mask
        samples_out["samples"] = s
        return (samples_out,)

class LatentCrop:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "samples": ("LATENT",),
                              "width": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}),
                              "height": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}),
                              "x": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 8}),
                              "y": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 8}),
                              }}
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "crop"

    CATEGORY = "latent"

    def crop(self, samples, width, height, x, y):
        s = samples.copy()
        samples = samples['samples']
        x = x // 8
        y = y // 8

        #enforce minimum size of 64
        if x > (samples.shape[3] - 8):
            x = samples.shape[3] - 8
        if y > (samples.shape[2] - 8):
            y = samples.shape[2] - 8

        new_height = height // 8
        new_width = width // 8
        to_x = new_width + x
        to_y = new_height + y
        def enforce_image_dim(d, to_d, max_d):
            if to_d > max_d:
                leftover = (to_d - max_d) % 8
                to_d = max_d
                d -= leftover
            return (d, to_d)

        #make sure size is always multiple of 64
        x, to_x = enforce_image_dim(x, to_x, samples.shape[3])
        y, to_y = enforce_image_dim(y, to_y, samples.shape[2])
        s['samples'] = samples[:,:,y:to_y, x:to_x]
        return (s,)

class SetLatentNoiseMask:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "samples": ("LATENT",),
                              "mask": ("MASK",),
                              }}
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "set_mask"

    CATEGORY = "latent/inpaint"

    def set_mask(self, samples, mask):
        s = samples.copy()
        s["noise_mask"] = mask
        return (s,)


def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
    latent_image = latent["samples"]
    noise_mask = None

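    # Noise is created on the CPU with an explicitly seeded generator, so a given seed reproduces across devices.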
    if disable_noise:
        noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
    else:
        noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=torch.manual_seed(seed), device="cpu")

    if "noise_mask" in latent:
        noise_mask = latent['noise_mask']
        noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear")
        noise_mask = noise_mask.round()
        noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1)
        noise_mask = torch.cat([noise_mask] * noise.shape[0])
        noise_mask = noise_mask.to(device)

    real_model = None
    if device != "cpu":
        model_management.load_model_gpu(model)
        real_model = model.model
    else:
        #TODO: cpu support
        real_model = model.patch_model()
    noise = noise.to(device)
    latent_image = latent_image.to(device)

    positive_copy = []
    negative_copy = []

    control_nets = []
    for p in positive:
        t = p[0]
        if t.shape[0] < noise.shape[0]:
            t = torch.cat([t] * noise.shape[0])
        t = t.to(device)
        if 'control' in p[1]:
            control_nets += [p[1]['control']]
        positive_copy += [[t] + p[1:]]
    for n in negative:
        t = n[0]
        if t.shape[0] < noise.shape[0]:
            t = torch.cat([t] * noise.shape[0])
        t = t.to(device)
        if 'control' in n[1]:
            control_nets += [n[1]['control']]
        negative_copy += [[t] + n[1:]]

    control_net_models = []
    for x in control_nets:
        control_net_models += x.get_control_models()
    model_management.load_controlnet_gpu(control_net_models)

    if sampler_name in comfy.samplers.KSampler.SAMPLERS:
        sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
    else:
        #other samplers
        pass
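    # sampler_name comes from a dropdown constrained to KSampler.SAMPLERS, so 'sampler' is always bound above.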

    samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask)
    samples = samples.cpu()
    for c in control_nets:
        c.cleanup()

    out = latent.copy()
    out["samples"] = samples
    return (out, )

class KSampler:
    def __init__(self, device="cuda"):
        self.device = device

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"model": ("MODEL",),
                    "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                    "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
                    "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}),
                    "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
                    "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
                    "positive": ("CONDITIONING", ),
                    "negative": ("CONDITIONING", ),
                    "latent_image": ("LATENT", ),
                    "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                    }}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "sample"

    CATEGORY = "sampling"

    def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0):
        return common_ksampler(self.device, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)

class KSamplerAdvanced:
    def __init__(self, device="cuda"):
        self.device = device

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"model": ("MODEL",),
                    "add_noise": (["enable", "disable"], ),
                    "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                    "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
                    "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}),
                    "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
                    "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
                    "positive": ("CONDITIONING", ),
                    "negative": ("CONDITIONING", ),
                    "latent_image": ("LATENT", ),
                    "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}),
                    "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}),
                    "return_with_leftover_noise": (["disable", "enable"], ),
                    }}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "sample"

    CATEGORY = "sampling"

    def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0):
        force_full_denoise = True
        if return_with_leftover_noise == "enable":
            force_full_denoise = False
        disable_noise = False
        if add_noise == "disable":
            disable_noise = True
        return common_ksampler(self.device, model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise)

class SaveImage:
    def __init__(self):
        self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")

    @classmethod
    def INPUT_TYPES(s):
        return {"required": 
                    {"images": ("IMAGE", ),
                     "filename_prefix": ("STRING", {"default": "ComfyUI"})},
                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
                }

    RETURN_TYPES = ()
    FUNCTION = "save_images"

    OUTPUT_NODE = True

    CATEGORY = "image"

    def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
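        # Pick the next free counter by parsing existing "<filename_prefix>_#####_" names in the output directory.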
        def map_filename(filename):
            prefix_len = len(filename_prefix)
            prefix = filename[:prefix_len + 1]
            try:
                digits = int(filename[prefix_len + 1:].split('_')[0])
            except ValueError:
                digits = 0
            return (digits, prefix)
        try:
            counter = max(filter(lambda a: a[1][:-1] == filename_prefix and a[1][-1] == "_", map(map_filename, os.listdir(self.output_dir))))[0] + 1
        except ValueError:
            counter = 1
        except FileNotFoundError:
            os.mkdir(self.output_dir)
            counter = 1

        paths = list()
        for image in images:
            i = 255. * image.cpu().numpy()
            img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
            metadata = PngInfo()
            if prompt is not None:
                metadata.add_text("prompt", json.dumps(prompt))
            if extra_pnginfo is not None:
                for x in extra_pnginfo:
                    metadata.add_text(x, json.dumps(extra_pnginfo[x]))
            file = f"{filename_prefix}_{counter:05}_.png"
            img.save(os.path.join(self.output_dir, file), pnginfo=metadata, optimize=True)
            paths.append(file)
            counter += 1
        return { "ui": { "images": paths } }

class LoadImage:
    input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"image": (sorted(os.listdir(s.input_dir)), )},
                }

    CATEGORY = "image"

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "load_image"
    def load_image(self, image):
        image_path = os.path.join(self.input_dir, image)
        i = Image.open(image_path)
        image = i.convert("RGB")
        image = np.array(image).astype(np.float32) / 255.0
        image = torch.from_numpy(image)[None,]
        return (image,)

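    # Hash the file contents so the node re-executes whenever the image on disk changes.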
    @classmethod
    def IS_CHANGED(s, image):
        image_path = os.path.join(s.input_dir, image)
        m = hashlib.sha256()
        with open(image_path, 'rb') as f:
            m.update(f.read())
        return m.digest().hex()

class LoadImageMask:
    input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"image": (sorted(os.listdir(s.input_dir)), ),
                    "channel": (["alpha", "red", "green", "blue"], ),}
                }

    CATEGORY = "image"

    RETURN_TYPES = ("MASK",)
    FUNCTION = "load_image"
    def load_image(self, image, channel):
        image_path = os.path.join(self.input_dir, image)
        i = Image.open(image_path)
        mask = None
        c = channel[0].upper()
        if c in i.getbands():
            mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0
            mask = torch.from_numpy(mask)
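            # Invert alpha so transparent pixels (alpha 0) become 1.0, i.e. the masked region.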
            if c == 'A':
                mask = 1. - mask
        else:
            mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
        return (mask,)

    @classmethod
    def IS_CHANGED(s, image, channel):
        image_path = os.path.join(s.input_dir, image)
        m = hashlib.sha256()
        with open(image_path, 'rb') as f:
            m.update(f.read())
        return m.digest().hex()

class ImageScale:
    upscale_methods = ["nearest-exact", "bilinear", "area"]
    crop_methods = ["disabled", "center"]

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,),
                              "width": ("INT", {"default": 512, "min": 1, "max": 4096, "step": 1}),
                              "height": ("INT", {"default": 512, "min": 1, "max": 4096, "step": 1}),
                              "crop": (s.crop_methods,)}}
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "upscale"

    CATEGORY = "image"

    def upscale(self, image, upscale_method, width, height, crop):
        samples = image.movedim(-1,1)
        s = comfy.utils.common_upscale(samples, width, height, upscale_method, crop)
        s = s.movedim(1,-1)
        return (s,)

class ImageInvert:

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "image": ("IMAGE",)}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "invert"

    CATEGORY = "image"

    def invert(self, image):
        s = 1.0 - image
        return (s,)


NODE_CLASS_MAPPINGS = {
    "KSampler": KSampler,
    "CheckpointLoader": CheckpointLoader,
    "CLIPTextEncode": CLIPTextEncode,
    "VAEDecode": VAEDecode,
    "VAEEncode": VAEEncode,
    "VAEEncodeForInpaint": VAEEncodeForInpaint,
    "VAELoader": VAELoader,
    "EmptyLatentImage": EmptyLatentImage,
    "LatentUpscale": LatentUpscale,
    "SaveImage": SaveImage,
    "LoadImage": LoadImage,
    "LoadImageMask": LoadImageMask,
    "ImageScale": ImageScale,
    "ImageInvert": ImageInvert,
    "ConditioningCombine": ConditioningCombine,
    "ConditioningSetArea": ConditioningSetArea,
    "KSamplerAdvanced": KSamplerAdvanced,
    "SetLatentNoiseMask": SetLatentNoiseMask,
    "LatentComposite": LatentComposite,
    "LatentRotate": LatentRotate,
    "LatentFlip": LatentFlip,
    "LatentCrop": LatentCrop,
    "LoraLoader": LoraLoader,
    "CLIPLoader": CLIPLoader,
    "ControlNetApply": ControlNetApply,
    "ControlNetLoader": ControlNetLoader,
    "DiffControlNetLoader": DiffControlNetLoader,
    "T2IAdapterLoader": T2IAdapterLoader,
    "VAEDecodeTiled": VAEDecodeTiled,
    "CheckpointLoaderSimple": CheckpointLoaderSimple,
}

CUSTOM_NODE_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes")
def load_custom_nodes():
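    # Custom nodes are single .py files or packages (directories with an __init__.py); each must export NODE_CLASS_MAPPINGS.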
    possible_modules = os.listdir(CUSTOM_NODE_PATH)
    if "__pycache__" in possible_modules:
        possible_modules.remove("__pycache__")

    for possible_module in possible_modules:
        module_path = os.path.join(CUSTOM_NODE_PATH, possible_module)
        if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue

        module_name = possible_module
        try:
            if os.path.isfile(module_path):
                module_spec = importlib.util.spec_from_file_location(module_name, module_path)
            else:
                module_spec = importlib.util.spec_from_file_location(module_name, os.path.join(module_path, "__init__.py"))
            module = importlib.util.module_from_spec(module_spec)
            sys.modules[module_name] = module
            module_spec.loader.exec_module(module)
            if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
                NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS)
            else:
                print(f"Skipping {possible_module} module for custom nodes: it has no NODE_CLASS_MAPPINGS.")
        except Exception as e:
            print(traceback.format_exc())
            print(f"Cannot import {possible_module} module for custom nodes:", e)

load_custom_nodes()