"test/gemm/gemm_dl_fp32.cpp" did not exist on "5b178874a1b2a1cae217e87e1988ab92a40d71b8"
Commit 022a9f27 authored by mligaintart's avatar mligaintart
Browse files

Adds masking to Latent Composite, and provides new masking utilities to

allow better compositing.
parent d5cce834
import torch
from nodes import MAX_RESOLUTION
class LatentCompositeMasked:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"destination": ("LATENT",),
"source": ("LATENT",),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
},
"optional": {
"mask": ("MASK",),
}
}
RETURN_TYPES = ("LATENT",)
FUNCTION = "composite"
CATEGORY = "latent"
def composite(self, destination, source, x, y, mask = None):
output = destination.copy()
destination = destination["samples"].clone()
source = source["samples"]
left, top = (x // 8, y // 8)
right, bottom = (left + source.shape[3], top + source.shape[2],)
if mask is None:
mask = torch.ones_like(source)
else:
mask = mask.clone()
mask = torch.nn.functional.interpolate(mask[None, None], size=(source.shape[2], source.shape[3]), mode="bilinear")
mask = mask.repeat((source.shape[0], source.shape[1], 1, 1))
# calculate the bounds of the source that will be overlapping the destination
# this prevents the source trying to overwrite latent pixels that are out of bounds
# of the destination
visible_width, visible_height = (destination.shape[3] - left, destination.shape[2] - top,)
mask = mask[:, :, :visible_height, :visible_width]
inverse_mask = torch.ones_like(mask) - mask
source_portion = mask * source[:, :, :visible_height, :visible_width]
destination_portion = inverse_mask * destination[:, :, top:bottom, left:right]
destination[:, :, top:bottom, left:right] = source_portion + destination_portion
output["samples"] = destination
return (output,)
class MaskToImage:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("MASK",),
}
}
CATEGORY = "mask"
RETURN_TYPES = ("IMAGE",)
FUNCTION = "convert"
def convert(self, mask):
image = torch.cat([torch.reshape(mask.clone(), [1, mask.shape[0], mask.shape[1], 1,])] * 3, 3)
return (image,)
class SolidMask:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
}
}
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)
FUNCTION = "solid"
def solid(self, value, width, height):
out = torch.full((height, width), value, dtype=torch.float32, device="cpu")
return (out,)
class InvertMask:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("MASK",),
}
}
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)
FUNCTION = "invert"
def invert(self, mask):
out = 1.0 - mask
return (out,)
class CropMask:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("MASK",),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
}
}
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)
FUNCTION = "crop"
def crop(self, mask, x, y, width, height):
out = mask[y:y + height, x:x + width]
return (out,)
class MaskComposite:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"destination": ("MASK",),
"source": ("MASK",),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"operation": (["multiply", "add", "subtract"],),
}
}
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)
FUNCTION = "combine"
def combine(self, destination, source, x, y, operation):
output = destination.clone()
left, top = (x, y,)
right, bottom = (min(left + source.shape[1], destination.shape[1]), min(top + source.shape[0], destination.shape[0]))
visible_width, visible_height = (right - left, bottom - top,)
source_portion = source[:visible_height, :visible_width]
destination_portion = destination[top:bottom, left:right]
match operation:
case "multiply":
output[top:bottom, left:right] = destination_portion * source_portion
case "add":
output[top:bottom, left:right] = destination_portion + source_portion
case "subtract":
output[top:bottom, left:right] = destination_portion - source_portion
output = torch.clamp(output, 0.0, 1.0)
return (output,)
class FeatherMask:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("MASK",),
"left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
}
}
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)
FUNCTION = "feather"
def feather(self, mask, left, top, right, bottom):
output = mask.clone()
left = min(left, output.shape[1])
right = min(right, output.shape[1])
top = min(top, output.shape[0])
bottom = min(bottom, output.shape[0])
for x in range(left):
feather_rate = (x + 1.0) / left
output[:, x] *= feather_rate
for x in range(right):
feather_rate = (x + 1) / right
output[:, -x] *= feather_rate
for y in range(top):
feather_rate = (y + 1) / top
output[y, :] *= feather_rate
for y in range(bottom):
feather_rate = (y + 1) / bottom
output[-y, :] *= feather_rate
return (output,)
NODE_CLASS_MAPPINGS = {
"LatentCompositeMasked": LatentCompositeMasked,
"MaskToImage": MaskToImage,
"SolidMask": SolidMask,
"InvertMask": InvertMask,
"CropMask": CropMask,
"MaskComposite": MaskComposite,
"FeatherMask": FeatherMask,
}
...@@ -553,44 +553,64 @@ class LatentFlip: ...@@ -553,44 +553,64 @@ class LatentFlip:
class LatentComposite: class LatentComposite:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "samples_to": ("LATENT",), return {
"samples_from": ("LATENT",), "required": {
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), "samples_to": ("LATENT",),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), "samples_from": ("LATENT",),
"feather": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
}} "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"feather": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
}
}
RETURN_TYPES = ("LATENT",) RETURN_TYPES = ("LATENT",)
FUNCTION = "composite" FUNCTION = "composite"
CATEGORY = "latent" CATEGORY = "latent"
def composite(self, samples_to, samples_from, x, y, composite_method="normal", feather=0): def composite(self, samples_to, samples_from, x, y, feather):
x = x // 8 output = samples_to.copy()
y = y // 8 destination = samples_to["samples"].clone()
source = samples_from["samples"]
left, top = (x // 8, y // 8)
right, bottom = (left + source.shape[3], top + source.shape[2],)
feather = feather // 8 feather = feather // 8
samples_out = samples_to.copy()
s = samples_to["samples"].clone()
samples_to = samples_to["samples"]
samples_from = samples_from["samples"] # calculate the bounds of the source that will be overlapping the destination
if feather == 0: # this prevents the source trying to overwrite latent pixels that are out of bounds
s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] # of the destination
else: visible_width, visible_height = (destination.shape[3] - left, destination.shape[2] - top,)
samples_from = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x]
mask = torch.ones_like(samples_from) mask = torch.ones_like(source)
for t in range(feather):
if y != 0: for f in range(feather):
mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1)) feather_rate = (f + 1.0) / feather
if y + samples_from.shape[2] < samples_to.shape[2]: if left > 0:
mask[:,:,mask.shape[2] -1 -t: mask.shape[2]-t,:] *= ((1.0/feather) * (t + 1)) mask[:, :, :, f] *= feather_rate
if x != 0:
mask[:,:,:,t:1+t] *= ((1.0/feather) * (t + 1)) if right < destination.shape[3] - 1:
if x + samples_from.shape[3] < samples_to.shape[3]: mask[:, :, :, -f] *= feather_rate
mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1))
rev_mask = torch.ones_like(mask) - mask if top > 0:
s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] * mask + s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] * rev_mask mask[:, :, f, :] *= feather_rate
samples_out["samples"] = s
return (samples_out,) if bottom < destination.shape[2] - 1:
mask[:, :, -f, :] *= feather_rate
mask = mask[:, :, :visible_height, :visible_width]
inverse_mask = torch.ones_like(mask) - mask
source_portion = mask * source[:, :, :visible_height, :visible_width]
destination_portion = inverse_mask * destination[:, :, top:bottom, left:right]
destination[:, :, top:bottom, left:right] = source_portion + destination_portion
output["samples"] = destination
return (output,)
class LatentCrop: class LatentCrop:
@classmethod @classmethod
...@@ -907,7 +927,7 @@ class LoadImageMask: ...@@ -907,7 +927,7 @@ class LoadImageMask:
"channel": (["alpha", "red", "green", "blue"], ),} "channel": (["alpha", "red", "green", "blue"], ),}
} }
CATEGORY = "image" CATEGORY = "mask"
RETURN_TYPES = ("MASK",) RETURN_TYPES = ("MASK",)
FUNCTION = "load_image" FUNCTION = "load_image"
...@@ -1114,3 +1134,4 @@ def init_custom_nodes(): ...@@ -1114,3 +1134,4 @@ def init_custom_nodes():
load_custom_nodes() load_custom_nodes()
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment