nodes_compositing.py 7.45 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import numpy as np
import torch
import comfy.utils
from enum import Enum


class PorterDuffMode(Enum):
    """Closed set of compositing operators understood by ``porter_duff_composite``.

    The member *names* are what ``PorterDuffImageComposite.INPUT_TYPES`` offers as
    choices, and ``composite`` maps the chosen name back with ``PorterDuffMode[name]``.
    The integer values are spelled out explicitly (alphabetical, 0..17) so every
    member keeps a stable identity.
    """

    # Separable blend / arithmetic operators.
    ADD = 0
    CLEAR = 1
    DARKEN = 2

    # Destination-centric operators.
    DST = 3
    DST_ATOP = 4
    DST_IN = 5
    DST_OUT = 6
    DST_OVER = 7

    # More blend operators.
    LIGHTEN = 8
    MULTIPLY = 9
    OVERLAY = 10
    SCREEN = 11

    # Source-centric operators.
    SRC = 12
    SRC_ATOP = 13
    SRC_IN = 14
    SRC_OUT = 15
    SRC_OVER = 16

    # Exclusive-or of the two coverages.
    XOR = 17


def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_image: torch.Tensor, dst_alpha: torch.Tensor, mode: PorterDuffMode):
    """Combine a source and a destination with a Porter-Duff / blend operator.

    All arithmetic is element-wise, so the four tensors must be
    broadcast-compatible (e.g. colour ``(H, W, C)`` with alpha ``(H, W, 1)``).
    NOTE(review): the colour formulas match the premultiplied-alpha forms
    (SRC_OVER is ``Sc + (1 - Sa) * Dc``); confirm callers pass premultiplied colour.
    Only ADD clamps its results to [0, 1]; the other operators do not clamp.

    Args:
        src_image: Source colour values.
        src_alpha: Source alpha (coverage).
        dst_image: Destination colour values.
        dst_alpha: Destination alpha (coverage).
        mode: The operator to apply.

    Returns:
        ``(out_image, out_alpha)``.

    Raises:
        ValueError: if ``mode`` is not one of the operators handled below.
    """
    if mode == PorterDuffMode.ADD:
        out_alpha = torch.clamp(src_alpha + dst_alpha, 0, 1)
        out_image = torch.clamp(src_image + dst_image, 0, 1)
    elif mode == PorterDuffMode.CLEAR:
        out_alpha = torch.zeros_like(dst_alpha)
        out_image = torch.zeros_like(dst_image)
    elif mode == PorterDuffMode.DARKEN:
        out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
        out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + torch.min(src_image, dst_image)
    elif mode == PorterDuffMode.DST:
        out_alpha = dst_alpha
        out_image = dst_image
    elif mode == PorterDuffMode.DST_ATOP:
        out_alpha = src_alpha
        out_image = src_alpha * dst_image + (1 - dst_alpha) * src_image
    elif mode == PorterDuffMode.DST_IN:
        out_alpha = src_alpha * dst_alpha
        out_image = dst_image * src_alpha
    elif mode == PorterDuffMode.DST_OUT:
        out_alpha = (1 - src_alpha) * dst_alpha
        out_image = (1 - src_alpha) * dst_image
    elif mode == PorterDuffMode.DST_OVER:
        out_alpha = dst_alpha + (1 - dst_alpha) * src_alpha
        out_image = dst_image + (1 - dst_alpha) * src_image
    elif mode == PorterDuffMode.LIGHTEN:
        out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
        out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + torch.max(src_image, dst_image)
    elif mode == PorterDuffMode.MULTIPLY:
        out_alpha = src_alpha * dst_alpha
        out_image = src_image * dst_image
    elif mode == PorterDuffMode.OVERLAY:
        out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
        # Piecewise: multiply where the destination is "dark" relative to its
        # coverage, screen-like term otherwise.
        out_image = torch.where(2 * dst_image < dst_alpha, 2 * src_image * dst_image,
            src_alpha * dst_alpha - 2 * (dst_alpha - src_image) * (src_alpha - dst_image))
    elif mode == PorterDuffMode.SCREEN:
        out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
        out_image = src_image + dst_image - src_image * dst_image
    elif mode == PorterDuffMode.SRC:
        out_alpha = src_alpha
        out_image = src_image
    elif mode == PorterDuffMode.SRC_ATOP:
        out_alpha = dst_alpha
        out_image = dst_alpha * src_image + (1 - src_alpha) * dst_image
    elif mode == PorterDuffMode.SRC_IN:
        out_alpha = src_alpha * dst_alpha
        out_image = src_image * dst_alpha
    elif mode == PorterDuffMode.SRC_OUT:
        out_alpha = (1 - dst_alpha) * src_alpha
        out_image = (1 - dst_alpha) * src_image
    elif mode == PorterDuffMode.SRC_OVER:
        out_alpha = src_alpha + (1 - src_alpha) * dst_alpha
        out_image = src_image + (1 - src_alpha) * dst_image
    elif mode == PorterDuffMode.XOR:
        out_alpha = (1 - dst_alpha) * src_alpha + (1 - src_alpha) * dst_alpha
        out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image
    else:
        # Previously this returned (None, None), which only failed later with an
        # unrelated AttributeError in the caller; fail loudly at the source instead.
        raise ValueError(f"Unsupported Porter-Duff mode: {mode!r}")
    return out_image, out_alpha


class PorterDuffImageComposite:
    """ComfyUI node that composites a source image/mask pair onto a destination pair."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "source": ("IMAGE",),
                "source_alpha": ("MASK",),
                "destination": ("IMAGE",),
                "destination_alpha": ("MASK",),
                "mode": ([mode.name for mode in PorterDuffMode], {"default": PorterDuffMode.DST.name}),
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "composite"
    CATEGORY = "mask/compositing"

    @staticmethod
    def _resize_hwc(tensor: torch.Tensor, width: int, height: int) -> torch.Tensor:
        """Resize one channels-last ``(H, W, C)`` tensor to ``(height, width, C)``.

        Uses bicubic ``comfy.utils.common_upscale`` with ``crop='center'``.
        """
        # Add a batch axis and move channels to dim 1 before the call;
        # presumably common_upscale expects (N, C, H, W) — permute back after.
        upscale_input = tensor.unsqueeze(0).permute(0, 3, 1, 2)
        upscale_output = comfy.utils.common_upscale(upscale_input, width, height, upscale_method='bicubic', crop='center')
        return upscale_output.permute(0, 2, 3, 1).squeeze(0)

    def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode):
        """Composite each batch item of ``source`` onto ``destination``.

        Only the first ``min`` of the four batch lengths is processed. Per item, the
        destination alpha, source image and source alpha are resized (bicubic,
        center crop) to the destination's spatial size when they differ.
        Assumes images are ``(B, H, W, C)`` and masks ``(B, H, W)``: per-item
        ``shape[2]`` is treated as the channel count and masks get
        ``unsqueeze(2)`` — TODO confirm against the IMAGE/MASK conventions.

        Args:
            source / destination: Image batches.
            source_alpha / destination_alpha: Mask batches.
            mode: Name of a ``PorterDuffMode`` member.

        Returns:
            ``(images, masks)`` stacked along a new batch dimension.

        Raises:
            KeyError: if ``mode`` is not a ``PorterDuffMode`` member name.
            ValueError: if any input batch is empty, or an item's source and
                destination images have different channel counts.
        """
        batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha))
        if batch_size == 0:
            raise ValueError("PorterDuffImageComposite needs at least one item in every input batch")

        # The mode does not change per item; resolve the name once.
        composite_mode = PorterDuffMode[mode]

        out_images = []
        out_alphas = []
        for i in range(batch_size):
            src_image = source[i]
            dst_image = destination[i]

            # Explicit check (not assert, which is stripped under `python -O`).
            if src_image.shape[2] != dst_image.shape[2]:
                raise ValueError(
                    f"source and destination must have the same number of channels "
                    f"(got {src_image.shape[2]} and {dst_image.shape[2]})"
                )

            # Give masks a trailing channel axis so they broadcast against images.
            src_alpha = source_alpha[i].unsqueeze(2)
            dst_alpha = destination_alpha[i].unsqueeze(2)

            dst_height, dst_width = dst_image.shape[0], dst_image.shape[1]
            if dst_alpha.shape[:2] != dst_image.shape[:2]:
                dst_alpha = self._resize_hwc(dst_alpha, dst_width, dst_height)
            if src_image.shape != dst_image.shape:
                src_image = self._resize_hwc(src_image, dst_width, dst_height)
            if src_alpha.shape != dst_alpha.shape:
                src_alpha = self._resize_hwc(src_alpha, dst_alpha.shape[1], dst_alpha.shape[0])

            out_image, out_alpha = porter_duff_composite(src_image, src_alpha, dst_image, dst_alpha, composite_mode)

            out_images.append(out_image)
            out_alphas.append(out_alpha.squeeze(2))

        return torch.stack(out_images), torch.stack(out_alphas)


class SplitImageWithAlpha:
    """ComfyUI node splitting an image batch into its colour channels and an alpha mask."""

    @classmethod
    def INPUT_TYPES(s):
        return {
                "required": {
                    "image": ("IMAGE",),
                }
        }

    CATEGORY = "mask/compositing"
    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "split_image_with_alpha"

    def split_image_with_alpha(self, image: torch.Tensor):
        """Return ``(first three channels, alpha)`` for a channels-last image batch.

        The alpha is channel 3 when the image has more than three channels,
        otherwise an all-ones mask of the same dtype/device. Assumes ``image`` is
        ``(B, H, W, C)`` — TODO confirm against the IMAGE convention. An empty
        batch yields empty outputs instead of failing in ``torch.stack``.
        """
        # clone() so neither output aliases the caller's tensor (the previous
        # per-item torch.stack also produced fresh storage).
        rgb = image[..., :3].clone()
        if image.shape[-1] > 3:
            alpha = image[..., 3].clone()
        else:
            alpha = torch.ones_like(image[..., 0])
        return rgb, alpha


class JoinImageWithAlpha:
    """ComfyUI node appending a mask to an image batch as a fourth channel."""

    @classmethod
    def INPUT_TYPES(s):
        return {
                "required": {
                    "image": ("IMAGE",),
                    "alpha": ("MASK",),
                }
        }

    CATEGORY = "mask/compositing"
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "join_image_with_alpha"

    def join_image_with_alpha(self, image: torch.Tensor, alpha: torch.Tensor):
        """Concatenate ``image[..., :3]`` with ``alpha`` along the channel axis.

        Only the first ``min(len(image), len(alpha))`` items are joined; any
        channels beyond the third are dropped. Assumes ``image`` is
        ``(B, H, W, C)`` and ``alpha`` is ``(B, H, W)`` with matching H/W —
        TODO confirm against the IMAGE/MASK conventions. An empty batch yields
        an empty ``(0, H, W, 4)`` result instead of failing in ``torch.stack``.

        Returns:
            A 1-tuple holding the ``(batch, H, W, 4)`` tensor.
        """
        batch_size = min(len(image), len(alpha))
        rgb = image[:batch_size, ..., :3]
        # Give the mask a trailing channel axis so it concatenates as channel 3.
        mask_channel = alpha[:batch_size].unsqueeze(-1)
        return (torch.cat((rgb, mask_channel), dim=-1),)


# Node type name -> implementing class; presumably read by the ComfyUI node
# loader. Each key equals the class's own __name__, in declaration order.
NODE_CLASS_MAPPINGS = {
    node_class.__name__: node_class
    for node_class in (PorterDuffImageComposite, SplitImageWithAlpha, JoinImageWithAlpha)
}


# Human-readable titles for each node, keyed by the node type name.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    PorterDuffImageComposite="Porter-Duff Image Composite",
    SplitImageWithAlpha="Split Image with Alpha",
    JoinImageWithAlpha="Join Image with Alpha",
)