Commit 022a9f27 authored by mligaintart's avatar mligaintart
Browse files

Adds masking to Latent Composite, and provides new masking utilities to

allow better compositing.
parent d5cce834
import torch
from nodes import MAX_RESOLUTION
class LatentCompositeMasked:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"destination": ("LATENT",),
"source": ("LATENT",),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
},
"optional": {
"mask": ("MASK",),
}
}
RETURN_TYPES = ("LATENT",)
FUNCTION = "composite"
CATEGORY = "latent"
def composite(self, destination, source, x, y, mask = None):
output = destination.copy()
destination = destination["samples"].clone()
source = source["samples"]
left, top = (x // 8, y // 8)
right, bottom = (left + source.shape[3], top + source.shape[2],)
if mask is None:
mask = torch.ones_like(source)
else:
mask = mask.clone()
mask = torch.nn.functional.interpolate(mask[None, None], size=(source.shape[2], source.shape[3]), mode="bilinear")
mask = mask.repeat((source.shape[0], source.shape[1], 1, 1))
# calculate the bounds of the source that will be overlapping the destination
# this prevents the source trying to overwrite latent pixels that are out of bounds
# of the destination
visible_width, visible_height = (destination.shape[3] - left, destination.shape[2] - top,)
mask = mask[:, :, :visible_height, :visible_width]
inverse_mask = torch.ones_like(mask) - mask
source_portion = mask * source[:, :, :visible_height, :visible_width]
destination_portion = inverse_mask * destination[:, :, top:bottom, left:right]
destination[:, :, top:bottom, left:right] = source_portion + destination_portion
output["samples"] = destination
return (output,)
class MaskToImage:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("MASK",),
}
}
CATEGORY = "mask"
RETURN_TYPES = ("IMAGE",)
FUNCTION = "convert"
def convert(self, mask):
image = torch.cat([torch.reshape(mask.clone(), [1, mask.shape[0], mask.shape[1], 1,])] * 3, 3)
return (image,)
class SolidMask:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
}
}
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)
FUNCTION = "solid"
def solid(self, value, width, height):
out = torch.full((height, width), value, dtype=torch.float32, device="cpu")
return (out,)
class InvertMask:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("MASK",),
}
}
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)
FUNCTION = "invert"
def invert(self, mask):
out = 1.0 - mask
return (out,)
class CropMask:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("MASK",),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
}
}
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)
FUNCTION = "crop"
def crop(self, mask, x, y, width, height):
out = mask[y:y + height, x:x + width]
return (out,)
class MaskComposite:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"destination": ("MASK",),
"source": ("MASK",),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"operation": (["multiply", "add", "subtract"],),
}
}
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)
FUNCTION = "combine"
def combine(self, destination, source, x, y, operation):
output = destination.clone()
left, top = (x, y,)
right, bottom = (min(left + source.shape[1], destination.shape[1]), min(top + source.shape[0], destination.shape[0]))
visible_width, visible_height = (right - left, bottom - top,)
source_portion = source[:visible_height, :visible_width]
destination_portion = destination[top:bottom, left:right]
match operation:
case "multiply":
output[top:bottom, left:right] = destination_portion * source_portion
case "add":
output[top:bottom, left:right] = destination_portion + source_portion
case "subtract":
output[top:bottom, left:right] = destination_portion - source_portion
output = torch.clamp(output, 0.0, 1.0)
return (output,)
class FeatherMask:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("MASK",),
"left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
}
}
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)
FUNCTION = "feather"
def feather(self, mask, left, top, right, bottom):
output = mask.clone()
left = min(left, output.shape[1])
right = min(right, output.shape[1])
top = min(top, output.shape[0])
bottom = min(bottom, output.shape[0])
for x in range(left):
feather_rate = (x + 1.0) / left
output[:, x] *= feather_rate
for x in range(right):
feather_rate = (x + 1) / right
output[:, -x] *= feather_rate
for y in range(top):
feather_rate = (y + 1) / top
output[y, :] *= feather_rate
for y in range(bottom):
feather_rate = (y + 1) / bottom
output[-y, :] *= feather_rate
return (output,)
NODE_CLASS_MAPPINGS = {
"LatentCompositeMasked": LatentCompositeMasked,
"MaskToImage": MaskToImage,
"SolidMask": SolidMask,
"InvertMask": InvertMask,
"CropMask": CropMask,
"MaskComposite": MaskComposite,
"FeatherMask": FeatherMask,
}
......@@ -553,44 +553,64 @@ class LatentFlip:
class LatentComposite:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples_to": ("LATENT",),
"samples_from": ("LATENT",),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"feather": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
}}
return {
"required": {
"samples_to": ("LATENT",),
"samples_from": ("LATENT",),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"feather": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
}
}
RETURN_TYPES = ("LATENT",)
FUNCTION = "composite"
CATEGORY = "latent"
def composite(self, samples_to, samples_from, x, y, composite_method="normal", feather=0):
x = x // 8
y = y // 8
def composite(self, samples_to, samples_from, x, y, feather):
output = samples_to.copy()
destination = samples_to["samples"].clone()
source = samples_from["samples"]
left, top = (x // 8, y // 8)
right, bottom = (left + source.shape[3], top + source.shape[2],)
feather = feather // 8
samples_out = samples_to.copy()
s = samples_to["samples"].clone()
samples_to = samples_to["samples"]
samples_from = samples_from["samples"]
if feather == 0:
s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x]
else:
samples_from = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x]
mask = torch.ones_like(samples_from)
for t in range(feather):
if y != 0:
mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1))
if y + samples_from.shape[2] < samples_to.shape[2]:
mask[:,:,mask.shape[2] -1 -t: mask.shape[2]-t,:] *= ((1.0/feather) * (t + 1))
if x != 0:
mask[:,:,:,t:1+t] *= ((1.0/feather) * (t + 1))
if x + samples_from.shape[3] < samples_to.shape[3]:
mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1))
rev_mask = torch.ones_like(mask) - mask
s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] * mask + s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] * rev_mask
samples_out["samples"] = s
return (samples_out,)
# calculate the bounds of the source that will be overlapping the destination
# this prevents the source trying to overwrite latent pixels that are out of bounds
# of the destination
visible_width, visible_height = (destination.shape[3] - left, destination.shape[2] - top,)
mask = torch.ones_like(source)
for f in range(feather):
feather_rate = (f + 1.0) / feather
if left > 0:
mask[:, :, :, f] *= feather_rate
if right < destination.shape[3] - 1:
mask[:, :, :, -f] *= feather_rate
if top > 0:
mask[:, :, f, :] *= feather_rate
if bottom < destination.shape[2] - 1:
mask[:, :, -f, :] *= feather_rate
mask = mask[:, :, :visible_height, :visible_width]
inverse_mask = torch.ones_like(mask) - mask
source_portion = mask * source[:, :, :visible_height, :visible_width]
destination_portion = inverse_mask * destination[:, :, top:bottom, left:right]
destination[:, :, top:bottom, left:right] = source_portion + destination_portion
output["samples"] = destination
return (output,)
class LatentCrop:
@classmethod
......@@ -907,7 +927,7 @@ class LoadImageMask:
"channel": (["alpha", "red", "green", "blue"], ),}
}
CATEGORY = "image"
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)
FUNCTION = "load_image"
......@@ -1114,3 +1134,4 @@ def init_custom_nodes():
load_custom_nodes()
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment