Unverified Commit 867e0c91 authored by Vinh H. Pham, committed by GitHub

StableDiffusionLatentUpscalePipeline - positive/negative prompt embeds support (#8947)



* make latent upscaler accept prompt embeds
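
A minimal sketch of the new call path, assuming an illustrative checkpoint id and a `low_res_latents` tensor produced elsewhere (for example by a base pipeline run with output_type="latent"); none of these names are part of this change:

import torch
from diffusers import StableDiffusionLatentUpscalePipeline

# Illustrative checkpoint id; any latent-upscaler checkpoint with the same components should behave the same.
upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(
    "stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16
).to("cuda")

# Pre-compute the embeddings once (e.g. for prompt weighting or reuse across calls) ...
(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = upscaler.encode_prompt(
    prompt="a photo of an astronaut riding a horse",
    device="cuda",
    do_classifier_free_guidance=True,
    negative_prompt="blurry, low quality",
)

# ... and pass them instead of a plain-text prompt.
upscaled = upscaler(
    image=low_res_latents,  # assumed latents from a base pipeline
    num_inference_steps=20,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
).images[0]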

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent 16a3dad4
@@ -33,6 +33,20 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffu

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.preprocess
 def preprocess(image):
     warnings.warn(
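
For reference, a small sketch of what the copied retrieve_latents() helper returns for a VAE encode output, assuming the helper above is in scope and using an illustrative checkpoint id:

import torch
from diffusers import AutoencoderKL

# Assumed checkpoint/subfolder; any AutoencoderKL behaves the same way here.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-x2-latent-upscaler", subfolder="vae")

image = torch.randn(1, 3, 128, 128)           # dummy pixel-space input
generator = torch.Generator().manual_seed(0)

posterior = vae.encode(image)                                      # output carries a .latent_dist
latents = retrieve_latents(posterior, generator=generator)         # stochastic sample (default mode)
latents_mode = retrieve_latents(posterior, sample_mode="argmax")   # distribution mode instead
latents = latents * vae.config.scaling_factor                      # the pipeline applies the same scaling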
@@ -105,7 +119,54 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic")

-    def _encode_prompt(self, prompt, device, do_classifier_free_guidance, negative_prompt):
+    def _encode_prompt(
+        self,
+        prompt,
+        device,
+        do_classifier_free_guidance,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = self.encode_prompt(
+            prompt=prompt,
+            device=device,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            **kwargs,
+        )
+
+        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+        pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
+
+        return prompt_embeds, pooled_prompt_embeds
+
+    def encode_prompt(
+        self,
+        prompt,
+        device,
+        do_classifier_free_guidance,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+    ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -119,9 +180,29 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
             negative_prompt (`str` or `List[str]`):
                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                 if `guidance_scale` is less than `1`).
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            pooled_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+                input argument.
         """
-        batch_size = len(prompt) if isinstance(prompt, list) else 1
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]

+        if prompt_embeds is None or pooled_prompt_embeds is None:
             text_inputs = self.tokenizer(
                 prompt,
                 padding="max_length",
@@ -134,8 +215,12 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
             untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

-            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-                removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = self.tokenizer.batch_decode(
+                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                )
                 logger.warning(
                     "The following part of your input was truncated because CLIP can only handle sequences up to"
                     f" {self.tokenizer.model_max_length} tokens: {removed_text}"
@@ -145,11 +230,12 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
                 text_input_ids.to(device),
                 output_hidden_states=True,
             )
-            text_embeddings = text_encoder_out.hidden_states[-1]
-            text_pooler_out = text_encoder_out.pooler_output
+            prompt_embeds = text_encoder_out.hidden_states[-1]
+            pooled_prompt_embeds = text_encoder_out.pooler_output

         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
+            if negative_prompt_embeds is None or negative_pooled_prompt_embeds is None:
                 uncond_tokens: List[str]
                 if negative_prompt is None:
                     uncond_tokens = [""] * batch_size
@@ -184,16 +270,10 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
                     output_hidden_states=True,
                 )
-            uncond_embeddings = uncond_encoder_out.hidden_states[-1]
-            uncond_pooler_out = uncond_encoder_out.pooler_output
-
-            # For classifier free guidance, we need to do two forward passes.
-            # Here we concatenate the unconditional and text embeddings into a single batch
-            # to avoid doing two forward passes
-            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
-            text_pooler_out = torch.cat([uncond_pooler_out, text_pooler_out])
+                negative_prompt_embeds = uncond_encoder_out.hidden_states[-1]
+                negative_pooled_prompt_embeds = uncond_encoder_out.pooler_output

-        return text_embeddings, text_pooler_out
+        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
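
With this change encode_prompt() returns a four-tuple instead of the two concatenated tensors _encode_prompt() used to return, and the concatenation for classifier-free guidance moves to the caller. A minimal sketch of rebuilding the old layout, reusing the assumed `upscaler` pipeline and `torch` import from the sketch above:

(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = upscaler.encode_prompt(
    prompt="a castle on a hill, highly detailed",
    device="cuda",
    do_classifier_free_guidance=True,
    negative_prompt="low quality",
)

# Unconditional first, conditional second: the order the UNet expects under
# classifier-free guidance, and what _encode_prompt()/__call__ now assemble themselves.
text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
text_pooler_out = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])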
@@ -207,12 +287,56 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image

-    def check_inputs(self, prompt, image, callback_steps):
-        if not isinstance(prompt, str) and not isinstance(prompt, list):
+    def check_inputs(
+        self,
+        prompt,
+        image,
+        callback_steps,
+        negative_prompt=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+        pooled_prompt_embeds=None,
+        negative_pooled_prompt_embeds=None,
+    ):
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and not isinstance(prompt, str) and not isinstance(prompt, list):
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
+        if prompt_embeds is not None and pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+            )
+
+        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+            )
+
         if (
             not isinstance(image, torch.Tensor)
+            and not isinstance(image, np.ndarray)
             and not isinstance(image, PIL.Image.Image)
             and not isinstance(image, list)
         ):
@@ -222,10 +346,14 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix

         # verify batch size of prompt and image are same if image is a list or tensor
         if isinstance(image, (list, torch.Tensor)):
+            if prompt is not None:
                 if isinstance(prompt, str):
                     batch_size = 1
                 else:
                     batch_size = len(prompt)
+            else:
+                batch_size = prompt_embeds.shape[0]

             if isinstance(image, list):
                 image_batch_size = len(image)
             else:
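
Illustrative only: the stricter check_inputs() rejects ambiguous prompt inputs up front, reusing the assumed `upscaler`, `prompt_embeds`, and `low_res_latents` names from the sketches above:

try:
    # Passing both a text prompt and precomputed embeddings is now an error.
    upscaler(prompt="a cat", prompt_embeds=prompt_embeds, image=low_res_latents)
except ValueError as err:
    print(err)  # "Cannot forward both `prompt`: ... and `prompt_embeds`: ..."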
@@ -261,13 +389,17 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
     @torch.no_grad()
     def __call__(
         self,
-        prompt: Union[str, List[str]],
+        prompt: Union[str, List[str]] = None,
         image: PipelineImageInput = None,
         num_inference_steps: int = 75,
         guidance_scale: float = 9.0,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
@@ -359,10 +491,22 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
         """

         # 1. Check inputs
-        self.check_inputs(prompt, image, callback_steps)
+        self.check_inputs(
+            prompt,
+            image,
+            callback_steps,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        )

         # 2. Define call parameters
-        batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        if prompt is not None:
+            batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
         device = self._execution_device

         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -373,16 +517,32 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
             prompt = [""] * batch_size

         # 3. Encode input prompt
-        text_embeddings, text_pooler_out = self._encode_prompt(
-            prompt, device, do_classifier_free_guidance, negative_prompt
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = self.encode_prompt(
+            prompt,
+            device,
+            do_classifier_free_guidance,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
         )

+        if do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
+
         # 4. Preprocess image
         image = self.image_processor.preprocess(image)
-        image = image.to(dtype=text_embeddings.dtype, device=device)
+        image = image.to(dtype=prompt_embeds.dtype, device=device)
         if image.shape[1] == 3:
             # encode image if not in latent-space yet
-            image = self.vae.encode(image).latent_dist.sample() * self.vae.config.scaling_factor
+            image = retrieve_latents(self.vae.encode(image), generator=generator) * self.vae.config.scaling_factor

         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -400,17 +560,17 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
         inv_noise_level = (noise_level**2 + 1) ** (-0.5)

         image_cond = F.interpolate(image, scale_factor=2, mode="nearest") * inv_noise_level[:, None, None, None]
-        image_cond = image_cond.to(text_embeddings.dtype)
+        image_cond = image_cond.to(prompt_embeds.dtype)

         noise_level_embed = torch.cat(
             [
-                torch.ones(text_pooler_out.shape[0], 64, dtype=text_pooler_out.dtype, device=device),
-                torch.zeros(text_pooler_out.shape[0], 64, dtype=text_pooler_out.dtype, device=device),
+                torch.ones(pooled_prompt_embeds.shape[0], 64, dtype=pooled_prompt_embeds.dtype, device=device),
+                torch.zeros(pooled_prompt_embeds.shape[0], 64, dtype=pooled_prompt_embeds.dtype, device=device),
             ],
             dim=1,
         )

-        timestep_condition = torch.cat([noise_level_embed, text_pooler_out], dim=1)
+        timestep_condition = torch.cat([noise_level_embed, pooled_prompt_embeds], dim=1)

         # 6. Prepare latent variables
         height, width = image.shape[2:]
@@ -420,7 +580,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
             num_channels_latents,
             height * 2,  # 2x upscale
             width * 2,
-            text_embeddings.dtype,
+            prompt_embeds.dtype,
             device,
             generator,
             latents,
@@ -454,7 +614,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
             noise_pred = self.unet(
                 scaled_model_input,
                 timestep,
-                encoder_hidden_states=text_embeddings,
+                encoder_hidden_states=prompt_embeds,
                 timestep_cond=timestep_condition,
             ).sample
......
@@ -178,6 +178,46 @@ class StableDiffusionLatentUpscalePipelineFastTests(
         max_diff = np.abs(image_slice.flatten() - expected_slice).max()
         self.assertLessEqual(max_diff, 1e-3)

+    def test_stable_diffusion_latent_upscaler_negative_prompt(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionLatentUpscalePipeline(**components)
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        negative_prompt = "french fries"
+        output = sd_pipe(**inputs, negative_prompt=negative_prompt)
+        image = output.images
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = np.array(
+            [0.43865365, 0.404124, 0.42618454, 0.44333526, 0.40564927, 0.43818694, 0.4411913, 0.43404633, 0.46392226]
+        )
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+
+    def test_stable_diffusion_latent_upscaler_multiple_init_images(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionLatentUpscalePipeline(**components)
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        inputs["prompt"] = [inputs["prompt"]] * 2
+        inputs["image"] = inputs["image"].repeat(2, 1, 1, 1)
+        image = sd_pipe(**inputs).images
+        image_slice = image[-1, -3:, -3:, -1]
+
+        assert image.shape == (2, 256, 256, 3)
+        expected_slice = np.array(
+            [0.38730142, 0.35695046, 0.40646142, 0.40967226, 0.3981609, 0.4195988, 0.4248805, 0.430259, 0.45694894]
+        )
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+
     def test_attention_slicing_forward_pass(self):
         super().test_attention_slicing_forward_pass(expected_max_diff=7e-3)
......