"tests/python/vscode:/vscode.git/clone" did not exist on "1328baf7cbfc908fa179df3011b8d8eee064958c"
Unverified Commit cac7adab authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Flax SDXL] fix zero out sdxl (#5203)

parent a584d42c
......@@ -188,9 +188,10 @@ class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
# Get unconditional embeddings
batch_size = prompt_embeds.shape[0]
if neg_prompt_ids is None:
neg_prompt_ids = self.prepare_inputs([""] * batch_size)
neg_prompt_embeds, negative_pooled_embeds = self.get_embeddings(neg_prompt_ids, params)
neg_prompt_embeds = jnp.zeros_like(prompt_embeds)
negative_pooled_embeds = jnp.zeros_like(pooled_embeds)
else:
neg_prompt_embeds, negative_pooled_embeds = self.get_embeddings(neg_prompt_ids, params)
add_time_ids = self._get_add_time_ids(
(height, width), (0, 0), (height, width), prompt_embeds.shape[0], dtype=prompt_embeds.dtype
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment