"...composable_kernel_rocm.git" did not exist on "cd8de112189710f9bd3a2a817e50ee5e583b4397"
Unverified Commit b562b661 authored by Patrick von Platen, committed by GitHub

Allow directly passing text embeddings to Stable Diffusion Pipeline for prompt weighting (#2071)

* add text embeds to sd

* add text embeds to sd

* finish tests

* finish

* finish

* make style

* fix tests

* make style

* make style

* up

* better docs

* fix

* fix

* new try

* up

* up

* finish
parent c1184918
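The change lets callers compute and edit the CLIP text embeddings outside the pipeline and pass them in directly, which is the hook needed for prompt weighting. A minimal usage sketch against the main StableDiffusionPipeline (the checkpoint name and the negative prompt text are assumptions for illustration only; the tokenize-then-encode pattern mirrors the tests added at the bottom of this diff):

from diffusers import StableDiffusionPipeline

# Checkpoint name is an assumption, not part of this commit.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe = pipe.to("cuda")

def encode(text):
    # Tokenize-then-encode pattern, mirroring the tests added in this commit.
    input_ids = pipe.tokenizer(
        text,
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids.to(pipe.device)
    return pipe.text_encoder(input_ids)[0]

prompt_embeds = encode("an astronaut riding a horse")
negative_prompt_embeds = encode("blurry, low quality")

# `prompt` can now be left as None when `prompt_embeds` is supplied.
image = pipe(
    prompt=None,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
).images[0]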
......@@ -136,12 +136,21 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
return self.device
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
......@@ -149,12 +158,26 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, `negative_prompt_embeds` will be generated from `negative_prompt` input
argument.
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
text_inputs = self.tokenizer(
prompt,
padding="max_length",
......@@ -165,8 +188,12 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
......@@ -177,19 +204,21 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
else:
attention_mask = None
text_embeddings = self.text_encoder(
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = text_embeddings[0]
prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
......@@ -209,7 +238,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
......@@ -223,23 +252,27 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = uncond_embeddings[0]
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return text_embeddings
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
......@@ -325,8 +358,8 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]],
prompt: Union[str, List[str]] = None,
image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None,
num_inference_steps: int = 75,
guidance_scale: float = 9.0,
noise_level: int = 20,
......@@ -335,6 +368,8 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
......@@ -344,8 +379,9 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`):
`Image`, or tensor representing an image batch which will be upscaled.
num_inference_steps (`int`, *optional*, defaults to 75):
......@@ -358,8 +394,9 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
......@@ -372,6 +409,13 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, `negative_prompt_embeds` will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
......@@ -422,6 +466,9 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
# 1. Check inputs
self.check_inputs(prompt, image, noise_level, callback_steps)
if image is None:
raise ValueError("`image` input cannot be undefined.")
# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
......@@ -431,13 +478,19 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
prompt_embeds = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
# 4. Preprocess image
image = preprocess(image)
image = image.to(dtype=text_embeddings.dtype, device=device)
image = image.to(dtype=prompt_embeds.dtype, device=device)
# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
......@@ -445,7 +498,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
# 5. Add noise to image
noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
noise = randn_tensor(image.shape, generator=generator, device=device, dtype=text_embeddings.dtype)
noise = randn_tensor(image.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
batch_multiplier = 2 if do_classifier_free_guidance else 1
......@@ -460,7 +513,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
num_channels_latents,
height,
width,
text_embeddings.dtype,
prompt_embeds.dtype,
device,
generator,
latents,
......@@ -493,7 +546,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
# predict the noise residual
noise_pred = self.unet(
latent_model_input, t, encoder_hidden_states=text_embeddings, class_labels=noise_level
latent_model_input, t, encoder_hidden_states=prompt_embeds, class_labels=noise_level
).sample
# perform guidance
......
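The reworked _encode_prompt above now supports two paths: embeddings are computed from `prompt` when `prompt_embeds` is None, otherwise the supplied tensors are used and the batch size is read from `prompt_embeds.shape[0]`. A hedged sketch of both calls (`pipe` is assumed to be an already loaded StableDiffusionUpscalePipeline; calling the private helper directly is only to illustrate the new signature):

device = pipe._execution_device

# Classic path: embeddings are computed from the text prompt.
prompt_embeds = pipe._encode_prompt(
    "a white cat", device, 1, False, None,
)

# New path: `prompt` may be None; the precomputed tensor is forwarded as-is,
# the batch size is taken from prompt_embeds.shape[0], and the unconditional
# embeddings for classifier-free guidance are still generated internally.
final_embeds = pipe._encode_prompt(
    None, device, 1, True, None,
    prompt_embeds=prompt_embeds,
)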
......@@ -217,7 +217,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt (`str` or `List[str]`):
prompt to be encoded
device: (`torch.device`):
torch device
......@@ -253,16 +253,16 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
else:
attention_mask = None
text_embeddings = self.text_encoder(
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = text_embeddings[0]
prompt_embeds = prompt_embeds[0]
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
......@@ -299,16 +299,16 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = uncond_embeddings[0]
negative_prompt_embeds = negative_prompt_embeds[0]
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# Encode the safety concept text
if enable_safety_guidance:
......@@ -329,15 +329,15 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
# For classifier free guidance + sld, we need to do three forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing three forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings, safety_embeddings])
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, safety_embeddings])
else:
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return text_embeddings
return prompt_embeds
def run_safety_checker(self, image, device, dtype, enable_safety_guidance):
if self.safety_checker is not None:
......@@ -390,10 +390,16 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
return extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
def check_inputs(self, prompt, height, width, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
def check_inputs(
self,
prompt,
height,
width,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
......@@ -405,6 +411,32 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
f" {type(callback_steps)}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
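The added validation makes `prompt` and `prompt_embeds` mutually exclusive and requires the two embedding tensors to share a shape. A short sketch of which combinations pass the updated check_inputs (`pipe` is any pipeline carrying this copied implementation, and `embeds`/`neg_embeds` are placeholder tensors of matching shape; all values are illustrative):

pipe.check_inputs("a cat", 512, 512, 1)                          # ok: prompt only
pipe.check_inputs(None, 512, 512, 1, prompt_embeds=embeds)       # ok: embeddings only
pipe.check_inputs("a cat", 512, 512, 1, prompt_embeds=embeds)    # ValueError: both forwarded
pipe.check_inputs(None, 512, 512, 1)                             # ValueError: neither provided
pipe.check_inputs(
    None, 512, 512, 1,
    prompt_embeds=embeds,
    negative_prompt_embeds=neg_embeds[:, :-1],                   # ValueError: shape mismatch
)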
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
......@@ -583,7 +615,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
warnings.warn("Safety checker disabled!")
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt_embeds = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, enable_safety_guidance
)
......@@ -598,7 +630,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
num_channels_latents,
height,
width,
text_embeddings.dtype,
prompt_embeds.dtype,
device,
generator,
latents,
......@@ -621,7 +653,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
# perform guidance
if do_classifier_free_guidance:
......@@ -680,7 +712,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
# 9. Run safety checker
image, has_nsfw_concept, flagged_images = self.run_safety_checker(
image, device, text_embeddings.dtype, enable_safety_guidance
image, device, prompt_embeds.dtype, enable_safety_guidance
)
# 10. Convert to PIL
......
......@@ -153,15 +153,15 @@ class UnCLIPPipeline(DiffusionPipeline):
text_encoder_output = self.text_encoder(text_input_ids.to(device))
text_embeddings = text_encoder_output.text_embeds
prompt_embeds = text_encoder_output.text_embeds
text_encoder_hidden_states = text_encoder_output.last_hidden_state
else:
batch_size = text_model_output[0].shape[0]
text_embeddings, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
prompt_embeds, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
text_mask = text_attention_mask
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
......@@ -176,16 +176,16 @@ class UnCLIPPipeline(DiffusionPipeline):
return_tensors="pt",
)
uncond_text_mask = uncond_input.attention_mask.bool().to(device)
uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
uncond_embeddings = uncond_embeddings_text_encoder_output.text_embeds
uncond_text_encoder_hidden_states = uncond_embeddings_text_encoder_output.last_hidden_state
negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len)
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
seq_len = uncond_text_encoder_hidden_states.shape[1]
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
......@@ -199,12 +199,12 @@ class UnCLIPPipeline(DiffusionPipeline):
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
text_mask = torch.cat([uncond_text_mask, text_mask])
return text_embeddings, text_encoder_hidden_states, text_mask
return prompt_embeds, text_encoder_hidden_states, text_mask
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
......@@ -336,7 +336,7 @@ class UnCLIPPipeline(DiffusionPipeline):
do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt(
prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask
)
......@@ -349,7 +349,7 @@ class UnCLIPPipeline(DiffusionPipeline):
prior_latents = self.prepare_latents(
(batch_size, embedding_dim),
text_embeddings.dtype,
prompt_embeds.dtype,
device,
generator,
prior_latents,
......@@ -363,7 +363,7 @@ class UnCLIPPipeline(DiffusionPipeline):
predicted_image_embedding = self.prior(
latent_model_input,
timestep=t,
proj_embedding=text_embeddings,
proj_embedding=prompt_embeds,
encoder_hidden_states=text_encoder_hidden_states,
attention_mask=text_mask,
).predicted_image_embedding
......@@ -397,7 +397,7 @@ class UnCLIPPipeline(DiffusionPipeline):
text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
image_embeddings=image_embeddings,
text_embeddings=text_embeddings,
prompt_embeds=prompt_embeds,
text_encoder_hidden_states=text_encoder_hidden_states,
do_classifier_free_guidance=do_classifier_free_guidance,
)
......
......@@ -136,10 +136,10 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
text_mask = text_inputs.attention_mask.bool().to(device)
text_encoder_output = self.text_encoder(text_input_ids.to(device))
text_embeddings = text_encoder_output.text_embeds
prompt_embeds = text_encoder_output.text_embeds
text_encoder_hidden_states = text_encoder_output.last_hidden_state
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
......@@ -155,16 +155,16 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
return_tensors="pt",
)
uncond_text_mask = uncond_input.attention_mask.bool().to(device)
uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
uncond_embeddings = uncond_embeddings_text_encoder_output.text_embeds
uncond_text_encoder_hidden_states = uncond_embeddings_text_encoder_output.last_hidden_state
negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len)
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
seq_len = uncond_text_encoder_hidden_states.shape[1]
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
......@@ -178,12 +178,12 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
text_mask = torch.cat([uncond_text_mask, text_mask])
return text_embeddings, text_encoder_hidden_states, text_mask
return prompt_embeds, text_encoder_hidden_states, text_mask
def _encode_image(self, image, device, num_images_per_prompt, image_embeddings: Optional[torch.Tensor] = None):
dtype = next(self.image_encoder.parameters()).dtype
......@@ -314,7 +314,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
do_classifier_free_guidance = decoder_guidance_scale > 1.0
text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt(
prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance
)
......@@ -323,7 +323,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
# decoder
text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
image_embeddings=image_embeddings,
text_embeddings=text_embeddings,
prompt_embeds=prompt_embeds,
text_encoder_hidden_states=text_encoder_hidden_states,
do_classifier_free_guidance=do_classifier_free_guidance,
)
......
......@@ -52,7 +52,7 @@ class UnCLIPTextProjModel(ModelMixin, ConfigMixin):
self.encoder_hidden_states_proj = nn.Linear(clip_embeddings_dim, cross_attention_dim)
self.text_encoder_hidden_states_norm = nn.LayerNorm(cross_attention_dim)
def forward(self, *, image_embeddings, text_embeddings, text_encoder_hidden_states, do_classifier_free_guidance):
def forward(self, *, image_embeddings, prompt_embeds, text_encoder_hidden_states, do_classifier_free_guidance):
if do_classifier_free_guidance:
# Add the classifier free guidance embeddings to the image embeddings
image_embeddings_batch_size = image_embeddings.shape[0]
......@@ -63,15 +63,15 @@ class UnCLIPTextProjModel(ModelMixin, ConfigMixin):
image_embeddings = torch.cat([classifier_free_guidance_embeddings, image_embeddings], dim=0)
# The image embeddings batch size and the text embeddings batch size are equal
assert image_embeddings.shape[0] == text_embeddings.shape[0]
assert image_embeddings.shape[0] == prompt_embeds.shape[0]
batch_size = text_embeddings.shape[0]
batch_size = prompt_embeds.shape[0]
# "Specifically, we modify the architecture described in Nichol et al. (2021) by projecting and
# adding CLIP embeddings to the existing timestep embedding, ...
time_projected_text_embeddings = self.embedding_proj(text_embeddings)
time_projected_prompt_embeds = self.embedding_proj(prompt_embeds)
time_projected_image_embeddings = self.clip_image_embeddings_project_to_time_embeddings(image_embeddings)
additive_clip_time_embeddings = time_projected_image_embeddings + time_projected_text_embeddings
additive_clip_time_embeddings = time_projected_image_embeddings + time_projected_prompt_embeds
# ... and by projecting CLIP embeddings into four
# extra tokens of context that are concatenated to the sequence of outputs from the GLIDE text encoder"
......
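Since UnCLIPTextProjModel.forward takes keyword-only arguments, the rename above is a breaking change for any external caller of the projection model; the call has to be updated along the lines of the sketch below (the tensor variables are placeholders standing in for the values the unCLIP pipelines pass):

text_encoder_hidden_states, additive_clip_time_embeddings = text_proj(
    image_embeddings=image_embeddings,
    prompt_embeds=prompt_embeds,  # previously passed as text_embeddings=
    text_encoder_hidden_states=text_encoder_hidden_states,
    do_classifier_free_guidance=do_classifier_free_guidance,
)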
......@@ -187,7 +187,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt (`str` or `List[str]`):
prompt to be encoded
device: (`torch.device`):
torch device
......@@ -227,16 +227,16 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
else:
attention_mask = None
text_embeddings = self.text_encoder(
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = normalize_embeddings(text_embeddings)
prompt_embeds = normalize_embeddings(prompt_embeds)
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
......@@ -255,30 +255,30 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = normalize_embeddings(uncond_embeddings)
negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds)
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return text_embeddings
return prompt_embeds
def _encode_image_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt (`str` or `List[str]`):
prompt to be encoded
device: (`torch.device`):
torch device
......@@ -313,18 +313,18 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size
uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt")
pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype)
uncond_embeddings = self.image_encoder(pixel_values)
uncond_embeddings = normalize_embeddings(uncond_embeddings)
negative_prompt_embeds = self.image_encoder(pixel_values)
negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds)
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and conditional embeddings into a single batch
# to avoid doing two forward passes
image_embeddings = torch.cat([uncond_embeddings, image_embeddings])
image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
return image_embeddings
......@@ -524,9 +524,9 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompts
text_embeddings = self._encode_text_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance)
prompt_embeds = self._encode_text_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance)
image_embeddings = self._encode_image_prompt(image, device, num_images_per_prompt, do_classifier_free_guidance)
dual_prompt_embeddings = torch.cat([text_embeddings, image_embeddings], dim=1)
dual_prompt_embeddings = torch.cat([prompt_embeds, image_embeddings], dim=1)
prompt_types = ("text", "image")
# 4. Prepare timesteps
......
......@@ -114,7 +114,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt (`str` or `List[str]`):
prompt to be encoded
device: (`torch.device`):
torch device
......@@ -173,18 +173,18 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt")
pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype)
uncond_embeddings = self.image_encoder(pixel_values)
uncond_embeddings = normalize_embeddings(uncond_embeddings)
negative_prompt_embeds = self.image_encoder(pixel_values)
negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds)
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and conditional embeddings into a single batch
# to avoid doing two forward passes
image_embeddings = torch.cat([uncond_embeddings, image_embeddings])
image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
return image_embeddings
......
......@@ -138,7 +138,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt (`str` or `List[str]`):
prompt to be encoded
device: (`torch.device`):
torch device
......@@ -181,16 +181,16 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
else:
attention_mask = None
text_embeddings = self.text_encoder(
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = normalize_embeddings(text_embeddings)
prompt_embeds = normalize_embeddings(prompt_embeds)
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
......@@ -227,23 +227,23 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = normalize_embeddings(uncond_embeddings)
negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds)
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return text_embeddings
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
......@@ -273,10 +273,16 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
return extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
def check_inputs(self, prompt, height, width, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
def check_inputs(
self,
prompt,
height,
width,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
......@@ -288,6 +294,32 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
f" {type(callback_steps)}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
......@@ -412,7 +444,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt_embeds = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)
......@@ -427,7 +459,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
num_channels_latents,
height,
width,
text_embeddings.dtype,
prompt_embeds.dtype,
device,
generator,
latents,
......@@ -443,7 +475,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
# perform guidance
if do_classifier_free_guidance:
......
......@@ -120,7 +120,7 @@ class VQDiffusionPipeline(DiffusionPipeline):
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
prompt_embeds = self.text_encoder(text_input_ids.to(self.device))[0]
# NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
# While CLIP does normalize the pooled output of the text transformer when combining
......@@ -128,15 +128,15 @@ class VQDiffusionPipeline(DiffusionPipeline):
#
# CLIP normalizing the pooled output.
# https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
# duplicate text embeddings for each generation per prompt
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
if self.learned_classifier_free_sampling_embeddings.learnable:
uncond_embeddings = self.learned_classifier_free_sampling_embeddings.embeddings
uncond_embeddings = uncond_embeddings.unsqueeze(0).repeat(batch_size, 1, 1)
negative_prompt_embeds = self.learned_classifier_free_sampling_embeddings.embeddings
negative_prompt_embeds = negative_prompt_embeds.unsqueeze(0).repeat(batch_size, 1, 1)
else:
uncond_tokens = [""] * batch_size
......@@ -148,21 +148,21 @@ class VQDiffusionPipeline(DiffusionPipeline):
truncation=True,
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
# See comment for normalizing text embeddings
uncond_embeddings = uncond_embeddings / uncond_embeddings.norm(dim=-1, keepdim=True)
negative_prompt_embeds = negative_prompt_embeds / negative_prompt_embeds.norm(dim=-1, keepdim=True)
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return text_embeddings
return prompt_embeds
@torch.no_grad()
def __call__(
......@@ -234,7 +234,7 @@ class VQDiffusionPipeline(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0
text_embeddings = self._encode_prompt(prompt, num_images_per_prompt, do_classifier_free_guidance)
prompt_embeds = self._encode_prompt(prompt, num_images_per_prompt, do_classifier_free_guidance)
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
......@@ -273,9 +273,7 @@ class VQDiffusionPipeline(DiffusionPipeline):
# predict the un-noised image
# model_output == `log_p_x_0`
model_output = self.transformer(
latent_model_input, encoder_hidden_states=text_embeddings, timestep=t
).sample
model_output = self.transformer(latent_model_input, encoder_hidden_states=prompt_embeds, timestep=t).sample
if do_classifier_free_guidance:
model_output_uncond, model_output_text = model_output.chunk(2)
......
......@@ -120,7 +120,7 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
......@@ -134,6 +134,82 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_prompt_embeds(self):
components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt"] = 3 * [inputs["prompt"]]
# forward
output = sd_pipe(**inputs)
image_slice_1 = output.images[0, -3:, -3:, -1]
inputs = self.get_dummy_inputs(torch_device)
prompt = 3 * [inputs.pop("prompt")]
text_inputs = sd_pipe.tokenizer(
prompt,
padding="max_length",
max_length=sd_pipe.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_inputs = text_inputs["input_ids"].to(torch_device)
prompt_embeds = sd_pipe.text_encoder(text_inputs)[0]
inputs["prompt_embeds"] = prompt_embeds
# forward
output = sd_pipe(**inputs)
image_slice_2 = output.images[0, -3:, -3:, -1]
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
def test_stable_diffusion_negative_prompt_embeds(self):
components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
negative_prompt = 3 * ["this is a negative prompt"]
inputs["negative_prompt"] = negative_prompt
inputs["prompt"] = 3 * [inputs["prompt"]]
# forward
output = sd_pipe(**inputs)
image_slice_1 = output.images[0, -3:, -3:, -1]
inputs = self.get_dummy_inputs(torch_device)
prompt = 3 * [inputs.pop("prompt")]
embeds = []
for p in [prompt, negative_prompt]:
text_inputs = sd_pipe.tokenizer(
p,
padding="max_length",
max_length=sd_pipe.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_inputs = text_inputs["input_ids"].to(torch_device)
embeds.append(sd_pipe.text_encoder(text_inputs)[0])
inputs["prompt_embeds"], inputs["negative_prompt_embeds"] = embeds
# forward
output = sd_pipe(**inputs)
image_slice_2 = output.images[0, -3:, -3:, -1]
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
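The two tests above only verify that precomputed embeddings reproduce the plain-prompt result; the point of exposing `prompt_embeds` is that the tensor can be edited before it is passed in. A hedged sketch of how a weighted variant could be exercised (a hypothetical test, not part of this commit; the uniform 1.1 scaling is purely illustrative):

def test_stable_diffusion_weighted_prompt_embeds(self):
    # Hypothetical test, not part of this commit.
    components = self.get_dummy_components()
    sd_pipe = StableDiffusionPipeline(**components).to(torch_device)
    sd_pipe.set_progress_bar_config(disable=None)

    inputs = self.get_dummy_inputs(torch_device)
    prompt = inputs.pop("prompt")

    text_inputs = sd_pipe.tokenizer(
        prompt,
        padding="max_length",
        max_length=sd_pipe.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    input_ids = text_inputs.input_ids.to(torch_device)

    # Scale the encoder hidden states before handing them to the pipeline.
    prompt_embeds = sd_pipe.text_encoder(input_ids)[0] * 1.1

    inputs["prompt_embeds"] = prompt_embeds
    image = sd_pipe(**inputs).images[0]
    assert image.shape[-1] == 3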
def test_stable_diffusion_ddim_factor_8(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
......