Unverified commit e27142ac, authored by Tolga Cangöz and committed by GitHub
Browse files

[`Wan`] Fix VAE sampling mode in `WanVideoToVideoPipeline` (#11639)

* fix: vae sampling mode

* fix a typo
parent 8e88495d
...@@ -419,12 +419,7 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): ...@@ -419,12 +419,7 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
) )
if latents is None: if latents is None:
if isinstance(generator, list): init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), sample_mode="argmax") for vid in video]
init_latents = [
retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
]
else:
init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
init_latents = torch.cat(init_latents, dim=0).to(dtype) init_latents = torch.cat(init_latents, dim=0).to(dtype)
...@@ -441,7 +436,7 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): ...@@ -441,7 +436,7 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
if hasattr(self.scheduler, "add_noise"): if hasattr(self.scheduler, "add_noise"):
latents = self.scheduler.add_noise(init_latents, noise, timestep) latents = self.scheduler.add_noise(init_latents, noise, timestep)
else: else:
latents = self.scheduelr.scale_noise(init_latents, timestep, noise) latents = self.scheduler.scale_noise(init_latents, timestep, noise)
else: else:
latents = latents.to(device) latents = latents.to(device)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment