Unverified Commit a0acbdc9 authored by Bagheera's avatar Bagheera Committed by GitHub
Browse files

fix for #7365, prevent pipelines from overriding provided prompt embeds (#7926)



* fix for #7365, prevent pipelines from overriding provided prompt embeds

* fix-copies

* fix implementation

* update

---------
Co-authored-by: default avatarbghira <bghira@users.github.com>
Co-authored-by: default avatarAryan <aryan@huggingface.co>
Co-authored-by: default avatarsayakpaul <spsayakpaul@gmail.com>
parent 5655b22e
......@@ -827,7 +827,9 @@ class SDXLLongPromptWeightingPipeline(
)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
prompt_embeds_list.append(prompt_embeds)
......@@ -879,7 +881,8 @@ class SDXLLongPromptWeightingPipeline(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
......
......@@ -290,7 +290,9 @@ class DemoFusionSDXLPipeline(
)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
prompt_embeds_list.append(prompt_embeds)
......@@ -342,7 +344,8 @@ class DemoFusionSDXLPipeline(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
......
......@@ -628,7 +628,9 @@ class StyleAlignedSDXLPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
......@@ -688,7 +690,8 @@ class StyleAlignedSDXLPipeline(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
......
......@@ -359,7 +359,9 @@ class StableDiffusionXLControlNetAdapterPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
......@@ -419,7 +421,8 @@ class StableDiffusionXLControlNetAdapterPipeline(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
......
......@@ -507,7 +507,9 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
......@@ -567,7 +569,8 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
......
......@@ -394,7 +394,9 @@ class StableDiffusionXLDifferentialImg2ImgPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
......@@ -454,7 +456,8 @@ class StableDiffusionXLDifferentialImg2ImgPipeline(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
......
......@@ -390,7 +390,9 @@ class StableDiffusionXLPipelineIpex(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
......@@ -450,7 +452,8 @@ class StableDiffusionXLPipelineIpex(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
......
......@@ -438,7 +438,9 @@ class AnimateDiffSDXLPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
......@@ -497,8 +499,10 @@ class AnimateDiffSDXLPipeline(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
......
......@@ -406,7 +406,9 @@ class StableDiffusionXLControlNetInpaintPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
......@@ -465,8 +467,10 @@ class StableDiffusionXLControlNetInpaintPipeline(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
......
......@@ -415,7 +415,9 @@ class StableDiffusionXLControlNetPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
......@@ -474,8 +476,10 @@ class StableDiffusionXLControlNetPipeline(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
......
......@@ -408,7 +408,9 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
......@@ -467,8 +469,10 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
......
......@@ -388,7 +388,9 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
......@@ -447,8 +449,10 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
......
......@@ -397,7 +397,9 @@ class StableDiffusionXLControlNetUnionPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
......@@ -456,8 +458,10 @@ class StableDiffusionXLControlNetUnionPipeline(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
......
......@@ -422,7 +422,9 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
......@@ -481,8 +483,10 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
......
......@@ -336,7 +336,9 @@ class StableDiffusionXLControlNetXSPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
......@@ -395,8 +397,10 @@ class StableDiffusionXLControlNetXSPipeline(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
......
......@@ -421,7 +421,9 @@ class StableDiffusionXLControlNetPAGPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
......@@ -480,8 +482,10 @@ class StableDiffusionXLControlNetPAGPipeline(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
......
......@@ -413,7 +413,9 @@ class StableDiffusionXLControlNetPAGImg2ImgPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
......@@ -472,8 +474,10 @@ class StableDiffusionXLControlNetPAGImg2ImgPipeline(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
......
......@@ -415,7 +415,9 @@ class StableDiffusionXLPAGPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
......@@ -474,8 +476,10 @@ class StableDiffusionXLPAGPipeline(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
......
......@@ -436,7 +436,9 @@ class StableDiffusionXLPAGImg2ImgPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
......@@ -495,8 +497,10 @@ class StableDiffusionXLPAGImg2ImgPipeline(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
......
......@@ -526,7 +526,9 @@ class StableDiffusionXLPAGInpaintPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
......@@ -585,8 +587,10 @@ class StableDiffusionXLPAGInpaintPipeline(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment