Unverified Commit b562b661 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Allow directly passing text embeddings to Stable Diffusion Pipeline for prompt weighting (#2071)

* add text embeds to sd

* add text embeds to sd

* finish tests

* finish

* finish

* make style

* fix tests

* make style

* make style

* up

* better docs

* fix

* fix

* new try

* up

* up

* finish
parent c1184918
......@@ -220,12 +220,21 @@ class AltDiffusionPipeline(DiffusionPipeline):
return torch.device(module._hf_hook.execution_device)
return self.device
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
......@@ -233,47 +242,67 @@ class AltDiffusionPipeline(DiffusionPipeline):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
negative_ prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
if prompt_embeds is None:
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_embeddings = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = text_embeddings[0]
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
......@@ -293,7 +322,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
......@@ -307,23 +336,27 @@ class AltDiffusionPipeline(DiffusionPipeline):
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = uncond_embeddings[0]
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return text_embeddings
return prompt_embeds
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is not None:
......@@ -360,10 +393,16 @@ class AltDiffusionPipeline(DiffusionPipeline):
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(self, prompt, height, width, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
def check_inputs(
self,
prompt,
height,
width,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
......@@ -375,6 +414,32 @@ class AltDiffusionPipeline(DiffusionPipeline):
f" {type(callback_steps)}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if isinstance(generator, list) and len(generator) != batch_size:
......@@ -396,7 +461,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]],
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
......@@ -406,6 +471,8 @@ class AltDiffusionPipeline(DiffusionPipeline):
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
......@@ -415,8 +482,9 @@ class AltDiffusionPipeline(DiffusionPipeline):
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
......@@ -431,8 +499,9 @@ class AltDiffusionPipeline(DiffusionPipeline):
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
......@@ -445,6 +514,13 @@ class AltDiffusionPipeline(DiffusionPipeline):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
......@@ -472,10 +548,18 @@ class AltDiffusionPipeline(DiffusionPipeline):
width = width or self.unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, callback_steps)
self.check_inputs(
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
)
# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
......@@ -483,8 +567,14 @@ class AltDiffusionPipeline(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
prompt_embeds = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
# 4. Prepare timesteps
......@@ -498,7 +588,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
num_channels_latents,
height,
width,
text_embeddings.dtype,
prompt_embeds.dtype,
device,
generator,
latents,
......@@ -516,7 +606,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
# perform guidance
if do_classifier_free_guidance:
......@@ -536,7 +626,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
image = self.decode_latents(latents)
# 9. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
# 10. Convert to PIL
if output_type == "pil":
......
......@@ -242,12 +242,21 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
return torch.device(module._hf_hook.execution_device)
return self.device
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
......@@ -255,47 +264,67 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
negative_ prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
if prompt_embeds is None:
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_embeddings = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = text_embeddings[0]
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
......@@ -315,7 +344,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
......@@ -329,23 +358,27 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = uncond_embeddings[0]
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return text_embeddings
return prompt_embeds
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is not None:
......@@ -382,7 +415,9 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(self, prompt, strength, callback_steps):
def check_inputs(
self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
......@@ -397,6 +432,32 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
f" {type(callback_steps)}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
......@@ -462,7 +523,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]],
prompt: Union[str, List[str]] = None,
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
......@@ -471,6 +532,8 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
......@@ -481,8 +544,9 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
image (`torch.FloatTensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process.
......@@ -502,8 +566,9 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
......@@ -512,6 +577,13 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
generator (`torch.Generator`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
......@@ -537,8 +609,8 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)
image = init_image or image
# 1. Check inputs
self.check_inputs(prompt, strength, callback_steps)
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
......@@ -549,8 +621,14 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
prompt_embeds = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
# 4. Preprocess image
......@@ -563,7 +641,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
# 6. Prepare latent variables
latents = self.prepare_latents(
image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
)
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
......@@ -578,7 +656,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
# perform guidance
if do_classifier_free_guidance:
......@@ -598,7 +676,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
image = self.decode_latents(latents)
# 10. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
# 11. Convert to PIL
if output_type == "pil":
......
......@@ -129,11 +129,11 @@ class LDMTextToImagePipeline(DiffusionPipeline):
uncond_input = self.tokenizer(
[""] * batch_size, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
)
uncond_embeddings = self.bert(uncond_input.input_ids.to(self.device))[0]
negative_prompt_embeds = self.bert(uncond_input.input_ids.to(self.device))[0]
# get prompt text embeddings
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt")
text_embeddings = self.bert(text_input.input_ids.to(self.device))[0]
prompt_embeds = self.bert(text_input.input_ids.to(self.device))[0]
# get the initial random noise unless the user supplied it
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
......@@ -144,7 +144,7 @@ class LDMTextToImagePipeline(DiffusionPipeline):
)
if latents is None:
latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=text_embeddings.dtype)
latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=prompt_embeds.dtype)
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
......@@ -163,13 +163,13 @@ class LDMTextToImagePipeline(DiffusionPipeline):
if guidance_scale == 1.0:
# guidance_scale of 1 means no guidance
latents_input = latents
context = text_embeddings
context = prompt_embeds
else:
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
latents_input = torch.cat([latents] * 2)
context = torch.cat([uncond_embeddings, text_embeddings])
context = torch.cat([negative_prompt_embeds, prompt_embeds])
# predict the noise residual
noise_pred = self.unet(latents_input, t, encoder_hidden_states=context).sample
......
......@@ -364,7 +364,7 @@ class PaintByExamplePipeline(DiffusionPipeline):
image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
image_embeddings, uncond_embeddings = self.image_encoder(image, return_uncond_vector=True)
image_embeddings, negative_prompt_embeds = self.image_encoder(image, return_uncond_vector=True)
# duplicate image embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = image_embeddings.shape
......@@ -372,13 +372,13 @@ class PaintByExamplePipeline(DiffusionPipeline):
image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
if do_classifier_free_guidance:
uncond_embeddings = uncond_embeddings.repeat(1, image_embeddings.shape[0], 1)
uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, 1, -1)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, image_embeddings.shape[0], 1)
negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, 1, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
image_embeddings = torch.cat([uncond_embeddings, image_embeddings])
image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
return image_embeddings
......
......@@ -261,12 +261,21 @@ class CycleDiffusionPipeline(DiffusionPipeline):
return self.device
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
......@@ -274,47 +283,67 @@ class CycleDiffusionPipeline(DiffusionPipeline):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
negative_ prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
if prompt_embeds is None:
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_embeddings = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = text_embeddings[0]
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
......@@ -334,7 +363,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
......@@ -348,26 +377,32 @@ class CycleDiffusionPipeline(DiffusionPipeline):
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = uncond_embeddings[0]
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return text_embeddings
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
def check_inputs(self, prompt, strength, callback_steps):
def check_inputs(
self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
......@@ -382,6 +417,32 @@ class CycleDiffusionPipeline(DiffusionPipeline):
f" {type(callback_steps)}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
......@@ -492,6 +553,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
......@@ -533,6 +595,13 @@ class CycleDiffusionPipeline(DiffusionPipeline):
generator (`torch.Generator`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
......@@ -569,8 +638,14 @@ class CycleDiffusionPipeline(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_embeddings = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, None)
source_text_embeddings = self._encode_prompt(
prompt_embeds = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
prompt_embeds=prompt_embeds,
)
source_prompt_embeds = self._encode_prompt(
source_prompt, device, num_images_per_prompt, do_classifier_free_guidance, None
)
......@@ -584,7 +659,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
# 6. Prepare latent variables
latents, clean_latents = self.prepare_latents(
image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
)
source_latents = latents
......@@ -612,17 +687,17 @@ class CycleDiffusionPipeline(DiffusionPipeline):
],
dim=0,
)
concat_text_embeddings = torch.stack(
concat_prompt_embeds = torch.stack(
[
source_text_embeddings[0],
text_embeddings[0],
source_text_embeddings[1],
text_embeddings[1],
source_prompt_embeds[0],
prompt_embeds[0],
source_prompt_embeds[1],
prompt_embeds[1],
],
dim=0,
)
concat_noise_pred = self.unet(
concat_latent_model_input, t, encoder_hidden_states=concat_text_embeddings
concat_latent_model_input, t, encoder_hidden_states=concat_prompt_embeds
).sample
# perform guidance
......@@ -662,7 +737,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
image = self.decode_latents(latents)
# 10. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
# 11. Convert to PIL
if output_type == "pil":
......
......@@ -196,7 +196,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
# get prompt text embeddings
text_embeddings = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]
prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]
# TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
# implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
......@@ -210,8 +210,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
).input_ids
else:
uncond_input = neg_prompt_ids
uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
context = jnp.concatenate([uncond_embeddings, text_embeddings])
negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])
latents_shape = (
batch_size,
......
......@@ -192,7 +192,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
# get prompt text embeddings
text_embeddings = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]
prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]
# TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
# implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
......@@ -206,8 +206,8 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
).input_ids
else:
uncond_input = neg_prompt_ids
uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
context = jnp.concatenate([uncond_embeddings, text_embeddings])
negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])
latents_shape = (
batch_size,
......
......@@ -224,7 +224,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
# get prompt text embeddings
text_embeddings = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]
prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]
# TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
# implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
......@@ -238,8 +238,8 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
).input_ids
else:
uncond_input = neg_prompt_ids
uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
context = jnp.concatenate([uncond_embeddings, text_embeddings])
negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])
latents_shape = (
batch_size,
......
......@@ -117,7 +117,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt (`str` or `List[str]`):
prompt to be encoded
num_images_per_prompt (`int`):
number of images that should be generated per prompt
......@@ -147,8 +147,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)
prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
......@@ -179,15 +179,15 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
truncation=True,
return_tensors="np",
)
uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0)
negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])
return text_embeddings
return prompt_embeds
def __call__(
self,
......@@ -232,12 +232,12 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
text_embeddings = self._encode_prompt(
prompt_embeds = self._encode_prompt(
prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)
# get the initial random noise unless the user supplied it
latents_dtype = text_embeddings.dtype
latents_dtype = prompt_embeds.dtype
latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
if latents is None:
latents = generator.randn(*latents_shape).astype(latents_dtype)
......@@ -271,7 +271,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
# predict the noise residual
timestep = np.array([t], dtype=timestep_dtype)
noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings)
noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)
noise_pred = noise_pred[0]
# perform guidance
......
......@@ -167,7 +167,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt (`str` or `List[str]`):
prompt to be encoded
num_images_per_prompt (`int`):
number of images that should be generated per prompt
......@@ -197,8 +197,8 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)
prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
......@@ -229,15 +229,15 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
truncation=True,
return_tensors="np",
)
uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0)
negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])
return text_embeddings
return prompt_embeds
def __call__(
self,
......@@ -345,11 +345,11 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
text_embeddings = self._encode_prompt(
prompt_embeds = self._encode_prompt(
prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)
latents_dtype = text_embeddings.dtype
latents_dtype = prompt_embeds.dtype
image = image.astype(latents_dtype)
# encode the init image into latents and scale the latents
init_latents = self.vae_encoder(sample=image)[0]
......@@ -417,9 +417,9 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
# predict the noise residual
timestep = np.array([t], dtype=timestep_dtype)
noise_pred = self.unet(
sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings
)[0]
noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[
0
]
# perform guidance
if do_classifier_free_guidance:
......
......@@ -168,7 +168,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt (`str` or `List[str]`):
prompt to be encoded
num_images_per_prompt (`int`):
number of images that should be generated per prompt
......@@ -198,8 +198,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)
prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
......@@ -230,15 +230,15 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
truncation=True,
return_tensors="np",
)
uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0)
negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])
return text_embeddings
return prompt_embeds
@torch.no_grad()
def __call__(
......@@ -351,13 +351,13 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
text_embeddings = self._encode_prompt(
prompt_embeds = self._encode_prompt(
prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)
num_channels_latents = NUM_LATENT_CHANNELS
latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
latents_dtype = text_embeddings.dtype
latents_dtype = prompt_embeds.dtype
if latents is None:
latents = generator.randn(*latents_shape).astype(latents_dtype)
else:
......@@ -424,9 +424,9 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
# predict the noise residual
timestep = np.array([t], dtype=timestep_dtype)
noise_pred = self.unet(
sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings
)[0]
noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[
0
]
# perform guidance
if do_classifier_free_guidance:
......
......@@ -153,7 +153,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt (`str` or `List[str]`):
prompt to be encoded
num_images_per_prompt (`int`):
number of images that should be generated per prompt
......@@ -183,8 +183,8 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)
prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
......@@ -215,15 +215,15 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
truncation=True,
return_tensors="np",
)
uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0)
negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])
return text_embeddings
return prompt_embeds
def __call__(
self,
......@@ -338,11 +338,11 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
text_embeddings = self._encode_prompt(
prompt_embeds = self._encode_prompt(
prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)
latents_dtype = text_embeddings.dtype
latents_dtype = prompt_embeds.dtype
image = image.astype(latents_dtype)
# encode the init image into latents and scale the latents
......@@ -399,7 +399,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
# predict the noise residual
noise_pred = self.unet(
sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings
sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=prompt_embeds
)[0]
# perform guidance
......
......@@ -217,12 +217,21 @@ class StableDiffusionPipeline(DiffusionPipeline):
return torch.device(module._hf_hook.execution_device)
return self.device
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
......@@ -230,47 +239,67 @@ class StableDiffusionPipeline(DiffusionPipeline):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
negative_ prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
if prompt_embeds is None:
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_embeddings = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = text_embeddings[0]
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
......@@ -290,7 +319,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
......@@ -304,23 +333,27 @@ class StableDiffusionPipeline(DiffusionPipeline):
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = uncond_embeddings[0]
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return text_embeddings
return prompt_embeds
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is not None:
......@@ -357,10 +390,16 @@ class StableDiffusionPipeline(DiffusionPipeline):
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(self, prompt, height, width, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
def check_inputs(
self,
prompt,
height,
width,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
......@@ -372,6 +411,32 @@ class StableDiffusionPipeline(DiffusionPipeline):
f" {type(callback_steps)}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if isinstance(generator, list) and len(generator) != batch_size:
......@@ -393,7 +458,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]],
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
......@@ -403,6 +468,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
......@@ -412,8 +479,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
......@@ -428,8 +496,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
......@@ -442,6 +511,13 @@ class StableDiffusionPipeline(DiffusionPipeline):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
......@@ -469,10 +545,18 @@ class StableDiffusionPipeline(DiffusionPipeline):
width = width or self.unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, callback_steps)
self.check_inputs(
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
)
# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
......@@ -480,8 +564,14 @@ class StableDiffusionPipeline(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
prompt_embeds = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
# 4. Prepare timesteps
......@@ -495,7 +585,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
num_channels_latents,
height,
width,
text_embeddings.dtype,
prompt_embeds.dtype,
device,
generator,
latents,
......@@ -513,7 +603,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
# perform guidance
if do_classifier_free_guidance:
......@@ -533,7 +623,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
image = self.decode_latents(latents)
# 9. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
# 10. Convert to PIL
if output_type == "pil":
......
......@@ -159,12 +159,21 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
return self.device
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
......@@ -172,47 +181,67 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
negative_ prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
if prompt_embeds is None:
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_embeddings = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = text_embeddings[0]
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
......@@ -232,7 +261,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
......@@ -246,23 +275,27 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = uncond_embeddings[0]
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return text_embeddings
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
......@@ -302,12 +335,15 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(self, prompt, strength, callback_steps):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
def check_inputs(
self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [1.0, 1.0] but is {strength}")
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
......@@ -317,6 +353,32 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
f" {type(callback_steps)}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
......@@ -424,8 +486,8 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
image: Union[torch.FloatTensor, PIL.Image.Image],
prompt: Union[str, List[str]] = None,
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
depth_map: Optional[torch.FloatTensor] = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
......@@ -434,6 +496,8 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
......@@ -443,8 +507,9 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
image (`torch.FloatTensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process.
......@@ -464,8 +529,9 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
......@@ -474,6 +540,13 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
generator (`torch.Generator`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
......@@ -520,6 +593,9 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
# 1. Check inputs
self.check_inputs(prompt, strength, callback_steps)
if image is None:
raise ValueError("`image` input cannot be undefined.")
# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
......@@ -529,8 +605,14 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
prompt_embeds = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
# 4. Prepare depth mask
......@@ -539,7 +621,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
depth_map,
batch_size * num_images_per_prompt,
do_classifier_free_guidance,
text_embeddings.dtype,
prompt_embeds.dtype,
device,
)
......@@ -553,7 +635,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
# 7. Prepare latent variables
latents = self.prepare_latents(
image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
)
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
......@@ -569,7 +651,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
latent_model_input = torch.cat([latent_model_input, depth_mask], dim=1)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
# perform guidance
if do_classifier_free_guidance:
......
......@@ -173,12 +173,12 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
if do_classifier_free_guidance:
uncond_embeddings = torch.zeros_like(image_embeddings)
negative_prompt_embeds = torch.zeros_like(image_embeddings)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
image_embeddings = torch.cat([uncond_embeddings, image_embeddings])
image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
return image_embeddings
......
......@@ -249,12 +249,21 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
return self.device
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
......@@ -262,47 +271,67 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
negative_ prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
if prompt_embeds is None:
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_embeddings = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = text_embeddings[0]
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
......@@ -322,7 +351,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
......@@ -336,23 +365,27 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = uncond_embeddings[0]
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return text_embeddings
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
......@@ -392,7 +425,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(self, prompt, strength, callback_steps):
def check_inputs(
self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
......@@ -407,6 +442,32 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
f" {type(callback_steps)}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
......@@ -472,7 +533,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]],
prompt: Union[str, List[str]] = None,
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
......@@ -481,6 +542,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
......@@ -491,8 +554,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
image (`torch.FloatTensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process.
......@@ -512,8 +576,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
......@@ -522,6 +587,13 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
generator (`torch.Generator`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
......@@ -547,8 +619,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)
image = init_image or image
# 1. Check inputs
self.check_inputs(prompt, strength, callback_steps)
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
......@@ -559,8 +631,14 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
prompt_embeds = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
# 4. Preprocess image
......@@ -573,7 +651,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
# 6. Prepare latent variables
latents = self.prepare_latents(
image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
)
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
......@@ -588,7 +666,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
# perform guidance
if do_classifier_free_guidance:
......@@ -608,7 +686,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
image = self.decode_latents(latents)
# 10. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
# 11. Convert to PIL
if output_type == "pil":
......
......@@ -297,12 +297,21 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
return self.device
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
......@@ -310,47 +319,67 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
negative_ prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
if prompt_embeds is None:
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_embeddings = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = text_embeddings[0]
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
......@@ -370,7 +399,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
......@@ -384,23 +413,27 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = uncond_embeddings[0]
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return text_embeddings
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
......@@ -441,10 +474,16 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
def check_inputs(self, prompt, height, width, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
def check_inputs(
self,
prompt,
height,
width,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
......@@ -456,6 +495,32 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
f" {type(callback_steps)}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
......@@ -528,9 +593,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
image: Union[torch.FloatTensor, PIL.Image.Image],
mask_image: Union[torch.FloatTensor, PIL.Image.Image],
prompt: Union[str, List[str]] = None,
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
......@@ -540,6 +605,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
nrompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
......@@ -549,8 +616,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
image (`PIL.Image.Image`):
`Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
be masked out with `mask_image` and repainted according to `prompt`.
......@@ -573,8 +641,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
......@@ -587,6 +656,13 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
......@@ -645,6 +721,12 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
# 1. Check inputs
self.check_inputs(prompt, height, width, callback_steps)
if image is None:
raise ValueError("`image` input cannot be undefined.")
if mask_image is None:
raise ValueError("`mask_image` input cannot be undefined.")
# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
......@@ -654,7 +736,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt_embeds = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)
......@@ -672,7 +754,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
num_channels_latents,
height,
width,
text_embeddings.dtype,
prompt_embeds.dtype,
device,
generator,
latents,
......@@ -685,7 +767,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
batch_size * num_images_per_prompt,
height,
width,
text_embeddings.dtype,
prompt_embeds.dtype,
device,
generator,
do_classifier_free_guidance,
......@@ -718,7 +800,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
# perform guidance
if do_classifier_free_guidance:
......@@ -738,7 +820,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
image = self.decode_latents(latents)
# 12. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
# 13. Convert to PIL
if output_type == "pil":
......
......@@ -216,12 +216,21 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
return self.device
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
......@@ -229,47 +238,67 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
negative_ prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
if prompt_embeds is None:
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_embeddings = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = text_embeddings[0]
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
......@@ -289,7 +318,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
......@@ -303,23 +332,27 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = uncond_embeddings[0]
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return text_embeddings
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
......@@ -360,7 +393,9 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
return extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
def check_inputs(self, prompt, strength, callback_steps):
def check_inputs(
self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
......@@ -375,6 +410,32 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
f" {type(callback_steps)}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
......@@ -415,6 +476,8 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
add_predicted_noise: Optional[bool] = False,
eta: Optional[float] = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
......@@ -425,8 +488,9 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
image (`torch.FloatTensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process. This is the image whose masked region will be inpainted.
......@@ -450,8 +514,9 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
add_predicted_noise (`bool`, *optional*, defaults to True):
......@@ -463,6 +528,13 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
generator (`torch.Generator`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
......@@ -499,8 +571,14 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
prompt_embeds = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
# 4. Preprocess image and mask
......@@ -518,7 +596,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
# 6. Prepare latent variables
# encode the init image into latents and scale the latents
latents, init_latents_orig, noise = self.prepare_latents(
image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
)
# 7. Prepare mask latent
......@@ -537,7 +615,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
# perform guidance
if do_classifier_free_guidance:
......@@ -566,7 +644,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
image = self.decode_latents(latents)
# 11. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
# 12. Convert to PIL
if output_type == "pil":
......
......@@ -127,8 +127,8 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
image: Union[torch.FloatTensor, PIL.Image.Image],
prompt: Union[str, List[str]] = None,
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
num_inference_steps: int = 100,
guidance_scale: float = 7.5,
image_guidance_scale: float = 1.5,
......@@ -137,6 +137,8 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
......@@ -146,8 +148,9 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
image (`PIL.Image.Image`):
`Image`, or tensor representing an image batch which will be repainted according to `prompt`.
num_inference_steps (`int`, *optional*, defaults to 100):
......@@ -165,8 +168,9 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
generate images that are closely linked to the source image `image`, usually at the expense of lower
image quality. This pipeline requires a value of at least `1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
......@@ -179,6 +183,13 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
......@@ -231,6 +242,9 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
# 0. Check inputs
self.check_inputs(prompt, callback_steps)
if image is None:
raise ValueError("`image` input cannot be undefined.")
# 1. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
......@@ -242,8 +256,14 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
# 2. Encode input prompt
text_embeddings = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
prompt_embeds = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
# 3. Preprocess image
......@@ -259,7 +279,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
image,
batch_size,
num_images_per_prompt,
text_embeddings.dtype,
prompt_embeds.dtype,
device,
do_classifier_free_guidance,
generator,
......@@ -272,7 +292,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
num_channels_latents,
height,
width,
text_embeddings.dtype,
prompt_embeds.dtype,
device,
generator,
latents,
......@@ -306,7 +326,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1)
# predict the noise residual
noise_pred = self.unet(scaled_latent_model_input, t, encoder_hidden_states=text_embeddings).sample
noise_pred = self.unet(scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
# Hack:
# For karras style schedulers the model does classifer free guidance using the
......@@ -348,7 +368,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
image = self.decode_latents(latents)
# 11. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
# 12. Convert to PIL
if output_type == "pil":
......@@ -398,12 +418,21 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
return torch.device(module._hf_hook.execution_device)
return self.device
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
......@@ -411,47 +440,67 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
negative_ prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
if prompt_embeds is None:
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_embeddings = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = text_embeddings[0]
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
......@@ -471,7 +520,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
......@@ -485,23 +534,28 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = uncond_embeddings[0]
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([text_embeddings, uncond_embeddings, uncond_embeddings])
# pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds])
return text_embeddings
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
......
......@@ -162,12 +162,21 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
return self.device
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
......@@ -175,47 +184,67 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
negative_ prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
if prompt_embeds is None:
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_embeddings = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = text_embeddings[0]
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
......@@ -235,7 +264,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
......@@ -249,23 +278,27 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = uncond_embeddings[0]
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return text_embeddings
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
......@@ -317,7 +350,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
......@@ -327,6 +360,8 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
......@@ -336,8 +371,9 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
......@@ -352,8 +388,9 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
......@@ -366,6 +403,13 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
......@@ -404,14 +448,20 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
raise ValueError("has to use guidance_scale")
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
prompt_embeds = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=text_embeddings.device)
self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device)
sigmas = self.scheduler.sigmas
sigmas = sigmas.to(text_embeddings.dtype)
sigmas = sigmas.to(prompt_embeds.dtype)
# 5. Prepare latent variables
num_channels_latents = self.unet.in_channels
......@@ -420,7 +470,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
num_channels_latents,
height,
width,
text_embeddings.dtype,
prompt_embeds.dtype,
device,
generator,
latents,
......@@ -434,7 +484,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
latent_model_input = torch.cat([x] * 2)
t = torch.cat([t] * 2)
noise_pred = self.k_diffusion_model(latent_model_input, t, cond=text_embeddings)
noise_pred = self.k_diffusion_model(latent_model_input, t, cond=prompt_embeds)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
......@@ -447,7 +497,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
image = self.decode_latents(latents)
# 9. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
# 10. Convert to PIL
if output_type == "pil":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment