"...text-generation-inference.git" did not exist on "6f88bd9390a3edce1dfec025a526d6c2849effa4"
Unverified Commit aed30dff authored by apolinário's avatar apolinário Committed by GitHub
Browse files

Allow passing different prompts to each `text_encoder` on `stable_diffusion_xl` pipelines (#4156)



* sdxl prompt2

* Improve checks

* doc linting

* whoops

* remove cat

* Apply suggestions from code review
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* Add other pipelines and tests

* Add multi-prompting to docs

* doc and copies check

* Fix copied froms

* Apply suggestions from code review
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* Bring back the original code for unrelated files

* Fix tests

* Fix img2img

* Fix all

* fix

---------
Co-authored-by: default avatarmultimodalart <joaopaulo.passos+multimodal@gmail.com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent e2bbaa4f
...@@ -21,6 +21,7 @@ The abstract of the paper is the following: ...@@ -21,6 +21,7 @@ The abstract of the paper is the following:
## Tips ## Tips
- Stable Diffusion XL works especially well with images between 768 and 1024. - Stable Diffusion XL works especially well with images between 768 and 1024.
- Stable Diffusion XL can pass a different prompt for each of the text encoders it was trained on as shown below. We can even pass different parts of the same prompt to the text encoders.
- Stable Diffusion XL output image can be improved by making use of a refiner as shown below. - Stable Diffusion XL output image can be improved by making use of a refiner as shown below.
### Available checkpoints: ### Available checkpoints:
...@@ -362,3 +363,25 @@ pip install xformers ...@@ -362,3 +363,25 @@ pip install xformers
[[autodoc]] StableDiffusionXLInpaintPipeline [[autodoc]] StableDiffusionXLInpaintPipeline
- all - all
- __call__ - __call__
### Passing different prompts to each text-encoder
Stable Diffusion XL was trained on two text encoders. The default behavior is to pass the same prompt to each. But it is possible to pass a different prompt for each text-encoder, as [some users](https://github.com/huggingface/diffusers/issues/4004#issuecomment-1627764201) noted that it can boost quality.
To do so, you can pass `prompt_2` and `negative_prompt_2` in addition to `prompt` and `negative_prompt`. By doing that, you will pass the original prompts and negative prompts (as in `prompt` and `negative_prompt`) to `text_encoder` (in official SDXL 0.9/1.0 that is [OpenAI CLIP-ViT/L-14](https://huggingface.co/openai/clip-vit-large-patch14)),
and `prompt_2` and `negative_prompt_2` to `text_encoder_2` (in official SDXL 0.9/1.0 that is [OpenCLIP-ViT/bigG-14](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)).
```py
from diffusers import StableDiffusionXLPipeline
import torch
pipe = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
pipe.to("cuda")
# prompt will be passed to OAI CLIP-ViT/L-14
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
# prompt_2 will be passed to OpenCLIP-ViT/bigG-14
prompt_2 = "monet painting"
image = pipe(prompt=prompt, prompt_2=prompt_2).images[0]
```
...@@ -196,11 +196,13 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa ...@@ -196,11 +196,13 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
def encode_prompt( def encode_prompt(
self, self,
prompt, prompt: str,
prompt_2: Optional[str] = None,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
num_images_per_prompt: int = 1, num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True, do_classifier_free_guidance: bool = True,
negative_prompt=None, negative_prompt: Optional[str] = None,
negative_prompt_2: Optional[str] = None,
prompt_embeds: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
...@@ -211,8 +213,11 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa ...@@ -211,8 +213,11 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
Encodes the prompt into text encoder hidden states. Encodes the prompt into text encoder hidden states.
Args: Args:
prompt (`str` or `List[str]`, *optional*): prompt (`str` or `List[str]`, *optional*):
prompt to be encoded prompt to be encoded
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders
device: (`torch.device`): device: (`torch.device`):
torch device torch device
num_images_per_prompt (`int`): num_images_per_prompt (`int`):
...@@ -223,6 +228,9 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa ...@@ -223,6 +228,9 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
The prompt or prompts not to guide the image generation. If not defined, one has to pass The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`). less than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
prompt_embeds (`torch.FloatTensor`, *optional*): prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument. provided, text embeddings will be generated from `prompt` input argument.
...@@ -261,9 +269,11 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa ...@@ -261,9 +269,11 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
) )
if prompt_embeds is None: if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
# textual inversion: procecss multi-vector tokens if necessary # textual inversion: procecss multi-vector tokens if necessary
prompt_embeds_list = [] prompt_embeds_list = []
for tokenizer, text_encoder in zip(tokenizers, text_encoders): prompts = [prompt, prompt_2]
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
if isinstance(self, TextualInversionLoaderMixin): if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, tokenizer) prompt = self.maybe_convert_prompt(prompt, tokenizer)
...@@ -274,8 +284,10 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa ...@@ -274,8 +284,10 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
truncation=True, truncation=True,
return_tensors="pt", return_tensors="pt",
) )
text_input_ids = text_inputs.input_ids text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids text_input_ids, untruncated_ids
...@@ -311,6 +323,8 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa ...@@ -311,6 +323,8 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
elif do_classifier_free_guidance and negative_prompt_embeds is None: elif do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or "" negative_prompt = negative_prompt or ""
negative_prompt_2 = negative_prompt_2 or negative_prompt
uncond_tokens: List[str] uncond_tokens: List[str]
if prompt is not None and type(prompt) is not type(negative_prompt): if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError( raise TypeError(
...@@ -318,7 +332,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa ...@@ -318,7 +332,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
f" {type(prompt)}." f" {type(prompt)}."
) )
elif isinstance(negative_prompt, str): elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt] uncond_tokens = [negative_prompt, negative_prompt_2]
elif batch_size != len(negative_prompt): elif batch_size != len(negative_prompt):
raise ValueError( raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
...@@ -326,17 +340,16 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa ...@@ -326,17 +340,16 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
" the batch size of `prompt`." " the batch size of `prompt`."
) )
else: else:
uncond_tokens = negative_prompt uncond_tokens = [negative_prompt, negative_prompt_2]
negative_prompt_embeds_list = [] negative_prompt_embeds_list = []
for tokenizer, text_encoder in zip(tokenizers, text_encoders): for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin): if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer) negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
max_length = prompt_embeds.shape[1] max_length = prompt_embeds.shape[1]
uncond_input = tokenizer( uncond_input = tokenizer(
uncond_tokens, negative_prompt,
padding="max_length", padding="max_length",
max_length=max_length, max_length=max_length,
truncation=True, truncation=True,
...@@ -401,9 +414,11 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa ...@@ -401,9 +414,11 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
def check_inputs( def check_inputs(
self, self,
prompt, prompt,
prompt_2,
image, image,
callback_steps, callback_steps,
negative_prompt=None, negative_prompt=None,
negative_prompt_2=None,
prompt_embeds=None, prompt_embeds=None,
negative_prompt_embeds=None, negative_prompt_embeds=None,
controlnet_conditioning_scale=1.0, controlnet_conditioning_scale=1.0,
...@@ -423,18 +438,30 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa ...@@ -423,18 +438,30 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two." " only forward one of the two."
) )
elif prompt_2 is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None: elif prompt is None and prompt_embeds is None:
raise ValueError( raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
) )
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
if negative_prompt is not None and negative_prompt_embeds is not None: if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError( raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two." f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
) )
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape: if prompt_embeds.shape != negative_prompt_embeds.shape:
...@@ -610,6 +637,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa ...@@ -610,6 +637,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
def __call__( def __call__(
self, self,
prompt: Union[str, List[str]] = None, prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
image: Union[ image: Union[
torch.FloatTensor, torch.FloatTensor,
PIL.Image.Image, PIL.Image.Image,
...@@ -623,6 +651,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa ...@@ -623,6 +651,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
num_inference_steps: int = 50, num_inference_steps: int = 50,
guidance_scale: float = 7.5, guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1, num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0, eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
...@@ -649,6 +678,9 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa ...@@ -649,6 +678,9 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
prompt (`str` or `List[str]`, *optional*): prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead. instead.
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
...@@ -674,6 +706,9 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa ...@@ -674,6 +706,9 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
The prompt or prompts not to guide the image generation. If not defined, one has to pass The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`). less than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
num_images_per_prompt (`int`, *optional*, defaults to 1): num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt. The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0): eta (`float`, *optional*, defaults to 0.0):
...@@ -749,9 +784,11 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa ...@@ -749,9 +784,11 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
# 1. Check inputs. Raise error if not correct # 1. Check inputs. Raise error if not correct
self.check_inputs( self.check_inputs(
prompt, prompt,
prompt_2,
image, image,
callback_steps, callback_steps,
negative_prompt, negative_prompt,
negative_prompt_2,
prompt_embeds, prompt_embeds,
negative_prompt_embeds, negative_prompt_embeds,
controlnet_conditioning_scale, controlnet_conditioning_scale,
...@@ -791,10 +828,12 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa ...@@ -791,10 +828,12 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
negative_pooled_prompt_embeds, negative_pooled_prompt_embeds,
) = self.encode_prompt( ) = self.encode_prompt(
prompt, prompt,
prompt_2,
device, device,
num_images_per_prompt, num_images_per_prompt,
do_classifier_free_guidance, do_classifier_free_guidance,
negative_prompt, negative_prompt,
negative_prompt_2,
prompt_embeds=prompt_embeds, prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale, lora_scale=text_encoder_lora_scale,
......
...@@ -211,11 +211,13 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad ...@@ -211,11 +211,13 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
def encode_prompt( def encode_prompt(
self, self,
prompt, prompt: str,
prompt_2: Optional[str] = None,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
num_images_per_prompt: int = 1, num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True, do_classifier_free_guidance: bool = True,
negative_prompt=None, negative_prompt: Optional[str] = None,
negative_prompt_2: Optional[str] = None,
prompt_embeds: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
...@@ -226,8 +228,11 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad ...@@ -226,8 +228,11 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
Encodes the prompt into text encoder hidden states. Encodes the prompt into text encoder hidden states.
Args: Args:
prompt (`str` or `List[str]`, *optional*): prompt (`str` or `List[str]`, *optional*):
prompt to be encoded prompt to be encoded
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders
device: (`torch.device`): device: (`torch.device`):
torch device torch device
num_images_per_prompt (`int`): num_images_per_prompt (`int`):
...@@ -238,6 +243,9 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad ...@@ -238,6 +243,9 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
The prompt or prompts not to guide the image generation. If not defined, one has to pass The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`). less than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
prompt_embeds (`torch.FloatTensor`, *optional*): prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument. provided, text embeddings will be generated from `prompt` input argument.
...@@ -276,9 +284,11 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad ...@@ -276,9 +284,11 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
) )
if prompt_embeds is None: if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
# textual inversion: procecss multi-vector tokens if necessary # textual inversion: procecss multi-vector tokens if necessary
prompt_embeds_list = [] prompt_embeds_list = []
for tokenizer, text_encoder in zip(tokenizers, text_encoders): prompts = [prompt, prompt_2]
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
if isinstance(self, TextualInversionLoaderMixin): if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, tokenizer) prompt = self.maybe_convert_prompt(prompt, tokenizer)
...@@ -289,8 +299,10 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad ...@@ -289,8 +299,10 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
truncation=True, truncation=True,
return_tensors="pt", return_tensors="pt",
) )
text_input_ids = text_inputs.input_ids text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids text_input_ids, untruncated_ids
...@@ -326,6 +338,8 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad ...@@ -326,6 +338,8 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
elif do_classifier_free_guidance and negative_prompt_embeds is None: elif do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or "" negative_prompt = negative_prompt or ""
negative_prompt_2 = negative_prompt_2 or negative_prompt
uncond_tokens: List[str] uncond_tokens: List[str]
if prompt is not None and type(prompt) is not type(negative_prompt): if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError( raise TypeError(
...@@ -333,7 +347,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad ...@@ -333,7 +347,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
f" {type(prompt)}." f" {type(prompt)}."
) )
elif isinstance(negative_prompt, str): elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt] uncond_tokens = [negative_prompt, negative_prompt_2]
elif batch_size != len(negative_prompt): elif batch_size != len(negative_prompt):
raise ValueError( raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
...@@ -341,17 +355,16 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad ...@@ -341,17 +355,16 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
" the batch size of `prompt`." " the batch size of `prompt`."
) )
else: else:
uncond_tokens = negative_prompt uncond_tokens = [negative_prompt, negative_prompt_2]
negative_prompt_embeds_list = [] negative_prompt_embeds_list = []
for tokenizer, text_encoder in zip(tokenizers, text_encoders): for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin): if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer) negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
max_length = prompt_embeds.shape[1] max_length = prompt_embeds.shape[1]
uncond_input = tokenizer( uncond_input = tokenizer(
uncond_tokens, negative_prompt,
padding="max_length", padding="max_length",
max_length=max_length, max_length=max_length,
truncation=True, truncation=True,
...@@ -416,10 +429,12 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad ...@@ -416,10 +429,12 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
def check_inputs( def check_inputs(
self, self,
prompt, prompt,
prompt_2,
height, height,
width, width,
callback_steps, callback_steps,
negative_prompt=None, negative_prompt=None,
negative_prompt_2=None,
prompt_embeds=None, prompt_embeds=None,
negative_prompt_embeds=None, negative_prompt_embeds=None,
pooled_prompt_embeds=None, pooled_prompt_embeds=None,
...@@ -441,18 +456,30 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad ...@@ -441,18 +456,30 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two." " only forward one of the two."
) )
elif prompt_2 is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None: elif prompt is None and prompt_embeds is None:
raise ValueError( raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
) )
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
if negative_prompt is not None and negative_prompt_embeds is not None: if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError( raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two." f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
) )
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape: if prompt_embeds.shape != negative_prompt_embeds.shape:
...@@ -531,12 +558,14 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad ...@@ -531,12 +558,14 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
def __call__( def __call__(
self, self,
prompt: Union[str, List[str]] = None, prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None, height: Optional[int] = None,
width: Optional[int] = None, width: Optional[int] = None,
num_inference_steps: int = 50, num_inference_steps: int = 50,
denoising_end: Optional[float] = None, denoising_end: Optional[float] = None,
guidance_scale: float = 5.0, guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1, num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0, eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
...@@ -562,6 +591,9 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad ...@@ -562,6 +591,9 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
prompt (`str` or `List[str]`, *optional*): prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead. instead.
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
...@@ -587,6 +619,9 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad ...@@ -587,6 +619,9 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
The prompt or prompts not to guide the image generation. If not defined, one has to pass The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`). less than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
num_images_per_prompt (`int`, *optional*, defaults to 1): num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt. The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0): eta (`float`, *optional*, defaults to 0.0):
...@@ -660,10 +695,12 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad ...@@ -660,10 +695,12 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
# 1. Check inputs. Raise error if not correct # 1. Check inputs. Raise error if not correct
self.check_inputs( self.check_inputs(
prompt, prompt,
prompt_2,
height, height,
width, width,
callback_steps, callback_steps,
negative_prompt, negative_prompt,
negative_prompt_2,
prompt_embeds, prompt_embeds,
negative_prompt_embeds, negative_prompt_embeds,
pooled_prompt_embeds, pooled_prompt_embeds,
...@@ -695,11 +732,13 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad ...@@ -695,11 +732,13 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
pooled_prompt_embeds, pooled_prompt_embeds,
negative_pooled_prompt_embeds, negative_pooled_prompt_embeds,
) = self.encode_prompt( ) = self.encode_prompt(
prompt, prompt=prompt,
device, prompt_2=prompt_2,
num_images_per_prompt, device=device,
do_classifier_free_guidance, num_images_per_prompt=num_images_per_prompt,
negative_prompt, do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds, prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
......
...@@ -219,11 +219,13 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L ...@@ -219,11 +219,13 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
def encode_prompt( def encode_prompt(
self, self,
prompt, prompt: str,
prompt_2: Optional[str] = None,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
num_images_per_prompt: int = 1, num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True, do_classifier_free_guidance: bool = True,
negative_prompt=None, negative_prompt: Optional[str] = None,
negative_prompt_2: Optional[str] = None,
prompt_embeds: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
...@@ -234,8 +236,11 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L ...@@ -234,8 +236,11 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
Encodes the prompt into text encoder hidden states. Encodes the prompt into text encoder hidden states.
Args: Args:
prompt (`str` or `List[str]`, *optional*): prompt (`str` or `List[str]`, *optional*):
prompt to be encoded prompt to be encoded
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders
device: (`torch.device`): device: (`torch.device`):
torch device torch device
num_images_per_prompt (`int`): num_images_per_prompt (`int`):
...@@ -246,6 +251,9 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L ...@@ -246,6 +251,9 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
The prompt or prompts not to guide the image generation. If not defined, one has to pass The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`). less than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
prompt_embeds (`torch.FloatTensor`, *optional*): prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument. provided, text embeddings will be generated from `prompt` input argument.
...@@ -284,9 +292,11 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L ...@@ -284,9 +292,11 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
) )
if prompt_embeds is None: if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
# textual inversion: procecss multi-vector tokens if necessary # textual inversion: procecss multi-vector tokens if necessary
prompt_embeds_list = [] prompt_embeds_list = []
for tokenizer, text_encoder in zip(tokenizers, text_encoders): prompts = [prompt, prompt_2]
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
if isinstance(self, TextualInversionLoaderMixin): if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, tokenizer) prompt = self.maybe_convert_prompt(prompt, tokenizer)
...@@ -297,8 +307,10 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L ...@@ -297,8 +307,10 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
truncation=True, truncation=True,
return_tensors="pt", return_tensors="pt",
) )
text_input_ids = text_inputs.input_ids text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids text_input_ids, untruncated_ids
...@@ -334,6 +346,8 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L ...@@ -334,6 +346,8 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
elif do_classifier_free_guidance and negative_prompt_embeds is None: elif do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or "" negative_prompt = negative_prompt or ""
negative_prompt_2 = negative_prompt_2 or negative_prompt
uncond_tokens: List[str] uncond_tokens: List[str]
if prompt is not None and type(prompt) is not type(negative_prompt): if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError( raise TypeError(
...@@ -341,7 +355,7 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L ...@@ -341,7 +355,7 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
f" {type(prompt)}." f" {type(prompt)}."
) )
elif isinstance(negative_prompt, str): elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt] uncond_tokens = [negative_prompt, negative_prompt_2]
elif batch_size != len(negative_prompt): elif batch_size != len(negative_prompt):
raise ValueError( raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
...@@ -349,17 +363,16 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L ...@@ -349,17 +363,16 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
" the batch size of `prompt`." " the batch size of `prompt`."
) )
else: else:
uncond_tokens = negative_prompt uncond_tokens = [negative_prompt, negative_prompt_2]
negative_prompt_embeds_list = [] negative_prompt_embeds_list = []
for tokenizer, text_encoder in zip(tokenizers, text_encoders): for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin): if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer) negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
max_length = prompt_embeds.shape[1] max_length = prompt_embeds.shape[1]
uncond_input = tokenizer( uncond_input = tokenizer(
uncond_tokens, negative_prompt,
padding="max_length", padding="max_length",
max_length=max_length, max_length=max_length,
truncation=True, truncation=True,
...@@ -424,10 +437,12 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L ...@@ -424,10 +437,12 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
def check_inputs( def check_inputs(
self, self,
prompt, prompt,
prompt_2,
strength, strength,
num_inference_steps, num_inference_steps,
callback_steps, callback_steps,
negative_prompt=None, negative_prompt=None,
negative_prompt_2=None,
prompt_embeds=None, prompt_embeds=None,
negative_prompt_embeds=None, negative_prompt_embeds=None,
): ):
...@@ -453,18 +468,30 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L ...@@ -453,18 +468,30 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two." " only forward one of the two."
) )
elif prompt_2 is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None: elif prompt is None and prompt_embeds is None:
raise ValueError( raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
) )
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
if negative_prompt is not None and negative_prompt_embeds is not None: if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError( raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two." f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
) )
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape: if prompt_embeds.shape != negative_prompt_embeds.shape:
...@@ -617,6 +644,7 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L ...@@ -617,6 +644,7 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
def __call__( def __call__(
self, self,
prompt: Union[str, List[str]] = None, prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
image: Union[ image: Union[
torch.FloatTensor, torch.FloatTensor,
PIL.Image.Image, PIL.Image.Image,
...@@ -631,6 +659,7 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L ...@@ -631,6 +659,7 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
denoising_end: Optional[float] = None, denoising_end: Optional[float] = None,
guidance_scale: float = 5.0, guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1, num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0, eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
...@@ -658,6 +687,9 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L ...@@ -658,6 +687,9 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
prompt (`str` or `List[str]`, *optional*): prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead. instead.
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders
image (`torch.FloatTensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`): image (`torch.FloatTensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`):
The image(s) to modify with the pipeline. The image(s) to modify with the pipeline.
strength (`float`, *optional*, defaults to 0.3): strength (`float`, *optional*, defaults to 0.3):
...@@ -697,6 +729,9 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L ...@@ -697,6 +729,9 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
The prompt or prompts not to guide the image generation. If not defined, one has to pass The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`). less than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
num_images_per_prompt (`int`, *optional*, defaults to 1): num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt. The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0): eta (`float`, *optional*, defaults to 0.0):
...@@ -767,10 +802,12 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L ...@@ -767,10 +802,12 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
# 1. Check inputs. Raise error if not correct # 1. Check inputs. Raise error if not correct
self.check_inputs( self.check_inputs(
prompt, prompt,
prompt_2,
strength, strength,
num_inference_steps, num_inference_steps,
callback_steps, callback_steps,
negative_prompt, negative_prompt,
negative_prompt_2,
prompt_embeds, prompt_embeds,
negative_prompt_embeds, negative_prompt_embeds,
) )
...@@ -800,11 +837,13 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L ...@@ -800,11 +837,13 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
pooled_prompt_embeds, pooled_prompt_embeds,
negative_pooled_prompt_embeds, negative_pooled_prompt_embeds,
) = self.encode_prompt( ) = self.encode_prompt(
prompt, prompt=prompt,
device, prompt_2=prompt_2,
num_images_per_prompt, device=device,
do_classifier_free_guidance, num_images_per_prompt=num_images_per_prompt,
negative_prompt, do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds, prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
......
...@@ -325,11 +325,13 @@ class StableDiffusionXLInpaintPipeline( ...@@ -325,11 +325,13 @@ class StableDiffusionXLInpaintPipeline(
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
def encode_prompt( def encode_prompt(
self, self,
prompt, prompt: str,
prompt_2: Optional[str] = None,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
num_images_per_prompt: int = 1, num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True, do_classifier_free_guidance: bool = True,
negative_prompt=None, negative_prompt: Optional[str] = None,
negative_prompt_2: Optional[str] = None,
prompt_embeds: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
...@@ -340,8 +342,11 @@ class StableDiffusionXLInpaintPipeline( ...@@ -340,8 +342,11 @@ class StableDiffusionXLInpaintPipeline(
Encodes the prompt into text encoder hidden states. Encodes the prompt into text encoder hidden states.
Args: Args:
prompt (`str` or `List[str]`, *optional*): prompt (`str` or `List[str]`, *optional*):
prompt to be encoded prompt to be encoded
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders
device: (`torch.device`): device: (`torch.device`):
torch device torch device
num_images_per_prompt (`int`): num_images_per_prompt (`int`):
...@@ -352,6 +357,9 @@ class StableDiffusionXLInpaintPipeline( ...@@ -352,6 +357,9 @@ class StableDiffusionXLInpaintPipeline(
The prompt or prompts not to guide the image generation. If not defined, one has to pass The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`). less than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
prompt_embeds (`torch.FloatTensor`, *optional*): prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument. provided, text embeddings will be generated from `prompt` input argument.
...@@ -390,9 +398,11 @@ class StableDiffusionXLInpaintPipeline( ...@@ -390,9 +398,11 @@ class StableDiffusionXLInpaintPipeline(
) )
if prompt_embeds is None: if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
# textual inversion: procecss multi-vector tokens if necessary # textual inversion: procecss multi-vector tokens if necessary
prompt_embeds_list = [] prompt_embeds_list = []
for tokenizer, text_encoder in zip(tokenizers, text_encoders): prompts = [prompt, prompt_2]
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
if isinstance(self, TextualInversionLoaderMixin): if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, tokenizer) prompt = self.maybe_convert_prompt(prompt, tokenizer)
...@@ -403,8 +413,10 @@ class StableDiffusionXLInpaintPipeline( ...@@ -403,8 +413,10 @@ class StableDiffusionXLInpaintPipeline(
truncation=True, truncation=True,
return_tensors="pt", return_tensors="pt",
) )
text_input_ids = text_inputs.input_ids text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids text_input_ids, untruncated_ids
...@@ -440,6 +452,8 @@ class StableDiffusionXLInpaintPipeline( ...@@ -440,6 +452,8 @@ class StableDiffusionXLInpaintPipeline(
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
elif do_classifier_free_guidance and negative_prompt_embeds is None: elif do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or "" negative_prompt = negative_prompt or ""
negative_prompt_2 = negative_prompt_2 or negative_prompt
uncond_tokens: List[str] uncond_tokens: List[str]
if prompt is not None and type(prompt) is not type(negative_prompt): if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError( raise TypeError(
...@@ -447,7 +461,7 @@ class StableDiffusionXLInpaintPipeline( ...@@ -447,7 +461,7 @@ class StableDiffusionXLInpaintPipeline(
f" {type(prompt)}." f" {type(prompt)}."
) )
elif isinstance(negative_prompt, str): elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt] uncond_tokens = [negative_prompt, negative_prompt_2]
elif batch_size != len(negative_prompt): elif batch_size != len(negative_prompt):
raise ValueError( raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
...@@ -455,17 +469,16 @@ class StableDiffusionXLInpaintPipeline( ...@@ -455,17 +469,16 @@ class StableDiffusionXLInpaintPipeline(
" the batch size of `prompt`." " the batch size of `prompt`."
) )
else: else:
uncond_tokens = negative_prompt uncond_tokens = [negative_prompt, negative_prompt_2]
negative_prompt_embeds_list = [] negative_prompt_embeds_list = []
for tokenizer, text_encoder in zip(tokenizers, text_encoders): for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin): if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer) negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
max_length = prompt_embeds.shape[1] max_length = prompt_embeds.shape[1]
uncond_input = tokenizer( uncond_input = tokenizer(
uncond_tokens, negative_prompt,
padding="max_length", padding="max_length",
max_length=max_length, max_length=max_length,
truncation=True, truncation=True,
...@@ -527,15 +540,16 @@ class StableDiffusionXLInpaintPipeline( ...@@ -527,15 +540,16 @@ class StableDiffusionXLInpaintPipeline(
extra_step_kwargs["generator"] = generator extra_step_kwargs["generator"] = generator
return extra_step_kwargs return extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.check_inputs
def check_inputs( def check_inputs(
self, self,
prompt, prompt,
prompt_2,
height, height,
width, width,
strength, strength,
callback_steps, callback_steps,
negative_prompt=None, negative_prompt=None,
negative_prompt_2=None,
prompt_embeds=None, prompt_embeds=None,
negative_prompt_embeds=None, negative_prompt_embeds=None,
): ):
...@@ -558,18 +572,30 @@ class StableDiffusionXLInpaintPipeline( ...@@ -558,18 +572,30 @@ class StableDiffusionXLInpaintPipeline(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two." " only forward one of the two."
) )
elif prompt_2 is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None: elif prompt is None and prompt_embeds is None:
raise ValueError( raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
) )
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
if negative_prompt is not None and negative_prompt_embeds is not None: if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError( raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two." f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
) )
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape: if prompt_embeds.shape != negative_prompt_embeds.shape:
...@@ -785,6 +811,7 @@ class StableDiffusionXLInpaintPipeline( ...@@ -785,6 +811,7 @@ class StableDiffusionXLInpaintPipeline(
def __call__( def __call__(
self, self,
prompt: Union[str, List[str]] = None, prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
image: Union[torch.FloatTensor, PIL.Image.Image] = None, image: Union[torch.FloatTensor, PIL.Image.Image] = None,
mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
height: Optional[int] = None, height: Optional[int] = None,
...@@ -795,6 +822,7 @@ class StableDiffusionXLInpaintPipeline( ...@@ -795,6 +822,7 @@ class StableDiffusionXLInpaintPipeline(
denoising_end: Optional[float] = None, denoising_end: Optional[float] = None,
guidance_scale: float = 7.5, guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1, num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0, eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
...@@ -822,6 +850,9 @@ class StableDiffusionXLInpaintPipeline( ...@@ -822,6 +850,9 @@ class StableDiffusionXLInpaintPipeline(
prompt (`str` or `List[str]`, *optional*): prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead. instead.
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders
image (`PIL.Image.Image`): image (`PIL.Image.Image`):
`Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
be masked out with `mask_image` and repainted according to `prompt`. be masked out with `mask_image` and repainted according to `prompt`.
...@@ -868,6 +899,13 @@ class StableDiffusionXLInpaintPipeline( ...@@ -868,6 +899,13 @@ class StableDiffusionXLInpaintPipeline(
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality. usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
prompt_embeds (`torch.FloatTensor`, *optional*): prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument. provided, text embeddings will be generated from `prompt` input argument.
...@@ -894,13 +932,6 @@ class StableDiffusionXLInpaintPipeline( ...@@ -894,13 +932,6 @@ class StableDiffusionXLInpaintPipeline(
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`. tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`): output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
...@@ -962,11 +993,13 @@ class StableDiffusionXLInpaintPipeline( ...@@ -962,11 +993,13 @@ class StableDiffusionXLInpaintPipeline(
# 1. Check inputs # 1. Check inputs
self.check_inputs( self.check_inputs(
prompt, prompt,
prompt_2,
height, height,
width, width,
strength, strength,
callback_steps, callback_steps,
negative_prompt, negative_prompt,
negative_prompt_2,
prompt_embeds, prompt_embeds,
negative_prompt_embeds, negative_prompt_embeds,
) )
...@@ -996,11 +1029,13 @@ class StableDiffusionXLInpaintPipeline( ...@@ -996,11 +1029,13 @@ class StableDiffusionXLInpaintPipeline(
pooled_prompt_embeds, pooled_prompt_embeds,
negative_pooled_prompt_embeds, negative_pooled_prompt_embeds,
) = self.encode_prompt( ) = self.encode_prompt(
prompt, prompt=prompt,
device, prompt_2=prompt_2,
num_images_per_prompt, device=device,
do_classifier_free_guidance, num_images_per_prompt=num_images_per_prompt,
negative_prompt, do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds, prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import unittest import unittest
import numpy as np
import torch import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
...@@ -177,3 +178,56 @@ class ControlNetPipelineSDXLFastTests( ...@@ -177,3 +178,56 @@ class ControlNetPipelineSDXLFastTests(
def test_inference_batch_single_identical(self): def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=2e-3) self._test_inference_batch_single_identical(expected_max_diff=2e-3)
def test_stable_diffusion_xl_multi_prompts(self):
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components).to(torch_device)
# forward with single prompt
inputs = self.get_dummy_inputs(torch_device)
output = sd_pipe(**inputs)
image_slice_1 = output.images[0, -3:, -3:, -1]
# forward with same prompt duplicated
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt_2"] = inputs["prompt"]
output = sd_pipe(**inputs)
image_slice_2 = output.images[0, -3:, -3:, -1]
# ensure the results are equal
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
# forward with different prompt
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt_2"] = "different prompt"
output = sd_pipe(**inputs)
image_slice_3 = output.images[0, -3:, -3:, -1]
# ensure the results are not equal
assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
# manually set a negative_prompt
inputs = self.get_dummy_inputs(torch_device)
inputs["negative_prompt"] = "negative prompt"
output = sd_pipe(**inputs)
image_slice_1 = output.images[0, -3:, -3:, -1]
# forward with same negative_prompt duplicated
inputs = self.get_dummy_inputs(torch_device)
inputs["negative_prompt"] = "negative prompt"
inputs["negative_prompt_2"] = inputs["negative_prompt"]
output = sd_pipe(**inputs)
image_slice_2 = output.images[0, -3:, -3:, -1]
# ensure the results are equal
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
# forward with different negative_prompt
inputs = self.get_dummy_inputs(torch_device)
inputs["negative_prompt"] = "negative prompt"
inputs["negative_prompt_2"] = "different negative prompt"
output = sd_pipe(**inputs)
image_slice_3 = output.images[0, -3:, -3:, -1]
# ensure the results are not equal
assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
...@@ -355,3 +355,56 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest ...@@ -355,3 +355,56 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest
HeunDiscreteScheduler, HeunDiscreteScheduler,
]: ]:
assert_run_mixture(steps, split_1, split_2, scheduler_cls) assert_run_mixture(steps, split_1, split_2, scheduler_cls)
def test_stable_diffusion_xl_multi_prompts(self):
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components).to(torch_device)
# forward with single prompt
inputs = self.get_dummy_inputs(torch_device)
output = sd_pipe(**inputs)
image_slice_1 = output.images[0, -3:, -3:, -1]
# forward with same prompt duplicated
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt_2"] = inputs["prompt"]
output = sd_pipe(**inputs)
image_slice_2 = output.images[0, -3:, -3:, -1]
# ensure the results are equal
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
# forward with different prompt
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt_2"] = "different prompt"
output = sd_pipe(**inputs)
image_slice_3 = output.images[0, -3:, -3:, -1]
# ensure the results are not equal
assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
# manually set a negative_prompt
inputs = self.get_dummy_inputs(torch_device)
inputs["negative_prompt"] = "negative prompt"
output = sd_pipe(**inputs)
image_slice_1 = output.images[0, -3:, -3:, -1]
# forward with same negative_prompt duplicated
inputs = self.get_dummy_inputs(torch_device)
inputs["negative_prompt"] = "negative prompt"
inputs["negative_prompt_2"] = inputs["negative_prompt"]
output = sd_pipe(**inputs)
image_slice_2 = output.images[0, -3:, -3:, -1]
# ensure the results are equal
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
# forward with different negative_prompt
inputs = self.get_dummy_inputs(torch_device)
inputs["negative_prompt"] = "negative prompt"
inputs["negative_prompt_2"] = "different negative prompt"
output = sd_pipe(**inputs)
image_slice_3 = output.images[0, -3:, -3:, -1]
# ensure the results are not equal
assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
...@@ -113,8 +113,6 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel ...@@ -113,8 +113,6 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
"tokenizer": tokenizer, "tokenizer": tokenizer,
"text_encoder_2": text_encoder_2, "text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2, "tokenizer_2": tokenizer_2,
# "safety_checker": None,
# "feature_extractor": None,
} }
return components return components
...@@ -132,7 +130,7 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel ...@@ -132,7 +130,7 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
"num_inference_steps": 2, "num_inference_steps": 2,
"guidance_scale": 5.0, "guidance_scale": 5.0,
"output_type": "numpy", "output_type": "numpy",
"strength": 0.75, "strength": 0.8,
} }
return inputs return inputs
...@@ -231,3 +229,62 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel ...@@ -231,3 +229,62 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3 assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3 assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3
def test_stable_diffusion_xl_multi_prompts(self):
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components).to(torch_device)
# forward with single prompt
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = 5
output = sd_pipe(**inputs)
image_slice_1 = output.images[0, -3:, -3:, -1]
# forward with same prompt duplicated
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = 5
inputs["prompt_2"] = inputs["prompt"]
output = sd_pipe(**inputs)
image_slice_2 = output.images[0, -3:, -3:, -1]
# ensure the results are equal
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
# forward with different prompt
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = 5
inputs["prompt_2"] = "different prompt"
output = sd_pipe(**inputs)
image_slice_3 = output.images[0, -3:, -3:, -1]
# ensure the results are not equal
assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
# manually set a negative_prompt
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = 5
inputs["negative_prompt"] = "negative prompt"
output = sd_pipe(**inputs)
image_slice_1 = output.images[0, -3:, -3:, -1]
# forward with same negative_prompt duplicated
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = 5
inputs["negative_prompt"] = "negative prompt"
inputs["negative_prompt_2"] = inputs["negative_prompt"]
output = sd_pipe(**inputs)
image_slice_2 = output.images[0, -3:, -3:, -1]
# ensure the results are equal
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
# forward with different negative_prompt
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = 5
inputs["negative_prompt"] = "negative prompt"
inputs["negative_prompt_2"] = "different negative prompt"
output = sd_pipe(**inputs)
image_slice_3 = output.images[0, -3:, -3:, -1]
# ensure the results are not equal
assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
...@@ -123,7 +123,10 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel ...@@ -123,7 +123,10 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
image = image.cpu().permute(0, 2, 3, 1)[0] image = image.cpu().permute(0, 2, 3, 1)[0]
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64)) init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64)) # create mask
image[8:, 8:, :] = 255
mask_image = Image.fromarray(np.uint8(image)).convert("L").resize((64, 64))
if str(device).startswith("mps"): if str(device).startswith("mps"):
generator = torch.manual_seed(seed) generator = torch.manual_seed(seed)
else: else:
...@@ -152,7 +155,7 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel ...@@ -152,7 +155,7 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
assert image.shape == (1, 64, 64, 3) assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.4924, 0.4966, 0.4100, 0.5233, 0.5322, 0.4532, 0.5804, 0.5876, 0.4150]) expected_slice = np.array([0.6965, 0.5584, 0.5693, 0.5739, 0.6092, 0.6620, 0.5902, 0.5612, 0.5319])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
...@@ -367,3 +370,62 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel ...@@ -367,3 +370,62 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
HeunDiscreteScheduler, HeunDiscreteScheduler,
]: ]:
assert_run_mixture(steps, split_1, split_2, scheduler_cls) assert_run_mixture(steps, split_1, split_2, scheduler_cls)
def test_stable_diffusion_xl_multi_prompts(self):
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components).to(torch_device)
# forward with single prompt
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = 5
output = sd_pipe(**inputs)
image_slice_1 = output.images[0, -3:, -3:, -1]
# forward with same prompt duplicated
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = 5
inputs["prompt_2"] = inputs["prompt"]
output = sd_pipe(**inputs)
image_slice_2 = output.images[0, -3:, -3:, -1]
# ensure the results are equal
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
# forward with different prompt
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = 5
inputs["prompt_2"] = "different prompt"
output = sd_pipe(**inputs)
image_slice_3 = output.images[0, -3:, -3:, -1]
# ensure the results are not equal
assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
# manually set a negative_prompt
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = 5
inputs["negative_prompt"] = "negative prompt"
output = sd_pipe(**inputs)
image_slice_1 = output.images[0, -3:, -3:, -1]
# forward with same negative_prompt duplicated
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = 5
inputs["negative_prompt"] = "negative prompt"
inputs["negative_prompt_2"] = inputs["negative_prompt"]
output = sd_pipe(**inputs)
image_slice_2 = output.images[0, -3:, -3:, -1]
# ensure the results are equal
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
# forward with different negative_prompt
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = 5
inputs["negative_prompt"] = "negative prompt"
inputs["negative_prompt_2"] = "different negative prompt"
output = sd_pipe(**inputs)
image_slice_3 = output.images[0, -3:, -3:, -1]
# ensure the results are not equal
assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment