"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "f73ed179610653bf100215a54ca2c8a3cba91cf0"
Unverified commit d06750a5, authored by Zijian Zhou, committed by GitHub
Browse files

Fix autoencoder_kl_wan.py bugs for Wan2.2 VAE (#12335)

* Update autoencoder_kl_wan.py

When using the Wan2.2 VAE, the spatial compression ratio calculated here is incorrect. It should be 16 instead of 8. Pass it in directly via the config to ensure it’s correct here.

* Update autoencoder_kl_wan.py
parent 8c72cd12
@@ -1052,7 +1052,7 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             is_residual=is_residual,
         )
-        self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
+        self.spatial_compression_ratio = scale_factor_spatial

         # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
         # to perform decoding of a single video latent at a time.
@@ -1145,12 +1145,13 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     def _encode(self, x: torch.Tensor):
         _, _, num_frame, height, width = x.shape
-        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
-            return self.tiled_encode(x)

         self.clear_cache()
         if self.config.patch_size is not None:
             x = patchify(x, patch_size=self.config.patch_size)
+        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+            return self.tiled_encode(x)

         iter_ = 1 + (num_frame - 1) // 4
         for i in range(iter_):
             self._enc_conv_idx = [0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment