"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "66e6f824816197a390f4c9861ab3747054cbe40e"
Commit 10f2609f authored by comfyanonymous

Add InpaintModelConditioning node.

This is an alternative to VAE Encode (for Inpainting) that should work
with lower denoise values.

This is a different take on #2501
parent b4e915e7
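
For orientation, here is a minimal sketch (not part of the commit) of how the new node slots in where VAEEncodeForInpaint used to sit. The helper name inpaint_low_denoise is hypothetical; model, conditioning, vae, pixels and mask are assumed to come from the usual loader and encoder nodes in a running ComfyUI environment.

from nodes import InpaintModelConditioning, KSampler

def inpaint_low_denoise(model, positive, negative, vae, pixels, mask,
                        seed=0, steps=20, cfg=8.0, denoise=0.6):
    # encode() returns patched positive/negative conditioning plus a LATENT
    # dict holding the unmasked image latent and its noise mask.
    positive, negative, latent = InpaintModelConditioning().encode(
        positive, negative, pixels, vae, mask)
    # Because the inpaint data travels inside the conditioning, the sampler
    # can run at denoise < 1.0 and still keep the unmasked region intact.
    return KSampler().sample(model, seed, steps, cfg, "euler", "normal",
                             positive, negative, latent, denoise=denoise)[0]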
@@ -100,11 +100,29 @@ class BaseModel(torch.nn.Module):
         if self.inpaint_model:
             concat_keys = ("mask", "masked_image")
             cond_concat = []
-            denoise_mask = kwargs.get("denoise_mask", None)
-            latent_image = kwargs.get("latent_image", None)
+            denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
+            concat_latent_image = kwargs.get("concat_latent_image", None)
+            if concat_latent_image is None:
+                concat_latent_image = kwargs.get("latent_image", None)
+            else:
+                concat_latent_image = self.process_latent_in(concat_latent_image)
+
             noise = kwargs.get("noise", None)
             device = kwargs["device"]
+
+            if concat_latent_image.shape[1:] != noise.shape[1:]:
+                concat_latent_image = utils.common_upscale(concat_latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
+
+            concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0])
+
+            if len(denoise_mask.shape) == len(noise.shape):
+                denoise_mask = denoise_mask[:,:1]
+
+            denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
+            if denoise_mask.shape[-2:] != noise.shape[-2:]:
+                denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
+            denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])
+
             def blank_inpaint_image_like(latent_image):
                 blank_image = torch.ones_like(latent_image)
                 # these are the values for "zero" in pixel space translated to latent space
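
The hunk above normalizes both the concat latent and the mask to the noise latent's shape. The standalone sketch below (hypothetical helper name, plain torch instead of comfy.utils, no center-crop handling) reproduces the mask side of that normalization for the simple case:

import torch
import torch.nn.functional as F

def normalize_denoise_mask(denoise_mask, noise):
    # If the mask already has the latent's rank, keep only one channel.
    if len(denoise_mask.shape) == len(noise.shape):
        denoise_mask = denoise_mask[:, :1]
    # Force (N, 1, H, W) regardless of how the mask arrived.
    denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
    # Bilinear resize to the latent's spatial size.
    if denoise_mask.shape[-2:] != noise.shape[-2:]:
        denoise_mask = F.interpolate(denoise_mask, size=noise.shape[-2:], mode="bilinear")
    # Binarize, then repeat/trim to the latent batch size.
    denoise_mask = denoise_mask.round()
    n = noise.shape[0]
    if denoise_mask.shape[0] < n:
        denoise_mask = denoise_mask.repeat((n // denoise_mask.shape[0]) + 1, 1, 1, 1)
    return denoise_mask[:n]

# Example: a 512x512 pixel-space mask mapped onto a 64x64 latent batch of 2.
mask = torch.ones((1, 512, 512))
noise = torch.zeros((2, 4, 64, 64))
print(normalize_denoise_mask(mask, noise).shape)  # torch.Size([2, 1, 64, 64])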
@@ -117,9 +135,9 @@ class BaseModel(torch.nn.Module):
             for ck in concat_keys:
                 if denoise_mask is not None:
                     if ck == "mask":
-                        cond_concat.append(denoise_mask[:,:1].to(device))
+                        cond_concat.append(denoise_mask.to(device))
                     elif ck == "masked_image":
-                        cond_concat.append(latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
+                        cond_concat.append(concat_latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
                 else:
                     if ck == "mask":
                         cond_concat.append(torch.ones_like(noise)[:,:1])
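
For context: SD inpaint checkpoints use a 9-channel UNet input, 4 channels of noised latent plus the 1-channel mask and the 4-channel masked-image latent that cond_concat collects here. Since the mask is already reduced to one channel upstream, the old [:,:1] slice at append time is no longer needed. A toy shape check (illustration only, not code from the diff):

import torch

noise = torch.randn(2, 4, 64, 64)                # noised latent
denoise_mask = torch.ones(2, 1, 64, 64)          # normalized mask (first hunk)
concat_latent_image = torch.randn(2, 4, 64, 64)  # VAE-encoded masked image

# The inpaint UNet sees the three concatenated along the channel dim.
unet_input = torch.cat([noise, denoise_mask, concat_latent_image], dim=1)
print(unet_input.shape)  # torch.Size([2, 9, 64, 64])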
@@ -359,6 +359,62 @@ class VAEEncodeForInpaint:
         return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
 
+class InpaintModelConditioning:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"positive": ("CONDITIONING", ),
+                             "negative": ("CONDITIONING", ),
+                             "vae": ("VAE", ),
+                             "pixels": ("IMAGE", ),
+                             "mask": ("MASK", ),
+                             }}
+
+    RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT")
+    RETURN_NAMES = ("positive", "negative", "latent")
+    FUNCTION = "encode"
+
+    CATEGORY = "conditioning/inpaint"
+
+    def encode(self, positive, negative, pixels, vae, mask):
+        x = (pixels.shape[1] // 8) * 8
+        y = (pixels.shape[2] // 8) * 8
+        mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
+
+        orig_pixels = pixels
+        pixels = orig_pixels.clone()
+        if pixels.shape[1] != x or pixels.shape[2] != y:
+            x_offset = (pixels.shape[1] % 8) // 2
+            y_offset = (pixels.shape[2] % 8) // 2
+            pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
+            mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
+
+        m = (1.0 - mask.round()).squeeze(1)
+        for i in range(3):
+            pixels[:,:,:,i] -= 0.5
+            pixels[:,:,:,i] *= m
+            pixels[:,:,:,i] += 0.5
+        concat_latent = vae.encode(pixels)
+        orig_latent = vae.encode(orig_pixels)
+
+        out_latent = {}
+
+        out_latent["samples"] = orig_latent
+        out_latent["noise_mask"] = mask
+
+        out = []
+        for conditioning in [positive, negative]:
+            c = []
+            for t in conditioning:
+                d = t[1].copy()
+                d["concat_latent_image"] = concat_latent
+                d["concat_mask"] = mask
+                n = [t[0], d]
+                c.append(n)
+            out.append(c)
+        return (out[0], out[1], out_latent)
+
 class SaveLatent:
     def __init__(self):
         self.output_dir = folder_paths.get_output_directory()
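
A quick toy check (not from the commit) of the masking step in encode(): pixels under the mask are pulled to 0.5, the neutral gray the VAE treats as "zero", while pixels outside the mask pass through unchanged:

import torch

pixels = torch.rand(1, 16, 16, 3)    # IMAGE layout: (N, H, W, C)
mask = torch.zeros(1, 1, 16, 16)
mask[:, :, 4:12, 4:12] = 1.0         # region to inpaint

m = (1.0 - mask.round()).squeeze(1)  # 0 where we inpaint, 1 elsewhere
for i in range(3):
    pixels[:, :, :, i] -= 0.5
    pixels[:, :, :, i] *= m
    pixels[:, :, :, i] += 0.5

print(pixels[0, 8, 8])  # tensor([0.5000, 0.5000, 0.5000]) inside the mask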
@@ -1628,10 +1684,11 @@ class ImagePadForOutpaint:
     def expand_image(self, image, left, top, right, bottom, feathering):
         d1, d2, d3, d4 = image.size()
 
-        new_image = torch.zeros(
+        new_image = torch.ones(
             (d1, d2 + top + bottom, d3 + left + right, d4),
             dtype=torch.float32,
-        )
+        ) * 0.5
 
         new_image[:, top:top + d2, left:left + d3, :] = image
 
         mask = torch.ones(
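
This brings ImagePadForOutpaint in line with the same convention: images are in the 0..1 range, so filling the padded border with 0.5 gives neutral gray ("zero" once pixels are shifted by -0.5) instead of black. A minimal check:

import torch

d1, d2, d3, d4 = 1, 64, 64, 3    # example image dimensions
top = bottom = left = right = 8  # padding on each side

# Same expression as the new code: ones * 0.5 fills the canvas with gray.
new_image = torch.ones(
    (d1, d2 + top + bottom, d3 + left + right, d4),
    dtype=torch.float32,
) * 0.5
print(new_image[0, 0, 0])  # tensor([0.5000, 0.5000, 0.5000])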
@@ -1723,6 +1780,7 @@ NODE_CLASS_MAPPINGS = {
     "unCLIPCheckpointLoader": unCLIPCheckpointLoader,
     "GLIGENLoader": GLIGENLoader,
     "GLIGENTextBoxApply": GLIGENTextBoxApply,
+    "InpaintModelConditioning": InpaintModelConditioning,
 
     "CheckpointLoader": CheckpointLoader,
     "DiffusersLoader": DiffusersLoader,