Unverified Commit 6a51427b authored by hlky, committed by GitHub

Fix multi-prompt inference (#10103)



* Fix multi-prompt inference

Fix generation of multiple images with multiple prompts, i.e. len(prompt) > 1 together with num_images_per_prompt > 1

* make

* fix copies

---------
Co-authored-by: Nikita Balabin <nikita@mxl.ru>
parent 5effcd3e
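The bug shows up when one of the affected pipelines is called with several prompts and more than one image (or video) per prompt. A minimal reproduction sketch, assuming a PixArt-Sigma checkpoint on a CUDA device (the checkpoint id, prompts and settings below are illustrative, not taken from this PR):

import torch
from diffusers import PixArtSigmaPipeline

# Illustrative checkpoint; any pipeline touched by this commit goes through the same encode_prompt path.
pipe = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
).to("cuda")

# Two prompts, two images per prompt, so four images in total.
# Before this fix the duplicated attention masks were tiled in a different
# order than the duplicated prompt embeddings, so with len(prompt) > 1 each
# mask could end up paired with the wrong prompt's embeddings.
images = pipe(
    prompt=["a red fox in the snow", "a lighthouse at dusk"],
    num_images_per_prompt=2,
).images
assert len(images) == 4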
@@ -251,13 +251,6 @@ class AllegroPipeline(DiffusionPipeline):
         if device is None:
             device = self._execution_device

-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
         # See Section 3.1. of the paper.
         max_length = max_sequence_length
@@ -302,12 +295,12 @@ class AllegroPipeline(DiffusionPipeline):
         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
-        prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
-        prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
+        prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt)
+        prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_videos_per_prompt, -1)

         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance and negative_prompt_embeds is None:
-            uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+            uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt
             uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
@@ -334,10 +327,10 @@ class AllegroPipeline(DiffusionPipeline):
             negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)

             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
-            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
-            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_videos_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_videos_per_prompt)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_videos_per_prompt, -1)
         else:
             negative_prompt_embeds = None
             negative_prompt_attention_mask = None
...
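The Allegro hunks above are the whole fix; the PixArt diffs that follow are the same change propagated by the "fix copies" step listed in the commit message. Two things happen: the now-unused local batch_size is dropped in favour of bs_embed (taken from prompt_embeds.shape before duplication, so references such as the uncond_tokens replication switch to it), and the attention-mask duplication is reordered from view-then-repeat to repeat-then-view so its row order matches the embedding duplication. A standalone sketch of the two orderings, using toy shapes rather than the pipeline's real tensors:

import torch

# Two prompts (bs_embed=2), mask length 3, two generations per prompt (n=2).
bs_embed, n = 2, 2
mask = torch.tensor([[1, 1, 0],   # prompt 0
                     [1, 0, 0]])  # prompt 1

# Old order: tiles the whole batch, giving rows [p0, p1, p0, p1].
old = mask.view(bs_embed, -1).repeat(n, 1)

# New order: repeats each prompt consecutively, giving rows [p0, p0, p1, p1],
# which matches prompt_embeds.repeat(1, n, 1).view(bs_embed * n, seq_len, -1).
new = mask.repeat(1, n).view(bs_embed * n, -1)

print(old)  # tensor([[1, 1, 0], [1, 0, 0], [1, 1, 0], [1, 0, 0]])
print(new)  # tensor([[1, 1, 0], [1, 1, 0], [1, 0, 0], [1, 0, 0]])

With a single prompt both orderings coincide, which is why the bug only appears once len(prompt) > 1.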
@@ -227,13 +227,6 @@ class PixArtSigmaPAGPipeline(DiffusionPipeline, PAGMixin):
         if device is None:
             device = self._execution_device

-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
         # See Section 3.1. of the paper.
         max_length = max_sequence_length
@@ -278,12 +271,12 @@ class PixArtSigmaPAGPipeline(DiffusionPipeline, PAGMixin):
         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-        prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
-        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+        prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
+        prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)

         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance and negative_prompt_embeds is None:
-            uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+            uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt
             uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
@@ -310,10 +303,10 @@ class PixArtSigmaPAGPipeline(DiffusionPipeline, PAGMixin):
             negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)

             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
-            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
         else:
             negative_prompt_embeds = None
             negative_prompt_attention_mask = None
...
@@ -338,13 +338,6 @@ class PixArtAlphaPipeline(DiffusionPipeline):
         if device is None:
             device = self._execution_device

-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
         # See Section 3.1. of the paper.
         max_length = max_sequence_length
@@ -389,12 +382,12 @@ class PixArtAlphaPipeline(DiffusionPipeline):
         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-        prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
-        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+        prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
+        prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)

         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance and negative_prompt_embeds is None:
-            uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+            uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt
             uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
@@ -421,10 +414,10 @@ class PixArtAlphaPipeline(DiffusionPipeline):
             negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)

             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
-            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
         else:
             negative_prompt_embeds = None
             negative_prompt_attention_mask = None
...
@@ -264,13 +264,6 @@ class PixArtSigmaPipeline(DiffusionPipeline):
         if device is None:
             device = self._execution_device

-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
         # See Section 3.1. of the paper.
         max_length = max_sequence_length
@@ -315,12 +308,12 @@ class PixArtSigmaPipeline(DiffusionPipeline):
         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-        prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
-        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+        prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
+        prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)

         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance and negative_prompt_embeds is None:
-            uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+            uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt
             uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
@@ -347,10 +340,10 @@ class PixArtSigmaPipeline(DiffusionPipeline):
             negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)

             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
-            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
         else:
             negative_prompt_embeds = None
             negative_prompt_attention_mask = None
...