Unverified Commit be38b2d7 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[UnCLIPPipeline] fix num_images_per_prompt (#1762)

duplicate maks for num_images_per_prompt
parent 32a5d70c
......@@ -143,6 +143,7 @@ class UnCLIPPipeline(DiffusionPipeline):
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
uncond_tokens = [""] * batch_size
......@@ -172,6 +173,7 @@ class UnCLIPPipeline(DiffusionPipeline):
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
batch_size * num_images_per_prompt, seq_len, -1
)
uncond_text_mask = uncond_text_mask.repeat(1, num_images_per_prompt)
# done duplicates
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment