Unverified commit d204e532, authored by Aryan, committed by GitHub
Browse files

[core] improve VAE encode/decode framewise batching (#9684)



* update

* apply suggestions from review

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 8cabd4a0
...@@ -1182,7 +1182,8 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -1182,7 +1182,8 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
frame_batch_size = self.num_sample_frames_batch_size frame_batch_size = self.num_sample_frames_batch_size
# Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k. # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
num_batches = num_frames // frame_batch_size if num_frames > 1 else 1 # As the extra single frame is handled inside the loop, it is not required to round up here.
num_batches = max(num_frames // frame_batch_size, 1)
conv_cache = None conv_cache = None
enc = [] enc = []
...@@ -1330,7 +1331,8 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -1330,7 +1331,8 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
row = [] row = []
for j in range(0, width, overlap_width): for j in range(0, width, overlap_width):
# Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k. # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
num_batches = num_frames // frame_batch_size if num_frames > 1 else 1 # As the extra single frame is handled inside the loop, it is not required to round up here.
num_batches = max(num_frames // frame_batch_size, 1)
conv_cache = None conv_cache = None
time = [] time = []
...@@ -1409,7 +1411,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -1409,7 +1411,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
for i in range(0, height, overlap_height): for i in range(0, height, overlap_height):
row = [] row = []
for j in range(0, width, overlap_width): for j in range(0, width, overlap_width):
num_batches = num_frames // frame_batch_size num_batches = max(num_frames // frame_batch_size, 1)
conv_cache = None conv_cache = None
time = [] time = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment