Unverified commit 4ea9f89b, authored by hlky and committed by GitHub

Wan Pipeline scaling fix, type hint warning, multi generator fix (#11007)

* Wan Pipeline scaling fix, type hint warning, multi generator fix

* Apply suggestions from code review
parent 733b44ac
@@ -109,14 +109,30 @@ def prompt_clean(text):
 def retrieve_latents(
-    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+    encoder_output: torch.Tensor,
+    latents_mean: torch.Tensor,
+    latents_std: torch.Tensor,
+    generator: Optional[torch.Generator] = None,
+    sample_mode: str = "sample",
 ):
     if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
+        encoder_output.latent_dist.logvar = torch.clamp(
+            (encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
+        )
+        encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
+        encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
         return encoder_output.latent_dist.sample(generator)
     elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
+        encoder_output.latent_dist.logvar = torch.clamp(
+            (encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
+        )
+        encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
+        encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
         return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
-        return encoder_output.latents
+        return (encoder_output.latents - latents_mean) * latents_std
     else:
         raise AttributeError("Could not access latents of provided encoder_output")
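
For reference, a standalone sketch of calling the updated helper with its new signature: latents_mean and latents_std are per-channel tensors built from the VAE config and shaped to broadcast over (B, C, T, H, W) latents. This is illustrative and not part of the commit; the checkpoint id and dummy input are assumptions, and retrieve_latents is taken to be the helper defined in the hunk above.

# Standalone sketch (illustrative, not from this diff).
# Assumes a diffusers-format Wan checkpoint and that retrieve_latents from the
# hunk above is in scope (e.g. copied locally).
import torch
from diffusers import AutoencoderKLWan

vae = AutoencoderKLWan.from_pretrained(
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", subfolder="vae", torch_dtype=torch.float32
)

# Per-channel normalization tensors, shaped to broadcast over (B, C, T, H, W) latents.
latents_mean = torch.tensor(vae.config.latents_mean).view(1, vae.config.z_dim, 1, 1, 1)
latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1)

# Dummy conditioning clip in pixel space: (batch, channels, frames, height, width).
video_condition = torch.zeros(1, 3, 5, 64, 64)
generator = torch.Generator().manual_seed(0)
latent_condition = retrieve_latents(vae.encode(video_condition), latents_mean, latents_std, generator)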
@@ -385,13 +401,6 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         )
         video_condition = video_condition.to(device=device, dtype=dtype)
-        if isinstance(generator, list):
-            latent_condition = [retrieve_latents(self.vae.encode(video_condition), g) for g in generator]
-            latents = latent_condition = torch.cat(latent_condition)
-        else:
-            latent_condition = retrieve_latents(self.vae.encode(video_condition), generator)
-            latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
         latents_mean = (
             torch.tensor(self.vae.config.latents_mean)
             .view(1, self.vae.config.z_dim, 1, 1, 1)
@@ -401,7 +410,14 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             latents.device, latents.dtype
         )
-        latent_condition = (latent_condition - latents_mean) * latents_std
+        if isinstance(generator, list):
+            latent_condition = [
+                retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, g) for g in generator
+            ]
+            latent_condition = torch.cat(latent_condition)
+        else:
+            latent_condition = retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, generator)
+            latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
         mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
         mask_lat_size[:, :, list(range(1, num_frames))] = 0
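
The per-generator branch above is exercised when a caller passes one generator per requested video. A hedged usage sketch follows, assuming the standard diffusers WanImageToVideoPipeline call signature; the checkpoint id, input image path, resolution, and frame count are illustrative.

import torch
from diffusers import WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

# Illustrative checkpoint id; any diffusers-format Wan I2V checkpoint should behave the same.
pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

image = load_image("input.png")  # placeholder path for the conditioning frame

# One generator per video: the list path triggers the per-generator retrieve_latents calls above.
generators = [torch.Generator(device="cuda").manual_seed(seed) for seed in (0, 1)]
output = pipe(
    image=image,
    prompt="a slow pan across a mountain lake at sunrise",
    height=480,
    width=832,
    num_frames=33,               # Wan expects 4k + 1 frames
    num_videos_per_prompt=2,     # must match the number of generators
    generator=generators,
)
for i, frames in enumerate(output.frames):
    export_to_video(frames, f"wan_i2v_{i}.mp4", fps=16)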