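# Image post-processing nodes for ComfyUI: blending, Gaussian blur, palette
# quantization with dithering, sharpening, and scaling to a target pixel count.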
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import math

import comfy.utils
import comfy.model_management


class Blend:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image1": ("IMAGE",),
                "image2": ("IMAGE",),
                "blend_factor": ("FLOAT", {
                    "default": 0.5,
                    "min": 0.0,
                    "max": 1.0,
                    "step": 0.01
                }),
                "blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light", "difference"],),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "blend_images"

    CATEGORY = "image/postprocessing"

    def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
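        # Inputs are (B, H, W, C) float tensors in [0, 1]; image2 is resized
        # bicubically with a center crop to match image1 if the shapes differ.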
        image2 = image2.to(image1.device)
        if image1.shape != image2.shape:
            image2 = image2.permute(0, 3, 1, 2)
            image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center')
            image2 = image2.permute(0, 2, 3, 1)

        blended_image = self.blend_mode(image1, image2, blend_mode)
        blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor
        blended_image = torch.clamp(blended_image, 0, 1)
        return (blended_image,)

    def blend_mode(self, img1, img2, mode):
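        # Standard blend-mode formulas, applied element-wise to values in [0, 1].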
        if mode == "normal":
            return img2
        elif mode == "multiply":
            return img1 * img2
        elif mode == "screen":
            return 1 - (1 - img1) * (1 - img2)
        elif mode == "overlay":
            return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
        elif mode == "soft_light":
            return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1))
        elif mode == "difference":
            return img1 - img2
        else:
            raise ValueError(f"Unsupported blend mode: {mode}")

    def g(self, x):
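        # Piecewise helper curve for soft_light (as in the W3C compositing spec).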
        return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))

def gaussian_kernel(kernel_size: int, sigma: float, device=None):
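    # 2D Gaussian sampled on a fixed [-1, 1] grid and normalized to sum to 1;
    # sigma is therefore relative to the kernel half-width rather than to pixels.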
    x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size, device=device), torch.linspace(-1, 1, kernel_size, device=device), indexing="ij")
    d = torch.sqrt(x * x + y * y)
    g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
    return g / g.sum()

class Blur:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "blur_radius": ("INT", {
                    "default": 1,
                    "min": 1,
                    "max": 31,
                    "step": 1
                }),
                "sigma": ("FLOAT", {
                    "default": 1.0,
                    "min": 0.1,
                    "max": 10.0,
                    "step": 0.1
                }),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "blur"

    CATEGORY = "image/postprocessing"

    def blur(self, image: torch.Tensor, blur_radius: int, sigma: float):
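        # Gaussian blur via a depthwise 2D convolution with reflect padding.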
        if blur_radius == 0:
            return (image,)

        image = image.to(comfy.model_management.get_torch_device())
        batch_size, height, width, channels = image.shape

        kernel_size = blur_radius * 2 + 1
        kernel = gaussian_kernel(kernel_size, sigma, device=image.device).repeat(channels, 1, 1).unsqueeze(1)

        image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
        padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect')
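        # Depthwise convolution (groups=channels); the slice crops off the border
        # that padding=kernel_size // 2 re-adds on top of the reflect padding.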
        blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius]
        blurred = blurred.permute(0, 2, 3, 1)

        return (blurred.to(comfy.model_management.intermediate_device()),)

class Quantize:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "colors": ("INT", {
                    "default": 256,
                    "min": 1,
                    "max": 256,
                    "step": 1
                }),
                "dither": (["none", "floyd-steinberg", "bayer-2", "bayer-4", "bayer-8", "bayer-16"],),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "quantize"

    CATEGORY = "image/postprocessing"

    @staticmethod
    def bayer(im, pal_im, order):
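        # Ordered (Bayer) dithering: tile a 2^n x 2^n threshold matrix over the
        # image, add it as per-pixel noise, then quantize with dithering disabled.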
        def normalized_bayer_matrix(n):
            if n == 0:
                return np.zeros((1,1), "float32")
            else:
                q = 4 ** n
                m = q * normalized_bayer_matrix(n - 1)
                return np.bmat(((m-1.5, m+0.5), (m+1.5, m-0.5))) / q

        num_colors = len(pal_im.getpalette()) // 3
        spread = 2 * 256 / num_colors
        bayer_n = int(math.log2(order))
        bayer_matrix = torch.from_numpy(spread * normalized_bayer_matrix(bayer_n) + 0.5)

        result = torch.from_numpy(np.array(im).astype(np.float32))
        tw = math.ceil(result.shape[0] / bayer_matrix.shape[0])
        th = math.ceil(result.shape[1] / bayer_matrix.shape[1])
        tiled_matrix = bayer_matrix.tile(tw, th).unsqueeze(-1)
        result.add_(tiled_matrix[:result.shape[0],:result.shape[1]]).clamp_(0, 255)
        result = result.to(dtype=torch.uint8)

        im = Image.fromarray(result.cpu().numpy())
        im = im.quantize(palette=pal_im, dither=Image.Dither.NONE)
        return im

    def quantize(self, image: torch.Tensor, colors: int, dither: str):
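        # Reduce each image in the batch to an adaptive palette of `colors` entries,
        # optionally dithering via Floyd-Steinberg or an ordered Bayer matrix.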
        batch_size, height, width, _ = image.shape
        result = torch.zeros_like(image)

        for b in range(batch_size):
            im = Image.fromarray((image[b] * 255).to(torch.uint8).numpy(), mode='RGB')

            pal_im = im.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836

            if dither == "none":
                quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.NONE)
            elif dither == "floyd-steinberg":
                quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.FLOYDSTEINBERG)
            elif dither.startswith("bayer"):
                order = int(dither.split('-')[-1])
                quantized_image = Quantize.bayer(im, pal_im, order)

            quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255
            result[b] = quantized_array

        return (result,)

class Sharpen:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "sharpen_radius": ("INT", {
                    "default": 1,
                    "min": 1,
                    "max": 31,
                    "step": 1
                }),
                "sigma": ("FLOAT", {
                    "default": 1.0,
                    "min": 0.1,
                    "max": 10.0,
                    "step": 0.01
                }),
                "alpha": ("FLOAT", {
                    "default": 1.0,
                    "min": 0.0,
                    "max": 5.0,
                    "step": 0.01
                }),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "sharpen"

    CATEGORY = "image/postprocessing"

    def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma: float, alpha: float):
        if sharpen_radius == 0:
            return (image,)

        batch_size, height, width, channels = image.shape
        image = image.to(comfy.model_management.get_torch_device())

        kernel_size = sharpen_radius * 2 + 1
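        # Unsharp-style kernel: a negative, alpha-scaled Gaussian whose center weight
        # is raised so the kernel sums to 1 (identity plus a high-pass boost).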
        kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10)
        center = kernel_size // 2
        kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
        kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)

        tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
        tensor_image = F.pad(tensor_image, (sharpen_radius,sharpen_radius,sharpen_radius,sharpen_radius), 'reflect')
        sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
        sharpened = sharpened.permute(0, 2, 3, 1)

        result = torch.clamp(sharpened, 0, 1)

        return (result.to(comfy.model_management.intermediate_device()),)

class ImageScaleToTotalPixels:
    upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
    crop_methods = ["disabled", "center"]

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,),
                              "megapixels": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 16.0, "step": 0.01}),
                            }}
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "upscale"

    CATEGORY = "image/upscaling"

    def upscale(self, image, upscale_method, megapixels):
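        # Scale uniformly so the output area is close to megapixels * 1024 * 1024,
        # preserving the aspect ratio.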
        samples = image.movedim(-1,1)
        total = int(megapixels * 1024 * 1024)

        scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
        width = round(samples.shape[3] * scale_by)
        height = round(samples.shape[2] * scale_by)

        s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
        s = s.movedim(1,-1)
        return (s,)

NODE_CLASS_MAPPINGS = {
    "ImageBlend": Blend,
    "ImageBlur": Blur,
    "ImageQuantize": Quantize,
    "ImageSharpen": Sharpen,
    "ImageScaleToTotalPixels": ImageScaleToTotalPixels,
}