Unverified Commit b562b661 authored by Patrick von Platen, committed by GitHub

Allow directly passing text embeddings to Stable Diffusion Pipeline for prompt weighting (#2071)

* add text embeds to sd

* add text embeds to sd

* finish tests

* finish

* finish

* make style

* fix tests

* make style

* make style

* up

* better docs

* fix

* fix

* new try

* up

* up

* finish
parent c1184918
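A minimal usage sketch of what this commit enables: pre-computing `prompt_embeds` and `negative_prompt_embeds` with the pipeline's own tokenizer and text encoder and passing them to `__call__` instead of text prompts. Illustrative only; the checkpoint id, device, and prompts are assumptions, not part of the change.

    import torch
    from diffusers import StableDiffusionPipeline

    # Assumed checkpoint for illustration; any Stable Diffusion checkpoint works the same way.
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
    pipe = pipe.to("cuda")

    def encode(text):
        # Reuse the pipeline's tokenizer and text encoder to build embeddings by hand.
        ids = pipe.tokenizer(
            text,
            padding="max_length",
            max_length=pipe.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids
        with torch.no_grad():
            return pipe.text_encoder(ids.to(pipe.device))[0]

    prompt_embeds = encode("an astronaut riding a horse on the moon")
    negative_prompt_embeds = encode("")  # unconditional/negative embeddings, same padded length

    # New in this PR: pass the embeddings directly instead of `prompt` / `negative_prompt`.
    image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds).images[0]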
@@ -220,12 +220,21 @@ class AltDiffusionPipeline(DiffusionPipeline):
                 return torch.device(module._hf_hook.execution_device)
         return self.device
 
-    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+    def _encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+    ):
         r"""
         Encodes the prompt into text encoder hidden states.
 
         Args:
-            prompt (`str` or `list(int)`):
+            prompt (`str` or `List[str]`, *optional*):
                 prompt to be encoded
             device: (`torch.device`):
                 torch device
@@ -233,47 +242,67 @@ class AltDiffusionPipeline(DiffusionPipeline):
                 number of images that should be generated per prompt
             do_classifier_free_guidance (`bool`):
                 whether to use classifier free guidance or not
-            negative_prompt (`str` or `List[str]`):
-                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
-                if `guidance_scale` is less than `1`).
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
+                `guidance_scale` is less than `1`).
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
         """
-        batch_size = len(prompt) if isinstance(prompt, list) else 1
-
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            truncation=True,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
-        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-            )
-
-        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
-            attention_mask = text_inputs.attention_mask.to(device)
-        else:
-            attention_mask = None
-
-        text_embeddings = self.text_encoder(
-            text_input_ids.to(device),
-            attention_mask=attention_mask,
-        )
-        text_embeddings = text_embeddings[0]
-
-        # duplicate text embeddings for each generation per prompt, using mps friendly method
-        bs_embed, seq_len, _ = text_embeddings.shape
-        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
-        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        if prompt_embeds is None:
+            text_inputs = self.tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            text_input_ids = text_inputs.input_ids
+            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = self.tokenizer.batch_decode(
+                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                )
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )
+
+            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                attention_mask = text_inputs.attention_mask.to(device)
+            else:
+                attention_mask = None
+
+            prompt_embeds = self.text_encoder(
+                text_input_ids.to(device),
+                attention_mask=attention_mask,
+            )
+            prompt_embeds = prompt_embeds[0]
+
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
 
         # get unconditional embeddings for classifier free guidance
-        if do_classifier_free_guidance:
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
             uncond_tokens: List[str]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
@@ -293,7 +322,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
             else:
                 uncond_tokens = negative_prompt
 
-            max_length = text_input_ids.shape[-1]
+            max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
                 padding="max_length",
@@ -307,23 +336,27 @@ class AltDiffusionPipeline(DiffusionPipeline):
             else:
                 attention_mask = None
 
-            uncond_embeddings = self.text_encoder(
+            negative_prompt_embeds = self.text_encoder(
                 uncond_input.input_ids.to(device),
                 attention_mask=attention_mask,
             )
-            uncond_embeddings = uncond_embeddings[0]
+            negative_prompt_embeds = negative_prompt_embeds[0]
 
+        if do_classifier_free_guidance:
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-            seq_len = uncond_embeddings.shape[1]
-            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
-            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+            seq_len = negative_prompt_embeds.shape[1]
+
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
-            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
-        return text_embeddings
+        return prompt_embeds
 
     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is not None:
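The docstring above describes `prompt_embeds` as a hook for prompt weighting. As a rough illustration (not part of the commit; it reuses `pipe` and the `encode` helper from the earlier sketch), one way to use the hook is to blend the embeddings of two prompts before handing them to the pipeline:

    # Illustrative sketch only: interpolate two prompts in embedding space.
    painting = encode("a watercolor painting of a fox")
    photo = encode("a photograph of a fox")

    blended = 0.7 * painting + 0.3 * photo  # arbitrary mixing ratio, weights the painting prompt more
    image = pipe(prompt_embeds=blended).images[0]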
@@ -360,10 +393,16 @@ class AltDiffusionPipeline(DiffusionPipeline):
         extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
 
-    def check_inputs(self, prompt, height, width, callback_steps):
-        if not isinstance(prompt, str) and not isinstance(prompt, list):
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+    def check_inputs(
+        self,
+        prompt,
+        height,
+        width,
+        callback_steps,
+        negative_prompt=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+    ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -375,6 +414,32 @@ class AltDiffusionPipeline(DiffusionPipeline):
                 f" {type(callback_steps)}."
             )
 
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
         shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
         if isinstance(generator, list) and len(generator) != batch_size:
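For completeness, a small sketch of how the new validation behaves at call time (again reusing `pipe` and `blended` from the sketches above; only the error shown in the diff is relied upon):

    # Passing both a text prompt and pre-computed embeddings is now rejected up front.
    try:
        pipe(prompt="a fox", prompt_embeds=blended)
    except ValueError as err:
        print(err)  # Cannot forward both `prompt` ... and `prompt_embeds` ... Please make sure to only forward one of the two.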
@@ -396,7 +461,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
-        prompt: Union[str, List[str]],
+        prompt: Union[str, List[str]] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 50,
@@ -406,6 +471,8 @@ class AltDiffusionPipeline(DiffusionPipeline):
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -415,8 +482,9 @@ class AltDiffusionPipeline(DiffusionPipeline):
         Function invoked when calling the pipeline for generation.
 
         Args:
-            prompt (`str` or `List[str]`):
-                The prompt or prompts to guide the image generation.
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
             height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -431,8 +499,9 @@ class AltDiffusionPipeline(DiffusionPipeline):
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
             negative_prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
-                if `guidance_scale` is less than `1`).
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
+                is less than `1`).
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             eta (`float`, *optional*, defaults to 0.0):
@@ -445,6 +514,13 @@ class AltDiffusionPipeline(DiffusionPipeline):
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor will ge generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -472,10 +548,18 @@ class AltDiffusionPipeline(DiffusionPipeline):
         width = width or self.unet.config.sample_size * self.vae_scale_factor
 
         # 1. Check inputs. Raise error if not correct
-        self.check_inputs(prompt, height, width, callback_steps)
+        self.check_inputs(
+            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+        )
 
         # 2. Define call parameters
-        batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
         device = self._execution_device
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -483,8 +567,14 @@ class AltDiffusionPipeline(DiffusionPipeline):
         do_classifier_free_guidance = guidance_scale > 1.0
 
         # 3. Encode input prompt
-        text_embeddings = self._encode_prompt(
-            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+        prompt_embeds = self._encode_prompt(
+            prompt,
+            device,
+            num_images_per_prompt,
+            do_classifier_free_guidance,
+            negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
         )
 
         # 4. Prepare timesteps
@@ -498,7 +588,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
             num_channels_latents,
             height,
             width,
-            text_embeddings.dtype,
+            prompt_embeds.dtype,
             device,
             generator,
             latents,
@@ -516,7 +606,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 # predict the noise residual
-                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
 
                 # perform guidance
                 if do_classifier_free_guidance:
@@ -536,7 +626,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
         image = self.decode_latents(latents)
 
         # 9. Run safety checker
-        image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
 
         # 10. Convert to PIL
         if output_type == "pil":
@@ -242,12 +242,21 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
                 return torch.device(module._hf_hook.execution_device)
         return self.device
 
-    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+    def _encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+    ):
         r"""
         Encodes the prompt into text encoder hidden states.
 
         Args:
-            prompt (`str` or `list(int)`):
+            prompt (`str` or `List[str]`, *optional*):
                 prompt to be encoded
             device: (`torch.device`):
                 torch device
@@ -255,47 +264,67 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
                 number of images that should be generated per prompt
             do_classifier_free_guidance (`bool`):
                 whether to use classifier free guidance or not
-            negative_prompt (`str` or `List[str]`):
-                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
-                if `guidance_scale` is less than `1`).
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
+                `guidance_scale` is less than `1`).
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
         """
-        batch_size = len(prompt) if isinstance(prompt, list) else 1
-
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            truncation=True,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
-        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-            )
-
-        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
-            attention_mask = text_inputs.attention_mask.to(device)
-        else:
-            attention_mask = None
-
-        text_embeddings = self.text_encoder(
-            text_input_ids.to(device),
-            attention_mask=attention_mask,
-        )
-        text_embeddings = text_embeddings[0]
-
-        # duplicate text embeddings for each generation per prompt, using mps friendly method
-        bs_embed, seq_len, _ = text_embeddings.shape
-        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
-        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        if prompt_embeds is None:
+            text_inputs = self.tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            text_input_ids = text_inputs.input_ids
+            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = self.tokenizer.batch_decode(
+                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                )
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )
+
+            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                attention_mask = text_inputs.attention_mask.to(device)
+            else:
+                attention_mask = None
+
+            prompt_embeds = self.text_encoder(
+                text_input_ids.to(device),
+                attention_mask=attention_mask,
+            )
+            prompt_embeds = prompt_embeds[0]
+
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
 
         # get unconditional embeddings for classifier free guidance
-        if do_classifier_free_guidance:
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
             uncond_tokens: List[str]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
@@ -315,7 +344,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
             else:
                 uncond_tokens = negative_prompt
 
-            max_length = text_input_ids.shape[-1]
+            max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
                 padding="max_length",
@@ -329,23 +358,27 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
             else:
                 attention_mask = None
 
-            uncond_embeddings = self.text_encoder(
+            negative_prompt_embeds = self.text_encoder(
                 uncond_input.input_ids.to(device),
                 attention_mask=attention_mask,
            )
-            uncond_embeddings = uncond_embeddings[0]
+            negative_prompt_embeds = negative_prompt_embeds[0]
 
+        if do_classifier_free_guidance:
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-            seq_len = uncond_embeddings.shape[1]
-            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
-            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+            seq_len = negative_prompt_embeds.shape[1]
+
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
             # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
-            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
-        return text_embeddings
+        return prompt_embeds
 
     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is not None:
@@ -382,7 +415,9 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
         extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
 
-    def check_inputs(self, prompt, strength, callback_steps):
+    def check_inputs(
+        self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
+    ):
         if not isinstance(prompt, str) and not isinstance(prompt, list):
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@@ -397,6 +432,32 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
                 f" {type(callback_steps)}."
             )
 
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
     def get_timesteps(self, num_inference_steps, strength, device):
         # get the original timestep using init_timestep
         init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
@@ -462,7 +523,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
-        prompt: Union[str, List[str]],
+        prompt: Union[str, List[str]] = None,
         image: Union[torch.FloatTensor, PIL.Image.Image] = None,
         strength: float = 0.8,
         num_inference_steps: Optional[int] = 50,
@@ -471,6 +532,8 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
         num_images_per_prompt: Optional[int] = 1,
         eta: Optional[float] = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -481,8 +544,9 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
         Function invoked when calling the pipeline for generation.
 
         Args:
-            prompt (`str` or `List[str]`):
-                The prompt or prompts to guide the image generation.
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
             image (`torch.FloatTensor` or `PIL.Image.Image`):
                 `Image`, or tensor representing an image batch, that will be used as the starting point for the
                 process.
@@ -502,8 +566,9 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
             negative_prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
-                if `guidance_scale` is less than `1`).
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
+                is less than `1`).
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             eta (`float`, *optional*, defaults to 0.0):
@@ -512,6 +577,13 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
             generator (`torch.Generator`, *optional*):
                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                 to make generation deterministic.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -537,8 +609,8 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
         init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)
         image = init_image or image
 
-        # 1. Check inputs
-        self.check_inputs(prompt, strength, callback_steps)
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
 
         # 2. Define call parameters
         batch_size = 1 if isinstance(prompt, str) else len(prompt)
@@ -549,8 +621,14 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
         do_classifier_free_guidance = guidance_scale > 1.0
 
         # 3. Encode input prompt
-        text_embeddings = self._encode_prompt(
-            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+        prompt_embeds = self._encode_prompt(
+            prompt,
+            device,
+            num_images_per_prompt,
+            do_classifier_free_guidance,
+            negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
         )
 
         # 4. Preprocess image
@@ -563,7 +641,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
 
         # 6. Prepare latent variables
         latents = self.prepare_latents(
-            image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
+            image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
         )
 
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
@@ -578,7 +656,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 # predict the noise residual
-                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
 
                 # perform guidance
                 if do_classifier_free_guidance:
@@ -598,7 +676,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
         image = self.decode_latents(latents)
 
         # 10. Run safety checker
-        image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
 
         # 11. Convert to PIL
         if output_type == "pil":
@@ -129,11 +129,11 @@ class LDMTextToImagePipeline(DiffusionPipeline):
             uncond_input = self.tokenizer(
                 [""] * batch_size, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
             )
-            uncond_embeddings = self.bert(uncond_input.input_ids.to(self.device))[0]
+            negative_prompt_embeds = self.bert(uncond_input.input_ids.to(self.device))[0]
 
         # get prompt text embeddings
         text_input = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt")
-        text_embeddings = self.bert(text_input.input_ids.to(self.device))[0]
+        prompt_embeds = self.bert(text_input.input_ids.to(self.device))[0]
 
         # get the initial random noise unless the user supplied it
         latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
@@ -144,7 +144,7 @@ class LDMTextToImagePipeline(DiffusionPipeline):
             )
 
         if latents is None:
-            latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=text_embeddings.dtype)
+            latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=prompt_embeds.dtype)
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
@@ -163,13 +163,13 @@ class LDMTextToImagePipeline(DiffusionPipeline):
             if guidance_scale == 1.0:
                 # guidance_scale of 1 means no guidance
                 latents_input = latents
-                context = text_embeddings
+                context = prompt_embeds
             else:
                 # For classifier free guidance, we need to do two forward passes.
                 # Here we concatenate the unconditional and text embeddings into a single batch
                 # to avoid doing two forward passes
                 latents_input = torch.cat([latents] * 2)
-                context = torch.cat([uncond_embeddings, text_embeddings])
+                context = torch.cat([negative_prompt_embeds, prompt_embeds])
 
             # predict the noise residual
             noise_pred = self.unet(latents_input, t, encoder_hidden_states=context).sample
@@ -364,7 +364,7 @@ class PaintByExamplePipeline(DiffusionPipeline):
             image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
 
         image = image.to(device=device, dtype=dtype)
-        image_embeddings, uncond_embeddings = self.image_encoder(image, return_uncond_vector=True)
+        image_embeddings, negative_prompt_embeds = self.image_encoder(image, return_uncond_vector=True)
 
         # duplicate image embeddings for each generation per prompt, using mps friendly method
         bs_embed, seq_len, _ = image_embeddings.shape
@@ -372,13 +372,13 @@ class PaintByExamplePipeline(DiffusionPipeline):
         image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
 
         if do_classifier_free_guidance:
-            uncond_embeddings = uncond_embeddings.repeat(1, image_embeddings.shape[0], 1)
-            uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, 1, -1)
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, image_embeddings.shape[0], 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, 1, -1)
 
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
-            image_embeddings = torch.cat([uncond_embeddings, image_embeddings])
+            image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
 
         return image_embeddings
@@ -261,12 +261,21 @@ class CycleDiffusionPipeline(DiffusionPipeline):
         return self.device
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
-    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+    def _encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+    ):
         r"""
         Encodes the prompt into text encoder hidden states.
 
         Args:
-            prompt (`str` or `list(int)`):
+            prompt (`str` or `List[str]`, *optional*):
                 prompt to be encoded
             device: (`torch.device`):
                 torch device
@@ -274,47 +283,67 @@ class CycleDiffusionPipeline(DiffusionPipeline):
                 number of images that should be generated per prompt
             do_classifier_free_guidance (`bool`):
                 whether to use classifier free guidance or not
-            negative_prompt (`str` or `List[str]`):
-                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
-                if `guidance_scale` is less than `1`).
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
+                `guidance_scale` is less than `1`).
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
         """
-        batch_size = len(prompt) if isinstance(prompt, list) else 1
-
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            truncation=True,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
-        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-            )
-
-        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
-            attention_mask = text_inputs.attention_mask.to(device)
-        else:
-            attention_mask = None
-
-        text_embeddings = self.text_encoder(
-            text_input_ids.to(device),
-            attention_mask=attention_mask,
-        )
-        text_embeddings = text_embeddings[0]
-
-        # duplicate text embeddings for each generation per prompt, using mps friendly method
-        bs_embed, seq_len, _ = text_embeddings.shape
-        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
-        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        if prompt_embeds is None:
+            text_inputs = self.tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            text_input_ids = text_inputs.input_ids
+            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = self.tokenizer.batch_decode(
+                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                )
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )
+
+            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                attention_mask = text_inputs.attention_mask.to(device)
+            else:
+                attention_mask = None
+
+            prompt_embeds = self.text_encoder(
+                text_input_ids.to(device),
+                attention_mask=attention_mask,
+            )
+            prompt_embeds = prompt_embeds[0]
+
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
 
         # get unconditional embeddings for classifier free guidance
-        if do_classifier_free_guidance:
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
             uncond_tokens: List[str]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
@@ -334,7 +363,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
             else:
                 uncond_tokens = negative_prompt
 
-            max_length = text_input_ids.shape[-1]
+            max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
                 padding="max_length",
@@ -348,26 +377,32 @@ class CycleDiffusionPipeline(DiffusionPipeline):
             else:
                 attention_mask = None
 
-            uncond_embeddings = self.text_encoder(
+            negative_prompt_embeds = self.text_encoder(
                 uncond_input.input_ids.to(device),
                 attention_mask=attention_mask,
             )
-            uncond_embeddings = uncond_embeddings[0]
+            negative_prompt_embeds = negative_prompt_embeds[0]
 
+        if do_classifier_free_guidance:
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-            seq_len = uncond_embeddings.shape[1]
-            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
-            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+            seq_len = negative_prompt_embeds.shape[1]
+
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
-            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
-        return text_embeddings
+        return prompt_embeds
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
-    def check_inputs(self, prompt, strength, callback_steps):
+    def check_inputs(
+        self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
+    ):
         if not isinstance(prompt, str) and not isinstance(prompt, list):
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@@ -382,6 +417,32 @@ class CycleDiffusionPipeline(DiffusionPipeline):
                 f" {type(callback_steps)}."
             )
 
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -492,6 +553,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
         num_images_per_prompt: Optional[int] = 1,
         eta: Optional[float] = 0.1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -533,6 +595,13 @@ class CycleDiffusionPipeline(DiffusionPipeline):
             generator (`torch.Generator`, *optional*):
                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                 to make generation deterministic.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -569,8 +638,14 @@ class CycleDiffusionPipeline(DiffusionPipeline):
         do_classifier_free_guidance = guidance_scale > 1.0
 
         # 3. Encode input prompt
-        text_embeddings = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, None)
-        source_text_embeddings = self._encode_prompt(
+        prompt_embeds = self._encode_prompt(
+            prompt,
+            device,
+            num_images_per_prompt,
+            do_classifier_free_guidance,
+            prompt_embeds=prompt_embeds,
+        )
+        source_prompt_embeds = self._encode_prompt(
             source_prompt, device, num_images_per_prompt, do_classifier_free_guidance, None
         )
@@ -584,7 +659,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
 
         # 6. Prepare latent variables
         latents, clean_latents = self.prepare_latents(
-            image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
+            image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
         )
         source_latents = latents
@@ -612,17 +687,17 @@ class CycleDiffusionPipeline(DiffusionPipeline):
                     ],
                     dim=0,
                 )
-                concat_text_embeddings = torch.stack(
+                concat_prompt_embeds = torch.stack(
                     [
-                        source_text_embeddings[0],
-                        text_embeddings[0],
-                        source_text_embeddings[1],
-                        text_embeddings[1],
+                        source_prompt_embeds[0],
+                        prompt_embeds[0],
+                        source_prompt_embeds[1],
+                        prompt_embeds[1],
                     ],
                     dim=0,
                 )
                 concat_noise_pred = self.unet(
-                    concat_latent_model_input, t, encoder_hidden_states=concat_text_embeddings
+                    concat_latent_model_input, t, encoder_hidden_states=concat_prompt_embeds
                 ).sample
 
                 # perform guidance
@@ -662,7 +737,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
         image = self.decode_latents(latents)
 
         # 10. Run safety checker
-        image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
 
         # 11. Convert to PIL
         if output_type == "pil":
......
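The CycleDiffusion loop above stacks the source and target embeddings so that a single UNet forward pass covers the unconditional and conditional branches of both prompts. A standalone sketch of that interleaving with dummy tensors (the 77x768 CLIP-like shape is only illustrative):

    import torch

    # each tensor holds [uncond, cond] along dim 0, as returned by _encode_prompt under guidance
    source_prompt_embeds = torch.randn(2, 77, 768)
    prompt_embeds = torch.randn(2, 77, 768)

    # interleave: source-uncond, target-uncond, source-cond, target-cond
    concat_prompt_embeds = torch.stack(
        [source_prompt_embeds[0], prompt_embeds[0], source_prompt_embeds[1], prompt_embeds[1]],
        dim=0,
    )
    assert concat_prompt_embeds.shape == (4, 77, 768)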
...@@ -196,7 +196,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): ...@@ -196,7 +196,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
# get prompt text embeddings # get prompt text embeddings
text_embeddings = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]
# TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
# implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
...@@ -210,8 +210,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): ...@@ -210,8 +210,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
).input_ids ).input_ids
else: else:
uncond_input = neg_prompt_ids uncond_input = neg_prompt_ids
uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0] negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
context = jnp.concatenate([uncond_embeddings, text_embeddings]) context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])
latents_shape = ( latents_shape = (
batch_size, batch_size,
......
...@@ -192,7 +192,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): ...@@ -192,7 +192,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
# get prompt text embeddings # get prompt text embeddings
text_embeddings = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]
# TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
# implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
...@@ -206,8 +206,8 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): ...@@ -206,8 +206,8 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
).input_ids ).input_ids
else: else:
uncond_input = neg_prompt_ids uncond_input = neg_prompt_ids
uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0] negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
context = jnp.concatenate([uncond_embeddings, text_embeddings]) context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])
latents_shape = ( latents_shape = (
batch_size, batch_size,
......
...@@ -224,7 +224,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline): ...@@ -224,7 +224,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
# get prompt text embeddings # get prompt text embeddings
text_embeddings = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]
# TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
# implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
...@@ -238,8 +238,8 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline): ...@@ -238,8 +238,8 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
).input_ids ).input_ids
else: else:
uncond_input = neg_prompt_ids uncond_input = neg_prompt_ids
uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0] negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
context = jnp.concatenate([uncond_embeddings, text_embeddings]) context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])
latents_shape = ( latents_shape = (
batch_size, batch_size,
......
...@@ -117,7 +117,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): ...@@ -117,7 +117,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
Encodes the prompt into text encoder hidden states. Encodes the prompt into text encoder hidden states.
Args: Args:
prompt (`str` or `list(int)`): prompt (`str` or `List[str]`):
prompt to be encoded prompt to be encoded
num_images_per_prompt (`int`): num_images_per_prompt (`int`):
number of images that should be generated per prompt number of images that should be generated per prompt
...@@ -147,8 +147,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): ...@@ -147,8 +147,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
f" {self.tokenizer.model_max_length} tokens: {removed_text}" f" {self.tokenizer.model_max_length} tokens: {removed_text}"
) )
text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0) prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
# get unconditional embeddings for classifier free guidance # get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance: if do_classifier_free_guidance:
...@@ -179,15 +179,15 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): ...@@ -179,15 +179,15 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
truncation=True, truncation=True,
return_tensors="np", return_tensors="np",
) )
uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0) negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)
# For classifier free guidance, we need to do two forward passes. # For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch # Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes # to avoid doing two forward passes
text_embeddings = np.concatenate([uncond_embeddings, text_embeddings]) prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])
return text_embeddings return prompt_embeds
def __call__( def __call__(
self, self,
...@@ -232,12 +232,12 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): ...@@ -232,12 +232,12 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
# corresponds to doing no classifier free guidance. # corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0 do_classifier_free_guidance = guidance_scale > 1.0
text_embeddings = self._encode_prompt( prompt_embeds = self._encode_prompt(
prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
) )
# get the initial random noise unless the user supplied it # get the initial random noise unless the user supplied it
latents_dtype = text_embeddings.dtype latents_dtype = prompt_embeds.dtype
latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8) latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
if latents is None: if latents is None:
latents = generator.randn(*latents_shape).astype(latents_dtype) latents = generator.randn(*latents_shape).astype(latents_dtype)
...@@ -271,7 +271,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): ...@@ -271,7 +271,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
# predict the noise residual # predict the noise residual
timestep = np.array([t], dtype=timestep_dtype) timestep = np.array([t], dtype=timestep_dtype)
noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings) noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)
noise_pred = noise_pred[0] noise_pred = noise_pred[0]
# perform guidance # perform guidance
......
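As in the torch pipelines, the ONNX variant runs the UNet once on a batch holding [negative_prompt_embeds, prompt_embeds] and then recombines the two noise predictions. A minimal numpy sketch of the classifier-free guidance step behind the elided "perform guidance" block, assuming the standard combination these pipelines use (dummy values stand in for the UNet outputs):

    import numpy as np

    guidance_scale = 7.5
    noise_pred = np.random.randn(2, 4, 64, 64).astype(np.float32)  # stacked [uncond, text] predictions
    noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)

    # classifier-free guidance: push the prediction toward the text-conditioned branch
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)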
...@@ -167,7 +167,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -167,7 +167,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
Encodes the prompt into text encoder hidden states. Encodes the prompt into text encoder hidden states.
Args: Args:
prompt (`str` or `list(int)`): prompt (`str` or `List[str]`):
prompt to be encoded prompt to be encoded
num_images_per_prompt (`int`): num_images_per_prompt (`int`):
number of images that should be generated per prompt number of images that should be generated per prompt
...@@ -197,8 +197,8 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -197,8 +197,8 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
f" {self.tokenizer.model_max_length} tokens: {removed_text}" f" {self.tokenizer.model_max_length} tokens: {removed_text}"
) )
text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0) prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
# get unconditional embeddings for classifier free guidance # get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance: if do_classifier_free_guidance:
...@@ -229,15 +229,15 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -229,15 +229,15 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
truncation=True, truncation=True,
return_tensors="np", return_tensors="np",
) )
uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0) negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)
# For classifier free guidance, we need to do two forward passes. # For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch # Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes # to avoid doing two forward passes
text_embeddings = np.concatenate([uncond_embeddings, text_embeddings]) prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])
return text_embeddings return prompt_embeds
def __call__( def __call__(
self, self,
...@@ -345,11 +345,11 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -345,11 +345,11 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
# corresponds to doing no classifier free guidance. # corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0 do_classifier_free_guidance = guidance_scale > 1.0
text_embeddings = self._encode_prompt( prompt_embeds = self._encode_prompt(
prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
) )
latents_dtype = text_embeddings.dtype latents_dtype = prompt_embeds.dtype
image = image.astype(latents_dtype) image = image.astype(latents_dtype)
# encode the init image into latents and scale the latents # encode the init image into latents and scale the latents
init_latents = self.vae_encoder(sample=image)[0] init_latents = self.vae_encoder(sample=image)[0]
...@@ -417,9 +417,9 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -417,9 +417,9 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
# predict the noise residual # predict the noise residual
timestep = np.array([t], dtype=timestep_dtype) timestep = np.array([t], dtype=timestep_dtype)
noise_pred = self.unet( noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[
sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings 0
)[0] ]
# perform guidance # perform guidance
if do_classifier_free_guidance: if do_classifier_free_guidance:
......
...@@ -168,7 +168,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -168,7 +168,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
Encodes the prompt into text encoder hidden states. Encodes the prompt into text encoder hidden states.
Args: Args:
prompt (`str` or `list(int)`): prompt (`str` or `List[str]`):
prompt to be encoded prompt to be encoded
num_images_per_prompt (`int`): num_images_per_prompt (`int`):
number of images that should be generated per prompt number of images that should be generated per prompt
...@@ -198,8 +198,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -198,8 +198,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
f" {self.tokenizer.model_max_length} tokens: {removed_text}" f" {self.tokenizer.model_max_length} tokens: {removed_text}"
) )
text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0) prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
# get unconditional embeddings for classifier free guidance # get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance: if do_classifier_free_guidance:
...@@ -230,15 +230,15 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -230,15 +230,15 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
truncation=True, truncation=True,
return_tensors="np", return_tensors="np",
) )
uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0) negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)
# For classifier free guidance, we need to do two forward passes. # For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch # Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes # to avoid doing two forward passes
text_embeddings = np.concatenate([uncond_embeddings, text_embeddings]) prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])
return text_embeddings return prompt_embeds
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
...@@ -351,13 +351,13 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -351,13 +351,13 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
# corresponds to doing no classifier free guidance. # corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0 do_classifier_free_guidance = guidance_scale > 1.0
text_embeddings = self._encode_prompt( prompt_embeds = self._encode_prompt(
prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
) )
num_channels_latents = NUM_LATENT_CHANNELS num_channels_latents = NUM_LATENT_CHANNELS
latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8) latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
latents_dtype = text_embeddings.dtype latents_dtype = prompt_embeds.dtype
if latents is None: if latents is None:
latents = generator.randn(*latents_shape).astype(latents_dtype) latents = generator.randn(*latents_shape).astype(latents_dtype)
else: else:
...@@ -424,9 +424,9 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -424,9 +424,9 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
# predict the noise residual # predict the noise residual
timestep = np.array([t], dtype=timestep_dtype) timestep = np.array([t], dtype=timestep_dtype)
noise_pred = self.unet( noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[
sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings 0
)[0] ]
# perform guidance # perform guidance
if do_classifier_free_guidance: if do_classifier_free_guidance:
......
...@@ -153,7 +153,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -153,7 +153,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
Encodes the prompt into text encoder hidden states. Encodes the prompt into text encoder hidden states.
Args: Args:
prompt (`str` or `list(int)`): prompt (`str` or `List[str]`):
prompt to be encoded prompt to be encoded
num_images_per_prompt (`int`): num_images_per_prompt (`int`):
number of images that should be generated per prompt number of images that should be generated per prompt
...@@ -183,8 +183,8 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -183,8 +183,8 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
f" {self.tokenizer.model_max_length} tokens: {removed_text}" f" {self.tokenizer.model_max_length} tokens: {removed_text}"
) )
text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0) prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
# get unconditional embeddings for classifier free guidance # get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance: if do_classifier_free_guidance:
...@@ -215,15 +215,15 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -215,15 +215,15 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
truncation=True, truncation=True,
return_tensors="np", return_tensors="np",
) )
uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0) negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)
# For classifier free guidance, we need to do two forward passes. # For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch # Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes # to avoid doing two forward passes
text_embeddings = np.concatenate([uncond_embeddings, text_embeddings]) prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])
return text_embeddings return prompt_embeds
def __call__( def __call__(
self, self,
...@@ -338,11 +338,11 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -338,11 +338,11 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
# corresponds to doing no classifier free guidance. # corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0 do_classifier_free_guidance = guidance_scale > 1.0
text_embeddings = self._encode_prompt( prompt_embeds = self._encode_prompt(
prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
) )
latents_dtype = text_embeddings.dtype latents_dtype = prompt_embeds.dtype
image = image.astype(latents_dtype) image = image.astype(latents_dtype)
# encode the init image into latents and scale the latents # encode the init image into latents and scale the latents
...@@ -399,7 +399,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -399,7 +399,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
# predict the noise residual # predict the noise residual
noise_pred = self.unet( noise_pred = self.unet(
sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=prompt_embeds
)[0] )[0]
# perform guidance # perform guidance
......
...@@ -217,12 +217,21 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -217,12 +217,21 @@ class StableDiffusionPipeline(DiffusionPipeline):
return torch.device(module._hf_hook.execution_device) return torch.device(module._hf_hook.execution_device)
return self.device return self.device
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
):
r""" r"""
Encodes the prompt into text encoder hidden states. Encodes the prompt into text encoder hidden states.
Args: Args:
prompt (`str` or `list(int)`): prompt (`str` or `List[str]`, *optional*):
prompt to be encoded prompt to be encoded
device: (`torch.device`): device: (`torch.device`):
torch device torch device
...@@ -230,47 +239,67 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -230,47 +239,67 @@ class StableDiffusionPipeline(DiffusionPipeline):
number of images that should be generated per prompt number of images that should be generated per prompt
do_classifier_free_guidance (`bool`): do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`): negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored The prompt or prompts not to guide the image generation. If not defined, one has to pass
if `guidance_scale` is less than `1`). `negative_prompt_embeds` instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
""" """
batch_size = len(prompt) if isinstance(prompt, list) else 1 if prompt is not None and isinstance(prompt, str):
batch_size = 1
text_inputs = self.tokenizer( elif prompt is not None and isinstance(prompt, list):
prompt, batch_size = len(prompt)
padding="max_length", else:
max_length=self.tokenizer.model_max_length, batch_size = prompt_embeds.shape[0]
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): if prompt_embeds is None:
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) text_inputs = self.tokenizer(
logger.warning( prompt,
"The following part of your input was truncated because CLIP can only handle sequences up to" padding="max_length",
f" {self.tokenizer.model_max_length} tokens: {removed_text}" max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
) )
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
attention_mask = text_inputs.attention_mask.to(device) text_input_ids, untruncated_ids
else: ):
attention_mask = None removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_embeddings = self.text_encoder( if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
text_input_ids.to(device), attention_mask = text_inputs.attention_mask.to(device)
attention_mask=attention_mask, else:
) attention_mask = None
text_embeddings = text_embeddings[0]
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method # duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance # get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance: if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str] uncond_tokens: List[str]
if negative_prompt is None: if negative_prompt is None:
uncond_tokens = [""] * batch_size uncond_tokens = [""] * batch_size
...@@ -290,7 +319,7 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -290,7 +319,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
else: else:
uncond_tokens = negative_prompt uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1] max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer( uncond_input = self.tokenizer(
uncond_tokens, uncond_tokens,
padding="max_length", padding="max_length",
...@@ -304,23 +333,27 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -304,23 +333,27 @@ class StableDiffusionPipeline(DiffusionPipeline):
else: else:
attention_mask = None attention_mask = None
uncond_embeddings = self.text_encoder( negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device), uncond_input.input_ids.to(device),
attention_mask=attention_mask, attention_mask=attention_mask,
) )
uncond_embeddings = uncond_embeddings[0] negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1] seq_len = negative_prompt_embeds.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes. # For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch # Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes # to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return text_embeddings return prompt_embeds
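Both the prompt and negative-prompt branches duplicate their embeddings per requested image with the repeat/view pattern the comments call mps-friendly. A tiny standalone check of what that reshaping produces (shapes are illustrative):

    import torch

    num_images_per_prompt = 3
    prompt_embeds = torch.randn(2, 77, 768)          # (batch, seq_len, dim)

    bs_embed, seq_len, _ = prompt_embeds.shape
    duplicated = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    duplicated = duplicated.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # each prompt's embedding now appears num_images_per_prompt times in a row
    assert duplicated.shape == (6, 77, 768)
    assert torch.equal(duplicated[0], duplicated[1]) and torch.equal(duplicated[1], duplicated[2])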
def run_safety_checker(self, image, device, dtype): def run_safety_checker(self, image, device, dtype):
if self.safety_checker is not None: if self.safety_checker is not None:
...@@ -357,10 +390,16 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -357,10 +390,16 @@ class StableDiffusionPipeline(DiffusionPipeline):
extra_step_kwargs["generator"] = generator extra_step_kwargs["generator"] = generator
return extra_step_kwargs return extra_step_kwargs
def check_inputs(self, prompt, height, width, callback_steps): def check_inputs(
if not isinstance(prompt, str) and not isinstance(prompt, list): self,
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") prompt,
height,
width,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
):
if height % 8 != 0 or width % 8 != 0: if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
...@@ -372,6 +411,32 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -372,6 +411,32 @@ class StableDiffusionPipeline(DiffusionPipeline):
f" {type(callback_steps)}." f" {type(callback_steps)}."
) )
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if isinstance(generator, list) and len(generator) != batch_size: if isinstance(generator, list) and len(generator) != batch_size:
...@@ -393,7 +458,7 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -393,7 +458,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
@replace_example_docstring(EXAMPLE_DOC_STRING) @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
self, self,
prompt: Union[str, List[str]], prompt: Union[str, List[str]] = None,
height: Optional[int] = None, height: Optional[int] = None,
width: Optional[int] = None, width: Optional[int] = None,
num_inference_steps: int = 50, num_inference_steps: int = 50,
...@@ -403,6 +468,8 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -403,6 +468,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
eta: float = 0.0, eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None, latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
...@@ -412,8 +479,9 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -412,8 +479,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
Args: Args:
prompt (`str` or `List[str]`): prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
...@@ -428,8 +496,9 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -428,8 +496,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality. usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*): negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored The prompt or prompts not to guide the image generation. If not defined, one has to pass
if `guidance_scale` is less than `1`). `negative_prompt_embeds` instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1): num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt. The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0): eta (`float`, *optional*, defaults to 0.0):
...@@ -442,6 +511,13 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -442,6 +511,13 @@ class StableDiffusionPipeline(DiffusionPipeline):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`. tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`): output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
...@@ -469,10 +545,18 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -469,10 +545,18 @@ class StableDiffusionPipeline(DiffusionPipeline):
width = width or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct # 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, callback_steps) self.check_inputs(
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
)
# 2. Define call parameters # 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt) if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
...@@ -480,8 +564,14 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -480,8 +564,14 @@ class StableDiffusionPipeline(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0 do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt # 3. Encode input prompt
text_embeddings = self._encode_prompt( prompt_embeds = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
) )
# 4. Prepare timesteps # 4. Prepare timesteps
...@@ -495,7 +585,7 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -495,7 +585,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
num_channels_latents, num_channels_latents,
height, height,
width, width,
text_embeddings.dtype, prompt_embeds.dtype,
device, device,
generator, generator,
latents, latents,
...@@ -513,7 +603,7 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -513,7 +603,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual # predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
# perform guidance # perform guidance
if do_classifier_free_guidance: if do_classifier_free_guidance:
...@@ -533,7 +623,7 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -533,7 +623,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
image = self.decode_latents(latents) image = self.decode_latents(latents)
# 9. Run safety checker # 9. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
# 10. Convert to PIL # 10. Convert to PIL
if output_type == "pil": if output_type == "pil":
......
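With these changes, embeddings can be computed once, edited (e.g. for prompt weighting), and fed straight into `__call__`. A hedged end-to-end sketch, assuming a public Stable Diffusion v1 checkpoint and a CUDA device; the embeddings are produced with the pipeline's own tokenizer and text encoder:

    import torch
    from diffusers import StableDiffusionPipeline

    # assumed checkpoint id; any Stable Diffusion v1-style checkpoint should behave the same
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
    pipe = pipe.to("cuda")

    text_inputs = pipe.tokenizer(
        "an astronaut riding a horse on mars",
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        prompt_embeds = pipe.text_encoder(text_inputs.input_ids.to("cuda"))[0]

    # prompt weighting would go here, e.g. rescaling the embeddings of selected tokens
    image = pipe(prompt_embeds=prompt_embeds, num_inference_steps=30).images[0]
    image.save("astronaut.png")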
...@@ -159,12 +159,21 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): ...@@ -159,12 +159,21 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
return self.device return self.device
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
):
r""" r"""
Encodes the prompt into text encoder hidden states. Encodes the prompt into text encoder hidden states.
Args: Args:
prompt (`str` or `list(int)`): prompt (`str` or `List[str]`, *optional*):
prompt to be encoded prompt to be encoded
device: (`torch.device`): device: (`torch.device`):
torch device torch device
...@@ -172,47 +181,67 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): ...@@ -172,47 +181,67 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
number of images that should be generated per prompt number of images that should be generated per prompt
do_classifier_free_guidance (`bool`): do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`): negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored The prompt or prompts not to guide the image generation. If not defined, one has to pass
if `guidance_scale` is less than `1`). `negative_prompt_embeds` instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
""" """
batch_size = len(prompt) if isinstance(prompt, list) else 1 if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
text_inputs = self.tokenizer( if prompt_embeds is None:
prompt, text_inputs = self.tokenizer(
padding="max_length", prompt,
max_length=self.tokenizer.model_max_length, padding="max_length",
truncation=True, max_length=self.tokenizer.model_max_length,
return_tensors="pt", truncation=True,
) return_tensors="pt",
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
) )
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
attention_mask = text_inputs.attention_mask.to(device) text_input_ids, untruncated_ids
else: ):
attention_mask = None removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_embeddings = self.text_encoder( if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
text_input_ids.to(device), attention_mask = text_inputs.attention_mask.to(device)
attention_mask=attention_mask, else:
) attention_mask = None
text_embeddings = text_embeddings[0]
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method # duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance # get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance: if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str] uncond_tokens: List[str]
if negative_prompt is None: if negative_prompt is None:
uncond_tokens = [""] * batch_size uncond_tokens = [""] * batch_size
...@@ -232,7 +261,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): ...@@ -232,7 +261,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
else: else:
uncond_tokens = negative_prompt uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1] max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer( uncond_input = self.tokenizer(
uncond_tokens, uncond_tokens,
padding="max_length", padding="max_length",
...@@ -246,23 +275,27 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): ...@@ -246,23 +275,27 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
else: else:
attention_mask = None attention_mask = None
uncond_embeddings = self.text_encoder( negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device), uncond_input.input_ids.to(device),
attention_mask=attention_mask, attention_mask=attention_mask,
) )
uncond_embeddings = uncond_embeddings[0] negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1] seq_len = negative_prompt_embeds.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes. # For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch # Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes # to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return text_embeddings return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype): def run_safety_checker(self, image, device, dtype):
...@@ -302,12 +335,15 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): ...@@ -302,12 +335,15 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
extra_step_kwargs["generator"] = generator extra_step_kwargs["generator"] = generator
return extra_step_kwargs return extra_step_kwargs
def check_inputs(self, prompt, strength, callback_steps): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
def check_inputs(
self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
):
if not isinstance(prompt, str) and not isinstance(prompt, list): if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if strength < 0 or strength > 1: if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [1.0, 1.0] but is {strength}") raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if (callback_steps is None) or ( if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
...@@ -317,6 +353,32 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): ...@@ -317,6 +353,32 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
f" {type(callback_steps)}." f" {type(callback_steps)}."
) )
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device): def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep # get the original timestep using init_timestep
...@@ -424,8 +486,8 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): ...@@ -424,8 +486,8 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
self, self,
prompt: Union[str, List[str]], prompt: Union[str, List[str]] = None,
image: Union[torch.FloatTensor, PIL.Image.Image], image: Union[torch.FloatTensor, PIL.Image.Image] = None,
depth_map: Optional[torch.FloatTensor] = None, depth_map: Optional[torch.FloatTensor] = None,
strength: float = 0.8, strength: float = 0.8,
num_inference_steps: Optional[int] = 50, num_inference_steps: Optional[int] = 50,
...@@ -434,6 +496,8 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): ...@@ -434,6 +496,8 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
num_images_per_prompt: Optional[int] = 1, num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0, eta: Optional[float] = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
...@@ -443,8 +507,9 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): ...@@ -443,8 +507,9 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
Args: Args:
prompt (`str` or `List[str]`): prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
image (`torch.FloatTensor` or `PIL.Image.Image`): image (`torch.FloatTensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the `Image`, or tensor representing an image batch, that will be used as the starting point for the
process. process.
...@@ -464,8 +529,9 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): ...@@ -464,8 +529,9 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality. usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*): negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored The prompt or prompts not to guide the image generation. If not defined, one has to pass
if `guidance_scale` is less than `1`). `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1): num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt. The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0): eta (`float`, *optional*, defaults to 0.0):
...@@ -474,6 +540,13 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): ...@@ -474,6 +540,13 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
generator (`torch.Generator`, *optional*): generator (`torch.Generator`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic. to make generation deterministic.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`): output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
...@@ -520,6 +593,9 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): ...@@ -520,6 +593,9 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
# 1. Check inputs # 1. Check inputs
self.check_inputs(prompt, strength, callback_steps) self.check_inputs(prompt, strength, callback_steps)
if image is None:
raise ValueError("`image` input cannot be undefined.")
# 2. Define call parameters # 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt) batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device device = self._execution_device
...@@ -529,8 +605,14 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): ...@@ -529,8 +605,14 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0 do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt # 3. Encode input prompt
text_embeddings = self._encode_prompt( prompt_embeds = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
) )
# 4. Prepare depth mask # 4. Prepare depth mask
...@@ -539,7 +621,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): ...@@ -539,7 +621,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
depth_map, depth_map,
batch_size * num_images_per_prompt, batch_size * num_images_per_prompt,
do_classifier_free_guidance, do_classifier_free_guidance,
text_embeddings.dtype, prompt_embeds.dtype,
device, device,
) )
...@@ -553,7 +635,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): ...@@ -553,7 +635,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
# 7. Prepare latent variables # 7. Prepare latent variables
latents = self.prepare_latents( latents = self.prepare_latents(
image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
) )
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
...@@ -569,7 +651,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): ...@@ -569,7 +651,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
latent_model_input = torch.cat([latent_model_input, depth_mask], dim=1) latent_model_input = torch.cat([latent_model_input, depth_mask], dim=1)
# predict the noise residual # predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
# perform guidance # perform guidance
if do_classifier_free_guidance: if do_classifier_free_guidance:
......
...@@ -173,12 +173,12 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline): ...@@ -173,12 +173,12 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
if do_classifier_free_guidance: if do_classifier_free_guidance:
uncond_embeddings = torch.zeros_like(image_embeddings) negative_prompt_embeds = torch.zeros_like(image_embeddings)
# For classifier free guidance, we need to do two forward passes. # For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch # Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes # to avoid doing two forward passes
image_embeddings = torch.cat([uncond_embeddings, image_embeddings]) image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
return image_embeddings return image_embeddings
......
...@@ -249,12 +249,21 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -249,12 +249,21 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
return self.device return self.device
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
):
r""" r"""
Encodes the prompt into text encoder hidden states. Encodes the prompt into text encoder hidden states.
Args: Args:
prompt (`str` or `list(int)`): prompt (`str` or `List[str]`, *optional*):
prompt to be encoded prompt to be encoded
device: (`torch.device`): device: (`torch.device`):
torch device torch device
...@@ -262,47 +271,67 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -262,47 +271,67 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
number of images that should be generated per prompt number of images that should be generated per prompt
do_classifier_free_guidance (`bool`): do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`): negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored The prompt or prompts not to guide the image generation. If not defined, one has to pass
if `guidance_scale` is less than `1`). `negative_prompt_embeds` instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
""" """
batch_size = len(prompt) if isinstance(prompt, list) else 1 if prompt is not None and isinstance(prompt, str):
batch_size = 1
text_inputs = self.tokenizer( elif prompt is not None and isinstance(prompt, list):
prompt, batch_size = len(prompt)
padding="max_length", else:
max_length=self.tokenizer.model_max_length, batch_size = prompt_embeds.shape[0]
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): if prompt_embeds is None:
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) text_inputs = self.tokenizer(
logger.warning( prompt,
"The following part of your input was truncated because CLIP can only handle sequences up to" padding="max_length",
f" {self.tokenizer.model_max_length} tokens: {removed_text}" max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
) )
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
attention_mask = text_inputs.attention_mask.to(device) text_input_ids, untruncated_ids
else: ):
attention_mask = None removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_embeddings = self.text_encoder( if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
text_input_ids.to(device), attention_mask = text_inputs.attention_mask.to(device)
attention_mask=attention_mask, else:
) attention_mask = None
text_embeddings = text_embeddings[0]
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method # duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance # get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance: if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str] uncond_tokens: List[str]
if negative_prompt is None: if negative_prompt is None:
uncond_tokens = [""] * batch_size uncond_tokens = [""] * batch_size
...@@ -322,7 +351,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -322,7 +351,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
else: else:
uncond_tokens = negative_prompt uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1] max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer( uncond_input = self.tokenizer(
uncond_tokens, uncond_tokens,
padding="max_length", padding="max_length",
...@@ -336,23 +365,27 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -336,23 +365,27 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
else: else:
attention_mask = None attention_mask = None
uncond_embeddings = self.text_encoder( negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device), uncond_input.input_ids.to(device),
attention_mask=attention_mask, attention_mask=attention_mask,
) )
uncond_embeddings = uncond_embeddings[0] negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1] seq_len = negative_prompt_embeds.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes. # For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch # Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes # to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return text_embeddings return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype): def run_safety_checker(self, image, device, dtype):
...@@ -392,7 +425,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -392,7 +425,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
extra_step_kwargs["generator"] = generator extra_step_kwargs["generator"] = generator
return extra_step_kwargs return extra_step_kwargs
def check_inputs(self, prompt, strength, callback_steps): def check_inputs(
self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
):
if not isinstance(prompt, str) and not isinstance(prompt, list): if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
...@@ -407,6 +442,32 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -407,6 +442,32 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
f" {type(callback_steps)}." f" {type(callback_steps)}."
) )
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
def get_timesteps(self, num_inference_steps, strength, device): def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep # get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps) init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
...@@ -472,7 +533,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -472,7 +533,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
@replace_example_docstring(EXAMPLE_DOC_STRING) @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
self, self,
prompt: Union[str, List[str]], prompt: Union[str, List[str]] = None,
image: Union[torch.FloatTensor, PIL.Image.Image] = None, image: Union[torch.FloatTensor, PIL.Image.Image] = None,
strength: float = 0.8, strength: float = 0.8,
num_inference_steps: Optional[int] = 50, num_inference_steps: Optional[int] = 50,
...@@ -481,6 +542,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -481,6 +542,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
num_images_per_prompt: Optional[int] = 1, num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0, eta: Optional[float] = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
...@@ -491,8 +554,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -491,8 +554,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
Args: Args:
prompt (`str` or `List[str]`): prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
image (`torch.FloatTensor` or `PIL.Image.Image`): image (`torch.FloatTensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the `Image`, or tensor representing an image batch, that will be used as the starting point for the
process. process.
...@@ -512,8 +576,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -512,8 +576,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality. usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*): negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored The prompt or prompts not to guide the image generation. If not defined, one has to pass
if `guidance_scale` is less than `1`). `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1): num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt. The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0): eta (`float`, *optional*, defaults to 0.0):
...@@ -522,6 +587,13 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -522,6 +587,13 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
generator (`torch.Generator`, *optional*): generator (`torch.Generator`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic. to make generation deterministic.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`): output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
...@@ -547,8 +619,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -547,8 +619,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs) init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)
image = init_image or image image = init_image or image
# 1. Check inputs # 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, strength, callback_steps) self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
# 2. Define call parameters # 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt) batch_size = 1 if isinstance(prompt, str) else len(prompt)
...@@ -559,8 +631,14 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -559,8 +631,14 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0 do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt # 3. Encode input prompt
text_embeddings = self._encode_prompt( prompt_embeds = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
) )
# 4. Preprocess image # 4. Preprocess image
...@@ -573,7 +651,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -573,7 +651,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
# 6. Prepare latent variables # 6. Prepare latent variables
latents = self.prepare_latents( latents = self.prepare_latents(
image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
) )
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
...@@ -588,7 +666,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -588,7 +666,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual # predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
# perform guidance # perform guidance
if do_classifier_free_guidance: if do_classifier_free_guidance:
...@@ -608,7 +686,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -608,7 +686,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
image = self.decode_latents(latents) image = self.decode_latents(latents)
# 10. Run safety checker # 10. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
# 11. Convert to PIL # 11. Convert to PIL
if output_type == "pil": if output_type == "pil":
......
...@@ -297,12 +297,21 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -297,12 +297,21 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
return self.device return self.device
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
):
r""" r"""
Encodes the prompt into text encoder hidden states. Encodes the prompt into text encoder hidden states.
Args: Args:
prompt (`str` or `list(int)`): prompt (`str` or `List[str]`, *optional*):
prompt to be encoded prompt to be encoded
device: (`torch.device`): device: (`torch.device`):
torch device torch device
...@@ -310,47 +319,67 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -310,47 +319,67 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
number of images that should be generated per prompt number of images that should be generated per prompt
do_classifier_free_guidance (`bool`): do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`): negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored The prompt or prompts not to guide the image generation. If not defined, one has to pass
if `guidance_scale` is less than `1`). `negative_prompt_embeds` instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
""" """
batch_size = len(prompt) if isinstance(prompt, list) else 1 if prompt is not None and isinstance(prompt, str):
batch_size = 1
text_inputs = self.tokenizer( elif prompt is not None and isinstance(prompt, list):
prompt, batch_size = len(prompt)
padding="max_length", else:
max_length=self.tokenizer.model_max_length, batch_size = prompt_embeds.shape[0]
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): if prompt_embeds is None:
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) text_inputs = self.tokenizer(
logger.warning( prompt,
"The following part of your input was truncated because CLIP can only handle sequences up to" padding="max_length",
f" {self.tokenizer.model_max_length} tokens: {removed_text}" max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
) )
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
attention_mask = text_inputs.attention_mask.to(device) text_input_ids, untruncated_ids
else: ):
attention_mask = None removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_embeddings = self.text_encoder( if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
text_input_ids.to(device), attention_mask = text_inputs.attention_mask.to(device)
attention_mask=attention_mask, else:
) attention_mask = None
text_embeddings = text_embeddings[0]
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method # duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance # get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance: if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str] uncond_tokens: List[str]
if negative_prompt is None: if negative_prompt is None:
uncond_tokens = [""] * batch_size uncond_tokens = [""] * batch_size
...@@ -370,7 +399,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -370,7 +399,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
else: else:
uncond_tokens = negative_prompt uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1] max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer( uncond_input = self.tokenizer(
uncond_tokens, uncond_tokens,
padding="max_length", padding="max_length",
...@@ -384,23 +413,27 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -384,23 +413,27 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
else: else:
attention_mask = None attention_mask = None
uncond_embeddings = self.text_encoder( negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device), uncond_input.input_ids.to(device),
attention_mask=attention_mask, attention_mask=attention_mask,
) )
uncond_embeddings = uncond_embeddings[0] negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1] seq_len = negative_prompt_embeds.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes. # For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch # Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes # to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return text_embeddings return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype): def run_safety_checker(self, image, device, dtype):
...@@ -441,10 +474,16 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -441,10 +474,16 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
return image return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
def check_inputs(self, prompt, height, width, callback_steps): def check_inputs(
if not isinstance(prompt, str) and not isinstance(prompt, list): self,
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") prompt,
height,
width,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
):
if height % 8 != 0 or width % 8 != 0: if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
...@@ -456,6 +495,32 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -456,6 +495,32 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
f" {type(callback_steps)}." f" {type(callback_steps)}."
) )
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
...@@ -528,9 +593,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -528,9 +593,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
self, self,
prompt: Union[str, List[str]], prompt: Union[str, List[str]] = None,
image: Union[torch.FloatTensor, PIL.Image.Image], image: Union[torch.FloatTensor, PIL.Image.Image] = None,
mask_image: Union[torch.FloatTensor, PIL.Image.Image], mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
height: Optional[int] = None, height: Optional[int] = None,
width: Optional[int] = None, width: Optional[int] = None,
num_inference_steps: int = 50, num_inference_steps: int = 50,
...@@ -540,6 +605,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -540,6 +605,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
eta: float = 0.0, eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None, latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
...@@ -549,8 +616,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -549,8 +616,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
Args: Args:
prompt (`str` or `List[str]`): prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
image (`PIL.Image.Image`): image (`PIL.Image.Image`):
`Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
be masked out with `mask_image` and repainted according to `prompt`. be masked out with `mask_image` and repainted according to `prompt`.
...@@ -573,8 +641,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -573,8 +641,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality. usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*): negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored The prompt or prompts not to guide the image generation. If not defined, one has to pass
if `guidance_scale` is less than `1`). `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1): num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt. The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0): eta (`float`, *optional*, defaults to 0.0):
...@@ -587,6 +656,13 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -587,6 +656,13 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`. tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`): output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
...@@ -645,6 +721,12 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -645,6 +721,12 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
# 1. Check inputs # 1. Check inputs
self.check_inputs(prompt, height, width, callback_steps) self.check_inputs(prompt, height, width, callback_steps)
if image is None:
raise ValueError("`image` input cannot be undefined.")
if mask_image is None:
raise ValueError("`mask_image` input cannot be undefined.")
# 2. Define call parameters # 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt) batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device device = self._execution_device
...@@ -654,7 +736,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -654,7 +736,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0 do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt # 3. Encode input prompt
text_embeddings = self._encode_prompt( prompt_embeds = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
) )
...@@ -672,7 +754,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -672,7 +754,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
num_channels_latents, num_channels_latents,
height, height,
width, width,
text_embeddings.dtype, prompt_embeds.dtype,
device, device,
generator, generator,
latents, latents,
...@@ -685,7 +767,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -685,7 +767,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
batch_size * num_images_per_prompt, batch_size * num_images_per_prompt,
height, height,
width, width,
text_embeddings.dtype, prompt_embeds.dtype,
device, device,
generator, generator,
do_classifier_free_guidance, do_classifier_free_guidance,
...@@ -718,7 +800,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -718,7 +800,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
# predict the noise residual # predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
# perform guidance # perform guidance
if do_classifier_free_guidance: if do_classifier_free_guidance:
...@@ -738,7 +820,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -738,7 +820,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
image = self.decode_latents(latents) image = self.decode_latents(latents)
# 12. Run safety checker # 12. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
# 13. Convert to PIL # 13. Convert to PIL
if output_type == "pil": if output_type == "pil":
......
...@@ -216,12 +216,21 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -216,12 +216,21 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
return self.device return self.device
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
):
r""" r"""
Encodes the prompt into text encoder hidden states. Encodes the prompt into text encoder hidden states.
Args: Args:
prompt (`str` or `list(int)`): prompt (`str` or `List[str]`, *optional*):
prompt to be encoded prompt to be encoded
device: (`torch.device`): device: (`torch.device`):
torch device torch device
...@@ -229,47 +238,67 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -229,47 +238,67 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
number of images that should be generated per prompt number of images that should be generated per prompt
do_classifier_free_guidance (`bool`): do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`): negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored The prompt or prompts not to guide the image generation. If not defined, one has to pass
if `guidance_scale` is less than `1`). `negative_prompt_embeds` instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
""" """
batch_size = len(prompt) if isinstance(prompt, list) else 1 if prompt is not None and isinstance(prompt, str):
batch_size = 1
text_inputs = self.tokenizer( elif prompt is not None and isinstance(prompt, list):
prompt, batch_size = len(prompt)
padding="max_length", else:
max_length=self.tokenizer.model_max_length, batch_size = prompt_embeds.shape[0]
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): if prompt_embeds is None:
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) text_inputs = self.tokenizer(
logger.warning( prompt,
"The following part of your input was truncated because CLIP can only handle sequences up to" padding="max_length",
f" {self.tokenizer.model_max_length} tokens: {removed_text}" max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
) )
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
attention_mask = text_inputs.attention_mask.to(device) text_input_ids, untruncated_ids
else: ):
attention_mask = None removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_embeddings = self.text_encoder( if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
text_input_ids.to(device), attention_mask = text_inputs.attention_mask.to(device)
attention_mask=attention_mask, else:
) attention_mask = None
text_embeddings = text_embeddings[0]
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method # duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance # get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance: if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str] uncond_tokens: List[str]
if negative_prompt is None: if negative_prompt is None:
uncond_tokens = [""] * batch_size uncond_tokens = [""] * batch_size
...@@ -289,7 +318,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -289,7 +318,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
else: else:
uncond_tokens = negative_prompt uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1] max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer( uncond_input = self.tokenizer(
uncond_tokens, uncond_tokens,
padding="max_length", padding="max_length",
...@@ -303,23 +332,27 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -303,23 +332,27 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
else: else:
attention_mask = None attention_mask = None
uncond_embeddings = self.text_encoder( negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device), uncond_input.input_ids.to(device),
attention_mask=attention_mask, attention_mask=attention_mask,
) )
uncond_embeddings = uncond_embeddings[0] negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1] seq_len = negative_prompt_embeds.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes. # For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch # Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes # to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return text_embeddings return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype): def run_safety_checker(self, image, device, dtype):
...@@ -360,7 +393,9 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -360,7 +393,9 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
return extra_step_kwargs return extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
def check_inputs(self, prompt, strength, callback_steps): def check_inputs(
self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
):
if not isinstance(prompt, str) and not isinstance(prompt, list): if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
...@@ -375,6 +410,32 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -375,6 +410,32 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
f" {type(callback_steps)}." f" {type(callback_steps)}."
) )
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
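The checks above enforce three rules: exactly one of `prompt` / `prompt_embeds` must be supplied, at most one of `negative_prompt` / `negative_prompt_embeds`, and the two embedding tensors must have matching shapes. A condensed, standalone sketch of the same rules (a hypothetical helper for illustration, not part of this diff):

import torch

def validate_prompt_inputs(prompt=None, prompt_embeds=None, negative_prompt=None, negative_prompt_embeds=None):
    # exactly one of `prompt` / `prompt_embeds`
    if (prompt is None) == (prompt_embeds is None):
        raise ValueError("Provide either `prompt` or `prompt_embeds`, but not both.")
    # a negative prompt may only be given in one form
    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError("Provide either `negative_prompt` or `negative_prompt_embeds`, not both.")
    # classifier-free guidance concatenates the two tensors, so their shapes must match
    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError("`prompt_embeds` and `negative_prompt_embeds` must have the same shape.")

validate_prompt_inputs(prompt_embeds=torch.randn(1, 77, 768),
                       negative_prompt_embeds=torch.randn(1, 77, 768))  # passes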
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device): def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep # get the original timestep using init_timestep
...@@ -415,6 +476,8 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -415,6 +476,8 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
add_predicted_noise: Optional[bool] = False, add_predicted_noise: Optional[bool] = False,
eta: Optional[float] = 0.0, eta: Optional[float] = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
...@@ -425,8 +488,9 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -425,8 +488,9 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
Args: Args:
prompt (`str` or `List[str]`): prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
image (`torch.FloatTensor` or `PIL.Image.Image`): image (`torch.FloatTensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the `Image`, or tensor representing an image batch, that will be used as the starting point for the
process. This is the image whose masked region will be inpainted. process. This is the image whose masked region will be inpainted.
...@@ -450,8 +514,9 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -450,8 +514,9 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality. usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*): negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored The prompt or prompts not to guide the image generation. If not defined, one has to pass
if `guidance_scale` is less than `1`). `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1): num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt. The number of images to generate per prompt.
add_predicted_noise (`bool`, *optional*, defaults to False): add_predicted_noise (`bool`, *optional*, defaults to False):
...@@ -463,6 +528,13 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -463,6 +528,13 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
generator (`torch.Generator`, *optional*): generator (`torch.Generator`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic. to make generation deterministic.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`): output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
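The new `prompt_embeds` argument is what makes prompt weighting possible: embeddings can be produced outside the pipeline, scaled or blended, and passed back in place of `prompt`. A hedged usage sketch against the plain text-to-image pipeline (the model id, the manual CLIP encoding, and the 1.1 scaling factor are illustrative assumptions, not part of this diff):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# encode the prompt by hand so the embeddings can be tweaked before generation
text_inputs = pipe.tokenizer(
    "a fantasy landscape, matte painting",
    padding="max_length",
    max_length=pipe.tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)
with torch.no_grad():
    prompt_embeds = pipe.text_encoder(text_inputs.input_ids)[0]

# naive prompt weighting: scale the embeddings before handing them back to the pipeline
image = pipe(prompt_embeds=prompt_embeds * 1.1).images[0]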
...@@ -499,8 +571,14 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -499,8 +571,14 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0 do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt # 3. Encode input prompt
text_embeddings = self._encode_prompt( prompt_embeds = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
) )
# 4. Preprocess image and mask # 4. Preprocess image and mask
...@@ -518,7 +596,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -518,7 +596,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
# 6. Prepare latent variables # 6. Prepare latent variables
# encode the init image into latents and scale the latents # encode the init image into latents and scale the latents
latents, init_latents_orig, noise = self.prepare_latents( latents, init_latents_orig, noise = self.prepare_latents(
image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
) )
# 7. Prepare mask latent # 7. Prepare mask latent
...@@ -537,7 +615,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -537,7 +615,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual # predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
# perform guidance # perform guidance
if do_classifier_free_guidance: if do_classifier_free_guidance:
...@@ -566,7 +644,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -566,7 +644,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
image = self.decode_latents(latents) image = self.decode_latents(latents)
# 11. Run safety checker # 11. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
# 12. Convert to PIL # 12. Convert to PIL
if output_type == "pil": if output_type == "pil":
......
...@@ -127,8 +127,8 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline): ...@@ -127,8 +127,8 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
self, self,
prompt: Union[str, List[str]], prompt: Union[str, List[str]] = None,
image: Union[torch.FloatTensor, PIL.Image.Image], image: Union[torch.FloatTensor, PIL.Image.Image] = None,
num_inference_steps: int = 100, num_inference_steps: int = 100,
guidance_scale: float = 7.5, guidance_scale: float = 7.5,
image_guidance_scale: float = 1.5, image_guidance_scale: float = 1.5,
...@@ -137,6 +137,8 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline): ...@@ -137,6 +137,8 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
eta: float = 0.0, eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None, latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
...@@ -146,8 +148,9 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline): ...@@ -146,8 +148,9 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
Args: Args:
prompt (`str` or `List[str]`): prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
image (`PIL.Image.Image`): image (`PIL.Image.Image`):
`Image`, or tensor representing an image batch which will be repainted according to `prompt`. `Image`, or tensor representing an image batch which will be repainted according to `prompt`.
num_inference_steps (`int`, *optional*, defaults to 100): num_inference_steps (`int`, *optional*, defaults to 100):
...@@ -165,8 +168,9 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline): ...@@ -165,8 +168,9 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
generate images that are closely linked to the source image `image`, usually at the expense of lower generate images that are closely linked to the source image `image`, usually at the expense of lower
image quality. This pipeline requires a value of at least `1`. image quality. This pipeline requires a value of at least `1`.
negative_prompt (`str` or `List[str]`, *optional*): negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored The prompt or prompts not to guide the image generation. If not defined, one has to pass
if `guidance_scale` is less than `1`). `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1): num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt. The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0): eta (`float`, *optional*, defaults to 0.0):
...@@ -179,6 +183,13 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline): ...@@ -179,6 +183,13 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`. tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`): output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
...@@ -231,6 +242,9 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline): ...@@ -231,6 +242,9 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
# 0. Check inputs # 0. Check inputs
self.check_inputs(prompt, callback_steps) self.check_inputs(prompt, callback_steps)
if image is None:
raise ValueError("`image` input cannot be undefined.")
# 1. Define call parameters # 1. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt) batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device device = self._execution_device
...@@ -242,8 +256,14 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline): ...@@ -242,8 +256,14 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas") scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
# 2. Encode input prompt # 2. Encode input prompt
text_embeddings = self._encode_prompt( prompt_embeds = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
) )
# 3. Preprocess image # 3. Preprocess image
...@@ -259,7 +279,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline): ...@@ -259,7 +279,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
image, image,
batch_size, batch_size,
num_images_per_prompt, num_images_per_prompt,
text_embeddings.dtype, prompt_embeds.dtype,
device, device,
do_classifier_free_guidance, do_classifier_free_guidance,
generator, generator,
...@@ -272,7 +292,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline): ...@@ -272,7 +292,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
num_channels_latents, num_channels_latents,
height, height,
width, width,
text_embeddings.dtype, prompt_embeds.dtype,
device, device,
generator, generator,
latents, latents,
...@@ -306,7 +326,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline): ...@@ -306,7 +326,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1) scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1)
# predict the noise residual # predict the noise residual
noise_pred = self.unet(scaled_latent_model_input, t, encoder_hidden_states=text_embeddings).sample noise_pred = self.unet(scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
# Hack: # Hack:
# For karras style schedulers the model does classifier free guidance using the # For karras style schedulers the model does classifier free guidance using the
...@@ -348,7 +368,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline): ...@@ -348,7 +368,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
image = self.decode_latents(latents) image = self.decode_latents(latents)
# 11. Run safety checker # 11. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
# 12. Convert to PIL # 12. Convert to PIL
if output_type == "pil": if output_type == "pil":
...@@ -398,12 +418,21 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline): ...@@ -398,12 +418,21 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
return torch.device(module._hf_hook.execution_device) return torch.device(module._hf_hook.execution_device)
return self.device return self.device
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
):
r""" r"""
Encodes the prompt into text encoder hidden states. Encodes the prompt into text encoder hidden states.
Args: Args:
prompt (`str` or `list(int)`): prompt (`str` or `List[str]`, *optional*):
prompt to be encoded prompt to be encoded
device: (`torch.device`): device: (`torch.device`):
torch device torch device
...@@ -411,47 +440,67 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline): ...@@ -411,47 +440,67 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
number of images that should be generated per prompt number of images that should be generated per prompt
do_classifier_free_guidance (`bool`): do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`): negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored The prompt or prompts not to guide the image generation. If not defined, one has to pass
if `guidance_scale` is less than `1`). `negative_prompt_embeds` instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
""" """
batch_size = len(prompt) if isinstance(prompt, list) else 1 if prompt is not None and isinstance(prompt, str):
batch_size = 1
text_inputs = self.tokenizer( elif prompt is not None and isinstance(prompt, list):
prompt, batch_size = len(prompt)
padding="max_length", else:
max_length=self.tokenizer.model_max_length, batch_size = prompt_embeds.shape[0]
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): if prompt_embeds is None:
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) text_inputs = self.tokenizer(
logger.warning( prompt,
"The following part of your input was truncated because CLIP can only handle sequences up to" padding="max_length",
f" {self.tokenizer.model_max_length} tokens: {removed_text}" max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
) )
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
attention_mask = text_inputs.attention_mask.to(device) text_input_ids, untruncated_ids
else: ):
attention_mask = None removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
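The comparison above detects prompts longer than the CLIP context window by tokenizing once with truncation and once with padding="longest" and checking whether the two disagree. A small standalone sketch of the same idea (the tokenizer checkpoint and the deliberately overlong prompt are assumptions for illustration):

import torch
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
prompt = "a very detailed castle on a hill, " * 20  # deliberately longer than 77 tokens

text_input_ids = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length,
                           truncation=True, return_tensors="pt").input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

# same comparison as above: anything beyond model_max_length tokens was silently cut
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
    print("truncated away:", removed_text)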
text_embeddings = self.text_encoder( if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
text_input_ids.to(device), attention_mask = text_inputs.attention_mask.to(device)
attention_mask=attention_mask, else:
) attention_mask = None
text_embeddings = text_embeddings[0]
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method # duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
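The repeat-then-view pattern above duplicates each prompt's embedding `num_images_per_prompt` times while staying friendly to the mps backend (which, at the time, did not handle `repeat_interleave` well). A tiny standalone sketch of the reshaping, with made-up sizes:

import torch

num_images_per_prompt = 2
prompt_embeds = torch.randn(3, 77, 768)    # (batch, seq_len, hidden) for 3 prompts
bs_embed, seq_len, _ = prompt_embeds.shape

# repeat along the sequence axis, then fold the copies back into the batch axis
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

print(prompt_embeds.shape)                 # torch.Size([6, 77, 768])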
# get unconditional embeddings for classifier free guidance # get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance: if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str] uncond_tokens: List[str]
if negative_prompt is None: if negative_prompt is None:
uncond_tokens = [""] * batch_size uncond_tokens = [""] * batch_size
...@@ -471,7 +520,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline): ...@@ -471,7 +520,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
else: else:
uncond_tokens = negative_prompt uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1] max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer( uncond_input = self.tokenizer(
uncond_tokens, uncond_tokens,
padding="max_length", padding="max_length",
...@@ -485,23 +534,28 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline): ...@@ -485,23 +534,28 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
else: else:
attention_mask = None attention_mask = None
uncond_embeddings = self.text_encoder( negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device), uncond_input.input_ids.to(device),
attention_mask=attention_mask, attention_mask=attention_mask,
) )
uncond_embeddings = uncond_embeddings[0] negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1] seq_len = negative_prompt_embeds.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes. # For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch # Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes # to avoid doing two forward passes
text_embeddings = torch.cat([text_embeddings, uncond_embeddings, uncond_embeddings]) # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds])
return text_embeddings return prompt_embeds
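The three-way batch is specific to InstructPix2Pix, which applies both a text guidance scale and an image guidance scale. A hedged sketch of how such a batch is typically split back apart in the denoising loop (`noise_pred`, `guidance_scale`, and `image_guidance_scale` are assumed to be in scope; the ordering follows the concatenation above):

# batch order: [text+image conditioned, image-only conditioned, fully unconditional]
noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
noise_pred = (
    noise_pred_uncond
    + guidance_scale * (noise_pred_text - noise_pred_image)
    + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
)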
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype): def run_safety_checker(self, image, device, dtype):
......
...@@ -162,12 +162,21 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline): ...@@ -162,12 +162,21 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
return self.device return self.device
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
):
r""" r"""
Encodes the prompt into text encoder hidden states. Encodes the prompt into text encoder hidden states.
Args: Args:
prompt (`str` or `list(int)`): prompt (`str` or `List[str]`, *optional*):
prompt to be encoded prompt to be encoded
device: (`torch.device`): device: (`torch.device`):
torch device torch device
...@@ -175,47 +184,67 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline): ...@@ -175,47 +184,67 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
number of images that should be generated per prompt number of images that should be generated per prompt
do_classifier_free_guidance (`bool`): do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`): negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored The prompt or prompts not to guide the image generation. If not defined, one has to pass
if `guidance_scale` is less than `1`). `negative_prompt_embeds` instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
""" """
batch_size = len(prompt) if isinstance(prompt, list) else 1 if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
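With `prompt` now optional, the batch size has to be recoverable from the embeddings themselves, which is what the `else` branch above does. A small illustration of the three cases (the tensor shape is a typical CLIP ViT-L example, assumed for illustration):

import torch

def infer_batch_size(prompt=None, prompt_embeds=None):
    # mirrors the branch above: string -> 1, list -> len, otherwise fall back to the embeddings
    if prompt is not None and isinstance(prompt, str):
        return 1
    elif prompt is not None and isinstance(prompt, list):
        return len(prompt)
    else:
        return prompt_embeds.shape[0]

print(infer_batch_size(prompt="a cat"))                          # 1
print(infer_batch_size(prompt=["a cat", "a dog"]))               # 2
print(infer_batch_size(prompt_embeds=torch.randn(4, 77, 768)))   # 4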
text_inputs = self.tokenizer( if prompt_embeds is None:
prompt, text_inputs = self.tokenizer(
padding="max_length", prompt,
max_length=self.tokenizer.model_max_length, padding="max_length",
truncation=True, max_length=self.tokenizer.model_max_length,
return_tensors="pt", truncation=True,
) return_tensors="pt",
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
) )
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
attention_mask = text_inputs.attention_mask.to(device) text_input_ids, untruncated_ids
else: ):
attention_mask = None removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_embeddings = self.text_encoder( if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
text_input_ids.to(device), attention_mask = text_inputs.attention_mask.to(device)
attention_mask=attention_mask, else:
) attention_mask = None
text_embeddings = text_embeddings[0]
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method # duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance # get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance: if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str] uncond_tokens: List[str]
if negative_prompt is None: if negative_prompt is None:
uncond_tokens = [""] * batch_size uncond_tokens = [""] * batch_size
...@@ -235,7 +264,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline): ...@@ -235,7 +264,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
else: else:
uncond_tokens = negative_prompt uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1] max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer( uncond_input = self.tokenizer(
uncond_tokens, uncond_tokens,
padding="max_length", padding="max_length",
...@@ -249,23 +278,27 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline): ...@@ -249,23 +278,27 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
else: else:
attention_mask = None attention_mask = None
uncond_embeddings = self.text_encoder( negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device), uncond_input.input_ids.to(device),
attention_mask=attention_mask, attention_mask=attention_mask,
) )
uncond_embeddings = uncond_embeddings[0] negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1] seq_len = negative_prompt_embeds.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes. # For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch # Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes # to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return text_embeddings return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype): def run_safety_checker(self, image, device, dtype):
...@@ -317,7 +350,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline): ...@@ -317,7 +350,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
self, self,
prompt: Union[str, List[str]], prompt: Union[str, List[str]] = None,
height: Optional[int] = None, height: Optional[int] = None,
width: Optional[int] = None, width: Optional[int] = None,
num_inference_steps: int = 50, num_inference_steps: int = 50,
...@@ -327,6 +360,8 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline): ...@@ -327,6 +360,8 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
eta: float = 0.0, eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None, latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
...@@ -336,8 +371,9 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline): ...@@ -336,8 +371,9 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
Args: Args:
prompt (`str` or `List[str]`): prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
...@@ -352,8 +388,9 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline): ...@@ -352,8 +388,9 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality. usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*): negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored The prompt or prompts not to guide the image generation. If not defined, one has to pass
if `guidance_scale` is less than `1`). `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1): num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt. The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0): eta (`float`, *optional*, defaults to 0.0):
...@@ -366,6 +403,13 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline): ...@@ -366,6 +403,13 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`. tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`): output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
...@@ -404,14 +448,20 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline): ...@@ -404,14 +448,20 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
raise ValueError("has to use guidance_scale") raise ValueError("has to use guidance_scale")
# 3. Encode input prompt # 3. Encode input prompt
text_embeddings = self._encode_prompt( prompt_embeds = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
) )
# 4. Prepare timesteps # 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=text_embeddings.device) self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device)
sigmas = self.scheduler.sigmas sigmas = self.scheduler.sigmas
sigmas = sigmas.to(text_embeddings.dtype) sigmas = sigmas.to(prompt_embeds.dtype)
# 5. Prepare latent variables # 5. Prepare latent variables
num_channels_latents = self.unet.in_channels num_channels_latents = self.unet.in_channels
...@@ -420,7 +470,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline): ...@@ -420,7 +470,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
num_channels_latents, num_channels_latents,
height, height,
width, width,
text_embeddings.dtype, prompt_embeds.dtype,
device, device,
generator, generator,
latents, latents,
...@@ -434,7 +484,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline): ...@@ -434,7 +484,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
latent_model_input = torch.cat([x] * 2) latent_model_input = torch.cat([x] * 2)
t = torch.cat([t] * 2) t = torch.cat([t] * 2)
noise_pred = self.k_diffusion_model(latent_model_input, t, cond=text_embeddings) noise_pred = self.k_diffusion_model(latent_model_input, t, cond=prompt_embeds)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
...@@ -447,7 +497,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline): ...@@ -447,7 +497,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
image = self.decode_latents(latents) image = self.decode_latents(latents)
# 9. Run safety checker # 9. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
# 10. Convert to PIL # 10. Convert to PIL
if output_type == "pil": if output_type == "pil":
......