Unverified Commit 4f3ddb6c authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Paint by Example] Better default for image width (#1587)

parent 4eb9ad0d
...@@ -442,14 +442,7 @@ class PaintByExamplePipeline(DiffusionPipeline): ...@@ -442,14 +442,7 @@ class PaintByExamplePipeline(DiffusionPipeline):
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`. (nsfw) content, according to the `safety_checker`.
""" """
# 0. Default height and width to unet # 1. Define call parameters
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs
self.check_inputs(example_image, height, width, callback_steps)
# 2. Define call parameters
if isinstance(image, PIL.Image.Image): if isinstance(image, PIL.Image.Image):
batch_size = 1 batch_size = 1
elif isinstance(image, list): elif isinstance(image, list):
...@@ -462,14 +455,18 @@ class PaintByExamplePipeline(DiffusionPipeline): ...@@ -462,14 +455,18 @@ class PaintByExamplePipeline(DiffusionPipeline):
# corresponds to doing no classifier free guidance. # corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0 do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input image # 2. Preprocess mask and image
mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
width, height = masked_image.shape[-2:]
# 3. Check inputs
self.check_inputs(example_image, height, width, callback_steps)
# 4. Encode input image
image_embeddings = self._encode_image( image_embeddings = self._encode_image(
example_image, device, num_images_per_prompt, do_classifier_free_guidance example_image, device, num_images_per_prompt, do_classifier_free_guidance
) )
# 4. Preprocess mask and image
mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
# 5. set timesteps # 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device) self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps timesteps = self.scheduler.timesteps
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment