Unverified Commit 513fc681 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Stable Diffusion Inpaint] Allow tensor as input image & mask (#1527)

up
parent cc22bda5
...@@ -630,11 +630,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -630,11 +630,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
) )
# 4. Preprocess mask and image # 4. Preprocess mask and image
if isinstance(image, PIL.Image.Image) and isinstance(mask_image, PIL.Image.Image): mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
else:
mask = mask_image
masked_image = image * (mask < 0.5)
# 5. set timesteps # 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device) self.scheduler.set_timesteps(num_inference_steps, device=device)
......
...@@ -218,6 +218,64 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test ...@@ -218,6 +218,64 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_inpaint_image_tensor(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet_inpaint
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
image = self.dummy_image.repeat(1, 1, 2, 2)
mask_image = image / 2
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionInpaintPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=None,
)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=device).manual_seed(0)
output = sd_pipe(
[prompt],
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="np",
image=image,
mask_image=mask_image[:, 0],
)
out_1 = output.images
image = image.cpu().permute(0, 2, 3, 1)[0]
mask_image = mask_image.cpu().permute(0, 2, 3, 1)[0]
image = Image.fromarray(np.uint8(image)).convert("RGB")
mask_image = Image.fromarray(np.uint8(mask_image)).convert("RGB")
generator = torch.Generator(device=device).manual_seed(0)
output = sd_pipe(
[prompt],
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="np",
image=image,
mask_image=mask_image,
)
out_2 = output.images
assert out_1.shape == (1, 64, 64, 3)
assert np.abs(out_1.flatten() - out_2.flatten()).max() < 5e-2
def test_stable_diffusion_inpaint_with_num_images_per_prompt(self): def test_stable_diffusion_inpaint_with_num_images_per_prompt(self):
device = "cpu" device = "cpu"
unet = self.dummy_cond_unet_inpaint unet = self.dummy_cond_unet_inpaint
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment