Unverified Commit be38b2d7 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[UnCLIPPipeline] fix num_images_per_prompt (#1762)

duplicate maks for num_images_per_prompt
parent 32a5d70c
...@@ -143,6 +143,7 @@ class UnCLIPPipeline(DiffusionPipeline): ...@@ -143,6 +143,7 @@ class UnCLIPPipeline(DiffusionPipeline):
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
if do_classifier_free_guidance: if do_classifier_free_guidance:
uncond_tokens = [""] * batch_size uncond_tokens = [""] * batch_size
...@@ -172,6 +173,7 @@ class UnCLIPPipeline(DiffusionPipeline): ...@@ -172,6 +173,7 @@ class UnCLIPPipeline(DiffusionPipeline):
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
batch_size * num_images_per_prompt, seq_len, -1 batch_size * num_images_per_prompt, seq_len, -1
) )
uncond_text_mask = uncond_text_mask.repeat(1, num_images_per_prompt)
# done duplicates # done duplicates
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment