Unverified commit 5d28d221, authored by Kashif Rasul, committed by GitHub

[Wuerstchen] fix combined pipeline's num_images_per_prompt (#4989)

* fix encode_prompt

* added prompt_embeds and negative_prompt_embeds

* prompt_embeds for the prior only
Parent commit: 73bf620d
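
What this means for users: multiple images per prompt now work end to end through the combined pipeline. A minimal usage sketch, assuming the public `warp-ai/wuerstchen` checkpoint and a CUDA device:

```python
import torch
from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained(
    "warp-ai/wuerstchen", torch_dtype=torch.float16
).to("cuda")

# Before this fix, num_images_per_prompt > 1 mis-sized the prompt
# embeddings handed from the prior to the decoder.
images = pipe(
    prompt="an astronaut riding a horse in space",
    negative_prompt="blurry, low resolution",
    num_images_per_prompt=2,
).images
print(len(images))  # 2
```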
Changes to `WuerstchenDecoderPipeline`:

```diff
@@ -330,7 +330,11 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
         # 2. Encode caption
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
-            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+            prompt,
+            device,
+            image_embeddings.size(0) * num_images_per_prompt,
+            do_classifier_free_guidance,
+            negative_prompt,
         )
         text_encoder_hidden_states = (
             torch.cat([prompt_embeds, negative_prompt_embeds]) if negative_prompt_embeds is not None else prompt_embeds
```
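
Context for the decoder change: by the time `encode_prompt` runs here, `image_embeddings` may already carry an expanded batch (the prior repeats its output per prompt), so the text embeddings must be tiled against the image-embedding batch rather than the raw prompt batch. A rough, self-contained shape sketch (tensor sizes are illustrative assumptions, not the model's real dimensions):

```python
import torch

# One prompt, but the prior already produced 2 image embeddings for it.
image_embeddings = torch.randn(2, 16, 24, 24)  # illustrative shape
prompt_embeds = torch.randn(1, 77, 1024)       # illustrative CLIP hidden states

num_images_per_prompt = 1  # decoder-side setting
repeats = image_embeddings.size(0) * num_images_per_prompt
prompt_embeds = prompt_embeds.repeat_interleave(repeats, dim=0)

# Text conditioning now lines up with the image-embedding batch.
assert prompt_embeds.size(0) == image_embeddings.size(0) * num_images_per_prompt
```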
Changes to `WuerstchenCombinedPipeline`:

```diff
@@ -154,6 +154,8 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
         decoder_timesteps: Optional[List[float]] = None,
         decoder_guidance_scale: float = 0.0,
         negative_prompt: Optional[Union[str, List[str]]] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         num_images_per_prompt: int = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
@@ -165,10 +167,17 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
         Args:
             prompt (`str` or `List[str]`):
-                The prompt or prompts to guide the image generation.
+                The prompt or prompts to guide the image generation for the prior and decoder.
             negative_prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                 if `guidance_scale` is less than `1`).
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.*
+                prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt`
+                input argument.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             height (`int`, *optional*, defaults to 512):
@@ -221,13 +230,15 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
                 otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
         """
         prior_outputs = self.prior_pipe(
-            prompt=prompt,
+            prompt=prompt if prompt_embeds is None else None,
             height=height,
             width=width,
             num_inference_steps=prior_num_inference_steps,
             timesteps=prior_timesteps,
             guidance_scale=prior_guidance_scale,
-            negative_prompt=negative_prompt,
+            negative_prompt=negative_prompt if negative_prompt_embeds is None else None,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
             num_images_per_prompt=num_images_per_prompt,
             generator=generator,
             latents=latents,
```
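
The `prompt if prompt_embeds is None else None` guard matters because the prior, like other diffusers pipelines, expects exactly one source of text conditioning; forwarding both a string and precomputed embeddings would be ambiguous. A hedged sketch of the new passthrough, reusing the `pipe` object from the example above (how the decoder handles the suppressed prompt internally is an assumption here):

```python
# Precompute prior text embeddings once, then reuse them across calls.
prompt_embeds, negative_prompt_embeds = pipe.prior_pipe.encode_prompt(
    device="cuda",
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    prompt="an astronaut riding a horse in space",
    negative_prompt="blurry",
)

# The combined pipeline forwards these to the prior and nulls out the
# raw strings, so the prior sees a single conditioning source.
images = pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    num_images_per_prompt=2,
).images
```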
Changes to `WuerstchenPriorPipeline`, starting with the reworked `encode_prompt`:

```diff
@@ -150,41 +150,57 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
     def encode_prompt(
         self,
-        prompt,
         device,
         num_images_per_prompt,
         do_classifier_free_guidance,
+        prompt=None,
         negative_prompt=None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
     ):
-        batch_size = len(prompt) if isinstance(prompt, list) else 1
-        # get prompt text embeddings
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            truncation=True,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-        attention_mask = text_inputs.attention_mask
-
-        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
-        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-            )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-            attention_mask = attention_mask[:, : self.tokenizer.model_max_length]
-
-        text_encoder_output = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask.to(device))
-        text_encoder_hidden_states = text_encoder_output.last_hidden_state
-        text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
-
-        uncond_text_encoder_hidden_states = None
-        if do_classifier_free_guidance:
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        if prompt_embeds is None:
+            # get prompt text embeddings
+            text_inputs = self.tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            text_input_ids = text_inputs.input_ids
+            attention_mask = text_inputs.attention_mask
+
+            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = self.tokenizer.batch_decode(
+                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                )
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )
+                text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+                attention_mask = attention_mask[:, : self.tokenizer.model_max_length]
+
+            text_encoder_output = self.text_encoder(
+                text_input_ids.to(device), attention_mask=attention_mask.to(device)
+            )
+            prompt_embeds = text_encoder_output.last_hidden_state
+
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+
+        if negative_prompt_embeds is None and do_classifier_free_guidance:
             uncond_tokens: List[str]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
```
The tail of `encode_prompt`, where the unconditional embeddings are built and returned under the new names:

```diff
@@ -215,17 +231,17 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
                 uncond_input.input_ids.to(device), attention_mask=uncond_input.attention_mask.to(device)
             )

-            uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
+            negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.last_hidden_state

+        if do_classifier_free_guidance:
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-            seq_len = uncond_text_encoder_hidden_states.shape[1]
-            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
-            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
-                batch_size * num_images_per_prompt, seq_len, -1
-            )
+            seq_len = negative_prompt_embeds.shape[1]
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
             # done duplicates

-        return text_encoder_hidden_states, uncond_text_encoder_hidden_states
+        return prompt_embeds, negative_prompt_embeds

     def check_inputs(
         self,
```
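
A note on the "mps friendly method" comment: `.repeat(1, n, 1)` followed by `.view(batch * n, seq, -1)` yields the same row ordering as `repeat_interleave(n, dim=0)`, while avoiding an op that has historically been unreliable on the MPS backend. A quick check of the equivalence:

```python
import torch

n = 3
neg = torch.randn(2, 5, 8)  # (batch, seq_len, dim)

tiled = neg.repeat(1, n, 1).view(2 * n, 5, -1)  # mps-friendly route
interleaved = neg.repeat_interleave(n, dim=0)   # direct route

assert torch.equal(tiled, interleaved)
```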
Finally, the `WuerstchenPriorPipeline.__call__` signature, its docstring, and the `encode_prompt` call site:

```diff
@@ -264,13 +280,15 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
-        prompt: Union[str, List[str]] = None,
+        prompt: Optional[Union[str, List[str]]] = None,
         height: int = 1024,
         width: int = 1024,
         num_inference_steps: int = 60,
         timesteps: List[float] = None,
         guidance_scale: float = 8.0,
         negative_prompt: Optional[Union[str, List[str]]] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         num_images_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
@@ -304,6 +322,13 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
             negative_prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                 if `decoder_guidance_scale` is less than `1`).
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -345,7 +370,13 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
         # 2. Encode caption
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
-            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+            prompt=prompt,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
         )

         # For classifier free guidance, we need to do two forward passes.
```
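
Taken together, the prior can now be driven entirely from precomputed embeddings. A sketch against the standalone prior (the `warp-ai/wuerstchen-prior` checkpoint id and the `image_embeddings` output field are assumed here):

```python
import torch
from diffusers import WuerstchenPriorPipeline

prior = WuerstchenPriorPipeline.from_pretrained(
    "warp-ai/wuerstchen-prior", torch_dtype=torch.float16
).to("cuda")

prompt_embeds, negative_prompt_embeds = prior.encode_prompt(
    device="cuda",
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    prompt="a red sports car",
    negative_prompt="low quality",
)

# Skip tokenization and text encoding on subsequent calls.
image_embeddings = prior(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
).image_embeddings
```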