Unverified Commit e7f3a737 authored by Aryan's avatar Aryan Committed by GitHub
Browse files

Fix Wan I2V prepare_latents dtype (#11371)

update
parent 7a4a126d
...@@ -409,7 +409,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): ...@@ -409,7 +409,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
[image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image], [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image],
dim=2, dim=2,
) )
video_condition = video_condition.to(device=device, dtype=dtype) video_condition = video_condition.to(device=device, dtype=self.vae.dtype)
latents_mean = ( latents_mean = (
torch.tensor(self.vae.config.latents_mean) torch.tensor(self.vae.config.latents_mean)
...@@ -429,6 +429,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): ...@@ -429,6 +429,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
latent_condition = latent_condition.to(dtype)
latent_condition = (latent_condition - latents_mean) * latents_std latent_condition = (latent_condition - latents_mean) * latents_std
mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment