Unverified Commit b562b661 authored by Patrick von Platen, committed by GitHub

Allow directly passing text embeddings to Stable Diffusion Pipeline for prompt weighting (#2071)

* add text embeds to sd

* add text embeds to sd

* finish tests

* finish

* finish

* make style

* fix tests

* make style

* make style

* up

* better docs

* fix

* fix

* new try

* up

* up

* finish
parent c1184918
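
The feature in a nutshell: instead of a string `prompt`, callers can now hand the pipeline pre-computed `prompt_embeds` / `negative_prompt_embeds`. A minimal usage sketch of the new arguments (the model id, prompts, and CUDA device below are illustrative, not part of this commit):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

def encode(texts):
    # Same tokenization the pipeline applies internally to raw strings.
    input_ids = pipe.tokenizer(
        texts,
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids.to("cuda")
    with torch.no_grad():
        return pipe.text_encoder(input_ids)[0]

prompt_embeds = encode(["a photo of an astronaut riding a horse"])
negative_prompt_embeds = encode(["blurry, low quality"])

# The embeddings can be edited (re-weighted, blended, ...) before this call.
image = pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
).images[0]
```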
......@@ -136,12 +136,21 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
return self.device
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
......@@ -149,47 +158,67 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
`guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
if prompt_embeds is None:
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_embeddings = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = text_embeddings[0]
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
......@@ -209,7 +238,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
......@@ -223,23 +252,27 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = uncond_embeddings[0]
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return text_embeddings
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
......@@ -325,8 +358,8 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]],
prompt: Union[str, List[str]] = None,
image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None,
num_inference_steps: int = 75,
guidance_scale: float = 9.0,
noise_level: int = 20,
......@@ -335,6 +368,8 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
......@@ -344,8 +379,9 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`):
`Image`, or tensor representing an image batch which will be upscaled.
num_inference_steps (`int`, *optional*, defaults to 75):
......@@ -358,8 +394,9 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
......@@ -372,6 +409,13 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
......@@ -422,6 +466,9 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
# 1. Check inputs
self.check_inputs(prompt, image, noise_level, callback_steps)
if image is None:
raise ValueError("`image` input cannot be undefined.")
# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
......@@ -431,13 +478,19 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
prompt_embeds = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
# 4. Preprocess image
image = preprocess(image)
image = image.to(dtype=text_embeddings.dtype, device=device)
image = image.to(dtype=prompt_embeds.dtype, device=device)
# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
......@@ -445,7 +498,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
# 5. Add noise to image
noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
noise = randn_tensor(image.shape, generator=generator, device=device, dtype=text_embeddings.dtype)
noise = randn_tensor(image.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
batch_multiplier = 2 if do_classifier_free_guidance else 1
......@@ -460,7 +513,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
num_channels_latents,
height,
width,
text_embeddings.dtype,
prompt_embeds.dtype,
device,
generator,
latents,
......@@ -493,7 +546,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
# predict the noise residual
noise_pred = self.unet(
latent_model_input, t, encoder_hidden_states=text_embeddings, class_labels=noise_level
latent_model_input, t, encoder_hidden_states=prompt_embeds, class_labels=noise_level
).sample
# perform guidance
......
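
The `prompt_embeds` plumbing added above is the hook that prompt-weighting tools build on. A toy sketch of what that manipulation can look like (the weighting scheme here is purely illustrative and not part of this commit; `pipe` is any Stable Diffusion pipeline with a CLIP tokenizer and text encoder):

```python
import torch

def weighted_prompt_embeds(pipe, text, token_weights, device="cuda"):
    """Toy prompt weighting: scale selected token embeddings before denoising.

    `token_weights` maps token positions to scale factors. Real prompt-weighting
    pipelines are more careful (e.g. renormalizing the embeddings afterwards);
    this only shows where the new hook sits.
    """
    input_ids = pipe.tokenizer(
        text,
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids.to(device)
    with torch.no_grad():
        embeds = pipe.text_encoder(input_ids)[0]  # (batch, seq_len, hidden)
    scale = torch.ones(embeds.shape[1], device=embeds.device, dtype=embeds.dtype)
    for position, weight in token_weights.items():
        scale[position] = weight
    return embeds * scale[None, :, None]

# e.g. emphasize tokens 2-4 of the prompt, then pass the result directly:
# embeds = weighted_prompt_embeds(pipe, "a photo of a corgi", {2: 1.3, 3: 1.3, 4: 1.3})
# image = pipe(prompt_embeds=embeds).images[0]
```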
......@@ -217,7 +217,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt (`str` or `List[str]`):
prompt to be encoded
device: (`torch.device`):
torch device
......@@ -253,16 +253,16 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
else:
attention_mask = None
text_embeddings = self.text_encoder(
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = text_embeddings[0]
prompt_embeds = prompt_embeds[0]
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
......@@ -299,16 +299,16 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = uncond_embeddings[0]
negative_prompt_embeds = negative_prompt_embeds[0]
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# Encode the safety concept text
if enable_safety_guidance:
......@@ -329,15 +329,15 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
# For classifier free guidance + sld, we need to do three forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing three forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings, safety_embeddings])
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, safety_embeddings])
else:
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return text_embeddings
return prompt_embeds
def run_safety_checker(self, image, device, dtype, enable_safety_guidance):
if self.safety_checker is not None:
......@@ -390,10 +390,16 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
return extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
def check_inputs(self, prompt, height, width, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
def check_inputs(
self,
prompt,
height,
width,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
......@@ -405,6 +411,32 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
f" {type(callback_steps)}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
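
In short, the new checks make the string and embedding inputs mutually exclusive and require matching shapes for the two embedding tensors. Schematically, for any pipeline that picked up this `check_inputs` (`embeds_a` / `embeds_b` are placeholders):

```python
pipe(prompt="a cat", prompt_embeds=embeds_a)     # ValueError: pass only one of the two
pipe()                                           # ValueError: need `prompt` or `prompt_embeds`
pipe(prompt_embeds=embeds_a,
     negative_prompt_embeds=embeds_b)            # ValueError if the shapes differ
pipe(prompt_embeds=embeds_a,
     negative_prompt_embeds=embeds_a.clone())    # OK
```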
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
......@@ -583,7 +615,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
warnings.warn("Safety checker disabled!")
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt_embeds = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, enable_safety_guidance
)
......@@ -598,7 +630,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
num_channels_latents,
height,
width,
text_embeddings.dtype,
prompt_embeds.dtype,
device,
generator,
latents,
......@@ -621,7 +653,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
# perform guidance
if do_classifier_free_guidance:
......@@ -680,7 +712,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
# 9. Run safety checker
image, has_nsfw_concept, flagged_images = self.run_safety_checker(
image, device, text_embeddings.dtype, enable_safety_guidance
image, device, prompt_embeds.dtype, enable_safety_guidance
)
# 10. Convert to PIL
......
......@@ -153,15 +153,15 @@ class UnCLIPPipeline(DiffusionPipeline):
text_encoder_output = self.text_encoder(text_input_ids.to(device))
text_embeddings = text_encoder_output.text_embeds
prompt_embeds = text_encoder_output.text_embeds
text_encoder_hidden_states = text_encoder_output.last_hidden_state
else:
batch_size = text_model_output[0].shape[0]
text_embeddings, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
prompt_embeds, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
text_mask = text_attention_mask
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
......@@ -176,16 +176,16 @@ class UnCLIPPipeline(DiffusionPipeline):
return_tensors="pt",
)
uncond_text_mask = uncond_input.attention_mask.bool().to(device)
uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
uncond_embeddings = uncond_embeddings_text_encoder_output.text_embeds
uncond_text_encoder_hidden_states = uncond_embeddings_text_encoder_output.last_hidden_state
negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len)
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
seq_len = uncond_text_encoder_hidden_states.shape[1]
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
......@@ -199,12 +199,12 @@ class UnCLIPPipeline(DiffusionPipeline):
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
text_mask = torch.cat([uncond_text_mask, text_mask])
return text_embeddings, text_encoder_hidden_states, text_mask
return prompt_embeds, text_encoder_hidden_states, text_mask
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
......@@ -336,7 +336,7 @@ class UnCLIPPipeline(DiffusionPipeline):
do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt(
prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask
)
......@@ -349,7 +349,7 @@ class UnCLIPPipeline(DiffusionPipeline):
prior_latents = self.prepare_latents(
(batch_size, embedding_dim),
text_embeddings.dtype,
prompt_embeds.dtype,
device,
generator,
prior_latents,
......@@ -363,7 +363,7 @@ class UnCLIPPipeline(DiffusionPipeline):
predicted_image_embedding = self.prior(
latent_model_input,
timestep=t,
proj_embedding=text_embeddings,
proj_embedding=prompt_embeds,
encoder_hidden_states=text_encoder_hidden_states,
attention_mask=text_mask,
).predicted_image_embedding
......@@ -397,7 +397,7 @@ class UnCLIPPipeline(DiffusionPipeline):
text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
image_embeddings=image_embeddings,
text_embeddings=text_embeddings,
prompt_embeds=prompt_embeds,
text_encoder_hidden_states=text_encoder_hidden_states,
do_classifier_free_guidance=do_classifier_free_guidance,
)
......
......@@ -136,10 +136,10 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
text_mask = text_inputs.attention_mask.bool().to(device)
text_encoder_output = self.text_encoder(text_input_ids.to(device))
text_embeddings = text_encoder_output.text_embeds
prompt_embeds = text_encoder_output.text_embeds
text_encoder_hidden_states = text_encoder_output.last_hidden_state
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
......@@ -155,16 +155,16 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
return_tensors="pt",
)
uncond_text_mask = uncond_input.attention_mask.bool().to(device)
uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
uncond_embeddings = uncond_embeddings_text_encoder_output.text_embeds
uncond_text_encoder_hidden_states = uncond_embeddings_text_encoder_output.last_hidden_state
negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len)
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
seq_len = uncond_text_encoder_hidden_states.shape[1]
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
......@@ -178,12 +178,12 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
text_mask = torch.cat([uncond_text_mask, text_mask])
return text_embeddings, text_encoder_hidden_states, text_mask
return prompt_embeds, text_encoder_hidden_states, text_mask
def _encode_image(self, image, device, num_images_per_prompt, image_embeddings: Optional[torch.Tensor] = None):
dtype = next(self.image_encoder.parameters()).dtype
......@@ -314,7 +314,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
do_classifier_free_guidance = decoder_guidance_scale > 1.0
text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt(
prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance
)
......@@ -323,7 +323,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
# decoder
text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
image_embeddings=image_embeddings,
text_embeddings=text_embeddings,
prompt_embeds=prompt_embeds,
text_encoder_hidden_states=text_encoder_hidden_states,
do_classifier_free_guidance=do_classifier_free_guidance,
)
......
......@@ -52,7 +52,7 @@ class UnCLIPTextProjModel(ModelMixin, ConfigMixin):
self.encoder_hidden_states_proj = nn.Linear(clip_embeddings_dim, cross_attention_dim)
self.text_encoder_hidden_states_norm = nn.LayerNorm(cross_attention_dim)
def forward(self, *, image_embeddings, text_embeddings, text_encoder_hidden_states, do_classifier_free_guidance):
def forward(self, *, image_embeddings, prompt_embeds, text_encoder_hidden_states, do_classifier_free_guidance):
if do_classifier_free_guidance:
# Add the classifier free guidance embeddings to the image embeddings
image_embeddings_batch_size = image_embeddings.shape[0]
......@@ -63,15 +63,15 @@ class UnCLIPTextProjModel(ModelMixin, ConfigMixin):
image_embeddings = torch.cat([classifier_free_guidance_embeddings, image_embeddings], dim=0)
# The image embeddings batch size and the text embeddings batch size are equal
assert image_embeddings.shape[0] == text_embeddings.shape[0]
assert image_embeddings.shape[0] == prompt_embeds.shape[0]
batch_size = text_embeddings.shape[0]
batch_size = prompt_embeds.shape[0]
# "Specifically, we modify the architecture described in Nichol et al. (2021) by projecting and
# adding CLIP embeddings to the existing timestep embedding, ...
time_projected_text_embeddings = self.embedding_proj(text_embeddings)
time_projected_prompt_embeds = self.embedding_proj(prompt_embeds)
time_projected_image_embeddings = self.clip_image_embeddings_project_to_time_embeddings(image_embeddings)
additive_clip_time_embeddings = time_projected_image_embeddings + time_projected_text_embeddings
additive_clip_time_embeddings = time_projected_image_embeddings + time_projected_prompt_embeds
# ... and by projecting CLIP embeddings into four
# extra tokens of context that are concatenated to the sequence of outputs from the GLIDE text encoder"
......
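
Because `UnCLIPTextProjModel.forward` takes keyword-only arguments, this rename is visible to any code that calls the projection model directly; roughly:

```python
# before this commit
hidden_states, clip_time_embeds = text_proj(
    image_embeddings=image_embeddings,
    text_embeddings=prompt_embeds,
    text_encoder_hidden_states=text_encoder_hidden_states,
    do_classifier_free_guidance=do_classifier_free_guidance,
)

# after this commit
hidden_states, clip_time_embeds = text_proj(
    image_embeddings=image_embeddings,
    prompt_embeds=prompt_embeds,
    text_encoder_hidden_states=text_encoder_hidden_states,
    do_classifier_free_guidance=do_classifier_free_guidance,
)
```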
......@@ -187,7 +187,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt (`str` or `List[str]`):
prompt to be encoded
device: (`torch.device`):
torch device
......@@ -227,16 +227,16 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
else:
attention_mask = None
text_embeddings = self.text_encoder(
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = normalize_embeddings(text_embeddings)
prompt_embeds = normalize_embeddings(prompt_embeds)
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
......@@ -255,30 +255,30 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = normalize_embeddings(uncond_embeddings)
negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds)
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return text_embeddings
return prompt_embeds
def _encode_image_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt (`str` or `List[str]`):
prompt to be encoded
device: (`torch.device`):
torch device
......@@ -313,18 +313,18 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size
uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt")
pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype)
uncond_embeddings = self.image_encoder(pixel_values)
uncond_embeddings = normalize_embeddings(uncond_embeddings)
negative_prompt_embeds = self.image_encoder(pixel_values)
negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds)
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and conditional embeddings into a single batch
# to avoid doing two forward passes
image_embeddings = torch.cat([uncond_embeddings, image_embeddings])
image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
return image_embeddings
......@@ -524,9 +524,9 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompts
text_embeddings = self._encode_text_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance)
prompt_embeds = self._encode_text_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance)
image_embeddings = self._encode_image_prompt(image, device, num_images_per_prompt, do_classifier_free_guidance)
dual_prompt_embeddings = torch.cat([text_embeddings, image_embeddings], dim=1)
dual_prompt_embeddings = torch.cat([prompt_embeds, image_embeddings], dim=1)
prompt_types = ("text", "image")
# 4. Prepare timesteps
......
......@@ -114,7 +114,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt (`str` or `List[str]`):
prompt to be encoded
device: (`torch.device`):
torch device
......@@ -173,18 +173,18 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt")
pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype)
uncond_embeddings = self.image_encoder(pixel_values)
uncond_embeddings = normalize_embeddings(uncond_embeddings)
negative_prompt_embeds = self.image_encoder(pixel_values)
negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds)
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and conditional embeddings into a single batch
# to avoid doing two forward passes
image_embeddings = torch.cat([uncond_embeddings, image_embeddings])
image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
return image_embeddings
......
......@@ -138,7 +138,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt (`str` or `List[str]`):
prompt to be encoded
device: (`torch.device`):
torch device
......@@ -181,16 +181,16 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
else:
attention_mask = None
text_embeddings = self.text_encoder(
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = normalize_embeddings(text_embeddings)
prompt_embeds = normalize_embeddings(prompt_embeds)
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
......@@ -227,23 +227,23 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = normalize_embeddings(uncond_embeddings)
negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds)
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return text_embeddings
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
......@@ -273,10 +273,16 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
return extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
def check_inputs(self, prompt, height, width, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
def check_inputs(
self,
prompt,
height,
width,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
......@@ -288,6 +294,32 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
f" {type(callback_steps)}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
......@@ -412,7 +444,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt_embeds = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)
......@@ -427,7 +459,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
num_channels_latents,
height,
width,
text_embeddings.dtype,
prompt_embeds.dtype,
device,
generator,
latents,
......@@ -443,7 +475,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
# perform guidance
if do_classifier_free_guidance:
......
......@@ -120,7 +120,7 @@ class VQDiffusionPipeline(DiffusionPipeline):
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
prompt_embeds = self.text_encoder(text_input_ids.to(self.device))[0]
# NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
# While CLIP does normalize the pooled output of the text transformer when combining
......@@ -128,15 +128,15 @@ class VQDiffusionPipeline(DiffusionPipeline):
#
# CLIP normalizing the pooled output.
# https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
# duplicate text embeddings for each generation per prompt
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
if self.learned_classifier_free_sampling_embeddings.learnable:
uncond_embeddings = self.learned_classifier_free_sampling_embeddings.embeddings
uncond_embeddings = uncond_embeddings.unsqueeze(0).repeat(batch_size, 1, 1)
negative_prompt_embeds = self.learned_classifier_free_sampling_embeddings.embeddings
negative_prompt_embeds = negative_prompt_embeds.unsqueeze(0).repeat(batch_size, 1, 1)
else:
uncond_tokens = [""] * batch_size
......@@ -148,21 +148,21 @@ class VQDiffusionPipeline(DiffusionPipeline):
truncation=True,
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
# See comment for normalizing text embeddings
uncond_embeddings = uncond_embeddings / uncond_embeddings.norm(dim=-1, keepdim=True)
negative_prompt_embeds = negative_prompt_embeds / negative_prompt_embeds.norm(dim=-1, keepdim=True)
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return text_embeddings
return prompt_embeds
@torch.no_grad()
def __call__(
......@@ -234,7 +234,7 @@ class VQDiffusionPipeline(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0
text_embeddings = self._encode_prompt(prompt, num_images_per_prompt, do_classifier_free_guidance)
prompt_embeds = self._encode_prompt(prompt, num_images_per_prompt, do_classifier_free_guidance)
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
......@@ -273,9 +273,7 @@ class VQDiffusionPipeline(DiffusionPipeline):
# predict the un-noised image
# model_output == `log_p_x_0`
model_output = self.transformer(
latent_model_input, encoder_hidden_states=text_embeddings, timestep=t
).sample
model_output = self.transformer(latent_model_input, encoder_hidden_states=prompt_embeds, timestep=t).sample
if do_classifier_free_guidance:
model_output_uncond, model_output_text = model_output.chunk(2)
......
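
For context (this logic is unchanged by the commit and cut off in the hunks above), classifier-free guidance recombines the two chunks by interpolating from the unconditional prediction toward the text-conditioned one; for the latent pipelines this is applied to the predicted noise, and for VQ-Diffusion to the `log_p_x_0` output. Roughly:

```python
# schematic of the guidance step, not a verbatim copy of the library code
pred_uncond, pred_text = model_output.chunk(2)
guided = pred_uncond + guidance_scale * (pred_text - pred_uncond)
```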
......@@ -120,7 +120,7 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
......@@ -134,6 +134,82 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_prompt_embeds(self):
components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt"] = 3 * [inputs["prompt"]]
# forward
output = sd_pipe(**inputs)
image_slice_1 = output.images[0, -3:, -3:, -1]
inputs = self.get_dummy_inputs(torch_device)
prompt = 3 * [inputs.pop("prompt")]
text_inputs = sd_pipe.tokenizer(
prompt,
padding="max_length",
max_length=sd_pipe.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_inputs = text_inputs["input_ids"].to(torch_device)
prompt_embeds = sd_pipe.text_encoder(text_inputs)[0]
inputs["prompt_embeds"] = prompt_embeds
# forward
output = sd_pipe(**inputs)
image_slice_2 = output.images[0, -3:, -3:, -1]
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
def test_stable_diffusion_negative_prompt_embeds(self):
components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
negative_prompt = 3 * ["this is a negative prompt"]
inputs["negative_prompt"] = negative_prompt
inputs["prompt"] = 3 * [inputs["prompt"]]
# forward
output = sd_pipe(**inputs)
image_slice_1 = output.images[0, -3:, -3:, -1]
inputs = self.get_dummy_inputs(torch_device)
prompt = 3 * [inputs.pop("prompt")]
embeds = []
for p in [prompt, negative_prompt]:
text_inputs = sd_pipe.tokenizer(
p,
padding="max_length",
max_length=sd_pipe.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_inputs = text_inputs["input_ids"].to(torch_device)
embeds.append(sd_pipe.text_encoder(text_inputs)[0])
inputs["prompt_embeds"], inputs["negative_prompt_embeds"] = embeds
# forward
output = sd_pipe(**inputs)
image_slice_2 = output.images[0, -3:, -3:, -1]
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
def test_stable_diffusion_ddim_factor_8(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
......
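
The two new tests above exercise the embedding path end to end. Assuming the usual repository layout for the test file (the path below is an assumption, not stated in this diff), they can be run in isolation with:

```python
# run only the prompt-embeds tests; path assumed from the standard diffusers layout
import pytest

pytest.main(["tests/pipelines/stable_diffusion/test_stable_diffusion.py", "-k", "prompt_embeds"])
```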