You need to sign in or sign up before continuing.
Commit ee030d28 authored by Jacob Segal
Browse files

Add support for multiple unique inpainting masks

This enables workflows like "Inpaint at full resolution" when using
batch sizes greater than 1.
parent 6908f9c9
...@@ -171,24 +171,28 @@ class VAEEncodeForInpaint: ...@@ -171,24 +171,28 @@ class VAEEncodeForInpaint:
def encode(self, vae, pixels, mask): def encode(self, vae, pixels, mask):
x = (pixels.shape[1] // 64) * 64 x = (pixels.shape[1] // 64) * 64
y = (pixels.shape[2] // 64) * 64 y = (pixels.shape[2] // 64) * 64
mask = torch.nn.functional.interpolate(mask[None,None,], size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")[0][0] if len(mask.shape) < 3:
mask = mask.unsqueeze(0).unsqueeze(0)
elif len(mask.shape) < 4:
mask = mask.unsqueeze(1)
mask = torch.nn.functional.interpolate(mask, size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
pixels = pixels.clone() pixels = pixels.clone()
if pixels.shape[1] != x or pixels.shape[2] != y: if pixels.shape[1] != x or pixels.shape[2] != y:
pixels = pixels[:,:x,:y,:] pixels = pixels[:,:x,:y,:]
mask = mask[:x,:y] mask = mask[:,:x,:y,:]
#grow mask by a few pixels to keep things seamless in latent space #grow mask by a few pixels to keep things seamless in latent space
kernel_tensor = torch.ones((1, 1, 6, 6)) kernel_tensor = torch.ones((1, 1, 6, 6))
mask_erosion = torch.clamp(torch.nn.functional.conv2d((mask.round())[None], kernel_tensor, padding=3), 0, 1) mask_erosion = torch.clamp(torch.nn.functional.conv2d(mask.round(), kernel_tensor, padding=3), 0, 1)
m = (1.0 - mask.round()) m = (1.0 - mask.round()).squeeze(1)
for i in range(3): for i in range(3):
pixels[:,:,:,i] -= 0.5 pixels[:,:,:,i] -= 0.5
pixels[:,:,:,i] *= m pixels[:,:,:,i] *= m
pixels[:,:,:,i] += 0.5 pixels[:,:,:,i] += 0.5
t = vae.encode(pixels) t = vae.encode(pixels)
return ({"samples":t, "noise_mask": (mask_erosion[0][:x,:y].round())}, ) return ({"samples":t, "noise_mask": (mask_erosion[:,:x,:y,:].round())}, )
class CheckpointLoader: class CheckpointLoader:
@classmethod @classmethod
...@@ -759,10 +763,15 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, ...@@ -759,10 +763,15 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
if "noise_mask" in latent: if "noise_mask" in latent:
noise_mask = latent['noise_mask'] noise_mask = latent['noise_mask']
noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear") if len(noise_mask.shape) < 3:
noise_mask = noise_mask.unsqueeze(0).unsqueeze(0)
elif len(noise_mask.shape) < 4:
noise_mask = noise_mask.unsqueeze(1)
noise_mask = torch.nn.functional.interpolate(noise_mask, size=(noise.shape[2], noise.shape[3]), mode="bilinear")
noise_mask = noise_mask.round() noise_mask = noise_mask.round()
noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1) noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1)
noise_mask = torch.cat([noise_mask] * noise.shape[0]) if noise_mask.shape[0] < latent_image.shape[0]:
noise_mask = noise_mask.repeat(latent_image.shape[0] // noise_mask.shape[0], 1, 1, 1)
noise_mask = noise_mask.to(device) noise_mask = noise_mask.to(device)
real_model = None real_model = None
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment