Unverified Commit 9800cc5e authored by Sayak Paul, committed by GitHub

[InstructPix2Pix] Fix pipeline implementation and add docs (#4844)

* initial evident fixes.

* instructpix2pix fixes.

* add: entry to doc.

* address PR feedback.

* make fix-copies
parent 541bb6ee
@@ -35,4 +35,12 @@ Make sure to check out the Schedulers [guide](/using-diffusers/schedulers) to le
   - save_lora_weights
 ## StableDiffusionPipelineOutput
-[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
\ No newline at end of file
+[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
+
+## StableDiffusionXLInstructPix2PixPipeline
+[[autodoc]] StableDiffusionXLInstructPix2PixPipeline
+  - __call__
+  - all
+
+## StableDiffusionXLPipelineOutput
+[[autodoc]] pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput
\ No newline at end of file
@@ -28,6 +28,7 @@ from ...models.attention_processor import (
     LoRAXFormersAttnProcessor,
     XFormersAttnProcessor,
 )
+from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     deprecate,
@@ -36,6 +37,7 @@ from ...utils import (
     is_invisible_watermark_available,
     logging,
     randn_tensor,
+    replace_example_docstring,
 )
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionXLPipelineOutput
@@ -47,6 +49,36 @@ if is_invisible_watermark_available():
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import torch
+        >>> from diffusers import StableDiffusionXLInstructPix2PixPipeline
+        >>> from diffusers.utils import load_image
+
+        >>> resolution = 768
+        >>> image = load_image(
+        ...     "https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png"
+        ... ).resize((resolution, resolution))
+        >>> edit_instruction = "Turn sky into a cloudy one"
+
+        >>> pipe = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
+        ...     "diffusers/sdxl-instructpix2pix-768", torch_dtype=torch.float16
+        ... ).to("cuda")
+
+        >>> edited_image = pipe(
+        ...     prompt=edit_instruction,
+        ...     image=image,
+        ...     height=resolution,
+        ...     width=resolution,
+        ...     guidance_scale=3.0,
+        ...     image_guidance_scale=1.5,
+        ...     num_inference_steps=30,
+        ... ).images[0]
+        >>> edited_image
+        ```
+"""

 def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     """
@@ -121,7 +153,6 @@ class StableDiffusionXLInstructPix2PixPipeline(
         tokenizer_2: CLIPTokenizer,
         unet: UNet2DConditionModel,
         scheduler: KarrasDiffusionSchedulers,
-        requires_aesthetics_score: bool = False,
         force_zeros_for_empty_prompt: bool = True,
         add_watermarker: Optional[bool] = None,
     ):
@@ -137,11 +168,9 @@ class StableDiffusionXLInstructPix2PixPipeline(
             scheduler=scheduler,
         )
         self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
-        self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.default_sample_size = self.unet.config.sample_size
-        self.vae.config.force_upcast = True  # force the VAE to be in float32 mode, as it overflows in float16
         add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
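The constructor keeps `default_sample_size` while dropping the forced VAE upcast, and that attribute is what `__call__` now falls back to when `height` and `width` are omitted: `self.default_sample_size * self.vae_scale_factor`. A minimal sketch of that arithmetic, assuming a UNet `sample_size` of 96 and a VAE with four `block_out_channels` (placeholder values, not taken from this commit):

```py
# Sketch only: how the pipeline derives its default output resolution.
unet_sample_size = 96                          # assumed self.unet.config.sample_size
num_vae_blocks = 4                             # assumed len(self.vae.config.block_out_channels)

vae_scale_factor = 2 ** (num_vae_blocks - 1)           # 2**3 = 8, same formula as __init__
default_height = unet_sample_size * vae_scale_factor   # 96 * 8 = 768
default_width = unet_sample_size * vae_scale_factor    # 768

print(vae_scale_factor, default_height, default_width)  # 8 768 768
```

With these assumed values the fallback resolution is 768, matching the resolution used in the example docstring above.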
@@ -213,13 +242,16 @@ class StableDiffusionXLInstructPix2PixPipeline(
         # We'll offload the last model manually.
         self.final_offload_hook = hook
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
     def encode_prompt(
         self,
-        prompt,
+        prompt: str,
+        prompt_2: Optional[str] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         do_classifier_free_guidance: bool = True,
-        negative_prompt=None,
+        negative_prompt: Optional[str] = None,
+        negative_prompt_2: Optional[str] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -230,8 +262,11 @@ class StableDiffusionXLInstructPix2PixPipeline(
         Encodes the prompt into text encoder hidden states.
         Args:
             prompt (`str` or `List[str]`, *optional*):
                 prompt to be encoded
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+                used in both text-encoders
             device: (`torch.device`):
                 torch device
             num_images_per_prompt (`int`):
@@ -242,6 +277,9 @@ class StableDiffusionXLInstructPix2PixPipeline(
                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                 less than `1`).
+            negative_prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
             prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
@@ -266,6 +304,10 @@ class StableDiffusionXLInstructPix2PixPipeline(
         if lora_scale is not None and isinstance(self, LoraLoaderMixin):
             self._lora_scale = lora_scale
+            # dynamically adjust the LoRA scale
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
@@ -280,9 +322,11 @@ class StableDiffusionXLInstructPix2PixPipeline(
             )
         if prompt_embeds is None:
+            prompt_2 = prompt_2 or prompt
             # textual inversion: procecss multi-vector tokens if necessary
             prompt_embeds_list = []
-            for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+            prompts = [prompt, prompt_2]
+            for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
                 if isinstance(self, TextualInversionLoaderMixin):
                     prompt = self.maybe_convert_prompt(prompt, tokenizer)
@@ -293,6 +337,7 @@ class StableDiffusionXLInstructPix2PixPipeline(
                     truncation=True,
                     return_tensors="pt",
                 )
                 text_input_ids = text_inputs.input_ids
                 untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
@@ -314,11 +359,6 @@ class StableDiffusionXLInstructPix2PixPipeline(
                 pooled_prompt_embeds = prompt_embeds[0]
                 prompt_embeds = prompt_embeds.hidden_states[-2]
-                bs_embed, seq_len, _ = prompt_embeds.shape
-                # duplicate text embeddings for each generation per prompt, using mps friendly method
-                prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
                 prompt_embeds_list.append(prompt_embeds)
             prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
@@ -330,6 +370,8 @@ class StableDiffusionXLInstructPix2PixPipeline(
             negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
         elif do_classifier_free_guidance and negative_prompt_embeds is None:
             negative_prompt = negative_prompt or ""
+            negative_prompt_2 = negative_prompt_2 or negative_prompt
             uncond_tokens: List[str]
             if prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
@@ -337,7 +379,7 @@ class StableDiffusionXLInstructPix2PixPipeline(
                     f" {type(prompt)}."
                 )
             elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt]
+                uncond_tokens = [negative_prompt, negative_prompt_2]
             elif batch_size != len(negative_prompt):
                 raise ValueError(
                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
@@ -345,17 +387,16 @@ class StableDiffusionXLInstructPix2PixPipeline(
                     " the batch size of `prompt`."
                 )
             else:
-                uncond_tokens = negative_prompt
+                uncond_tokens = [negative_prompt, negative_prompt_2]
             negative_prompt_embeds_list = []
-            for tokenizer, text_encoder in zip(tokenizers, text_encoders):
-                # textual inversion: procecss multi-vector tokens if necessary
+            for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
                 if isinstance(self, TextualInversionLoaderMixin):
-                    uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
+                    negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
                 max_length = prompt_embeds.shape[1]
                 uncond_input = tokenizer(
-                    uncond_tokens,
+                    negative_prompt,
                     padding="max_length",
                     max_length=max_length,
                     truncation=True,
@@ -370,32 +411,30 @@ class StableDiffusionXLInstructPix2PixPipeline(
                 negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                 negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
-                if do_classifier_free_guidance:
-                    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-                    seq_len = negative_prompt_embeds.shape[1]
-                    negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
-                    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                    negative_prompt_embeds = negative_prompt_embeds.view(
-                        batch_size * num_images_per_prompt, seq_len, -1
-                    )
-                    # For classifier free guidance, we need to do two forward passes.
-                    # Here we concatenate the unconditional and text embeddings into a single batch
-                    # to avoid doing two forward passes
                 negative_prompt_embeds_list.append(negative_prompt_embeds)
             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
-        bs_embed = pooled_prompt_embeds.shape[0]
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
             bs_embed * num_images_per_prompt, -1
         )
-        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
-            bs_embed * num_images_per_prompt, -1
-        )
+        if do_classifier_free_guidance:
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+                bs_embed * num_images_per_prompt, -1
+            )
         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
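With `prompt_2` and `negative_prompt_2` threaded through, `encode_prompt` now feeds one prompt to each tokenizer/text-encoder pair, concatenates the two penultimate hidden states along the last dimension, and only then duplicates the result for `num_images_per_prompt`. A self-contained sketch of that concatenate-then-duplicate pattern; the hidden sizes 768 and 1280 are illustrative assumptions, not values read from this diff:

```py
import torch

# Stand-ins for the two text encoders' penultimate hidden states for a batch
# of 2 prompts with 77 tokens each (widths are assumed for illustration).
hidden_1 = torch.randn(2, 77, 768)   # from text_encoder, driven by `prompt`
hidden_2 = torch.randn(2, 77, 1280)  # from text_encoder_2, driven by `prompt_2` (or `prompt`)

prompt_embeds_list = [hidden_1, hidden_2]
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)  # same concat as the pipeline
print(prompt_embeds.shape)                                # torch.Size([2, 77, 2048])

# Duplication for num_images_per_prompt, mirroring the repeat/view moved out of the loop.
num_images_per_prompt = 2
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
print(prompt_embeds.shape)                                # torch.Size([4, 77, 2048])
```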
@@ -417,15 +456,7 @@ class StableDiffusionXLInstructPix2PixPipeline(
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
-    def get_timesteps(self, num_inference_steps, strength, device):
-        # get the original timestep using init_timestep
-        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
-        t_start = max(num_inference_steps - init_timestep, 0)
-        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
-        return timesteps, num_inference_steps - t_start
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix.StableDiffusionInstructPix2PixPipeline.check_inputs
     def check_inputs(
         self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
     ):
@@ -463,6 +494,7 @@ class StableDiffusionXLInstructPix2PixPipeline(
                 f" {negative_prompt_embeds.shape}."
             )
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
         shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
         if isinstance(generator, list) and len(generator) != batch_size:
@@ -496,9 +528,9 @@ class StableDiffusionXLInstructPix2PixPipeline(
             image_latents = image
         else:
             # make sure the VAE is in float32 mode, as it overflows in float16
-            if self.vae.config.force_upcast:
-                image = image.float()
-                self.vae.to(dtype=torch.float32)
+            if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+                self.upcast_vae()
+                image = image.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
@@ -536,45 +568,24 @@ class StableDiffusionXLInstructPix2PixPipeline(
         return image_latents
-    def _get_add_time_ids(
-        self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype
-    ):
-        if self.config.requires_aesthetics_score:
-            add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
-            add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,))
-        else:
-            add_time_ids = list(original_size + crops_coords_top_left + target_size)
-            add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
+    def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
+        add_time_ids = list(original_size + crops_coords_top_left + target_size)
         passed_add_embed_dim = (
             self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
         )
         expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
-        if (
-            expected_add_embed_dim > passed_add_embed_dim
-            and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
-        ):
-            raise ValueError(
-                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
-            )
-        elif (
-            expected_add_embed_dim < passed_add_embed_dim
-            and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
-        ):
-            raise ValueError(
-                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
-            )
-        elif expected_add_embed_dim != passed_add_embed_dim:
+        if expected_add_embed_dim != passed_add_embed_dim:
             raise ValueError(
                 f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
             )
         add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
-        add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
-        return add_time_ids, add_neg_time_ids
+        return add_time_ids
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae
     def upcast_vae(self):
         dtype = self.vae.dtype
         self.vae.to(dtype=torch.float32)
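`_get_add_time_ids` now always packs six integers (`original_size + crops_coords_top_left + target_size`), so the UNet's added-embedding input has to equal `6 * addition_time_embed_dim + text_encoder_2.config.projection_dim`; the five-value aesthetics-score variant is gone. A small worked check of that dimension, reusing the fast-test numbers that appear later in this commit (`addition_time_embed_dim=8`, projection dim 32 as implied by the test's comment):

```py
# Sketch of the dimension check performed by _get_add_time_ids, not the method itself.
original_size = (32, 32)
crops_coords_top_left = (0, 0)
target_size = (32, 32)

add_time_ids = list(original_size + crops_coords_top_left + target_size)
assert len(add_time_ids) == 6  # previously 5 when an aesthetic score replaced target_size

addition_time_embed_dim = 8  # UNet config in the fast test below
projection_dim = 32          # text_encoder_2 projection dim in the fast test below

passed_add_embed_dim = addition_time_embed_dim * len(add_time_ids) + projection_dim
print(passed_add_embed_dim)  # 80, matching projection_class_embeddings_input_dim=80 in the test
```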
@@ -595,14 +606,20 @@ class StableDiffusionXLInstructPix2PixPipeline(
         self.vae.decoder.mid_block.to(dtype)
     @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
+        prompt_2: Optional[Union[str, List[str]]] = None,
         image: PipelineImageInput = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
         num_inference_steps: int = 100,
-        guidance_scale: float = 7.5,
+        denoising_end: Optional[float] = None,
+        guidance_scale: float = 5.0,
         image_guidance_scale: float = 1.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -620,8 +637,6 @@ class StableDiffusionXLInstructPix2PixPipeline(
         original_size: Tuple[int, int] = None,
         crops_coords_top_left: Tuple[int, int] = (0, 0),
         target_size: Tuple[int, int] = None,
-        aesthetic_score: float = 6.0,
-        negative_aesthetic_score: float = 2.5,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -630,12 +645,26 @@ class StableDiffusionXLInstructPix2PixPipeline(
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+                used in both text-encoders
             image (`torch.FloatTensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`):
                 The image(s) to modify with the pipeline.
+            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
-            guidance_scale (`float`, *optional*, defaults to 7.5):
+            denoising_end (`float`, *optional*):
+                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+                completed before it is intentionally prematurely terminated. As a result, the returned sample will
+                still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+                scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+                "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
+            guidance_scale (`float`, *optional*, defaults to 5.0):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
@@ -650,6 +679,9 @@ class StableDiffusionXLInstructPix2PixPipeline(
                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                 less than `1`).
+            negative_prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             eta (`float`, *optional*, defaults to 0.0):
@@ -698,25 +730,34 @@ class StableDiffusionXLInstructPix2PixPipeline(
                 [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
                 Guidance rescale factor should fix overexposure when using zero terminal SNR.
             original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
-                TODO
+                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+                `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
+                explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
             crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
-                TODO
+                `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
             target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
-                TODO
-            aesthetic_score (`float`, *optional*, defaults to 6.0):
-                TODO
-            negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
-                TDOO
+                For most cases, `target_size` should be set to the desired height and width of the generated image. If
+                not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
+                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
         Examples:
         Returns:
-            [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
-            [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
-            `tuple. When returning a tuple, the first element is a list with the generated images, and the second
-            element is a list of `bool`s denoting whether the corresponding generated image likely represents
-            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
+            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is a list with the generated images.
         """
+        # 0. Default height and width to unet
+        height = height or self.default_sample_size * self.vae_scale_factor
+        width = width or self.default_sample_size * self.vae_scale_factor
+        original_size = original_size or (height, width)
+        target_size = target_size or (height, width)
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
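The rewritten docstring replaces the old `TODO` placeholders with SDXL's micro-conditioning semantics, and `height`/`width` now default to the UNet resolution. A hedged call sketch exercising the new `__call__` arguments; the checkpoint, image URL, and guidance values come from the example docstring earlier in this commit, while the explicit size tuples are illustrative only:

```py
import torch
from diffusers import StableDiffusionXLInstructPix2PixPipeline
from diffusers.utils import load_image

pipe = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
    "diffusers/sdxl-instructpix2pix-768", torch_dtype=torch.float16
).to("cuda")

image = load_image(
    "https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png"
).resize((768, 768))

edited_image = pipe(
    prompt="Turn sky into a cloudy one",
    image=image,
    height=768,                    # omit to fall back to default_sample_size * vae_scale_factor
    width=768,
    original_size=(768, 768),      # micro-conditioning: nominal source resolution
    target_size=(768, 768),        # micro-conditioning: desired output resolution
    crops_coords_top_left=(0, 0),  # (0, 0) tends to give well-centered results
    guidance_scale=3.0,
    image_guidance_scale=1.5,
    num_inference_steps=30,
).images[0]
```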
@@ -750,11 +791,13 @@ class StableDiffusionXLInstructPix2PixPipeline(
             pooled_prompt_embeds,
             negative_pooled_prompt_embeds,
         ) = self.encode_prompt(
-            prompt,
-            device,
-            num_images_per_prompt,
-            do_classifier_free_guidance,
-            negative_prompt,
+            prompt=prompt,
+            prompt_2=prompt_2,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
@@ -780,10 +823,6 @@ class StableDiffusionXLInstructPix2PixPipeline(
             generator,
         )
-        height, width = image_latents.shape[-2:]
-        height = height * self.vae_scale_factor
-        width = width * self.vae_scale_factor
         # 7. Prepare latent variables
         num_channels_latents = self.vae.config.latent_channels
         latents = self.prepare_latents(
@@ -811,47 +850,40 @@ class StableDiffusionXLInstructPix2PixPipeline(
         # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
-        original_size = original_size or (height, width)
-        target_size = target_size or (height, width)
         # 10. Prepare added time ids & embeddings
         add_text_embeds = pooled_prompt_embeds
-        add_time_ids, add_neg_time_ids = self._get_add_time_ids(
-            original_size,
-            crops_coords_top_left,
-            target_size,
-            aesthetic_score,
-            negative_aesthetic_score,
-            dtype=prompt_embeds.dtype,
+        add_time_ids = self._get_add_time_ids(
+            original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
         )
-        add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
-        original_prompt_embeds_len = len(prompt_embeds)
-        original_add_text_embeds_len = len(add_text_embeds)
-        original_add_time_ids = len(add_time_ids)
         if do_classifier_free_guidance:
-            prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0)
-            add_text_embeds = torch.cat([add_text_embeds, negative_pooled_prompt_embeds], dim=0)
-            add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
-            add_time_ids = torch.cat([add_time_ids, add_neg_time_ids], dim=0)
-            # Make dimensions consistent
-            add_text_embeds = torch.concat((add_text_embeds, add_text_embeds[:original_add_text_embeds_len]), dim=0)
-            add_time_ids = torch.concat((add_time_ids, add_time_ids.clone()[:original_add_time_ids]), dim=0)
-            prompt_embeds = torch.concat((prompt_embeds, prompt_embeds.clone()[:original_prompt_embeds_len]), dim=0)
+            # The extra concat similar to how it's done in SD InstructPix2Pix.
+            prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds], dim=0)
+            add_text_embeds = torch.cat(
+                [add_text_embeds, negative_pooled_prompt_embeds, negative_pooled_prompt_embeds], dim=0
+            )
+            add_time_ids = torch.cat([add_time_ids, add_time_ids, add_time_ids], dim=0)
-        prompt_embeds = prompt_embeds.to(device).to(torch.float32)
-        add_text_embeds = add_text_embeds.to(device).to(torch.float32)
-        add_time_ids = add_time_ids.to(device)
+        prompt_embeds = prompt_embeds.to(device)
+        add_text_embeds = add_text_embeds.to(device)
+        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
         # 11. Denoising loop
-        self.unet = self.unet.to(torch.float32)
-        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+        if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
+            discrete_timestep_cutoff = int(
+                round(
+                    self.scheduler.config.num_train_timesteps
+                    - (denoising_end * self.scheduler.config.num_train_timesteps)
+                )
+            )
+            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+            timesteps = timesteps[:num_inference_steps]
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # Expand the latents if we are doing classifier free guidance.
-                # The latents are expanded 3 times because for pix2pix the guidance\
+                # The latents are expanded 3 times because for pix2pix the guidance
                 # is applied for both the text and the input image.
                 latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents
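The latents above are expanded three ways because InstructPix2Pix applies classifier-free guidance to both the text and the image conditioning. The mixing step falls outside this excerpt; the sketch below shows how the three noise predictions are typically recombined in the SD InstructPix2Pix pipeline, with the tensor shape chosen purely for illustration:

```py
import torch

guidance_scale = 3.0        # text guidance weight
image_guidance_scale = 1.5  # image guidance weight

# Stand-in for the UNet output on the 3x-expanded batch:
# [text + image conditioned, image conditioned only, fully unconditional].
noise_pred = torch.randn(3, 4, 96, 96)  # assumed latent shape for illustration

noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
noise_pred = (
    noise_pred_uncond
    + guidance_scale * (noise_pred_text - noise_pred_image)
    + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
)
print(noise_pred.shape)  # torch.Size([1, 4, 96, 96])
```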
@@ -68,7 +68,7 @@ class StableDiffusionXLInstructPix2PixPipelineFastTests(
             addition_embed_type="text_time",
             addition_time_embed_dim=8,
             transformer_layers_per_block=(1, 2),
-            projection_class_embeddings_input_dim=72,  # 5 * 8 + 32
+            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
             cross_attention_dim=64,
         )
@@ -118,12 +118,11 @@ class StableDiffusionXLInstructPix2PixPipelineFastTests(
             "tokenizer": tokenizer,
             "text_encoder_2": text_encoder_2,
             "tokenizer_2": tokenizer_2,
-            "requires_aesthetics_score": True,
         }
         return components
     def get_dummy_inputs(self, device, seed=0):
-        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
+        image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
         image = image / 2 + 0.5
         if str(device).startswith("mps"):
             generator = torch.manual_seed(seed)
@@ -142,7 +141,6 @@ class StableDiffusionXLInstructPix2PixPipelineFastTests(
     def test_components_function(self):
         init_components = self.get_dummy_components()
-        init_components.pop("requires_aesthetics_score")
         pipe = self.pipeline_class(**init_components)
         self.assertTrue(hasattr(pipe, "components"))