"vscode:/vscode.git/clone" did not exist on "784beee9697c0e37c5ca12ad63a9e1c6eb90bd1a"
Unverified Commit 78be4007 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[PixArt-Alpha] fix mask feature condition. (#5695)

* fix mask feature condition.

* debug

* remove identical test

* set correct

* Empty-Commit
parent c803a8f8
......@@ -156,6 +156,8 @@ class PixArtAlphaPipeline(DiffusionPipeline):
mask_feature: (bool, defaults to `True`):
If `True`, the function will mask the text embeddings.
"""
embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None
if device is None:
device = self._execution_device
......@@ -253,7 +255,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
negative_prompt_embeds = None
# Perform additional masking.
if mask_feature and prompt_embeds is None and negative_prompt_embeds is None:
if mask_feature and not embeds_initially_provided:
prompt_embeds = prompt_embeds.unsqueeze(1)
masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask)
masked_prompt_embeds = masked_prompt_embeds.squeeze(1)
......
......@@ -120,7 +120,6 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"generator": generator,
"num_inference_steps": num_inference_steps,
"output_type": output_type,
"mask_feature": False,
}
# set all optional components to None
......@@ -155,7 +154,6 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"generator": generator,
"num_inference_steps": num_inference_steps,
"output_type": output_type,
"mask_feature": False,
}
output_loaded = pipe_loaded(**inputs)[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment