Unverified Commit 3becd368 authored by hwuebben, committed by GitHub

Update pipeline_stable_diffusion_inpaint_legacy.py (#2903)



* Update pipeline_stable_diffusion_inpaint_legacy.py

* fix preprocessing of PIL images with adequate batch size

* revert map

* add tests

* reformat

* Update test_stable_diffusion_inpaint_legacy.py

* Update test_stable_diffusion_inpaint_legacy.py

* Update test_stable_diffusion_inpaint_legacy.py

* Update test_stable_diffusion_inpaint_legacy.py

* next try to fix the style

* wth is this

* Update testing_utils.py

* Update testing_utils.py

* Update test_stable_diffusion_inpaint_legacy.py

* Update test_stable_diffusion_inpaint_legacy.py

* Update test_stable_diffusion_inpaint_legacy.py

* Update test_stable_diffusion_inpaint_legacy.py

* Update test_stable_diffusion_inpaint_legacy.py

* Update test_stable_diffusion_inpaint_legacy.py

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent c8fdfe45
pipeline_stable_diffusion_inpaint_legacy.py
@@ -41,17 +41,17 @@ from .safety_checker import StableDiffusionSafetyChecker
 logger = logging.get_logger(__name__)
-def preprocess_image(image):
+def preprocess_image(image, batch_size):
     w, h = image.size
     w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8
     image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
     image = np.array(image).astype(np.float32) / 255.0
-    image = image[None].transpose(0, 3, 1, 2)
+    image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size)
     image = torch.from_numpy(image)
     return 2.0 * image - 1.0
-def preprocess_mask(mask, scale_factor=8):
+def preprocess_mask(mask, batch_size, scale_factor=8):
     if not isinstance(mask, torch.FloatTensor):
         mask = mask.convert("L")
         w, h = mask.size
@@ -59,7 +59,7 @@ def preprocess_mask(mask, scale_factor=8):
         mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
         mask = np.array(mask).astype(np.float32) / 255.0
         mask = np.tile(mask, (4, 1, 1))
-        mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
+        mask = np.vstack([mask[None]] * batch_size)
         mask = 1 - mask  # repaint white, keep black
         mask = torch.from_numpy(mask)
         return mask
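For context on the preprocess_mask change: the old mask[None].transpose(0, 1, 2, 3) was a no-op (the inline comment even asked what it did), while the new np.vstack replicates the mask along the batch axis. A minimal standalone sketch of the PIL branch (the function name, the 64x64 input, and PIL.Image.NEAREST resampling are illustrative stand-ins, not the library code):

import numpy as np
import PIL.Image
import torch

def preprocess_mask_sketch(mask, batch_size, scale_factor=8):
    mask = mask.convert("L")  # single channel
    w, h = mask.size
    w, h = (x - x % 8 for x in (w, h))  # crop to an integer multiple of 8
    mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL.Image.NEAREST)
    mask = np.array(mask).astype(np.float32) / 255.0
    mask = np.tile(mask, (4, 1, 1))  # one copy per latent channel
    mask = np.vstack([mask[None]] * batch_size)  # replicate along the batch axis
    mask = 1 - mask  # repaint white, keep black
    return torch.from_numpy(mask)

print(preprocess_mask_sketch(PIL.Image.new("L", (64, 64)), batch_size=2).shape)
# torch.Size([2, 4, 8, 8])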
@@ -521,14 +521,14 @@ class StableDiffusionInpaintPipelineLegacy(
         return timesteps, num_inference_steps - t_start
-    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator):
+    def prepare_latents(self, image, timestep, num_images_per_prompt, dtype, device, generator):
         image = image.to(device=self.device, dtype=dtype)
         init_latent_dist = self.vae.encode(image).latent_dist
         init_latents = init_latent_dist.sample(generator=generator)
         init_latents = self.vae.config.scaling_factor * init_latents
         # Expand init_latents for batch_size and num_images_per_prompt
-        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
+        init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
         init_latents_orig = init_latents
         # add noise to latents using the timesteps
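The batch_size factor is dropped from prepare_latents because the image now arrives already batched from preprocess_image, so the VAE latents carry the batch dimension and only num_images_per_prompt still needs tiling. A toy shape check (the dimensions are made up):

import torch

init_latents = torch.randn(2, 4, 8, 8)  # VAE latents; batch_size=2 is already baked in
num_images_per_prompt = 3
init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
print(init_latents.shape)  # torch.Size([6, 4, 8, 8])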
@@ -659,9 +659,9 @@ class StableDiffusionInpaintPipelineLegacy(
         # 4. Preprocess image and mask
         if not isinstance(image, torch.FloatTensor):
-            image = preprocess_image(image)
+            image = preprocess_image(image, batch_size)
-        mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
+        mask_image = preprocess_mask(mask_image, batch_size, self.vae_scale_factor)
         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -671,12 +671,12 @@ class StableDiffusionInpaintPipelineLegacy(
         # 6. Prepare latent variables
         # encode the init image into latents and scale the latents
         latents, init_latents_orig, noise = self.prepare_latents(
-            image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
+            image, latent_timestep, num_images_per_prompt, prompt_embeds.dtype, device, generator
         )
         # 7. Prepare mask latent
         mask = mask_image.to(device=self.device, dtype=latents.dtype)
-        mask = torch.cat([mask] * batch_size * num_images_per_prompt)
+        mask = torch.cat([mask] * num_images_per_prompt)
         # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
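The same reasoning applies at this call site: mask_image leaves preprocess_mask with the batch dimension in place, so multiplying by batch_size again would double-expand it relative to the latents. A hedged shape check with illustrative sizes:

import torch

batch_size, num_images_per_prompt = 2, 3
mask_image = torch.rand(batch_size, 4, 8, 8)  # already batched by preprocess_mask
latents = torch.randn(batch_size * num_images_per_prompt, 4, 8, 8)
mask = torch.cat([mask_image] * num_images_per_prompt)
assert mask.shape == latents.shape  # (6, 4, 8, 8) on both sides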
testing_utils.py
@@ -279,6 +279,16 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
     return image
+def preprocess_image(image: PIL.Image, batch_size: int):
+    w, h = image.size
+    w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8
+    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+    image = np.array(image).astype(np.float32) / 255.0
+    image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size)
+    image = torch.from_numpy(image)
+    return 2.0 * image - 1.0
 def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str:
     if is_opencv_available():
         import cv2
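This copy in testing_utils.py lets the tests batch a PIL image without importing from the pipeline module. A quick usage sketch (requires a diffusers build that includes this change; the blank 64x64 image is illustrative):

import numpy as np
import PIL.Image
from diffusers.utils.testing_utils import preprocess_image

init_image = PIL.Image.fromarray(np.zeros((64, 64, 3), dtype=np.uint8))
batched = preprocess_image(init_image, batch_size=2)
print(batched.shape)  # torch.Size([2, 3, 64, 64])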
test_stable_diffusion_inpaint_legacy.py
@@ -34,7 +34,7 @@ from diffusers import (
     VQModel,
 )
 from diffusers.utils import floats_tensor, load_image, nightly, slow, torch_device
-from diffusers.utils.testing_utils import load_numpy, require_torch_gpu
+from diffusers.utils.testing_utils import load_numpy, preprocess_image, require_torch_gpu
 torch.backends.cuda.matmul.allow_tf32 = False
@@ -217,6 +217,55 @@ class StableDiffusionInpaintLegacyPipelineFastTests(unittest.TestCase):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
         assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+    def test_stable_diffusion_inpaint_legacy_batched(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        unet = self.dummy_cond_unet
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
+        init_image = Image.fromarray(np.uint8(image)).convert("RGB")
+        init_images_tens = preprocess_image(init_image, batch_size=2)
+        init_masks_tens = init_images_tens + 4
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionInpaintPipelineLegacy(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=None,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.Generator(device=device).manual_seed(0)
+        images = sd_pipe(
+            [prompt] * 2,
+            generator=generator,
+            guidance_scale=6.0,
+            num_inference_steps=2,
+            output_type="np",
+            image=init_images_tens,
+            mask_image=init_masks_tens,
+        ).images
+        assert images.shape == (2, 32, 32, 3)
+        image_slice_0 = images[0, -3:, -3:, -1].flatten()
+        image_slice_1 = images[1, -3:, -3:, -1].flatten()
+        expected_slice_0 = np.array([0.4697, 0.3770, 0.4096, 0.4653, 0.4497, 0.4183, 0.3950, 0.4668, 0.4672])
+        expected_slice_1 = np.array([0.4105, 0.4987, 0.5771, 0.4921, 0.4237, 0.5684, 0.5496, 0.4645, 0.5272])
+        assert np.abs(expected_slice_0 - image_slice_0).max() < 1e-2
+        assert np.abs(expected_slice_1 - image_slice_1).max() < 1e-2
     def test_stable_diffusion_inpaint_legacy_negative_prompt(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         unet = self.dummy_cond_unet
@@ -349,7 +398,7 @@ class StableDiffusionInpaintLegacyPipelineSlowTests(unittest.TestCase):
         gc.collect()
         torch.cuda.empty_cache()
-    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
+    def get_inputs(self, generator_device="cpu", seed=0):
         generator = torch.Generator(device=generator_device).manual_seed(seed)
         init_image = load_image(
             "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
@@ -379,7 +428,7 @@ class StableDiffusionInpaintLegacyPipelineSlowTests(unittest.TestCase):
         pipe.set_progress_bar_config(disable=None)
         pipe.enable_attention_slicing()
-        inputs = self.get_inputs(torch_device)
+        inputs = self.get_inputs()
         image = pipe(**inputs).images
         image_slice = image[0, 253:256, 253:256, -1].flatten()
@@ -388,6 +437,40 @@ class StableDiffusionInpaintLegacyPipelineSlowTests(unittest.TestCase):
         assert np.abs(expected_slice - image_slice).max() < 1e-4
+    def test_stable_diffusion_inpaint_legacy_batched(self):
+        pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
+            "CompVis/stable-diffusion-v1-4", safety_checker=None
+        )
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.enable_attention_slicing()
+        inputs = self.get_inputs()
+        inputs["prompt"] = [inputs["prompt"]] * 2
+        inputs["image"] = preprocess_image(inputs["image"], batch_size=2)
+        mask = inputs["mask_image"].convert("L")
+        mask = np.array(mask).astype(np.float32) / 255.0
+        mask = torch.from_numpy(1 - mask)
+        masks = torch.vstack([mask[None][None]] * 2)
+        inputs["mask_image"] = masks
+        image = pipe(**inputs).images
+        assert image.shape == (2, 512, 512, 3)
+        image_slice_0 = image[0, 253:256, 253:256, -1].flatten()
+        image_slice_1 = image[1, 253:256, 253:256, -1].flatten()
+        expected_slice_0 = np.array(
+            [0.52093095, 0.4176447, 0.32752383, 0.6175223, 0.50563973, 0.36470804, 0.65460044, 0.5775188, 0.44332123]
+        )
+        expected_slice_1 = np.array(
+            [0.3592432, 0.4233033, 0.3914635, 0.31014425, 0.3702293, 0.39412856, 0.17526966, 0.2642669, 0.37480092]
+        )
+        assert np.abs(expected_slice_0 - image_slice_0).max() < 1e-4
+        assert np.abs(expected_slice_1 - image_slice_1).max() < 1e-4
     def test_stable_diffusion_inpaint_legacy_k_lms(self):
         pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
             "CompVis/stable-diffusion-v1-4", safety_checker=None
@@ -397,7 +480,7 @@ class StableDiffusionInpaintLegacyPipelineSlowTests(unittest.TestCase):
         pipe.set_progress_bar_config(disable=None)
         pipe.enable_attention_slicing()
-        inputs = self.get_inputs(torch_device)
+        inputs = self.get_inputs()
         image = pipe(**inputs).images
         image_slice = image[0, 253:256, 253:256, -1].flatten()
@@ -437,7 +520,7 @@ class StableDiffusionInpaintLegacyPipelineSlowTests(unittest.TestCase):
         pipe.set_progress_bar_config(disable=None)
         pipe.enable_attention_slicing()
-        inputs = self.get_inputs(torch_device, dtype=torch.float16)
+        inputs = self.get_inputs()
         pipe(**inputs, callback=callback_fn, callback_steps=1)
         assert callback_fn.has_been_called
         assert number_of_steps == 2