Unverified Commit a85b34e7 authored by Aryan, committed by GitHub

[refactor] CogVideoX followups + tiled decoding support (#9150)

* refactor context parallel cache; update torch compile time benchmark

* add tiling support

* make style

* remove num_frames % 8 == 0 requirement

* update default num_frames to original value

* add explanations + refactor

* update torch compile example

* update docs

* update

* clean up if-statements

* address review comments

* add test for vae tiling

* update docs

* update docs

* update docstrings

* add modeling test for cogvideox transformer

* make style
parent 5ffbe14c
@@ -15,9 +15,7 @@
# CogVideoX
[CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://arxiv.org/abs/2408.06072) from Tsinghua University & ZhipuAI, by Zhuoyi Yang, Jiayan Teng, Wendi Zheng, Ming Ding, Shiyu Huang, Jiazheng Xu, Yuanming Yang, Wenyi Hong, Xiaohan Zhang, Guanyu Feng, Da Yin, Xiaotao Gu, Yuxuan Zhang, Weihan Wang, Yean Cheng, Ting Liu, Bin Xu, Yuxiao Dong, Jie Tang.
The abstract from the paper is:
@@ -43,43 +41,42 @@ from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b").to("cuda")
prompt = (
"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
"The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
"pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
"casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
"The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
"atmosphere of this unique musical performance."
)
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)
```
Then change the memory layout of the pipeline's `transformer` component to `torch.channels_last`:
```python
pipe.transformer.to(memory_format=torch.channels_last)
```
Finally, compile the components and run inference:
```python
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
# CogVideoX works well with long and well-described prompts
prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance."
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
```
The [benchmark](https://gist.github.com/a-r-r-o-w/5183d75e452a368fd17448fcc810bd3f) results on an 80GB A100 machine are:
```
Without torch.compile(): Average inference time: 96.89 seconds.
With torch.compile(): Average inference time: 76.27 seconds.
```
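The full benchmarking script is in the gist linked above. The snippet below is only a minimal sketch of such a timing loop; the `float16` dtype, the short prompt, the warmup strategy, and the number of timed runs are assumptions, not the exact settings behind the reported numbers:

```python
import time

import torch
from diffusers import CogVideoXPipeline

# Assumed setup: fp16 weights on a single CUDA GPU (the benchmark gist may use different settings)
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
pipe.transformer.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)

prompt = "A panda playing a tiny guitar in a serene bamboo forest."  # illustrative prompt

def run_once() -> float:
    # Time a single end-to-end generation, synchronizing around the call
    torch.cuda.synchronize()
    start = time.perf_counter()
    pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50)
    torch.cuda.synchronize()
    return time.perf_counter() - start

# The first call triggers compilation, so warm up before timing.
run_once()
timings = [run_once() for _ in range(3)]
print(f"Average inference time: {sum(timings) / len(timings):.2f} seconds.")
```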
### Memory optimization
CogVideoX requires about 19 GB of GPU memory to decode 49 frames (6 seconds of video at 8 FPS) with an output resolution of 720x480 (W x H), which makes it impossible to run on consumer GPUs or the free-tier T4 Colab. The following memory optimizations can be used to reduce the memory footprint (a usage sketch follows the list). For replication, you can refer to [this](https://gist.github.com/a-r-r-o-w/3959a03f15be5c9bd1fe545b09dfcc93) script.
- `pipe.enable_model_cpu_offload()`:
  - Without CPU offloading, memory usage is `33 GB`
  - With CPU offloading, memory usage is `19 GB`
- `pipe.vae.enable_tiling()`:
  - With CPU offloading and tiling, memory usage is `11 GB`
- `pipe.vae.enable_slicing()`
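A minimal sketch of how these options might be combined is shown below. The `float16` dtype, the short prompt, and measuring peak memory with `torch.cuda.max_memory_allocated` are illustrative assumptions and are not taken from the replication script:

```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

# Do not call .to("cuda") when using model CPU offloading; the pipeline manages device placement itself.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)

# Keep submodules on the CPU and move them to the GPU only while they are being used.
pipe.enable_model_cpu_offload()

# Decode the latents in overlapping spatial tiles rather than all at once.
pipe.vae.enable_tiling()
# Decode one batch element at a time.
pipe.vae.enable_slicing()

prompt = "A panda playing a tiny guitar in a serene bamboo forest."  # illustrative prompt
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)

print(f"Peak GPU memory: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
```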
## CogVideoXPipeline
[[autodoc]] CogVideoXPipeline
...
@@ -36,7 +36,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class CogVideoXSafeConv3d(nn.Conv3d):
r"""
A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
"""
@@ -68,12 +68,12 @@ class CogVideoXCausalConv3d(nn.Module):
r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
Args:
in_channels (`int`): Number of channels in the input tensor.
out_channels (`int`): Number of output channels produced by the convolution.
kernel_size (`int` or `Tuple[int, int, int]`): Size of the convolutional kernel.
stride (`int`, defaults to `1`): Stride of the convolution.
dilation (`int`, defaults to `1`): Dilation rate of the convolution.
pad_mode (`str`, defaults to `"constant"`): Padding mode.
"""
def __init__(
@@ -118,19 +118,12 @@ class CogVideoXCausalConv3d(nn.Module):
self.conv_cache = None
def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
kernel_size = self.time_kernel_size
if kernel_size > 1:
cached_inputs = (
[self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
)
inputs = torch.cat(cached_inputs + [inputs], dim=2)
return inputs
def _clear_fake_context_parallel_cache(self):
@@ -138,16 +131,17 @@ class CogVideoXCausalConv3d(nn.Module):
self.conv_cache = None
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
inputs = self.fake_context_parallel_forward(inputs)
self._clear_fake_context_parallel_cache()
# Note: we could move these to the cpu for a lower maximum memory usage but it's only a few
# hundred megabytes and so let's not do it for now
self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
output = self.conv(inputs)
return output
@@ -163,6 +157,8 @@ class CogVideoXSpatialNorm3D(nn.Module):
The number of channels for input to group normalization layer, and output of the spatial norm layer.
zq_channels (`int`):
The number of channels for the quantized vector as described in the paper.
groups (`int`):
Number of groups to separate the channels into for group normalization.
""" """
def __init__( def __init__(
...@@ -197,17 +193,26 @@ class CogVideoXResnetBlock3D(nn.Module): ...@@ -197,17 +193,26 @@ class CogVideoXResnetBlock3D(nn.Module):
A 3D ResNet block used in the CogVideoX model. A 3D ResNet block used in the CogVideoX model.
Args: Args:
in_channels (int): Number of input channels. in_channels (`int`):
out_channels (Optional[int], optional): Number of input channels.
Number of output channels. If None, defaults to `in_channels`. Default is None. out_channels (`int`, *optional*):
dropout (float, optional): Dropout rate. Default is 0.0. Number of output channels. If None, defaults to `in_channels`.
temb_channels (int, optional): Number of time embedding channels. Default is 512. dropout (`float`, defaults to `0.0`):
groups (int, optional): Number of groups for group normalization. Default is 32. Dropout rate.
eps (float, optional): Epsilon value for normalization layers. Default is 1e-6. temb_channels (`int`, defaults to `512`):
non_linearity (str, optional): Activation function to use. Default is "swish". Number of time embedding channels.
conv_shortcut (bool, optional): If True, use a convolutional shortcut. Default is False. groups (`int`, defaults to `32`):
spatial_norm_dim (Optional[int], optional): Dimension of the spatial normalization. Default is None. Number of groups to separate the channels into for group normalization.
pad_mode (str, optional): Padding mode. Default is "first". eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
non_linearity (`str`, defaults to `"swish"`):
Activation function to use.
conv_shortcut (bool, defaults to `False`):
Whether or not to use a convolution shortcut.
spatial_norm_dim (`int`, *optional*):
The dimension to use for spatial norm if it is to be used instead of group norm.
pad_mode (str, defaults to `"first"`):
Padding mode.
""" """
def __init__(
@@ -309,18 +314,28 @@ class CogVideoXDownBlock3D(nn.Module):
A downsampling block used in the CogVideoX model.
Args:
in_channels (`int`):
Number of input channels.
out_channels (`int`, *optional*):
Number of output channels. If None, defaults to `in_channels`.
temb_channels (`int`, defaults to `512`):
Number of time embedding channels.
num_layers (`int`, defaults to `1`):
Number of resnet layers.
dropout (`float`, defaults to `0.0`):
Dropout rate.
resnet_eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
resnet_act_fn (`str`, defaults to `"swish"`):
Activation function to use.
resnet_groups (`int`, defaults to `32`):
Number of groups to separate the channels into for group normalization.
add_downsample (`bool`, defaults to `True`):
Whether or not to use a downsampling layer. If not used, the output dimension is the same as the input dimension.
compress_time (`bool`, defaults to `False`):
Whether or not to downsample across the temporal dimension.
pad_mode (`str`, defaults to `"first"`):
Padding mode.
"""
_supports_gradient_checkpointing = True
@@ -405,15 +420,24 @@ class CogVideoXMidBlock3D(nn.Module):
A middle block used in the CogVideoX model.
Args:
in_channels (`int`):
Number of input channels.
temb_channels (`int`, defaults to `512`):
Number of time embedding channels.
dropout (`float`, defaults to `0.0`):
Dropout rate.
num_layers (`int`, defaults to `1`):
Number of resnet layers.
resnet_eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
resnet_act_fn (`str`, defaults to `"swish"`):
Activation function to use.
resnet_groups (`int`, defaults to `32`):
Number of groups to separate the channels into for group normalization.
spatial_norm_dim (`int`, *optional*):
The dimension to use for spatial norm if it is to be used instead of group norm.
pad_mode (`str`, defaults to `"first"`):
Padding mode.
"""
_supports_gradient_checkpointing = True
@@ -480,19 +504,30 @@ class CogVideoXUpBlock3D(nn.Module):
An upsampling block used in the CogVideoX model.
Args:
in_channels (`int`):
Number of input channels.
out_channels (`int`, *optional*):
Number of output channels. If None, defaults to `in_channels`.
temb_channels (`int`, defaults to `512`):
Number of time embedding channels.
dropout (`float`, defaults to `0.0`):
Dropout rate.
num_layers (`int`, defaults to `1`):
Number of resnet layers.
resnet_eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
resnet_act_fn (`str`, defaults to `"swish"`):
Activation function to use.
resnet_groups (`int`, defaults to `32`):
Number of groups to separate the channels into for group normalization.
spatial_norm_dim (`int`, defaults to `16`):
The dimension to use for spatial norm if it is to be used instead of group norm.
add_upsample (`bool`, defaults to `True`):
Whether or not to use an upsampling layer. If not used, the output dimension is the same as the input dimension.
compress_time (`bool`, defaults to `False`):
Whether or not to downsample across the temporal dimension.
pad_mode (`str`, defaults to `"first"`):
Padding mode.
"""
def __init__(
@@ -587,14 +622,12 @@ class CogVideoXEncoder3D(nn.Module):
options.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
"""
_supports_gradient_checkpointing = True
@@ -723,14 +756,12 @@ class CogVideoXDecoder3D(nn.Module):
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
"""
_supports_gradient_checkpointing = True
@@ -911,7 +942,8 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
norm_eps: float = 1e-6,
norm_num_groups: int = 32,
temporal_compression_ratio: float = 4,
sample_height: int = 480,
sample_width: int = 720,
scaling_factor: float = 1.15258426,
shift_factor: Optional[float] = None,
latents_mean: Optional[Tuple[float]] = None,
@@ -950,25 +982,105 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.use_slicing = False
self.use_tiling = False
# Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
# recommended because the temporal parts of the VAE, here, are tricky to understand.
# If you decode X latent frames together, the number of output frames is:
# (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames
#
# Example with num_latent_frames_batch_size = 2:
# - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together
# => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
# => 6 * 8 = 48 frames
# - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together
# => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) +
# ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
# => 1 * 9 + 5 * 8 = 49 frames
# It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that
# setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
# number of temporal frames.
self.num_latent_frames_batch_size = 2
# We make the minimum height and width of a sample for tiling half of the generally supported resolution
self.tile_sample_min_height = sample_height // 2
self.tile_sample_min_width = sample_width // 2
self.tile_latent_min_height = int(
self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
)
self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
# These are experimental overlap factors that seem to work best for 720x480 (WxH) resolution. This is the
# strongly recommended generation resolution for CogVideoX, and so the tiling implementation has only been
# tested at that resolution.
self.tile_overlap_factor_height = 1 / 6
self.tile_overlap_factor_width = 1 / 5
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
module.gradient_checkpointing = value
def _clear_fake_context_parallel_cache(self):
for name, module in self.named_modules():
if isinstance(module, CogVideoXCausalConv3d):
logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
module._clear_fake_context_parallel_cache()
def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
tile_sample_min_width: Optional[int] = None,
tile_overlap_factor_height: Optional[float] = None,
tile_overlap_factor_width: Optional[float] = None,
) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow
processing larger images.
Args:
tile_sample_min_height (`int`, *optional*):
The minimum height required for a sample to be separated into tiles across the height dimension.
tile_sample_min_width (`int`, *optional*):
The minimum width required for a sample to be separated into tiles across the width dimension.
tile_overlap_factor_height (`float`, *optional*):
The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
value might cause more tiles to be processed, leading to a slowdown of the decoding process.
tile_overlap_factor_width (`float`, *optional*):
The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
value might cause more tiles to be processed, leading to a slowdown of the decoding process.
"""
self.use_tiling = True
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_latent_min_height = int(
self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
)
self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
@@ -993,8 +1105,34 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
return self.tiled_decode(z, return_dict=return_dict)
frame_batch_size = self.num_latent_frames_batch_size
dec = []
for i in range(num_frames // frame_batch_size):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
end_frame = frame_batch_size * (i + 1) + remaining_frames
z_intermediate = z[:, :, start_frame:end_frame]
if self.post_quant_conv is not None:
z_intermediate = self.post_quant_conv(z_intermediate)
z_intermediate = self.decoder(z_intermediate)
dec.append(z_intermediate)
self._clear_fake_context_parallel_cache()
dec = torch.cat(dec, dim=2)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
"""
Decode a batch of images.
@@ -1007,13 +1145,111 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
y / blend_extent
)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[4], b.shape[4], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
x / blend_extent
)
return b
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
""" """
if self.post_quant_conv is not None: # Rough memory assessment:
z = self.post_quant_conv(z) # - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
dec = self.decoder(z) # - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
# - Assume fp16 (2 bytes per value).
# Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
#
# Memory assessment when using tiling:
# - Assume everything as above but now HxW is 240x360 by tiling in half
# Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB
batch_size, num_channels, num_frames, height, width = z.shape
overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
row_limit_height = self.tile_sample_min_height - blend_extent_height
row_limit_width = self.tile_sample_min_width - blend_extent_width
frame_batch_size = self.num_latent_frames_batch_size
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, overlap_height):
row = []
for j in range(0, width, overlap_width):
time = []
for k in range(num_frames // frame_batch_size):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
end_frame = frame_batch_size * (k + 1) + remaining_frames
tile = z[
:,
:,
start_frame:end_frame,
i : i + self.tile_latent_min_height,
j : j + self.tile_latent_min_width,
]
if self.post_quant_conv is not None:
tile = self.post_quant_conv(tile)
tile = self.decoder(tile)
time.append(tile)
self._clear_fake_context_parallel_cache()
row.append(torch.cat(time, dim=2))
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent_width)
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
result_rows.append(torch.cat(result_row, dim=4))
dec = torch.cat(result_rows, dim=3)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
...
@@ -37,13 +37,20 @@ class CogVideoXBlock(nn.Module):
Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
Parameters:
dim (`int`):
The number of channels in the input and output.
num_attention_heads (`int`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`):
The number of channels in each head.
time_embed_dim (`int`):
The number of channels in the timestep embedding.
dropout (`float`, defaults to `0.0`):
The dropout probability to use.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to be used in feed-forward.
attention_bias (`bool`, defaults to `False`):
Whether or not to use bias in attention projection layers.
qk_norm (`bool`, defaults to `True`):
Whether or not to use normalization after query and key projections in Attention.
norm_elementwise_affine (`bool`, defaults to `True`):
@@ -147,36 +154,53 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
Parameters:
num_attention_heads (`int`, defaults to `30`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, defaults to `64`):
The number of channels in each head.
in_channels (`int`, defaults to `16`):
The number of channels in the input.
out_channels (`int`, *optional*, defaults to `16`):
The number of channels in the output.
flip_sin_to_cos (`bool`, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
time_embed_dim (`int`, defaults to `512`):
Output dimension of timestep embeddings.
text_embed_dim (`int`, defaults to `4096`):
Input dimension of text embeddings from the text encoder.
num_layers (`int`, defaults to `30`):
The number of layers of Transformer blocks to use.
dropout (`float`, defaults to `0.0`):
The dropout probability to use.
attention_bias (`bool`, defaults to `True`):
Whether or not to use bias in the attention projection layers.
sample_width (`int`, defaults to `90`):
The width of the input latents.
sample_height (`int`, defaults to `60`):
The height of the input latents.
sample_frames (`int`, defaults to `49`):
The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
instead of 13 because CogVideoX processes 13 latent frames at once in its default and recommended settings,
but it cannot be changed to the correct value without breaking backwards compatibility. To create a
transformer with K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
patch_size (`int`, defaults to `2`):
The size of the patches to use in the patch embedding layer.
temporal_compression_ratio (`int`, defaults to `4`):
The compression ratio across the temporal dimension. See documentation for `sample_frames`.
max_text_seq_length (`int`, defaults to `226`):
The maximum sequence length of the input text embeddings.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to use in feed-forward.
timestep_activation_fn (`str`, defaults to `"silu"`):
Activation function to use when generating the timestep embeddings.
norm_elementwise_affine (`bool`, defaults to `True`):
Whether or not to use elementwise affine in normalization layers.
norm_eps (`float`, defaults to `1e-5`):
The epsilon value to use in normalization layers.
spatial_interpolation_scale (`float`, defaults to `1.875`):
Scaling factor to apply in 3D positional embeddings across spatial dimensions.
temporal_interpolation_scale (`float`, defaults to `1.0`):
Scaling factor to apply in 3D positional embeddings across temporal dimensions.
"""
_supports_gradient_checkpointing = True
@@ -186,7 +210,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
self,
num_attention_heads: int = 30,
attention_head_dim: int = 64,
in_channels: int = 16,
out_channels: Optional[int] = 16,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
@@ -304,7 +328,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
encoder_hidden_states = hidden_states[:, : self.config.max_text_seq_length]
hidden_states = hidden_states[:, self.config.max_text_seq_length :]
# 4. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
@@ -331,11 +355,11 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
hidden_states = self.norm_final(hidden_states)
# 5. Final block
hidden_states = self.norm_out(hidden_states, temb=emb)
hidden_states = self.proj_out(hidden_states)
# 6. Unpatchify
p = self.config.patch_size
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
...
@@ -332,20 +332,11 @@ class CogVideoXPipeline(DiffusionPipeline):
latents = latents * self.scheduler.init_noise_sigma
return latents
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
latents = 1 / self.vae.config.scaling_factor * latents
frames = self.vae.decode(latents).sample
return frames
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
@@ -438,8 +429,7 @@ class CogVideoXPipeline(DiffusionPipeline):
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 480,
width: int = 720,
num_frames: int = 49,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
guidance_scale: float = 6,
@@ -534,9 +524,10 @@ class CogVideoXPipeline(DiffusionPipeline):
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
if num_frames > 49:
raise ValueError(
"The number of frames must be at most 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
)
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
@@ -593,7 +584,6 @@ class CogVideoXPipeline(DiffusionPipeline):
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
latent_channels,
@@ -673,7 +663,7 @@ class CogVideoXPipeline(DiffusionPipeline):
progress_bar.update()
if not output_type == "latent":
video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
else:
video = latents
...
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import CogVideoXTransformer3DModel
from diffusers.utils.testing_utils import (
enable_full_determinism,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class CogVideoXTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = CogVideoXTransformer3DModel
main_input_name = "hidden_states"
@property
def dummy_input(self):
batch_size = 2
num_channels = 4
num_frames = 1
height = 8
width = 8
embedding_dim = 8
sequence_length = 8
hidden_states = torch.randn((batch_size, num_frames, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
}
@property
def input_shape(self):
return (1, 4, 8, 8)
@property
def output_shape(self):
return (1, 4, 8, 8)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
# Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings.
"num_attention_heads": 2,
"attention_head_dim": 8,
"in_channels": 4,
"out_channels": 4,
"time_embed_dim": 2,
"text_embed_dim": 8,
"num_layers": 1,
"sample_width": 8,
"sample_height": 8,
"sample_frames": 8,
"patch_size": 2,
"temporal_compression_ratio": 4,
"max_text_seq_length": 8,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@@ -125,11 +125,6 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
# Cannot reduce because convolution kernel becomes bigger than sample
"height": 16,
"width": 16,
"num_frames": 8,
"max_sequence_length": 16,
"output_type": "pt",
@@ -148,8 +143,8 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
video = pipe(**inputs).frames
generated_video = video[0]
self.assertEqual(generated_video.shape, (8, 3, 16, 16))
expected_video = torch.randn(8, 3, 16, 16)
max_diff = np.abs(generated_video - expected_video).max()
self.assertLessEqual(max_diff, 1e10)
@@ -250,6 +245,36 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"Attention slicing should not affect the inference results",
)
def test_vae_tiling(self, expected_diff_max: float = 0.2):
generator_device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to("cpu")
pipe.set_progress_bar_config(disable=None)
# Without tiling
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_without_tiling = pipe(**inputs)[0]
# With tiling
pipe.vae.enable_tiling(
tile_sample_min_height=96,
tile_sample_min_width=96,
tile_overlap_factor_height=1 / 12,
tile_overlap_factor_width=1 / 12,
)
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_with_tiling = pipe(**inputs)[0]
self.assertLess(
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
expected_diff_max,
"VAE tiling should not affect the inference results",
)
@slow
@require_torch_gpu
...