Unverified commit bd4df285 authored by Aryan, committed by GitHub

[refactor] remove conv_cache from CogVideoX VAE (#9524)



* remove conv cache from the layer and pass as arg instead

* make style

* yiyi's cleaner implementation
Co-Authored-By: YiYi Xu <yixu310@gmail.com>

* sayak's compiled implementation
Co-Authored-By: Sayak Paul <spsayakpaul@gmail.com>

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent 11542431
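
The change below makes the CogVideoX causal convolutions stateless: instead of each CogVideoXCausalConv3d holding a `self.conv_cache` attribute that has to be cleared between frame batches, the temporal cache is returned from `forward` and passed back in by the caller on the next chunk. A minimal sketch of that pattern, using a hypothetical `TinyCausalConv3d` rather than the actual diffusers classes:

from typing import Optional, Tuple

import torch
import torch.nn as nn


class TinyCausalConv3d(nn.Module):
    # Illustrative stand-in for CogVideoXCausalConv3d: the temporal cache is not
    # stored on the module; it is returned and threaded back in by the caller.
    def __init__(self, channels: int, time_kernel_size: int = 3):
        super().__init__()
        self.time_kernel_size = time_kernel_size
        self.conv = nn.Conv3d(channels, channels, kernel_size=(time_kernel_size, 1, 1))

    def forward(
        self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        pad = self.time_kernel_size - 1
        if conv_cache is None:
            # First chunk: causally pad by repeating the first frame.
            conv_cache = inputs[:, :, :1].repeat(1, 1, pad, 1, 1)
        inputs = torch.cat([conv_cache, inputs], dim=2)
        # The trailing frames become the cache handed to the next chunk.
        new_conv_cache = inputs[:, :, -pad:].clone()
        return self.conv(inputs), new_conv_cache


# Frame-batched inference threads the cache through successive calls, mirroring
# how the encode/decode loops below pass `conv_cache` from one chunk to the next.
layer = TinyCausalConv3d(channels=4)
video = torch.randn(1, 4, 8, 16, 16)  # (batch, channels, frames, height, width)
conv_cache = None
outputs = []
for chunk in video.split(4, dim=2):
    out, conv_cache = layer(chunk, conv_cache=conv_cache)
    outputs.append(out)
result = torch.cat(outputs, dim=2)

Because nothing on the module mutates during `forward`, the layers also compose more cleanly with `torch.compile`, which appears to be what the "compiled implementation" bullet above refers to.
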
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -41,7 +41,9 @@ class CogVideoXSafeConv3d(nn.Conv3d):
     """
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
+        memory_count = (
+            (input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]) * 2 / 1024**3
+        )
 
         # Set to 2GB, suitable for CuDNN
         if memory_count > 2:
@@ -115,34 +117,24 @@ class CogVideoXCausalConv3d(nn.Module):
             dilation=dilation,
         )
 
-        self.conv_cache = None
-
-    def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
+    def fake_context_parallel_forward(
+        self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
         kernel_size = self.time_kernel_size
         if kernel_size > 1:
-            cached_inputs = (
-                [self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
-            )
+            cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
             inputs = torch.cat(cached_inputs + [inputs], dim=2)
         return inputs
 
-    def _clear_fake_context_parallel_cache(self):
-        del self.conv_cache
-        self.conv_cache = None
-
-    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
-        inputs = self.fake_context_parallel_forward(inputs)
-        self._clear_fake_context_parallel_cache()
-        # Note: we could move these to the cpu for a lower maximum memory usage but its only a few
-        # hundred megabytes and so let's not do it for now
-        self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
+    def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
+        inputs = self.fake_context_parallel_forward(inputs, conv_cache)
+        conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
 
         padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
         inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
 
         output = self.conv(inputs)
-        return output
+        return output, conv_cache
 
 
 class CogVideoXSpatialNorm3D(nn.Module):
@@ -172,7 +164,12 @@ class CogVideoXSpatialNorm3D(nn.Module):
         self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
         self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
 
-    def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
+    ) -> torch.Tensor:
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
         if f.shape[2] > 1 and f.shape[2] % 2 == 1:
             f_first, f_rest = f[:, :, :1], f[:, :, 1:]
             f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
@@ -183,9 +180,12 @@ class CogVideoXSpatialNorm3D(nn.Module):
         else:
             zq = F.interpolate(zq, size=f.shape[-3:])
 
+        conv_y, new_conv_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y"))
+        conv_b, new_conv_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b"))
+
         norm_f = self.norm_layer(f)
-        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
-        return new_f
+        new_f = norm_f * conv_y + conv_b
+        return new_f, new_conv_cache
 
 
 class CogVideoXResnetBlock3D(nn.Module):
@@ -236,6 +236,7 @@ class CogVideoXResnetBlock3D(nn.Module):
         self.out_channels = out_channels
         self.nonlinearity = get_activation(non_linearity)
         self.use_conv_shortcut = conv_shortcut
+        self.spatial_norm_dim = spatial_norm_dim
 
         if spatial_norm_dim is None:
             self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
@@ -279,34 +280,43 @@ class CogVideoXResnetBlock3D(nn.Module):
         inputs: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
         hidden_states = inputs
 
         if zq is not None:
-            hidden_states = self.norm1(hidden_states, zq)
+            hidden_states, new_conv_cache["norm1"] = self.norm1(hidden_states, zq, conv_cache=conv_cache.get("norm1"))
         else:
            hidden_states = self.norm1(hidden_states)
 
         hidden_states = self.nonlinearity(hidden_states)
-        hidden_states = self.conv1(hidden_states)
+        hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1"))
 
         if temb is not None:
             hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
 
         if zq is not None:
-            hidden_states = self.norm2(hidden_states, zq)
+            hidden_states, new_conv_cache["norm2"] = self.norm2(hidden_states, zq, conv_cache=conv_cache.get("norm2"))
         else:
             hidden_states = self.norm2(hidden_states)
 
         hidden_states = self.nonlinearity(hidden_states)
         hidden_states = self.dropout(hidden_states)
-        hidden_states = self.conv2(hidden_states)
+        hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2"))
 
         if self.in_channels != self.out_channels:
-            inputs = self.conv_shortcut(inputs)
+            if self.use_conv_shortcut:
+                inputs, new_conv_cache["conv_shortcut"] = self.conv_shortcut(
+                    inputs, conv_cache=conv_cache.get("conv_shortcut")
+                )
+            else:
+                inputs = self.conv_shortcut(inputs)
 
         hidden_states = hidden_states + inputs
-        return hidden_states
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXDownBlock3D(nn.Module):
@@ -392,8 +402,16 @@ class CogVideoXDownBlock3D(nn.Module):
         hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
-        for resnet in self.resnets:
+        r"""Forward method of the `CogVideoXDownBlock3D` class."""
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        for i, resnet in enumerate(self.resnets):
+            conv_cache_key = f"resnet_{i}"
+
             if self.training and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
@@ -402,17 +420,23 @@ class CogVideoXDownBlock3D(nn.Module):
                     return create_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, zq
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    zq,
+                    conv_cache=conv_cache.get(conv_cache_key),
                 )
             else:
-                hidden_states = resnet(hidden_states, temb, zq)
+                hidden_states, new_conv_cache[conv_cache_key] = resnet(
+                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
                 hidden_states = downsampler(hidden_states)
 
-        return hidden_states
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXMidBlock3D(nn.Module):
@@ -480,8 +504,16 @@ class CogVideoXMidBlock3D(nn.Module):
         hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
-        for resnet in self.resnets:
+        r"""Forward method of the `CogVideoXMidBlock3D` class."""
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        for i, resnet in enumerate(self.resnets):
+            conv_cache_key = f"resnet_{i}"
+
             if self.training and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
@@ -490,13 +522,15 @@ class CogVideoXMidBlock3D(nn.Module):
                     return create_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, zq
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet), hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
                 )
             else:
-                hidden_states = resnet(hidden_states, temb, zq)
+                hidden_states, new_conv_cache[conv_cache_key] = resnet(
+                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
-        return hidden_states
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXUpBlock3D(nn.Module):
@@ -584,9 +618,16 @@ class CogVideoXUpBlock3D(nn.Module):
         hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
         r"""Forward method of the `CogVideoXUpBlock3D` class."""
-        for resnet in self.resnets:
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        for i, resnet in enumerate(self.resnets):
+            conv_cache_key = f"resnet_{i}"
+
             if self.training and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
@@ -595,17 +636,23 @@ class CogVideoXUpBlock3D(nn.Module):
                     return create_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, zq
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    zq,
+                    conv_cache=conv_cache.get(conv_cache_key),
                 )
             else:
-                hidden_states = resnet(hidden_states, temb, zq)
+                hidden_states, new_conv_cache[conv_cache_key] = resnet(
+                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
                 hidden_states = upsampler(hidden_states)
 
-        return hidden_states
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXEncoder3D(nn.Module):
@@ -705,9 +752,18 @@ class CogVideoXEncoder3D(nn.Module):
         self.gradient_checkpointing = False
 
-    def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(
+        self,
+        sample: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> torch.Tensor:
         r"""The forward method of the `CogVideoXEncoder3D` class."""
-        hidden_states = self.conv_in(sample)
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
 
         if self.training and self.gradient_checkpointing:
@@ -718,28 +774,44 @@ class CogVideoXEncoder3D(nn.Module):
                 return custom_forward
 
             # 1. Down
-            for down_block in self.down_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(down_block), hidden_states, temb, None
+            for i, down_block in enumerate(self.down_blocks):
+                conv_cache_key = f"down_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(down_block),
+                    hidden_states,
+                    temb,
+                    None,
+                    conv_cache=conv_cache.get(conv_cache_key),
                 )
 
             # 2. Mid
-            hidden_states = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block), hidden_states, temb, None
+            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.mid_block),
+                hidden_states,
+                temb,
+                None,
+                conv_cache=conv_cache.get("mid_block"),
             )
         else:
             # 1. Down
-            for down_block in self.down_blocks:
-                hidden_states = down_block(hidden_states, temb, None)
+            for i, down_block in enumerate(self.down_blocks):
+                conv_cache_key = f"down_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = down_block(
+                    hidden_states, temb, None, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
             # 2. Mid
-            hidden_states = self.mid_block(hidden_states, temb, None)
+            hidden_states, new_conv_cache["mid_block"] = self.mid_block(
+                hidden_states, temb, None, conv_cache=conv_cache.get("mid_block")
+            )
 
         # 3. Post-process
         hidden_states = self.norm_out(hidden_states)
         hidden_states = self.conv_act(hidden_states)
-        hidden_states = self.conv_out(hidden_states)
-        return hidden_states
+
+        hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
+
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXDecoder3D(nn.Module):
@@ -846,9 +918,18 @@ class CogVideoXDecoder3D(nn.Module):
         self.gradient_checkpointing = False
 
-    def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(
+        self,
+        sample: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> torch.Tensor:
         r"""The forward method of the `CogVideoXDecoder3D` class."""
-        hidden_states = self.conv_in(sample)
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
 
         if self.training and self.gradient_checkpointing:
@@ -859,28 +940,45 @@ class CogVideoXDecoder3D(nn.Module):
                 return custom_forward
 
             # 1. Mid
-            hidden_states = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block), hidden_states, temb, sample
+            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.mid_block),
+                hidden_states,
+                temb,
+                sample,
+                conv_cache=conv_cache.get("mid_block"),
            )
 
             # 2. Up
-            for up_block in self.up_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(up_block), hidden_states, temb, sample
+            for i, up_block in enumerate(self.up_blocks):
+                conv_cache_key = f"up_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(up_block),
+                    hidden_states,
+                    temb,
+                    sample,
+                    conv_cache=conv_cache.get(conv_cache_key),
                 )
         else:
             # 1. Mid
-            hidden_states = self.mid_block(hidden_states, temb, sample)
+            hidden_states, new_conv_cache["mid_block"] = self.mid_block(
+                hidden_states, temb, sample, conv_cache=conv_cache.get("mid_block")
+            )
 
             # 2. Up
-            for up_block in self.up_blocks:
-                hidden_states = up_block(hidden_states, temb, sample)
+            for i, up_block in enumerate(self.up_blocks):
+                conv_cache_key = f"up_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = up_block(
+                    hidden_states, temb, sample, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
         # 3. Post-process
-        hidden_states = self.norm_out(hidden_states, sample)
+        hidden_states, new_conv_cache["norm_out"] = self.norm_out(
+            hidden_states, sample, conv_cache=conv_cache.get("norm_out")
+        )
         hidden_states = self.conv_act(hidden_states)
-        hidden_states = self.conv_out(hidden_states)
+        hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
 
-        return hidden_states
+        return hidden_states, new_conv_cache
 
 
 class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
@@ -1019,12 +1117,6 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
             module.gradient_checkpointing = value
 
-    def _clear_fake_context_parallel_cache(self):
-        for name, module in self.named_modules():
-            if isinstance(module, CogVideoXCausalConv3d):
-                logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
-                module._clear_fake_context_parallel_cache()
-
     def enable_tiling(
         self,
         tile_sample_min_height: Optional[int] = None,
@@ -1091,20 +1183,20 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         frame_batch_size = self.num_sample_frames_batch_size
         # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
         num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
+        conv_cache = None
         enc = []
+
         for i in range(num_batches):
             remaining_frames = num_frames % frame_batch_size
             start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
             end_frame = frame_batch_size * (i + 1) + remaining_frames
             x_intermediate = x[:, :, start_frame:end_frame]
-            x_intermediate = self.encoder(x_intermediate)
+            x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
             if self.quant_conv is not None:
                 x_intermediate = self.quant_conv(x_intermediate)
             enc.append(x_intermediate)
 
-        self._clear_fake_context_parallel_cache()
         enc = torch.cat(enc, dim=2)
         return enc
 
     @apply_forward_hook
@@ -1143,7 +1235,9 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         frame_batch_size = self.num_latent_frames_batch_size
         num_batches = num_frames // frame_batch_size
+        conv_cache = None
         dec = []
+
         for i in range(num_batches):
             remaining_frames = num_frames % frame_batch_size
             start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
@@ -1151,10 +1245,9 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             z_intermediate = z[:, :, start_frame:end_frame]
             if self.post_quant_conv is not None:
                 z_intermediate = self.post_quant_conv(z_intermediate)
-            z_intermediate = self.decoder(z_intermediate)
+            z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
             dec.append(z_intermediate)
 
-        self._clear_fake_context_parallel_cache()
         dec = torch.cat(dec, dim=2)
 
         if not return_dict:
@@ -1238,7 +1331,9 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             for j in range(0, width, overlap_width):
                 # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
                 num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
+                conv_cache = None
                 time = []
+
                 for k in range(num_batches):
                     remaining_frames = num_frames % frame_batch_size
                     start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
@@ -1250,11 +1345,11 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                         i : i + self.tile_sample_min_height,
                         j : j + self.tile_sample_min_width,
                     ]
-                    tile = self.encoder(tile)
+                    tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
                     if self.quant_conv is not None:
                         tile = self.quant_conv(tile)
                     time.append(tile)
-                self._clear_fake_context_parallel_cache()
+
                 row.append(torch.cat(time, dim=2))
             rows.append(row)
@@ -1315,7 +1410,9 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             row = []
             for j in range(0, width, overlap_width):
                 num_batches = num_frames // frame_batch_size
+                conv_cache = None
                 time = []
+
                 for k in range(num_batches):
                     remaining_frames = num_frames % frame_batch_size
                     start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
@@ -1329,9 +1426,9 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                     ]
                     if self.post_quant_conv is not None:
                         tile = self.post_quant_conv(tile)
-                    tile = self.decoder(tile)
+                    tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
                     time.append(tile)
-                self._clear_fake_context_parallel_cache()
+
                 row.append(torch.cat(time, dim=2))
             rows.append(row)
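
For reference, a hypothetical usage sketch (not part of this commit) of the refactored VAE; the checkpoint id, dtype, and latent shape are illustrative assumptions, and compiling the decoder is optional:

import torch
from diffusers import AutoencoderKLCogVideoX

vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16
).to("cuda")
# With no module-level cache mutated during forward, the decoder can be compiled.
vae.decoder = torch.compile(vae.decoder)

latents = torch.randn(1, 16, 4, 60, 90, dtype=torch.float16, device="cuda")
with torch.no_grad():
    frames = vae.decode(latents).sample  # (batch, channels, frames, height, width)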