Unverified Commit 9baa29e9 authored by Robin Hutmacher's avatar Robin Hutmacher Committed by GitHub
Browse files

Fix typo in StableDiffusionInpaintPipeline (#2197)



* Fix typo in StableDiffusionInpaintPipeline

* Add embedded prompt handling

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 58c416ab
...@@ -605,7 +605,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -605,7 +605,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
eta: float = 0.0, eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None, latents: Optional[torch.FloatTensor] = None,
nrompt_embeds: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
...@@ -719,7 +719,15 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -719,7 +719,15 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
width = width or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs # 1. Check inputs
self.check_inputs(prompt, height, width, callback_steps) self.check_inputs(
prompt,
height,
width,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
)
if image is None: if image is None:
raise ValueError("`image` input cannot be undefined.") raise ValueError("`image` input cannot be undefined.")
...@@ -728,7 +736,13 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -728,7 +736,13 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
raise ValueError("`mask_image` input cannot be undefined.") raise ValueError("`mask_image` input cannot be undefined.")
# 2. Define call parameters # 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt) if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
...@@ -737,7 +751,13 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -737,7 +751,13 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
# 3. Encode input prompt # 3. Encode input prompt
prompt_embeds = self._encode_prompt( prompt_embeds = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
) )
# 4. Preprocess mask and image # 4. Preprocess mask and image
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment