Unverified Commit 3517fb94 authored by MilkClouds's avatar MilkClouds Committed by GitHub
Browse files

fix: enabled num_images_per_prompt>1 for lpw_stable_diffusion_xl (community pipeline) (#5807)

* fix: enabled num_images_per_prompt>1 for lpw_stable_diffusion_xl

* style: fixed isort
parent cdadb023
...@@ -249,6 +249,7 @@ def get_weighted_text_embeddings_sdxl( ...@@ -249,6 +249,7 @@ def get_weighted_text_embeddings_sdxl(
prompt_2: str = None, prompt_2: str = None,
neg_prompt: str = "", neg_prompt: str = "",
neg_prompt_2: str = None, neg_prompt_2: str = None,
num_images_per_prompt: int = 1,
): ):
""" """
This function can process long prompt with weights, no length limitation This function can process long prompt with weights, no length limitation
...@@ -260,6 +261,7 @@ def get_weighted_text_embeddings_sdxl( ...@@ -260,6 +261,7 @@ def get_weighted_text_embeddings_sdxl(
prompt_2 (str) prompt_2 (str)
neg_prompt (str) neg_prompt (str)
neg_prompt_2 (str) neg_prompt_2 (str)
num_images_per_prompt (int)
Returns: Returns:
prompt_embeds (torch.Tensor) prompt_embeds (torch.Tensor)
neg_prompt_embeds (torch.Tensor) neg_prompt_embeds (torch.Tensor)
...@@ -383,6 +385,22 @@ def get_weighted_text_embeddings_sdxl( ...@@ -383,6 +385,22 @@ def get_weighted_text_embeddings_sdxl(
prompt_embeds = torch.cat(embeds, dim=1) prompt_embeds = torch.cat(embeds, dim=1)
negative_prompt_embeds = torch.cat(neg_embeds, dim=1) negative_prompt_embeds = torch.cat(neg_embeds, dim=1)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1).view(
bs_embed * num_images_per_prompt, -1
)
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1).view(
bs_embed * num_images_per_prompt, -1
)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
...@@ -1096,7 +1114,9 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo ...@@ -1096,7 +1114,9 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
negative_prompt_embeds, negative_prompt_embeds,
pooled_prompt_embeds, pooled_prompt_embeds,
negative_pooled_prompt_embeds, negative_pooled_prompt_embeds,
) = get_weighted_text_embeddings_sdxl(pipe=self, prompt=prompt, neg_prompt=negative_prompt) ) = get_weighted_text_embeddings_sdxl(
pipe=self, prompt=prompt, neg_prompt=negative_prompt, num_images_per_prompt=num_images_per_prompt
)
# 4. Prepare timesteps # 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device) self.scheduler.set_timesteps(num_inference_steps, device=device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment