Unverified Commit 5c404f20 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

[WIP] masked_latent_inputs for inpainting pipeline (#4819)



* add

---------
Co-authored-by: default avataryiyixuxu <yixu310@gmail,com>
parent d8b6f5d0
...@@ -872,7 +872,11 @@ class StableDiffusionControlNetInpaintPipeline( ...@@ -872,7 +872,11 @@ class StableDiffusionControlNetInpaintPipeline(
if return_image_latents or (latents is None and not is_strength_max): if return_image_latents or (latents is None and not is_strength_max):
image = image.to(device=device, dtype=dtype) image = image.to(device=device, dtype=dtype)
image_latents = self._encode_vae_image(image=image, generator=generator)
if image.shape[1] == 4:
image_latents = image
else:
image_latents = self._encode_vae_image(image=image, generator=generator)
if latents is None: if latents is None:
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
...@@ -907,7 +911,11 @@ class StableDiffusionControlNetInpaintPipeline( ...@@ -907,7 +911,11 @@ class StableDiffusionControlNetInpaintPipeline(
mask = mask.to(device=device, dtype=dtype) mask = mask.to(device=device, dtype=dtype)
masked_image = masked_image.to(device=device, dtype=dtype) masked_image = masked_image.to(device=device, dtype=dtype)
masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
if masked_image.shape[1] == 4:
masked_image_latents = masked_image
else:
masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size: if mask.shape[0] < batch_size:
......
...@@ -293,7 +293,11 @@ class PaintByExamplePipeline(DiffusionPipeline): ...@@ -293,7 +293,11 @@ class PaintByExamplePipeline(DiffusionPipeline):
mask = mask.to(device=device, dtype=dtype) mask = mask.to(device=device, dtype=dtype)
masked_image = masked_image.to(device=device, dtype=dtype) masked_image = masked_image.to(device=device, dtype=dtype)
masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
if masked_image.shape[1] == 4:
masked_image_latents = masked_image
else:
masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size: if mask.shape[0] < batch_size:
......
...@@ -554,7 +554,7 @@ class StableDiffusionInpaintPipeline( ...@@ -554,7 +554,7 @@ class StableDiffusionInpaintPipeline(
if strength < 0 or strength > 1: if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if height % 8 != 0 or width % 8 != 0: if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or ( if (callback_steps is None) or (
...@@ -622,7 +622,11 @@ class StableDiffusionInpaintPipeline( ...@@ -622,7 +622,11 @@ class StableDiffusionInpaintPipeline(
if return_image_latents or (latents is None and not is_strength_max): if return_image_latents or (latents is None and not is_strength_max):
image = image.to(device=device, dtype=dtype) image = image.to(device=device, dtype=dtype)
image_latents = self._encode_vae_image(image=image, generator=generator)
if image.shape[1] == 4:
image_latents = image
else:
image_latents = self._encode_vae_image(image=image, generator=generator)
if latents is None: if latents is None:
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
...@@ -670,7 +674,11 @@ class StableDiffusionInpaintPipeline( ...@@ -670,7 +674,11 @@ class StableDiffusionInpaintPipeline(
mask = mask.to(device=device, dtype=dtype) mask = mask.to(device=device, dtype=dtype)
masked_image = masked_image.to(device=device, dtype=dtype) masked_image = masked_image.to(device=device, dtype=dtype)
masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
if masked_image.shape[1] == 4:
masked_image_latents = masked_image
else:
masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size: if mask.shape[0] < batch_size:
...@@ -715,6 +723,7 @@ class StableDiffusionInpaintPipeline( ...@@ -715,6 +723,7 @@ class StableDiffusionInpaintPipeline(
prompt: Union[str, List[str]] = None, prompt: Union[str, List[str]] = None,
image: PipelineImageInput = None, image: PipelineImageInput = None,
mask_image: PipelineImageInput = None, mask_image: PipelineImageInput = None,
masked_image_latents: torch.FloatTensor = None,
height: Optional[int] = None, height: Optional[int] = None,
width: Optional[int] = None, width: Optional[int] = None,
strength: float = 1.0, strength: float = 1.0,
...@@ -914,12 +923,6 @@ class StableDiffusionInpaintPipeline( ...@@ -914,12 +923,6 @@ class StableDiffusionInpaintPipeline(
init_image = self.image_processor.preprocess(image, height=height, width=width) init_image = self.image_processor.preprocess(image, height=height, width=width)
init_image = init_image.to(dtype=torch.float32) init_image = init_image.to(dtype=torch.float32)
mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
masked_image = init_image * (mask < 0.5)
mask_condition = mask.clone()
# 6. Prepare latent variables # 6. Prepare latent variables
num_channels_latents = self.vae.config.latent_channels num_channels_latents = self.vae.config.latent_channels
num_channels_unet = self.unet.config.in_channels num_channels_unet = self.unet.config.in_channels
...@@ -947,8 +950,15 @@ class StableDiffusionInpaintPipeline( ...@@ -947,8 +950,15 @@ class StableDiffusionInpaintPipeline(
latents, noise = latents_outputs latents, noise = latents_outputs
# 7. Prepare mask latent variables # 7. Prepare mask latent variables
mask_condition = self.mask_processor.preprocess(mask_image, height=height, width=width)
if masked_image_latents is None:
masked_image = init_image * (mask_condition < 0.5)
else:
masked_image = masked_image_latents
mask, masked_image_latents = self.prepare_mask_latents( mask, masked_image_latents = self.prepare_mask_latents(
mask, mask_condition,
masked_image, masked_image,
batch_size * num_images_per_prompt, batch_size * num_images_per_prompt,
height, height,
......
...@@ -762,10 +762,16 @@ class StableDiffusionXLInpaintPipeline(DiffusionPipeline, LoraLoaderMixin, FromS ...@@ -762,10 +762,16 @@ class StableDiffusionXLInpaintPipeline(DiffusionPipeline, LoraLoaderMixin, FromS
mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
masked_image_latents = None if masked_image is not None and masked_image.shape[1] == 4:
masked_image_latents = masked_image
else:
masked_image_latents = None
if masked_image is not None: if masked_image is not None:
masked_image = masked_image.to(device=device, dtype=dtype) if masked_image_latents is None:
masked_image_latents = self._encode_vae_image(masked_image, generator=generator) masked_image = masked_image.to(device=device, dtype=dtype)
masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
if masked_image_latents.shape[0] < batch_size: if masked_image_latents.shape[0] < batch_size:
if not batch_size % masked_image_latents.shape[0] == 0: if not batch_size % masked_image_latents.shape[0] == 0:
raise ValueError( raise ValueError(
...@@ -890,6 +896,7 @@ class StableDiffusionXLInpaintPipeline(DiffusionPipeline, LoraLoaderMixin, FromS ...@@ -890,6 +896,7 @@ class StableDiffusionXLInpaintPipeline(DiffusionPipeline, LoraLoaderMixin, FromS
prompt_2: Optional[Union[str, List[str]]] = None, prompt_2: Optional[Union[str, List[str]]] = None,
image: PipelineImageInput = None, image: PipelineImageInput = None,
mask_image: PipelineImageInput = None, mask_image: PipelineImageInput = None,
masked_image_latents: torch.FloatTensor = None,
height: Optional[int] = None, height: Optional[int] = None,
width: Optional[int] = None, width: Optional[int] = None,
strength: float = 0.9999, strength: float = 0.9999,
...@@ -1152,7 +1159,9 @@ class StableDiffusionXLInpaintPipeline(DiffusionPipeline, LoraLoaderMixin, FromS ...@@ -1152,7 +1159,9 @@ class StableDiffusionXLInpaintPipeline(DiffusionPipeline, LoraLoaderMixin, FromS
mask = self.mask_processor.preprocess(mask_image, height=height, width=width) mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
if init_image.shape[1] == 4: if masked_image_latents is not None:
masked_image = masked_image_latents
elif init_image.shape[1] == 4:
# if images are in latent space, we can't mask it # if images are in latent space, we can't mask it
masked_image = None masked_image = None
else: else:
......
...@@ -252,6 +252,43 @@ class StableDiffusionInpaintPipelineFastTests( ...@@ -252,6 +252,43 @@ class StableDiffusionInpaintPipelineFastTests(
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
sd_pipe(**inputs).images sd_pipe(**inputs).images
def test_stable_diffusion_inpaint_mask_latents(self):
device = "cpu"
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components).to(device)
sd_pipe.set_progress_bar_config(disable=None)
# normal mask + normal image
## `image`: pil, `mask_image``: pil, `masked_image_latents``: None
inputs = self.get_dummy_inputs(device)
inputs["strength"] = 0.9
out_0 = sd_pipe(**inputs).images
# image latents + mask latents
inputs = self.get_dummy_inputs(device)
image = sd_pipe.image_processor.preprocess(inputs["image"]).to(sd_pipe.device)
mask = sd_pipe.mask_processor.preprocess(inputs["mask_image"]).to(sd_pipe.device)
masked_image = image * (mask < 0.5)
generator = torch.Generator(device=device).manual_seed(0)
image_latents = (
sd_pipe.vae.encode(image).latent_dist.sample(generator=generator) * sd_pipe.vae.config.scaling_factor
)
torch.randn((1, 4, 32, 32), generator=generator)
mask_latents = (
sd_pipe.vae.encode(masked_image).latent_dist.sample(generator=generator)
* sd_pipe.vae.config.scaling_factor
)
inputs["image"] = image_latents
inputs["masked_image_latents"] = mask_latents
inputs["mask_image"] = mask
inputs["strength"] = 0.9
generator = torch.Generator(device=device).manual_seed(0)
torch.randn((1, 4, 32, 32), generator=generator)
inputs["generator"] = generator
out_1 = sd_pipe(**inputs).images
assert np.abs(out_0 - out_1).max() < 1e-2
class StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipelineFastTests): class StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipelineFastTests):
pipeline_class = StableDiffusionInpaintPipeline pipeline_class = StableDiffusionInpaintPipeline
......
...@@ -499,3 +499,35 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel ...@@ -499,3 +499,35 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
np.abs(image_slice_with_no_neg_conditions.flatten() - image_slice_with_neg_conditions.flatten()).max() np.abs(image_slice_with_no_neg_conditions.flatten() - image_slice_with_neg_conditions.flatten()).max()
> 1e-4 > 1e-4
) )
def test_stable_diffusion_xl_inpaint_mask_latents(self):
device = "cpu"
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components).to(device)
sd_pipe.set_progress_bar_config(disable=None)
# normal mask + normal image
## `image`: pil, `mask_image``: pil, `masked_image_latents``: None
inputs = self.get_dummy_inputs(device)
inputs["strength"] = 0.9
out_0 = sd_pipe(**inputs).images
# image latents + mask latents
inputs = self.get_dummy_inputs(device)
image = sd_pipe.image_processor.preprocess(inputs["image"]).to(sd_pipe.device)
mask = sd_pipe.mask_processor.preprocess(inputs["mask_image"]).to(sd_pipe.device)
masked_image = image * (mask < 0.5)
generator = torch.Generator(device=device).manual_seed(0)
image_latents = sd_pipe._encode_vae_image(image, generator=generator)
torch.randn((1, 4, 32, 32), generator=generator)
mask_latents = sd_pipe._encode_vae_image(masked_image, generator=generator)
inputs["image"] = image_latents
inputs["masked_image_latents"] = mask_latents
inputs["mask_image"] = mask
inputs["strength"] = 0.9
generator = torch.Generator(device=device).manual_seed(0)
torch.randn((1, 4, 32, 32), generator=generator)
inputs["generator"] = generator
out_1 = sd_pipe(**inputs).images
assert np.abs(out_0 - out_1).max() < 1e-2
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment