Unverified Commit 3517fb94 authored by MilkClouds's avatar MilkClouds Committed by GitHub
Browse files

fix: enabled num_images_per_prompt>1 for lpw_stable_diffusion_xl (community pipeline) (#5807)

* fix: enabled num_images_per_prompt>1 for lpw_stable_diffusion_xl

* style: fixed isort
parent cdadb023
......@@ -249,6 +249,7 @@ def get_weighted_text_embeddings_sdxl(
prompt_2: str = None,
neg_prompt: str = "",
neg_prompt_2: str = None,
num_images_per_prompt: int = 1,
):
"""
This function can process long prompt with weights, no length limitation
......@@ -260,6 +261,7 @@ def get_weighted_text_embeddings_sdxl(
prompt_2 (str)
neg_prompt (str)
neg_prompt_2 (str)
num_images_per_prompt (int)
Returns:
prompt_embeds (torch.Tensor)
neg_prompt_embeds (torch.Tensor)
......@@ -383,6 +385,22 @@ def get_weighted_text_embeddings_sdxl(
prompt_embeds = torch.cat(embeds, dim=1)
negative_prompt_embeds = torch.cat(neg_embeds, dim=1)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1).view(
bs_embed * num_images_per_prompt, -1
)
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1).view(
bs_embed * num_images_per_prompt, -1
)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
......@@ -1096,7 +1114,9 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = get_weighted_text_embeddings_sdxl(pipe=self, prompt=prompt, neg_prompt=negative_prompt)
) = get_weighted_text_embeddings_sdxl(
pipe=self, prompt=prompt, neg_prompt=negative_prompt, num_images_per_prompt=num_images_per_prompt
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment