Unverified Commit 794f7e49 authored by Vinh H. Pham, committed by GitHub

Implement framewise encoding/decoding in LTX Video VAE (#10488)



* add framewise decode

* add framewise encode, refactor tiled encode/decode

* add sanity test tiling for ltx

* run make style

* Update src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
Co-authored-by: Aryan <contact.aryanvs@gmail.com>

---------
Co-authored-by: Pham Hong Vinh <vinhph3@vng.com.vn>
Co-authored-by: Aryan <contact.aryanvs@gmail.com>
parent 9fc9c6dd
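
For context before the diff, here is a minimal usage sketch of the frame-wise tiling options this commit adds. It is illustrative only: the checkpoint id, dtype, input sizes, and the explicit `use_framewise_*` flag assignments are assumptions on my part, not part of the commit.

```python
# Illustrative sketch only; checkpoint id, sizes, and flag handling are assumptions.
import torch
from diffusers import AutoencoderKLLTXVideo

vae = AutoencoderKLLTXVideo.from_pretrained(
    "Lightricks/LTX-Video",  # assumed checkpoint with a `vae` subfolder
    subfolder="vae",
    torch_dtype=torch.float32,
)

# Enable spatial tiling together with the new frame-wise (temporal) tiling knobs
# introduced in this commit; values shown match the defaults in the diff.
vae.enable_tiling(
    tile_sample_min_height=512,
    tile_sample_min_width=512,
    tile_sample_min_num_frames=16,
    tile_sample_stride_height=448,
    tile_sample_stride_width=448,
    tile_sample_stride_num_frames=8,
)

# The frame-wise code paths also gate on these attributes; setting them explicitly
# here is an assumption in case they default to False in a given diffusers version.
vae.use_framewise_encoding = True
vae.use_framewise_decoding = True

# A clip longer than tile_sample_min_num_frames takes the temporal tiled path.
video = torch.randn(1, 3, 49, 512, 512)  # (batch, channels, frames, height, width)
latents = vae.encode(video).latent_dist.sample()
decoded = vae.decode(latents).sample
```

With tiling enabled, long or large inputs are processed in overlapping spatial and temporal tiles and blended back together, trading a small amount of compute for a much lower peak memory footprint.
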
@@ -1010,10 +1010,12 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         # The minimal tile height and width for spatial tiling to be used
         self.tile_sample_min_height = 512
         self.tile_sample_min_width = 512
+        self.tile_sample_min_num_frames = 16
 
         # The minimal distance between two spatial tiles
         self.tile_sample_stride_height = 448
         self.tile_sample_stride_width = 448
+        self.tile_sample_stride_num_frames = 8
 
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)):
@@ -1023,8 +1025,10 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         self,
         tile_sample_min_height: Optional[int] = None,
         tile_sample_min_width: Optional[int] = None,
+        tile_sample_min_num_frames: Optional[int] = None,
         tile_sample_stride_height: Optional[float] = None,
         tile_sample_stride_width: Optional[float] = None,
+        tile_sample_stride_num_frames: Optional[float] = None,
     ) -> None:
         r"""
         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
@@ -1046,8 +1050,10 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         self.use_tiling = True
         self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
         self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+        self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames
         self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
         self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+        self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
 
     def disable_tiling(self) -> None:
         r"""
@@ -1073,18 +1079,13 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     def _encode(self, x: torch.Tensor) -> torch.Tensor:
         batch_size, num_channels, num_frames, height, width = x.shape
 
+        if self.use_framewise_decoding and num_frames > self.tile_sample_min_num_frames:
+            return self._temporal_tiled_encode(x)
+
         if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
             return self.tiled_encode(x)
 
-        if self.use_framewise_encoding:
-            # TODO(aryan): requires investigation
-            raise NotImplementedError(
-                "Frame-wise encoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
-                "quality issues caused by splitting inference across frame dimension. If you believe this "
-                "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
-            )
-        else:
-            enc = self.encoder(x)
+        enc = self.encoder(x)
 
         return enc
@@ -1121,19 +1122,15 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         batch_size, num_channels, num_frames, height, width = z.shape
         tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
         tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+        tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
+
+        if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
+            return self._temporal_tiled_decode(z, temb, return_dict=return_dict)
 
         if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
             return self.tiled_decode(z, temb, return_dict=return_dict)
 
-        if self.use_framewise_decoding:
-            # TODO(aryan): requires investigation
-            raise NotImplementedError(
-                "Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
-                "quality issues caused by splitting inference across frame dimension. If you believe this "
-                "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
-            )
-        else:
-            dec = self.decoder(z, temb)
+        dec = self.decoder(z, temb)
 
         if not return_dict:
             return (dec,)
@@ -1189,6 +1186,14 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             )
         return b
 
+    def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
+                x / blend_extent
+            )
+        return b
+
     def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
         r"""Encode a batch of images using a tiled encoder.
@@ -1217,17 +1222,9 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         for i in range(0, height, self.tile_sample_stride_height):
             row = []
             for j in range(0, width, self.tile_sample_stride_width):
-                if self.use_framewise_encoding:
-                    # TODO(aryan): requires investigation
-                    raise NotImplementedError(
-                        "Frame-wise encoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
-                        "quality issues caused by splitting inference across frame dimension. If you believe this "
-                        "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
-                    )
-                else:
-                    time = self.encoder(
-                        x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
-                    )
+                time = self.encoder(
+                    x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                )
                 row.append(time)
             rows.append(row)
@@ -1283,17 +1280,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         for i in range(0, height, tile_latent_stride_height):
             row = []
             for j in range(0, width, tile_latent_stride_width):
-                if self.use_framewise_decoding:
-                    # TODO(aryan): requires investigation
-                    raise NotImplementedError(
-                        "Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
-                        "quality issues caused by splitting inference across frame dimension. If you believe this "
-                        "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
-                    )
-                else:
-                    time = self.decoder(
-                        z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb
-                    )
+                time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb)
                 row.append(time)
             rows.append(row)
@@ -1318,6 +1305,74 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         return DecoderOutput(sample=dec)
 
+    def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
+        batch_size, num_channels, num_frames, height, width = x.shape
+        latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1
+
+        tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
+        tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
+        blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames
+
+        row = []
+        for i in range(0, num_frames, self.tile_sample_stride_num_frames):
+            tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :]
+            if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width):
+                tile = self.tiled_encode(tile)
+            else:
+                tile = self.encoder(tile)
+            if i > 0:
+                tile = tile[:, :, 1:, :, :]
+            row.append(tile)
+
+        result_row = []
+        for i, tile in enumerate(row):
+            if i > 0:
+                tile = self.blend_t(row[i - 1], tile, blend_num_frames)
+                result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :])
+            else:
+                result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :])
+
+        enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames]
+        return enc
+
+    def _temporal_tiled_decode(
+        self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True
+    ) -> Union[DecoderOutput, torch.Tensor]:
+        batch_size, num_channels, num_frames, height, width = z.shape
+        num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
+        tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
+        blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames
+
+        row = []
+        for i in range(0, num_frames, tile_latent_stride_num_frames):
+            tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :]
+            if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height):
+                decoded = self.tiled_decode(tile, temb, return_dict=True).sample
+            else:
+                decoded = self.decoder(tile, temb)
+            if i > 0:
+                decoded = decoded[:, :, :-1, :, :]
+            row.append(decoded)
+
+        result_row = []
+        for i, tile in enumerate(row):
+            if i > 0:
+                tile = self.blend_t(row[i - 1], tile, blend_num_frames)
+                tile = tile[:, :, : self.tile_sample_stride_num_frames, :, :]
+                result_row.append(tile)
+            else:
+                result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :])
+
+        dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames]
+
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
+
     def forward(
         self,
         sample: torch.Tensor,
@@ -1334,5 +1389,5 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             z = posterior.mode()
         dec = self.decode(z, temb)
         if not return_dict:
-            return (dec,)
+            return (dec.sample,)
         return dec
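
To make the temporal blending in `_temporal_tiled_encode` / `_temporal_tiled_decode` concrete, here is a small standalone sketch of the same linear crossfade along the frame dimension. The `blend_frames` helper mirrors the `blend_t` method added above but is not the library code, and the `temporal_compression_ratio` of 8 is an assumption about the LTX VAE config rather than something stated in this diff.

```python
# Standalone illustration of the frame-dimension crossfade used by blend_t above.
import torch


def blend_frames(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Linearly fade the last `blend_extent` frames of `a` into the first frames of `b`.
    blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
    for x in range(blend_extent):
        weight = x / blend_extent
        b[:, :, x] = a[:, :, -blend_extent + x] * (1 - weight) + b[:, :, x] * weight
    return b


tile_sample_min_num_frames = 16   # default from the diff
tile_sample_stride_num_frames = 8  # default from the diff
temporal_compression_ratio = 8     # assumed value

# Latent-space tile sizes, as computed in _temporal_tiled_encode:
tile_latent_min_num_frames = tile_sample_min_num_frames // temporal_compression_ratio        # 2
tile_latent_stride_num_frames = tile_sample_stride_num_frames // temporal_compression_ratio  # 1
blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames                # 1

prev_tile = torch.randn(1, 4, 3, 8, 8)  # (batch, channels, frames, height, width)
next_tile = torch.randn(1, 4, 3, 8, 8)
blended = blend_frames(prev_tile, next_tile, blend_num_frames)
print(blended.shape)  # torch.Size([1, 4, 3, 8, 8])
```

The overlap of one latent frame per tile at the default settings is what keeps seams between temporal tiles from showing up in the decoded video.
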
@@ -167,3 +167,34 @@ class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, UNetTesterMixin, unittest.
     @unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
     def test_forward_with_norm_groups(self):
         pass
+
+    def test_enable_disable_tiling(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict).to(torch_device)
+
+        inputs_dict.update({"return_dict": False})
+
+        torch.manual_seed(0)
+        output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+        torch.manual_seed(0)
+        model.enable_tiling()
+        output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+        self.assertLess(
+            (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
+            0.5,
+            "VAE tiling should not affect the inference results",
+        )
+
+        torch.manual_seed(0)
+        model.disable_tiling()
+        output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+        self.assertEqual(
+            output_without_tiling.detach().cpu().numpy().all(),
+            output_without_tiling_2.detach().cpu().numpy().all(),
+            "Without tiling outputs should match with the outputs when tiling is manually disabled.",
+        )