Unverified commit 363d1ab7, authored by hlky and committed by GitHub

Wan VAE move scaling to pipeline (#10998)

parent 6a0137eb
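This commit moves the per-channel latent normalization out of AutoencoderKLWan and into the Wan pipelines: the VAE now encodes and decodes raw latents, while each pipeline shifts and scales by the latents_mean/latents_std values kept on the VAE config. As a minimal sketch of the transform and its inverse (plain PyTorch, not the diffusers API; shapes follow the [B, C, T, H, W] layout used in the diff):

    import torch

    def normalize_latents(z, latents_mean, latents_std):
        # latents_mean / latents_std are per-channel lists of length C.
        mean = torch.tensor(latents_mean).view(1, -1, 1, 1, 1).to(z.device, z.dtype)
        std = torch.tensor(latents_std).view(1, -1, 1, 1, 1).to(z.device, z.dtype)
        return (z - mean) / std  # applied after vae.encode

    def denormalize_latents(z, latents_mean, latents_std):
        mean = torch.tensor(latents_mean).view(1, -1, 1, 1, 1).to(z.device, z.dtype)
        std = torch.tensor(latents_std).view(1, -1, 1, 1, 1).to(z.device, z.dtype)
        return z * std + mean  # applied before vae.decode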
@@ -715,11 +715,6 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
     ) -> None:
         super().__init__()
 
-        # Store normalization parameters as tensors
-        self.mean = torch.tensor(latents_mean)
-        self.std = torch.tensor(latents_std)
-        self.scale = torch.stack([self.mean, 1.0 / self.std])  # Shape: [2, C]
-
         self.z_dim = z_dim
         self.temperal_downsample = temperal_downsample
         self.temperal_upsample = temperal_downsample[::-1]
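latents_mean and latents_std remain constructor arguments, so ConfigMixin keeps them on the model config even though the module no longer caches tensor copies of them. A sketch of how a caller reaches them (assuming torch is imported and vae is a loaded AutoencoderKLWan):

    # The config retains the per-channel stats; build broadcastable tensors from it.
    mean = torch.tensor(vae.config.latents_mean).view(1, vae.config.z_dim, 1, 1, 1)
    inv_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1)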
@@ -751,7 +746,6 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
         self._enc_feat_map = [None] * self._enc_conv_num
 
     def _encode(self, x: torch.Tensor) -> torch.Tensor:
-        scale = self.scale.type_as(x)
         self.clear_cache()
         ## cache
         t = x.shape[2]
@@ -770,8 +764,6 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
 
         enc = self.quant_conv(out)
         mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]
-        mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
-        logvar = (logvar - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
         enc = torch.cat([mu, logvar], dim=1)
         self.clear_cache()
         return enc
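With these two lines removed, _encode returns unnormalized mu/logvar, and callers that need normalized latents apply the shift and scale themselves, as the image-to-video hunk further down does. A sketch (assuming torch is imported, vae is a loaded AutoencoderKLWan, and video is a [B, C, T, H, W] pixel tensor):

    # Sample raw latents from the posterior, then normalize per channel.
    z = vae.encode(video).latent_dist.sample()
    mean = torch.tensor(vae.config.latents_mean).view(1, vae.config.z_dim, 1, 1, 1).to(z.device, z.dtype)
    inv_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1).to(z.device, z.dtype)
    z = (z - mean) * inv_std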
@@ -798,10 +790,8 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
             return (posterior,)
         return AutoencoderKLOutput(latent_dist=posterior)
 
-    def _decode(self, z: torch.Tensor, scale, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         self.clear_cache()
-        # z: [b,c,t,h,w]
-        z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
 
         iter_ = z.shape[2]
         x = self.post_quant_conv(z)
@@ -835,8 +825,7 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
                 If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                 returned.
         """
-        scale = self.scale.type_as(z)
-        decoded = self._decode(z, scale).sample
+        decoded = self._decode(z).sample
 
         if not return_dict:
             return (decoded,)
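The public decode(z, return_dict=...) signature is unchanged; only the private _decode loses its scale argument. Callers are now expected to hand decode already-denormalized latents, e.g. (sketch, with mean/inv_std built as above):

    z = z / inv_std + mean                        # undo the normalization
    video = vae.decode(z, return_dict=False)[0]   # decode raw latents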
@@ -563,6 +563,15 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
 
         if not output_type == "latent":
             latents = latents.to(self.vae.dtype)
+            latents_mean = (
+                torch.tensor(self.vae.config.latents_mean)
+                .view(1, self.vae.config.z_dim, 1, 1, 1)
+                .to(latents.device, latents.dtype)
+            )
+            latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+                latents.device, latents.dtype
+            )
+            latents = latents / latents_std + latents_mean
             video = self.vae.decode(latents, return_dict=False)[0]
             video = self.video_processor.postprocess_video(video, output_type=output_type)
         else:
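For context, a minimal end-to-end sketch of the text-to-video pipeline this hunk lands in; the checkpoint id is an assumption for illustration, not part of this diff:

    import torch
    from diffusers import WanPipeline
    from diffusers.utils import export_to_video

    # Assumed checkpoint id; any Wan 2.1 T2V checkpoint in diffusers format should work.
    pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    frames = pipe(prompt="A cat surfing a wave at sunset", num_frames=81).frames[0]
    export_to_video(frames, "output.mp4", fps=16)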
@@ -392,6 +392,17 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
 
         latent_condition = retrieve_latents(self.vae.encode(video_condition), generator)
         latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
+
+        latents_mean = (
+            torch.tensor(self.vae.config.latents_mean)
+            .view(1, self.vae.config.z_dim, 1, 1, 1)
+            .to(latents.device, latents.dtype)
+        )
+        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+            latents.device, latents.dtype
+        )
+        latent_condition = (latent_condition - latents_mean) * latents_std
+
         mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
         mask_lat_size[:, :, list(range(1, num_frames))] = 0
         first_frame_mask = mask_lat_size[:, :, 0:1]
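Normalizing latent_condition here keeps the conditioning latents in the same space as the latents being denoised. A quick self-contained check that this encode-side transform and the decode-side transform in the next hunk are inverses (the channel count is an arbitrary stand-in for z_dim):

    import torch

    c = 16                                       # stand-in latent channel count
    mean = torch.randn(1, c, 1, 1, 1)
    std = torch.rand(1, c, 1, 1, 1) + 0.5
    z = torch.randn(2, c, 4, 8, 8)

    z_norm = (z - mean) / std                    # encode side (this hunk)
    z_back = z_norm * std + mean                 # decode side (next hunk)
    assert torch.allclose(z, z_back, atol=1e-5)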
@@ -654,6 +665,15 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
 
         if not output_type == "latent":
             latents = latents.to(self.vae.dtype)
+            latents_mean = (
+                torch.tensor(self.vae.config.latents_mean)
+                .view(1, self.vae.config.z_dim, 1, 1, 1)
+                .to(latents.device, latents.dtype)
+            )
+            latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+                latents.device, latents.dtype
+            )
+            latents = latents / latents_std + latents_mean
             video = self.vae.decode(latents, return_dict=False)[0]
             video = self.video_processor.postprocess_video(video, output_type=output_type)
         else:
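And a matching sketch for the image-to-video pipeline; the checkpoint id and image URL are assumptions for illustration:

    import torch
    from diffusers import WanImageToVideoPipeline
    from diffusers.utils import export_to_video, load_image

    # Assumed checkpoint id; any Wan 2.1 I2V checkpoint in diffusers format should work.
    pipe = WanImageToVideoPipeline.from_pretrained(
        "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", torch_dtype=torch.bfloat16
    )
    pipe.to("cuda")

    image = load_image("https://example.com/first_frame.png")  # placeholder URL
    frames = pipe(image=image, prompt="The scene comes to life", num_frames=81).frames[0]
    export_to_video(frames, "i2v_output.mp4", fps=16)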