Commit 06924c6a, authored by Suraj Patil, committed by GitHub

[StableDiffusionInpaintPipeline] accept tensors for init and mask image (#439)

* accept tensors

* fix mask handling

* make device placement cleaner

* update doc for mask image
parent 761f0297
@@ -145,8 +145,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
                 process. This is the image whose masked region will be inpainted.
             mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
                 `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
-                replaced by noise and therefore repainted, while black pixels will be preserved. The mask image will be
-                converted to a single channel (luminance) before use.
+                replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
+                PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
+                contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
             strength (`float`, *optional*, defaults to 0.8):
                 Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
                 is 1, the denoising process will be run on the masked area for the full number of iterations specified
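
The updated docstring asks tensor masks to carry a single luminance channel. As a rough sketch of what that could look like for a caller who starts from a 3-channel mask, the snippet below collapses RGB to one channel. The helper name `rgb_to_luminance` is hypothetical, and the weights are the ITU-R 601 luma coefficients, which is an assumption about how closely one wants to mirror PIL's "L" mode; none of this is taken from the pipeline itself.

```python
import torch


def rgb_to_luminance(mask_rgb):
    """Hypothetical helper: collapse a (B, H, W, 3) float mask to (B, H, W, 1)."""
    # Assumed weights: ITU-R 601 luma coefficients (R, G, B).
    weights = torch.tensor([0.299, 0.587, 0.114], dtype=mask_rgb.dtype)
    # Weighted sum over the channel axis; keepdim keeps the trailing channel axis.
    return (mask_rgb * weights).sum(dim=-1, keepdim=True)


# A 4x4 mask that is black everywhere except a white 2x2 square in the middle.
mask_rgb = torch.zeros(1, 4, 4, 3)
mask_rgb[:, 1:3, 1:3, :] = 1.0
mask_l = rgb_to_luminance(mask_rgb)
```

White pixels stay close to 1 and black pixels stay at 0, so the "white is repainted, black is preserved" convention described in the docstring carries over to the single-channel tensor.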
@@ -202,10 +203,12 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
 
         # preprocess image
-        init_image = preprocess_image(init_image).to(self.device)
+        if not isinstance(init_image, torch.FloatTensor):
+            init_image = preprocess_image(init_image)
+        init_image.to(self.device)
 
         # encode the init image into latents and scale the latents
-        init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
+        init_latent_dist = self.vae.encode(init_image).latent_dist
         init_latents = init_latent_dist.sample(generator=generator)
 
         init_latents = 0.18215 * init_latents
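
Two details of this hunk are worth a small illustration. First, `Tensor.to` returns a tensor rather than moving its input in place, so a bare `init_image.to(self.device)` statement leaves `init_image` where it was. Second, `isinstance(x, torch.FloatTensor)` is true only for CPU float32 tensors. The sketch below shows one way the same "accept a tensor or preprocess" logic could be written with both points in mind. The helper `to_device_tensor` and the stand-in `fake_preprocess` are hypothetical names introduced here, and checking against `torch.Tensor` is an assumption, not what the commit does.

```python
import torch


def to_device_tensor(image, device, preprocess):
    """Hypothetical helper (not part of the pipeline): return `image` as a tensor on `device`."""
    # Assumption: any torch.Tensor counts as already preprocessed.
    # torch.FloatTensor would match only CPU float32 tensors.
    if not isinstance(image, torch.Tensor):
        image = preprocess(image)
    # Tensor.to returns a tensor and leaves its input unchanged,
    # so the result has to be assigned or returned.
    return image.to(device)


# Stand-in for preprocess_image, whose body is not shown in this diff.
def fake_preprocess(img):
    return torch.as_tensor(img, dtype=torch.float32)


from_list = to_device_tensor([[0.0, 1.0]], "cpu", fake_preprocess)
from_tensor = to_device_tensor(torch.ones(1, 2), "cpu", fake_preprocess)

# .to is not in-place; shown here with a dtype change so it runs on CPU only.
x = torch.zeros(2)
y = x.to(torch.float64)
```

After the last two lines, `x` is still float32 and only `y` is float64, which is the same behaviour a device move would show.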
@@ -215,8 +218,10 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         init_latents_orig = init_latents
 
         # preprocess mask
-        mask = preprocess_mask(mask_image).to(self.device)
-        mask = torch.cat([mask] * batch_size)
+        if not isinstance(mask_image, torch.FloatTensor):
+            mask_image = preprocess_mask(mask_image)
+        mask_image.to(self.device)
+        mask = torch.cat([mask_image] * batch_size)
 
         # check sizes
         if not mask.shape == init_latents.shape:
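
The mask handling above repeats the mask `batch_size` times along the first axis and then compares its shape with the latents. A minimal sketch of that step follows; the concrete shapes (4 channels at 64x64) are assumptions chosen for illustration, not values taken from this diff.

```python
import torch

batch_size = 2

# Assumed shapes for illustration only.
init_latents = torch.randn(batch_size, 4, 64, 64)
mask_image = torch.ones(1, 4, 64, 64)  # a single mask with a leading batch axis of 1

# Repeat the mask along dim 0 so there is one copy per prompt.
mask = torch.cat([mask_image] * batch_size)
shapes_match = mask.shape == init_latents.shape

# A mask that already has batch_size entries is repeated as well,
# so its leading dimension becomes batch_size * batch_size.
prebatched = torch.ones(batch_size, 4, 64, 64)
repeated = torch.cat([prebatched] * batch_size)
```

Under these assumptions the single mask passes the size check, while the pre-batched one ends up with a leading dimension of 4 and would not match the latents, so callers passing tensors may want to supply a mask with a batch axis of 1.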
......