Unverified Commit 867e0c91 authored by Vinh H. Pham's avatar Vinh H. Pham Committed by GitHub
Browse files

StableDiffusionLatentUpscalePipeline - positive/negative prompt embeds support (#8947)



* make latent upscaler accept prompt embeds

---------
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
parent 16a3dad4
......@@ -33,6 +33,20 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffu
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.preprocess
def preprocess(image):
warnings.warn(
......@@ -105,7 +119,54 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic")
def _encode_prompt(self, prompt, device, do_classifier_free_guidance, negative_prompt):
def _encode_prompt(
self,
prompt,
device,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
pooled_prompt_embeds: Optional[torch.Tensor] = None,
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
**kwargs,
):
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.encode_prompt(
prompt=prompt,
device=device,
do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
**kwargs,
)
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
return prompt_embeds, pooled_prompt_embeds
def encode_prompt(
self,
prompt,
device,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
pooled_prompt_embeds: Optional[torch.Tensor] = None,
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
......@@ -119,9 +180,29 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
pooled_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None or pooled_prompt_embeds is None:
text_inputs = self.tokenizer(
prompt,
padding="max_length",
......@@ -134,8 +215,12 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
......@@ -145,11 +230,12 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
text_input_ids.to(device),
output_hidden_states=True,
)
text_embeddings = text_encoder_out.hidden_states[-1]
text_pooler_out = text_encoder_out.pooler_output
prompt_embeds = text_encoder_out.hidden_states[-1]
pooled_prompt_embeds = text_encoder_out.pooler_output
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
if negative_prompt_embeds is None or negative_pooled_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
......@@ -184,16 +270,10 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
output_hidden_states=True,
)
uncond_embeddings = uncond_encoder_out.hidden_states[-1]
uncond_pooler_out = uncond_encoder_out.pooler_output
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
text_pooler_out = torch.cat([uncond_pooler_out, text_pooler_out])
negative_prompt_embeds = uncond_encoder_out.hidden_states[-1]
negative_pooled_prompt_embeds = uncond_encoder_out.pooler_output
return text_embeddings, text_pooler_out
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
......@@ -207,12 +287,56 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
def check_inputs(self, prompt, image, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list):
def check_inputs(
self,
prompt,
image,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None,
):
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
)
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
raise ValueError(
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
if (
not isinstance(image, torch.Tensor)
and not isinstance(image, np.ndarray)
and not isinstance(image, PIL.Image.Image)
and not isinstance(image, list)
):
......@@ -222,10 +346,14 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
# verify batch size of prompt and image are same if image is a list or tensor
if isinstance(image, (list, torch.Tensor)):
if prompt is not None:
if isinstance(prompt, str):
batch_size = 1
else:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if isinstance(image, list):
image_batch_size = len(image)
else:
......@@ -261,13 +389,17 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
prompt: Union[str, List[str]] = None,
image: PipelineImageInput = None,
num_inference_steps: int = 75,
guidance_scale: float = 9.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
pooled_prompt_embeds: Optional[torch.Tensor] = None,
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
......@@ -359,10 +491,22 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
"""
# 1. Check inputs
self.check_inputs(prompt, image, callback_steps)
self.check_inputs(
prompt,
image,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
)
# 2. Define call parameters
if prompt is not None:
batch_size = 1 if isinstance(prompt, str) else len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
......@@ -373,16 +517,32 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
prompt = [""] * batch_size
# 3. Encode input prompt
text_embeddings, text_pooler_out = self._encode_prompt(
prompt, device, do_classifier_free_guidance, negative_prompt
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.encode_prompt(
prompt,
device,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
# 4. Preprocess image
image = self.image_processor.preprocess(image)
image = image.to(dtype=text_embeddings.dtype, device=device)
image = image.to(dtype=prompt_embeds.dtype, device=device)
if image.shape[1] == 3:
# encode image if not in latent-space yet
image = self.vae.encode(image).latent_dist.sample() * self.vae.config.scaling_factor
image = retrieve_latents(self.vae.encode(image), generator=generator) * self.vae.config.scaling_factor
# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
......@@ -400,17 +560,17 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
inv_noise_level = (noise_level**2 + 1) ** (-0.5)
image_cond = F.interpolate(image, scale_factor=2, mode="nearest") * inv_noise_level[:, None, None, None]
image_cond = image_cond.to(text_embeddings.dtype)
image_cond = image_cond.to(prompt_embeds.dtype)
noise_level_embed = torch.cat(
[
torch.ones(text_pooler_out.shape[0], 64, dtype=text_pooler_out.dtype, device=device),
torch.zeros(text_pooler_out.shape[0], 64, dtype=text_pooler_out.dtype, device=device),
torch.ones(pooled_prompt_embeds.shape[0], 64, dtype=pooled_prompt_embeds.dtype, device=device),
torch.zeros(pooled_prompt_embeds.shape[0], 64, dtype=pooled_prompt_embeds.dtype, device=device),
],
dim=1,
)
timestep_condition = torch.cat([noise_level_embed, text_pooler_out], dim=1)
timestep_condition = torch.cat([noise_level_embed, pooled_prompt_embeds], dim=1)
# 6. Prepare latent variables
height, width = image.shape[2:]
......@@ -420,7 +580,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
num_channels_latents,
height * 2, # 2x upscale
width * 2,
text_embeddings.dtype,
prompt_embeds.dtype,
device,
generator,
latents,
......@@ -454,7 +614,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
noise_pred = self.unet(
scaled_model_input,
timestep,
encoder_hidden_states=text_embeddings,
encoder_hidden_states=prompt_embeds,
timestep_cond=timestep_condition,
).sample
......
......@@ -178,6 +178,46 @@ class StableDiffusionLatentUpscalePipelineFastTests(
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
def test_stable_diffusion_latent_upscaler_negative_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionLatentUpscalePipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
negative_prompt = "french fries"
output = sd_pipe(**inputs, negative_prompt=negative_prompt)
image = output.images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 256, 256, 3)
expected_slice = np.array(
[0.43865365, 0.404124, 0.42618454, 0.44333526, 0.40564927, 0.43818694, 0.4411913, 0.43404633, 0.46392226]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_stable_diffusion_latent_upscaler_multiple_init_images(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionLatentUpscalePipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["prompt"] = [inputs["prompt"]] * 2
inputs["image"] = inputs["image"].repeat(2, 1, 1, 1)
image = sd_pipe(**inputs).images
image_slice = image[-1, -3:, -3:, -1]
assert image.shape == (2, 256, 256, 3)
expected_slice = np.array(
[0.38730142, 0.35695046, 0.40646142, 0.40967226, 0.3981609, 0.4195988, 0.4248805, 0.430259, 0.45694894]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_attention_slicing_forward_pass(self):
super().test_attention_slicing_forward_pass(expected_max_diff=7e-3)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment