Unverified Commit 363d1ab7 authored by hlky's avatar hlky Committed by GitHub
Browse files

Wan VAE move scaling to pipeline (#10998)

parent 6a0137eb
......@@ -715,11 +715,6 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
) -> None:
super().__init__()
# Store normalization parameters as tensors
self.mean = torch.tensor(latents_mean)
self.std = torch.tensor(latents_std)
self.scale = torch.stack([self.mean, 1.0 / self.std]) # Shape: [2, C]
self.z_dim = z_dim
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
......@@ -751,7 +746,6 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
self._enc_feat_map = [None] * self._enc_conv_num
def _encode(self, x: torch.Tensor) -> torch.Tensor:
scale = self.scale.type_as(x)
self.clear_cache()
## cache
t = x.shape[2]
......@@ -770,8 +764,6 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
enc = self.quant_conv(out)
mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
logvar = (logvar - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
enc = torch.cat([mu, logvar], dim=1)
self.clear_cache()
return enc
......@@ -798,10 +790,8 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, scale, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
self.clear_cache()
# z: [b,c,t,h,w]
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
iter_ = z.shape[2]
x = self.post_quant_conv(z)
......@@ -835,8 +825,7 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
scale = self.scale.type_as(z)
decoded = self._decode(z, scale).sample
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
......
......@@ -563,6 +563,15 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
if not output_type == "latent":
latents = latents.to(self.vae.dtype)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents.device, latents.dtype
)
latents = latents / latents_std + latents_mean
video = self.vae.decode(latents, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)
else:
......
......@@ -392,6 +392,17 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
latent_condition = retrieve_latents(self.vae.encode(video_condition), generator)
latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents.device, latents.dtype
)
latent_condition = (latent_condition - latents_mean) * latents_std
mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
mask_lat_size[:, :, list(range(1, num_frames))] = 0
first_frame_mask = mask_lat_size[:, :, 0:1]
......@@ -654,6 +665,15 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
if not output_type == "latent":
latents = latents.to(self.vae.dtype)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents.device, latents.dtype
)
latents = latents / latents_std + latents_mean
video = self.vae.decode(latents, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment