Unverified Commit af6c0fb7 authored by Aryan, committed by GitHub

[core] CogVideoX memory optimizations in VAE encode (#9340)

fake context parallel cache, vae encode tiling

(cherry picked from commit bf890bca0e8aed875d6a207f9b826ce894901522)
parent d8a16635
@@ -999,6 +999,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
        # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
        # number of temporal frames.
        self.num_latent_frames_batch_size = 2
        self.num_sample_frames_batch_size = 8

        # We make the minimum height and width of sample for tiling half that of the generally supported
        self.tile_sample_min_height = sample_height // 2
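The new sample-side batch size lines up with the existing latent-side one under the VAE's temporal compression. A minimal sketch of that relation, assuming CogVideoX's 4x temporal compression ratio (the ratio itself is not shown in this hunk):

```python
# Illustrative only: how the new sample-side frame batch size relates to the
# latent-side one, assuming a 4x temporal compression (not taken from this diff).
temporal_compression_ratio = 4
num_sample_frames_batch_size = 8
num_latent_frames_batch_size = num_sample_frames_batch_size // temporal_compression_ratio
assert num_latent_frames_batch_size == 2  # matches the hard-coded latent batch size above
```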
@@ -1081,6 +1082,29 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
        """
        self.use_slicing = False
    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, num_channels, num_frames, height, width = x.shape

        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
            return self.tiled_encode(x)

        frame_batch_size = self.num_sample_frames_batch_size
        enc = []
        for i in range(num_frames // frame_batch_size):
            remaining_frames = num_frames % frame_batch_size
            start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
            end_frame = frame_batch_size * (i + 1) + remaining_frames
            x_intermediate = x[:, :, start_frame:end_frame]
            x_intermediate = self.encoder(x_intermediate)
            if self.quant_conv is not None:
                x_intermediate = self.quant_conv(x_intermediate)
            enc.append(x_intermediate)

        self._clear_fake_context_parallel_cache()
        enc = torch.cat(enc, dim=2)

        return enc
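The chunk indexing in `_encode` folds any remainder frames into the first chunk so that every later chunk stays full-sized. A standalone sketch of that arithmetic, assuming a 49-frame input (a typical CogVideoX sample length, not stated in this diff):

```python
# Minimal sketch (not part of the diff): how the temporal chunking partitions the frame axis.
num_frames, frame_batch_size = 49, 8
remaining_frames = num_frames % frame_batch_size  # 1
chunks = []
for i in range(num_frames // frame_batch_size):   # 6 chunks
    start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
    end_frame = frame_batch_size * (i + 1) + remaining_frames
    chunks.append((start_frame, end_frame))
# First chunk absorbs the remainder:
# [(0, 9), (9, 17), (17, 25), (25, 33), (33, 41), (41, 49)]
print(chunks)
```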
    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True
@@ -1094,13 +1118,17 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
                The latent representations of the encoded videos. If `return_dict` is True, a
                [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        """
        if self.use_slicing and x.shape[0] > 1:
            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = self._encode(x)

        posterior = DiagonalGaussianDistribution(h)

        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)
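A hedged usage sketch of the slicing and tiling paths added here; the checkpoint name, dtype, input shape, and the `enable_slicing`/`enable_tiling` helpers are assumptions drawn from the surrounding diffusers API rather than from this hunk:

```python
# Usage sketch (assumptions: THUDM/CogVideoX-2b checkpoint layout, fp16 weights,
# and the enable_slicing/enable_tiling helpers defined elsewhere in this class).
import torch
from diffusers import AutoencoderKLCogVideoX

vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16
).to("cuda")
vae.enable_slicing()  # encode batch elements one at a time via _encode
vae.enable_tiling()   # fall back to tiled_encode for large spatial sizes

# (B, C, T, H, W) dummy video at the default CogVideoX resolution.
video = torch.randn(1, 3, 49, 480, 720, dtype=torch.float16, device="cuda")
with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()
```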
@@ -1172,6 +1200,75 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
            )
        return b
    def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
        r"""Encode a batch of images using a tiled encoder.

        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
        output, but they should be much less noticeable.

        Args:
            x (`torch.Tensor`): Input batch of videos.

        Returns:
            `torch.Tensor`:
                The latent representation of the encoded videos.
        """
        # For a rough memory estimate, take a look at the `tiled_decode` method.
        batch_size, num_channels, num_frames, height, width = x.shape

        overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
        overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
        blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
        blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
        row_limit_height = self.tile_latent_min_height - blend_extent_height
        row_limit_width = self.tile_latent_min_width - blend_extent_width
        frame_batch_size = self.num_sample_frames_batch_size

        # Split x into overlapping tiles and encode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, height, overlap_height):
            row = []
            for j in range(0, width, overlap_width):
                time = []
                for k in range(num_frames // frame_batch_size):
                    remaining_frames = num_frames % frame_batch_size
                    start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
                    end_frame = frame_batch_size * (k + 1) + remaining_frames
                    tile = x[
                        :,
                        :,
                        start_frame:end_frame,
                        i : i + self.tile_sample_min_height,
                        j : j + self.tile_sample_min_width,
                    ]
                    tile = self.encoder(tile)
                    if self.quant_conv is not None:
                        tile = self.quant_conv(tile)
                    time.append(tile)
                self._clear_fake_context_parallel_cache()
                row.append(torch.cat(time, dim=2))
            rows.append(row)

        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent_width)
                result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
            result_rows.append(torch.cat(result_row, dim=4))

        enc = torch.cat(result_rows, dim=3)
        return enc
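The `blend_v`/`blend_h` helpers used above are defined earlier in this file and are not part of this hunk. The sketch below mirrors the linear cross-fade they apply, under the assumption of the 5D `(batch, channel, frame, height, width)` layout used throughout this VAE:

```python
# Sketch only: a linear cross-fade over the horizontal overlap, mirroring what
# blend_h does elsewhere in this file (assumed, not shown in this hunk).
import torch

def blend_h_sketch(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    blend_extent = min(a.shape[4], b.shape[4], blend_extent)
    for x in range(blend_extent):
        # Fade from the right edge of the left tile `a` into the left edge of tile `b`.
        b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
            x / blend_extent
        )
    return b
```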
    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        r"""
        Decode a batch of images using a tiled decoder.
...