Unverified Commit d1efefe1 authored by 1lint's avatar 1lint Committed by GitHub
Browse files

[Breaking change] fix legacy inpaint noise and resize mask tensor (#2147)

* fix legacy inpaint noise and resize mask tensor

* updated legacy inpaint pipe test expected_slice
parent 7d96b38b
...@@ -45,16 +45,34 @@ def preprocess_image(image): ...@@ -45,16 +45,34 @@ def preprocess_image(image):
def preprocess_mask(mask, scale_factor=8):
    """Convert an inpainting mask to a latent-space tensor of shape (B, 4, H//scale_factor, W//scale_factor).

    Args:
        mask: a `PIL.Image.Image` (or anything `.convert("L")` accepts) or a
            `torch.FloatTensor` of shape `(B, H, W, C)` or `(B, C, H, W)` with C in {1, 3}.
        scale_factor: VAE downscaling factor; spatial dims are divided by it after
            being rounded down to a multiple of 32.

    Returns:
        `torch.FloatTensor` mask ready to broadcast against the latents. For the PIL
        path the mask is inverted (white pixels -> repaint, black -> keep) and tiled
        to 4 channels; for the tensor path the channels are averaged down to 1.

    Raises:
        ValueError: if a tensor mask has no valid channel dimension (size 1 or 3)
            in either the second or fourth position.
    """
    if not isinstance(mask, torch.FloatTensor):
        # PIL path: grayscale, snap to a multiple of 32, downscale to latent size.
        mask = mask.convert("L")
        w, h = mask.size
        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
        mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
        mask = np.array(mask).astype(np.float32) / 255.0
        # Tile to 4 channels to match the latent channel count, then add a batch dim.
        mask = np.tile(mask, (4, 1, 1))
        mask = mask[None]  # (1, 4, h, w); the original transpose(0, 1, 2, 3) was an identity no-op
        mask = 1 - mask  # repaint white, keep black
        mask = torch.from_numpy(mask)
        return mask
    else:
        valid_mask_channel_sizes = [1, 3]
        # If the mask channel is the fourth tensor dimension, permute to the
        # PyTorch-standard (B, C, H, W) layout.
        # NOTE(review): a channels-first mask whose width happens to be 1 or 3 would
        # also match this test and be permuted — ambiguity inherited from the original.
        if mask.shape[3] in valid_mask_channel_sizes:
            mask = mask.permute(0, 3, 1, 2)
        elif mask.shape[1] not in valid_mask_channel_sizes:
            raise ValueError(
                f"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension, but received mask of shape {tuple(mask.shape)}"
            )
        # (Potentially) reduce the channel dimension from 3 to 1 so the mask
        # broadcasts against the 4-channel latents.
        mask = mask.mean(dim=1, keepdim=True)
        h, w = mask.shape[-2:]
        h, w = map(lambda x: x - x % 32, (h, w))  # resize to integer multiple of 32
        mask = torch.nn.functional.interpolate(mask, (h // scale_factor, w // scale_factor))
        return mask
class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
...@@ -497,8 +515,8 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -497,8 +515,8 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
mask_image (`torch.FloatTensor` or `PIL.Image.Image`): mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should PIL image, it will be converted to a single channel (luminance) before use. If mask is a tensor, the
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. expected shape should be either `(B, H, W, C)` or `(B, C, H, W)`, where C is 1 or 3.
strength (`float`, *optional*, defaults to 0.8): strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
is 1, the denoising process will be run on the masked area for the full number of iterations specified is 1, the denoising process will be run on the masked area for the full number of iterations specified
...@@ -585,8 +603,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -585,8 +603,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
if not isinstance(image, torch.FloatTensor): if not isinstance(image, torch.FloatTensor):
image = preprocess_image(image) image = preprocess_image(image)
if not isinstance(mask_image, torch.FloatTensor): mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
# 5. set timesteps # 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device) self.scheduler.set_timesteps(num_inference_steps, device=device)
...@@ -640,6 +657,9 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -640,6 +657,9 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
if callback is not None and i % callback_steps == 0: if callback is not None and i % callback_steps == 0:
callback(i, t, latents) callback(i, t, latents)
# use original latents corresponding to unmasked portions of the image
latents = (init_latents_orig * mask) + (latents * (1 - mask))
# 10. Post-processing # 10. Post-processing
image = self.decode_latents(latents) image = self.decode_latents(latents)
......
...@@ -212,8 +212,8 @@ class StableDiffusionInpaintLegacyPipelineFastTests(unittest.TestCase): ...@@ -212,8 +212,8 @@ class StableDiffusionInpaintLegacyPipelineFastTests(unittest.TestCase):
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3) assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.4731, 0.5346, 0.4531, 0.6251, 0.5446, 0.4057, 0.5527, 0.5896, 0.5153]) expected_slice = np.array([0.4941, 0.5396, 0.4689, 0.6338, 0.5392, 0.4094, 0.5477, 0.5904, 0.5165])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
...@@ -260,7 +260,7 @@ class StableDiffusionInpaintLegacyPipelineFastTests(unittest.TestCase): ...@@ -260,7 +260,7 @@ class StableDiffusionInpaintLegacyPipelineFastTests(unittest.TestCase):
image_slice = image[0, -3:, -3:, -1] image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3) assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.4765, 0.5339, 0.4541, 0.6240, 0.5439, 0.4055, 0.5503, 0.5891, 0.5150]) expected_slice = np.array([0.4941, 0.5396, 0.4689, 0.6338, 0.5392, 0.4094, 0.5477, 0.5904, 0.5165])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment