Unverified Commit e4c356d3 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Fix for InstructPix2PixPipeline to allow for prompt embeds to be passed in...


Fix for InstructPix2PixPipeline to allow for prompt embeds to be passed in without prompts.  (#2456)

* fix check inputs to allow prompt embeds in instruct pix2pix

* linting

* add reference comment to check inputs

* remove comment

* style changes

---------
Co-authored-by: default avatarWill Berman <wlbberman@gmail.com>
parent 2ea1da89
...@@ -246,13 +246,19 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline): ...@@ -246,13 +246,19 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
(nsfw) content, according to the `safety_checker`. (nsfw) content, according to the `safety_checker`.
""" """
# 0. Check inputs # 0. Check inputs
self.check_inputs(prompt, callback_steps) self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
if image is None: if image is None:
raise ValueError("`image` input cannot be undefined.") raise ValueError("`image` input cannot be undefined.")
# 1. Define call parameters # 1. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt) if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
...@@ -640,10 +646,9 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline): ...@@ -640,10 +646,9 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
image = image.cpu().permute(0, 2, 3, 1).float().numpy() image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image return image
def check_inputs(self, prompt, callback_steps): def check_inputs(
if not isinstance(prompt, str) and not isinstance(prompt, list): self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") ):
if (callback_steps is None) or ( if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
): ):
...@@ -652,6 +657,32 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline): ...@@ -652,6 +657,32 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
f" {type(callback_steps)}." f" {type(callback_steps)}."
) )
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment