"...git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "31d7a2c54e41923ec7d9fae558238599060d3bca"
Commit 046b4fe0 authored by comfyanonymous's avatar comfyanonymous
Browse files

Support batches of masks in mask composite nodes.

parent ba7dfd60
import numpy as np
from scipy.ndimage import grey_dilation
import torch
import comfy.utils
from nodes import MAX_RESOLUTION
......@@ -8,6 +9,8 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
if resize_source:
source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")
source = comfy.utils.repeat_to_batch_size(source, destination.shape[0])
x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier))
y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier))
......@@ -18,8 +21,8 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
mask = torch.ones_like(source)
else:
mask = mask.clone()
mask = torch.nn.functional.interpolate(mask[None, None], size=(source.shape[2], source.shape[3]), mode="bilinear")
mask = mask.repeat((source.shape[0], source.shape[1], 1, 1))
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear")
mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0])
# calculate the bounds of the source that will be overlapping the destination
# this prevents the source trying to overwrite latent pixels that are out of bounds
......@@ -122,7 +125,7 @@ class ImageToMask:
def image_to_mask(self, image, channel):
channels = ["red", "green", "blue"]
mask = image[0, :, :, channels.index(channel)]
mask = image[:, :, :, channels.index(channel)]
return (mask,)
class ImageColorToMask:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment