Unverified Commit faa6cbc9 authored by kathath's avatar kathath Committed by GitHub
Browse files

Fix repeat of negative prompt (#4335)

fix repeat of negative prompt
parent 306a7bd0
...@@ -442,7 +442,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline): ...@@ -442,7 +442,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline):
if do_classifier_free_guidance: if do_classifier_free_guidance:
uncond_tokens: List[str] uncond_tokens: List[str]
if negative_prompt is None: if negative_prompt is None:
uncond_tokens = [""] uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt): elif type(prompt) is not type(negative_prompt):
raise TypeError( raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
...@@ -471,7 +471,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline): ...@@ -471,7 +471,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline):
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1] seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1) uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes. # For classifier free guidance, we need to do two forward passes.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment