Unverified Commit a0acbdc9 authored by Bagheera's avatar Bagheera Committed by GitHub
Browse files

fix for #7365, prevent pipelines from overriding provided prompt embeds (#7926)



* fix for #7365, prevent pipelines from overriding provided prompt embeds

* fix-copies

* fix implementation

* update

---------
Co-authored-by: default avatarbghira <bghira@users.github.com>
Co-authored-by: default avatarAryan <aryan@huggingface.co>
Co-authored-by: default avatarsayakpaul <spsayakpaul@gmail.com>
parent 5655b22e
...@@ -827,7 +827,9 @@ class SDXLLongPromptWeightingPipeline( ...@@ -827,7 +827,9 @@ class SDXLLongPromptWeightingPipeline(
) )
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0] if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds = prompt_embeds.hidden_states[-2]
prompt_embeds_list.append(prompt_embeds) prompt_embeds_list.append(prompt_embeds)
...@@ -879,7 +881,8 @@ class SDXLLongPromptWeightingPipeline( ...@@ -879,7 +881,8 @@ class SDXLLongPromptWeightingPipeline(
output_hidden_states=True, output_hidden_states=True,
) )
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0] if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds_list.append(negative_prompt_embeds)
......
...@@ -290,7 +290,9 @@ class DemoFusionSDXLPipeline( ...@@ -290,7 +290,9 @@ class DemoFusionSDXLPipeline(
) )
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0] if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds = prompt_embeds.hidden_states[-2]
prompt_embeds_list.append(prompt_embeds) prompt_embeds_list.append(prompt_embeds)
...@@ -342,7 +344,8 @@ class DemoFusionSDXLPipeline( ...@@ -342,7 +344,8 @@ class DemoFusionSDXLPipeline(
output_hidden_states=True, output_hidden_states=True,
) )
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0] if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds_list.append(negative_prompt_embeds)
......
...@@ -628,7 +628,9 @@ class StyleAlignedSDXLPipeline( ...@@ -628,7 +628,9 @@ class StyleAlignedSDXLPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0] if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None: if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds = prompt_embeds.hidden_states[-2]
else: else:
...@@ -688,7 +690,8 @@ class StyleAlignedSDXLPipeline( ...@@ -688,7 +690,8 @@ class StyleAlignedSDXLPipeline(
output_hidden_states=True, output_hidden_states=True,
) )
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0] if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds_list.append(negative_prompt_embeds)
......
...@@ -359,7 +359,9 @@ class StableDiffusionXLControlNetAdapterPipeline( ...@@ -359,7 +359,9 @@ class StableDiffusionXLControlNetAdapterPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0] if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None: if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds = prompt_embeds.hidden_states[-2]
else: else:
...@@ -419,7 +421,8 @@ class StableDiffusionXLControlNetAdapterPipeline( ...@@ -419,7 +421,8 @@ class StableDiffusionXLControlNetAdapterPipeline(
output_hidden_states=True, output_hidden_states=True,
) )
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0] if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds_list.append(negative_prompt_embeds)
......
...@@ -507,7 +507,9 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline( ...@@ -507,7 +507,9 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0] if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None: if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds = prompt_embeds.hidden_states[-2]
else: else:
...@@ -567,7 +569,8 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline( ...@@ -567,7 +569,8 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
output_hidden_states=True, output_hidden_states=True,
) )
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0] if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds_list.append(negative_prompt_embeds)
......
...@@ -394,7 +394,9 @@ class StableDiffusionXLDifferentialImg2ImgPipeline( ...@@ -394,7 +394,9 @@ class StableDiffusionXLDifferentialImg2ImgPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0] if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None: if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds = prompt_embeds.hidden_states[-2]
else: else:
...@@ -454,7 +456,8 @@ class StableDiffusionXLDifferentialImg2ImgPipeline( ...@@ -454,7 +456,8 @@ class StableDiffusionXLDifferentialImg2ImgPipeline(
output_hidden_states=True, output_hidden_states=True,
) )
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0] if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds_list.append(negative_prompt_embeds)
......
...@@ -390,7 +390,9 @@ class StableDiffusionXLPipelineIpex( ...@@ -390,7 +390,9 @@ class StableDiffusionXLPipelineIpex(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0] if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None: if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds = prompt_embeds.hidden_states[-2]
else: else:
...@@ -450,7 +452,8 @@ class StableDiffusionXLPipelineIpex( ...@@ -450,7 +452,8 @@ class StableDiffusionXLPipelineIpex(
output_hidden_states=True, output_hidden_states=True,
) )
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0] if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds_list.append(negative_prompt_embeds)
......
...@@ -438,7 +438,9 @@ class AnimateDiffSDXLPipeline( ...@@ -438,7 +438,9 @@ class AnimateDiffSDXLPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0] if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None: if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds = prompt_embeds.hidden_states[-2]
else: else:
...@@ -497,8 +499,10 @@ class AnimateDiffSDXLPipeline( ...@@ -497,8 +499,10 @@ class AnimateDiffSDXLPipeline(
uncond_input.input_ids.to(device), uncond_input.input_ids.to(device),
output_hidden_states=True, output_hidden_states=True,
) )
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0] if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds_list.append(negative_prompt_embeds)
......
...@@ -406,7 +406,9 @@ class StableDiffusionXLControlNetInpaintPipeline( ...@@ -406,7 +406,9 @@ class StableDiffusionXLControlNetInpaintPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0] if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None: if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds = prompt_embeds.hidden_states[-2]
else: else:
...@@ -465,8 +467,10 @@ class StableDiffusionXLControlNetInpaintPipeline( ...@@ -465,8 +467,10 @@ class StableDiffusionXLControlNetInpaintPipeline(
uncond_input.input_ids.to(device), uncond_input.input_ids.to(device),
output_hidden_states=True, output_hidden_states=True,
) )
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0] if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds_list.append(negative_prompt_embeds)
......
...@@ -415,7 +415,9 @@ class StableDiffusionXLControlNetPipeline( ...@@ -415,7 +415,9 @@ class StableDiffusionXLControlNetPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0] if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None: if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds = prompt_embeds.hidden_states[-2]
else: else:
...@@ -474,8 +476,10 @@ class StableDiffusionXLControlNetPipeline( ...@@ -474,8 +476,10 @@ class StableDiffusionXLControlNetPipeline(
uncond_input.input_ids.to(device), uncond_input.input_ids.to(device),
output_hidden_states=True, output_hidden_states=True,
) )
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0] if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds_list.append(negative_prompt_embeds)
......
...@@ -408,7 +408,9 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -408,7 +408,9 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0] if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None: if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds = prompt_embeds.hidden_states[-2]
else: else:
...@@ -467,8 +469,10 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -467,8 +469,10 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
uncond_input.input_ids.to(device), uncond_input.input_ids.to(device),
output_hidden_states=True, output_hidden_states=True,
) )
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0] if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds_list.append(negative_prompt_embeds)
......
...@@ -388,7 +388,9 @@ class StableDiffusionXLControlNetUnionInpaintPipeline( ...@@ -388,7 +388,9 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0] if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None: if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds = prompt_embeds.hidden_states[-2]
else: else:
...@@ -447,8 +449,10 @@ class StableDiffusionXLControlNetUnionInpaintPipeline( ...@@ -447,8 +449,10 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
uncond_input.input_ids.to(device), uncond_input.input_ids.to(device),
output_hidden_states=True, output_hidden_states=True,
) )
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0] if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds_list.append(negative_prompt_embeds)
......
...@@ -397,7 +397,9 @@ class StableDiffusionXLControlNetUnionPipeline( ...@@ -397,7 +397,9 @@ class StableDiffusionXLControlNetUnionPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0] if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None: if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds = prompt_embeds.hidden_states[-2]
else: else:
...@@ -456,8 +458,10 @@ class StableDiffusionXLControlNetUnionPipeline( ...@@ -456,8 +458,10 @@ class StableDiffusionXLControlNetUnionPipeline(
uncond_input.input_ids.to(device), uncond_input.input_ids.to(device),
output_hidden_states=True, output_hidden_states=True,
) )
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0] if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds_list.append(negative_prompt_embeds)
......
...@@ -422,7 +422,9 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( ...@@ -422,7 +422,9 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0] if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None: if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds = prompt_embeds.hidden_states[-2]
else: else:
...@@ -481,8 +483,10 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( ...@@ -481,8 +483,10 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
uncond_input.input_ids.to(device), uncond_input.input_ids.to(device),
output_hidden_states=True, output_hidden_states=True,
) )
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0] if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds_list.append(negative_prompt_embeds)
......
...@@ -336,7 +336,9 @@ class StableDiffusionXLControlNetXSPipeline( ...@@ -336,7 +336,9 @@ class StableDiffusionXLControlNetXSPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0] if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None: if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds = prompt_embeds.hidden_states[-2]
else: else:
...@@ -395,8 +397,10 @@ class StableDiffusionXLControlNetXSPipeline( ...@@ -395,8 +397,10 @@ class StableDiffusionXLControlNetXSPipeline(
uncond_input.input_ids.to(device), uncond_input.input_ids.to(device),
output_hidden_states=True, output_hidden_states=True,
) )
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0] if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds_list.append(negative_prompt_embeds)
......
...@@ -421,7 +421,9 @@ class StableDiffusionXLControlNetPAGPipeline( ...@@ -421,7 +421,9 @@ class StableDiffusionXLControlNetPAGPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0] if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None: if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds = prompt_embeds.hidden_states[-2]
else: else:
...@@ -480,8 +482,10 @@ class StableDiffusionXLControlNetPAGPipeline( ...@@ -480,8 +482,10 @@ class StableDiffusionXLControlNetPAGPipeline(
uncond_input.input_ids.to(device), uncond_input.input_ids.to(device),
output_hidden_states=True, output_hidden_states=True,
) )
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0] if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds_list.append(negative_prompt_embeds)
......
...@@ -413,7 +413,9 @@ class StableDiffusionXLControlNetPAGImg2ImgPipeline( ...@@ -413,7 +413,9 @@ class StableDiffusionXLControlNetPAGImg2ImgPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0] if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None: if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds = prompt_embeds.hidden_states[-2]
else: else:
...@@ -472,8 +474,10 @@ class StableDiffusionXLControlNetPAGImg2ImgPipeline( ...@@ -472,8 +474,10 @@ class StableDiffusionXLControlNetPAGImg2ImgPipeline(
uncond_input.input_ids.to(device), uncond_input.input_ids.to(device),
output_hidden_states=True, output_hidden_states=True,
) )
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0] if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds_list.append(negative_prompt_embeds)
......
...@@ -415,7 +415,9 @@ class StableDiffusionXLPAGPipeline( ...@@ -415,7 +415,9 @@ class StableDiffusionXLPAGPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0] if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None: if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds = prompt_embeds.hidden_states[-2]
else: else:
...@@ -474,8 +476,10 @@ class StableDiffusionXLPAGPipeline( ...@@ -474,8 +476,10 @@ class StableDiffusionXLPAGPipeline(
uncond_input.input_ids.to(device), uncond_input.input_ids.to(device),
output_hidden_states=True, output_hidden_states=True,
) )
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0] if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds_list.append(negative_prompt_embeds)
......
...@@ -436,7 +436,9 @@ class StableDiffusionXLPAGImg2ImgPipeline( ...@@ -436,7 +436,9 @@ class StableDiffusionXLPAGImg2ImgPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0] if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None: if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds = prompt_embeds.hidden_states[-2]
else: else:
...@@ -495,8 +497,10 @@ class StableDiffusionXLPAGImg2ImgPipeline( ...@@ -495,8 +497,10 @@ class StableDiffusionXLPAGImg2ImgPipeline(
uncond_input.input_ids.to(device), uncond_input.input_ids.to(device),
output_hidden_states=True, output_hidden_states=True,
) )
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0] if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds_list.append(negative_prompt_embeds)
......
...@@ -526,7 +526,9 @@ class StableDiffusionXLPAGInpaintPipeline( ...@@ -526,7 +526,9 @@ class StableDiffusionXLPAGInpaintPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0] if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None: if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds = prompt_embeds.hidden_states[-2]
else: else:
...@@ -585,8 +587,10 @@ class StableDiffusionXLPAGInpaintPipeline( ...@@ -585,8 +587,10 @@ class StableDiffusionXLPAGInpaintPipeline(
uncond_input.input_ids.to(device), uncond_input.input_ids.to(device),
output_hidden_states=True, output_hidden_states=True,
) )
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0] if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds_list.append(negative_prompt_embeds)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment