Unverified commit a0acbdc9, authored by Bagheera and committed by GitHub
Browse files

fix for #7365, prevent pipelines from overriding provided prompt embeds (#7926)



* fix for #7365, prevent pipelines from overriding provided prompt embeds

* fix-copies

* fix implementation

* update

---------
Co-authored-by: bghira <bghira@users.github.com>
Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
parent 5655b22e
...@@ -321,7 +321,9 @@ class StableDiffusionXLKDiffusionPipeline( ...@@ -321,7 +321,9 @@ class StableDiffusionXLKDiffusionPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0] pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None: if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds = prompt_embeds.hidden_states[-2]
else: else:
...@@ -380,7 +382,9 @@ class StableDiffusionXLKDiffusionPipeline( ...@@ -380,7 +382,9 @@ class StableDiffusionXLKDiffusionPipeline(
uncond_input.input_ids.to(device), uncond_input.input_ids.to(device),
output_hidden_states=True, output_hidden_states=True,
) )
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
......
...@@ -406,7 +406,9 @@ class StableDiffusionXLPipeline( ...@@ -406,7 +406,9 @@ class StableDiffusionXLPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0] pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None: if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds = prompt_embeds.hidden_states[-2]
else: else:
...@@ -465,7 +467,9 @@ class StableDiffusionXLPipeline( ...@@ -465,7 +467,9 @@ class StableDiffusionXLPipeline(
uncond_input.input_ids.to(device), uncond_input.input_ids.to(device),
output_hidden_states=True, output_hidden_states=True,
) )
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
......
...@@ -427,7 +427,9 @@ class StableDiffusionXLImg2ImgPipeline( ...@@ -427,7 +427,9 @@ class StableDiffusionXLImg2ImgPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0] pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None: if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds = prompt_embeds.hidden_states[-2]
else: else:
...@@ -486,7 +488,9 @@ class StableDiffusionXLImg2ImgPipeline( ...@@ -486,7 +488,9 @@ class StableDiffusionXLImg2ImgPipeline(
uncond_input.input_ids.to(device), uncond_input.input_ids.to(device),
output_hidden_states=True, output_hidden_states=True,
) )
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
......
...@@ -531,7 +531,9 @@ class StableDiffusionXLInpaintPipeline( ...@@ -531,7 +531,9 @@ class StableDiffusionXLInpaintPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0] pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None: if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds = prompt_embeds.hidden_states[-2]
else: else:
...@@ -590,7 +592,9 @@ class StableDiffusionXLInpaintPipeline( ...@@ -590,7 +592,9 @@ class StableDiffusionXLInpaintPipeline(
uncond_input.input_ids.to(device), uncond_input.input_ids.to(device),
output_hidden_states=True, output_hidden_states=True,
) )
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
......
...@@ -333,7 +333,9 @@ class StableDiffusionXLInstructPix2PixPipeline( ...@@ -333,7 +333,9 @@ class StableDiffusionXLInstructPix2PixPipeline(
) )
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0] pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds = prompt_embeds.hidden_states[-2]
prompt_embeds_list.append(prompt_embeds) prompt_embeds_list.append(prompt_embeds)
...@@ -385,6 +387,7 @@ class StableDiffusionXLInstructPix2PixPipeline( ...@@ -385,6 +387,7 @@ class StableDiffusionXLInstructPix2PixPipeline(
output_hidden_states=True, output_hidden_states=True,
) )
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
......
...@@ -423,7 +423,9 @@ class StableDiffusionXLAdapterPipeline( ...@@ -423,7 +423,9 @@ class StableDiffusionXLAdapterPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0] pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None: if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds = prompt_embeds.hidden_states[-2]
else: else:
...@@ -482,7 +484,9 @@ class StableDiffusionXLAdapterPipeline( ...@@ -482,7 +484,9 @@ class StableDiffusionXLAdapterPipeline(
uncond_input.input_ids.to(device), uncond_input.input_ids.to(device),
output_hidden_states=True, output_hidden_states=True,
) )
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
......
...@@ -705,7 +705,9 @@ class TextToVideoZeroSDXLPipeline( ...@@ -705,7 +705,9 @@ class TextToVideoZeroSDXLPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0] pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None: if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds = prompt_embeds.hidden_states[-2]
else: else:
...@@ -764,7 +766,9 @@ class TextToVideoZeroSDXLPipeline( ...@@ -764,7 +766,9 @@ class TextToVideoZeroSDXLPipeline(
uncond_input.input_ids.to(device), uncond_input.input_ids.to(device),
output_hidden_states=True, output_hidden_states=True,
) )
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment