"docs/source/vscode:/vscode.git/clone" did not exist on "30c977d1f5300bffdf14b92d06c90293e789b587"
Commit 10f2609f authored by comfyanonymous's avatar comfyanonymous
Browse files

Add InpaintModelConditioning node.

This is an alternative to VAE Encode for inpaint that should work with
lower denoise.

This is a different take on #2501
parent b4e915e7
...@@ -100,11 +100,29 @@ class BaseModel(torch.nn.Module): ...@@ -100,11 +100,29 @@ class BaseModel(torch.nn.Module):
if self.inpaint_model: if self.inpaint_model:
concat_keys = ("mask", "masked_image") concat_keys = ("mask", "masked_image")
cond_concat = [] cond_concat = []
denoise_mask = kwargs.get("denoise_mask", None) denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
latent_image = kwargs.get("latent_image", None) concat_latent_image = kwargs.get("concat_latent_image", None)
if concat_latent_image is None:
concat_latent_image = kwargs.get("latent_image", None)
else:
concat_latent_image = self.process_latent_in(concat_latent_image)
noise = kwargs.get("noise", None) noise = kwargs.get("noise", None)
device = kwargs["device"] device = kwargs["device"]
if concat_latent_image.shape[1:] != noise.shape[1:]:
concat_latent_image = utils.common_upscale(concat_latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0])
if len(denoise_mask.shape) == len(noise.shape):
denoise_mask = denoise_mask[:,:1]
denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
if denoise_mask.shape[-2:] != noise.shape[-2:]:
denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])
def blank_inpaint_image_like(latent_image): def blank_inpaint_image_like(latent_image):
blank_image = torch.ones_like(latent_image) blank_image = torch.ones_like(latent_image)
# these are the values for "zero" in pixel space translated to latent space # these are the values for "zero" in pixel space translated to latent space
...@@ -117,9 +135,9 @@ class BaseModel(torch.nn.Module): ...@@ -117,9 +135,9 @@ class BaseModel(torch.nn.Module):
for ck in concat_keys: for ck in concat_keys:
if denoise_mask is not None: if denoise_mask is not None:
if ck == "mask": if ck == "mask":
cond_concat.append(denoise_mask[:,:1].to(device)) cond_concat.append(denoise_mask.to(device))
elif ck == "masked_image": elif ck == "masked_image":
cond_concat.append(latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space cond_concat.append(concat_latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
else: else:
if ck == "mask": if ck == "mask":
cond_concat.append(torch.ones_like(noise)[:,:1]) cond_concat.append(torch.ones_like(noise)[:,:1])
......
...@@ -359,6 +359,62 @@ class VAEEncodeForInpaint: ...@@ -359,6 +359,62 @@ class VAEEncodeForInpaint:
return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, ) return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
class InpaintModelConditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"pixels": ("IMAGE", ),
"mask": ("MASK", ),
}}
RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/inpaint"
def encode(self, positive, negative, pixels, vae, mask):
x = (pixels.shape[1] // 8) * 8
y = (pixels.shape[2] // 8) * 8
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
orig_pixels = pixels
pixels = orig_pixels.clone()
if pixels.shape[1] != x or pixels.shape[2] != y:
x_offset = (pixels.shape[1] % 8) // 2
y_offset = (pixels.shape[2] % 8) // 2
pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
m = (1.0 - mask.round()).squeeze(1)
for i in range(3):
pixels[:,:,:,i] -= 0.5
pixels[:,:,:,i] *= m
pixels[:,:,:,i] += 0.5
concat_latent = vae.encode(pixels)
orig_latent = vae.encode(orig_pixels)
out_latent = {}
out_latent["samples"] = orig_latent
out_latent["noise_mask"] = mask
out = []
for conditioning in [positive, negative]:
c = []
for t in conditioning:
d = t[1].copy()
d["concat_latent_image"] = concat_latent
d["concat_mask"] = mask
n = [t[0], d]
c.append(n)
out.append(c)
return (out[0], out[1], out_latent)
class SaveLatent: class SaveLatent:
def __init__(self): def __init__(self):
self.output_dir = folder_paths.get_output_directory() self.output_dir = folder_paths.get_output_directory()
...@@ -1628,10 +1684,11 @@ class ImagePadForOutpaint: ...@@ -1628,10 +1684,11 @@ class ImagePadForOutpaint:
def expand_image(self, image, left, top, right, bottom, feathering): def expand_image(self, image, left, top, right, bottom, feathering):
d1, d2, d3, d4 = image.size() d1, d2, d3, d4 = image.size()
new_image = torch.zeros( new_image = torch.ones(
(d1, d2 + top + bottom, d3 + left + right, d4), (d1, d2 + top + bottom, d3 + left + right, d4),
dtype=torch.float32, dtype=torch.float32,
) ) * 0.5
new_image[:, top:top + d2, left:left + d3, :] = image new_image[:, top:top + d2, left:left + d3, :] = image
mask = torch.ones( mask = torch.ones(
...@@ -1723,6 +1780,7 @@ NODE_CLASS_MAPPINGS = { ...@@ -1723,6 +1780,7 @@ NODE_CLASS_MAPPINGS = {
"unCLIPCheckpointLoader": unCLIPCheckpointLoader, "unCLIPCheckpointLoader": unCLIPCheckpointLoader,
"GLIGENLoader": GLIGENLoader, "GLIGENLoader": GLIGENLoader,
"GLIGENTextBoxApply": GLIGENTextBoxApply, "GLIGENTextBoxApply": GLIGENTextBoxApply,
"InpaintModelConditioning": InpaintModelConditioning,
"CheckpointLoader": CheckpointLoader, "CheckpointLoader": CheckpointLoader,
"DiffusersLoader": DiffusersLoader, "DiffusersLoader": DiffusersLoader,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment