"tests/vscode:/vscode.git/clone" did not exist on "44a9faad7ab60d6bbcc08e094c0eaa86a0d73063"
Unverified Commit a38dd795 authored by Yushu's avatar Yushu Committed by GitHub
Browse files

[Pipeline] Fix error of SVD pipeline when num_videos_per_prompt > 1 (#7786)



swap the order for do_classifier_free_guidance concat with repeat
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
parent b1c5817a
...@@ -199,6 +199,9 @@ class StableVideoDiffusionPipeline(DiffusionPipeline): ...@@ -199,6 +199,9 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
image = image.to(device=device) image = image.to(device=device)
image_latents = self.vae.encode(image).latent_dist.mode() image_latents = self.vae.encode(image).latent_dist.mode()
# duplicate image_latents for each generation per prompt, using mps friendly method
image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
if do_classifier_free_guidance: if do_classifier_free_guidance:
negative_image_latents = torch.zeros_like(image_latents) negative_image_latents = torch.zeros_like(image_latents)
...@@ -207,9 +210,6 @@ class StableVideoDiffusionPipeline(DiffusionPipeline): ...@@ -207,9 +210,6 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
# to avoid doing two forward passes # to avoid doing two forward passes
image_latents = torch.cat([negative_image_latents, image_latents]) image_latents = torch.cat([negative_image_latents, image_latents])
# duplicate image_latents for each generation per prompt, using mps friendly method
image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
return image_latents return image_latents
def _get_add_time_ids( def _get_add_time_ids(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment