Commit 35f636b6 authored by comfyanonymous's avatar comfyanonymous
Browse files

Expose grow_mask_by in VAEEncodeForInpaint.

The mask is dilated by grow_mask_by pixels after being applied to the pixel
space image. This helps reduce seams caused by inpainting. Higher value
means less seams.
parent 9c335a55
...@@ -5,6 +5,7 @@ import sys ...@@ -5,6 +5,7 @@ import sys
import json import json
import hashlib import hashlib
import traceback import traceback
import math
from PIL import Image from PIL import Image
from PIL.PngImagePlugin import PngInfo from PIL.PngImagePlugin import PngInfo
...@@ -223,13 +224,13 @@ class VAEEncodeForInpaint: ...@@ -223,13 +224,13 @@ class VAEEncodeForInpaint:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", )}} return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", ), "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),}}
RETURN_TYPES = ("LATENT",) RETURN_TYPES = ("LATENT",)
FUNCTION = "encode" FUNCTION = "encode"
CATEGORY = "latent/inpaint" CATEGORY = "latent/inpaint"
def encode(self, vae, pixels, mask): def encode(self, vae, pixels, mask, grow_mask_by=6):
x = (pixels.shape[1] // 64) * 64 x = (pixels.shape[1] // 64) * 64
y = (pixels.shape[2] // 64) * 64 y = (pixels.shape[2] // 64) * 64
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
...@@ -240,8 +241,14 @@ class VAEEncodeForInpaint: ...@@ -240,8 +241,14 @@ class VAEEncodeForInpaint:
mask = mask[:,:,:x,:y] mask = mask[:,:,:x,:y]
#grow mask by a few pixels to keep things seamless in latent space #grow mask by a few pixels to keep things seamless in latent space
kernel_tensor = torch.ones((1, 1, 6, 6)) if grow_mask_by == 0:
mask_erosion = torch.clamp(torch.nn.functional.conv2d(mask.round(), kernel_tensor, padding=3), 0, 1) mask_erosion = mask
else:
kernel_tensor = torch.ones((1, 1, grow_mask_by, grow_mask_by))
padding = math.ceil((grow_mask_by - 1) / 2)
mask_erosion = torch.clamp(torch.nn.functional.conv2d(mask.round(), kernel_tensor, padding=padding), 0, 1)
m = (1.0 - mask.round()).squeeze(1) m = (1.0 - mask.round()).squeeze(1)
for i in range(3): for i in range(3):
pixels[:,:,:,i] -= 0.5 pixels[:,:,:,i] -= 0.5
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment