Unverified Commit a85b34e7 authored by Aryan's avatar Aryan Committed by GitHub
Browse files

[refactor] CogVideoX followups + tiled decoding support (#9150)

* refactor context parallel cache; update torch compile time benchmark

* add tiling support

* make style

* remove num_frames % 8 == 0 requirement

* update default num_frames to original value

* add explanations + refactor

* update torch compile example

* update docs

* update

* clean up if-statements

* address review comments

* add test for vae tiling

* update docs

* update docs

* update docstrings

* add modeling test for cogvideox transformer

* make style
parent 5ffbe14c
...@@ -15,9 +15,7 @@ ...@@ -15,9 +15,7 @@
# CogVideoX # CogVideoX
<!-- TODO: update paper with ArXiv link when ready. --> [CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://arxiv.org/abs/2408.06072) from Tsinghua University & ZhipuAI, by Zhuoyi Yang, Jiayan Teng, Wendi Zheng, Ming Ding, Shiyu Huang, Jiazheng Xu, Yuanming Yang, Wenyi Hong, Xiaohan Zhang, Guanyu Feng, Da Yin, Xiaotao Gu, Yuxuan Zhang, Weihan Wang, Yean Cheng, Ting Liu, Bin Xu, Yuxiao Dong, Jie Tang.
[CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://github.com/THUDM/CogVideo/blob/main/resources/CogVideoX.pdf) from Tsinghua University & ZhipuAI.
The abstract from the paper is: The abstract from the paper is:
...@@ -43,43 +41,42 @@ from diffusers import CogVideoXPipeline ...@@ -43,43 +41,42 @@ from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video from diffusers.utils import export_to_video
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b").to("cuda") pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b").to("cuda")
prompt = (
"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
"The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
"pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
"casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
"The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
"atmosphere of this unique musical performance."
)
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)
``` ```
Then change the memory layout of the pipelines `transformer` and `vae` components to `torch.channels-last`: Then change the memory layout of the pipelines `transformer` component to `torch.channels_last`:
```python ```python
pipeline.transformer.to(memory_format=torch.channels_last) pipe.transformer.to(memory_format=torch.channels_last)
pipeline.vae.to(memory_format=torch.channels_last)
``` ```
Finally, compile the components and run inference: Finally, compile the components and run inference:
```python ```python
pipeline.transformer = torch.compile(pipeline.transformer) pipe.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
pipeline.vae.decode = torch.compile(pipeline.vae.decode)
# CogVideoX works very well with long and well-described prompts # CogVideoX works well with long and well-described prompts
prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance." prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance."
video = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
``` ```
The [benchmark](TODO: link) results on an 80GB A100 machine are: The [benchmark](https://gist.github.com/a-r-r-o-w/5183d75e452a368fd17448fcc810bd3f) results on an 80GB A100 machine are:
``` ```
Without torch.compile(): Average inference time: TODO seconds. Without torch.compile(): Average inference time: 96.89 seconds.
With torch.compile(): Average inference time: TODO seconds. With torch.compile(): Average inference time: 76.27 seconds.
``` ```
### Memory optimization
CogVideoX requires about 19 GB of GPU memory to decode 49 frames (6 seconds of video at 8 FPS) with output resolution 720x480 (W x H), which makes it not possible to run on consumer GPUs or free-tier T4 Colab. The following memory optimizations could be used to reduce the memory footprint. For replication, you can refer to [this](https://gist.github.com/a-r-r-o-w/3959a03f15be5c9bd1fe545b09dfcc93) script.
- `pipe.enable_model_cpu_offload()`:
- Without enabling cpu offloading, memory usage is `33 GB`
- With enabling cpu offloading, memory usage is `19 GB`
- `pipe.vae.enable_tiling()`:
- With enabling cpu offloading and tiling, memory usage is `11 GB`
- `pipe.vae.enable_slicing()`
## CogVideoXPipeline ## CogVideoXPipeline
[[autodoc]] CogVideoXPipeline [[autodoc]] CogVideoXPipeline
......
...@@ -37,13 +37,20 @@ class CogVideoXBlock(nn.Module): ...@@ -37,13 +37,20 @@ class CogVideoXBlock(nn.Module):
Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model. Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
Parameters: Parameters:
dim (`int`): The number of channels in the input and output. dim (`int`):
num_attention_heads (`int`): The number of heads to use for multi-head attention. The number of channels in the input and output.
attention_head_dim (`int`): The number of channels in each head. num_attention_heads (`int`):
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. The number of heads to use for multi-head attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. attention_head_dim (`int`):
attention_bias (: The number of channels in each head.
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. time_embed_dim (`int`):
The number of channels in timestep embedding.
dropout (`float`, defaults to `0.0`):
The dropout probability to use.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to be used in feed-forward.
attention_bias (`bool`, defaults to `False`):
Whether or not to use bias in attention projection layers.
qk_norm (`bool`, defaults to `True`): qk_norm (`bool`, defaults to `True`):
Whether or not to use normalization after query and key projections in Attention. Whether or not to use normalization after query and key projections in Attention.
norm_elementwise_affine (`bool`, defaults to `True`): norm_elementwise_affine (`bool`, defaults to `True`):
...@@ -147,36 +154,53 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): ...@@ -147,36 +154,53 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo). A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
Parameters: Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. num_attention_heads (`int`, defaults to `30`):
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. The number of heads to use for multi-head attention.
in_channels (`int`, *optional*): attention_head_dim (`int`, defaults to `64`):
The number of channels in each head.
in_channels (`int`, defaults to `16`):
The number of channels in the input. The number of channels in the input.
out_channels (`int`, *optional*): out_channels (`int`, *optional*, defaults to `16`):
The number of channels in the output. The number of channels in the output.
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. flip_sin_to_cos (`bool`, defaults to `True`):
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. Whether to flip the sin to cos in the time embedding.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. time_embed_dim (`int`, defaults to `512`):
attention_bias (`bool`, *optional*): Output dimension of timestep embeddings.
Configure if the `TransformerBlocks` attention should contain a bias parameter. text_embed_dim (`int`, defaults to `4096`):
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). Input dimension of text embeddings from the text encoder.
This is fixed during training since it is used to learn a number of position embeddings. num_layers (`int`, defaults to `30`):
patch_size (`int`, *optional*): The number of layers of Transformer blocks to use.
dropout (`float`, defaults to `0.0`):
The dropout probability to use.
attention_bias (`bool`, defaults to `True`):
Whether or not to use bias in the attention projection layers.
sample_width (`int`, defaults to `90`):
The width of the input latents.
sample_height (`int`, defaults to `60`):
The height of the input latents.
sample_frames (`int`, defaults to `49`):
The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
patch_size (`int`, defaults to `2`):
The size of the patches to use in the patch embedding layer. The size of the patches to use in the patch embedding layer.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. temporal_compression_ratio (`int`, defaults to `4`):
num_embeds_ada_norm ( `int`, *optional*): The compression ratio across the temporal dimension. See documentation for `sample_frames`.
The number of diffusion steps used during training. Pass if at least one of the norm_layers is max_text_seq_length (`int`, defaults to `226`):
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are The maximum sequence length of the input text embeddings.
added to the hidden states. During inference, you can denoise for up to but not more steps than activation_fn (`str`, defaults to `"gelu-approximate"`):
`num_embeds_ada_norm`. Activation function to use in feed-forward.
norm_type (`str`, *optional*, defaults to `"layer_norm"`): timestep_activation_fn (`str`, defaults to `"silu"`):
The type of normalization to use. Options are `"layer_norm"` or `"ada_layer_norm"`. Activation function to use when generating the timestep embeddings.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`): norm_elementwise_affine (`bool`, defaults to `True`):
Whether or not to use elementwise affine in normalization layers. Whether or not to use elementwise affine in normalization layers.
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use in normalization layers. norm_eps (`float`, defaults to `1e-5`):
caption_channels (`int`, *optional*): The epsilon value to use in normalization layers.
The number of channels in the caption embeddings. spatial_interpolation_scale (`float`, defaults to `1.875`):
video_length (`int`, *optional*): Scaling factor to apply in 3D positional embeddings across spatial dimensions.
The number of frames in the video-like data. temporal_interpolation_scale (`float`, defaults to `1.0`):
Scaling factor to apply in 3D positional embeddings across temporal dimensions.
""" """
_supports_gradient_checkpointing = True _supports_gradient_checkpointing = True
...@@ -186,7 +210,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): ...@@ -186,7 +210,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
self, self,
num_attention_heads: int = 30, num_attention_heads: int = 30,
attention_head_dim: int = 64, attention_head_dim: int = 64,
in_channels: Optional[int] = 16, in_channels: int = 16,
out_channels: Optional[int] = 16, out_channels: Optional[int] = 16,
flip_sin_to_cos: bool = True, flip_sin_to_cos: bool = True,
freq_shift: int = 0, freq_shift: int = 0,
...@@ -304,7 +328,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): ...@@ -304,7 +328,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
encoder_hidden_states = hidden_states[:, : self.config.max_text_seq_length] encoder_hidden_states = hidden_states[:, : self.config.max_text_seq_length]
hidden_states = hidden_states[:, self.config.max_text_seq_length :] hidden_states = hidden_states[:, self.config.max_text_seq_length :]
# 5. Transformer blocks # 4. Transformer blocks
for i, block in enumerate(self.transformer_blocks): for i, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing: if self.training and self.gradient_checkpointing:
...@@ -331,11 +355,11 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): ...@@ -331,11 +355,11 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
hidden_states = self.norm_final(hidden_states) hidden_states = self.norm_final(hidden_states)
# 6. Final block # 5. Final block
hidden_states = self.norm_out(hidden_states, temb=emb) hidden_states = self.norm_out(hidden_states, temb=emb)
hidden_states = self.proj_out(hidden_states) hidden_states = self.proj_out(hidden_states)
# 7. Unpatchify # 6. Unpatchify
p = self.config.patch_size p = self.config.patch_size
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p) output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
......
...@@ -332,20 +332,11 @@ class CogVideoXPipeline(DiffusionPipeline): ...@@ -332,20 +332,11 @@ class CogVideoXPipeline(DiffusionPipeline):
latents = latents * self.scheduler.init_noise_sigma latents = latents * self.scheduler.init_noise_sigma
return latents return latents
def decode_latents(self, latents: torch.Tensor, num_seconds: int): def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
latents = 1 / self.vae.config.scaling_factor * latents latents = 1 / self.vae.config.scaling_factor * latents
frames = [] frames = self.vae.decode(latents).sample
for i in range(num_seconds):
start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3)
current_frames = self.vae.decode(latents[:, :, start_frame:end_frame]).sample
frames.append(current_frames)
self.vae.clear_fake_context_parallel_cache()
frames = torch.cat(frames, dim=2)
return frames return frames
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
...@@ -438,8 +429,7 @@ class CogVideoXPipeline(DiffusionPipeline): ...@@ -438,8 +429,7 @@ class CogVideoXPipeline(DiffusionPipeline):
negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 480, height: int = 480,
width: int = 720, width: int = 720,
num_frames: int = 48, num_frames: int = 49,
fps: int = 8,
num_inference_steps: int = 50, num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None, timesteps: Optional[List[int]] = None,
guidance_scale: float = 6, guidance_scale: float = 6,
...@@ -534,9 +524,10 @@ class CogVideoXPipeline(DiffusionPipeline): ...@@ -534,9 +524,10 @@ class CogVideoXPipeline(DiffusionPipeline):
`tuple`. When returning a tuple, the first element is a list with the generated images. `tuple`. When returning a tuple, the first element is a list with the generated images.
""" """
assert ( if num_frames > 49:
num_frames <= 48 and num_frames % fps == 0 and fps == 8 raise ValueError(
), f"The number of frames must be divisible by {fps=} and less than 48 frames (for now). Other values are not supported in CogVideoX." "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
)
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
...@@ -593,7 +584,6 @@ class CogVideoXPipeline(DiffusionPipeline): ...@@ -593,7 +584,6 @@ class CogVideoXPipeline(DiffusionPipeline):
# 5. Prepare latents. # 5. Prepare latents.
latent_channels = self.transformer.config.in_channels latent_channels = self.transformer.config.in_channels
num_frames += 1
latents = self.prepare_latents( latents = self.prepare_latents(
batch_size * num_videos_per_prompt, batch_size * num_videos_per_prompt,
latent_channels, latent_channels,
...@@ -673,7 +663,7 @@ class CogVideoXPipeline(DiffusionPipeline): ...@@ -673,7 +663,7 @@ class CogVideoXPipeline(DiffusionPipeline):
progress_bar.update() progress_bar.update()
if not output_type == "latent": if not output_type == "latent":
video = self.decode_latents(latents, num_frames // fps) video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type) video = self.video_processor.postprocess_video(video=video, output_type=output_type)
else: else:
video = latents video = latents
......
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import CogVideoXTransformer3DModel
from diffusers.utils.testing_utils import (
enable_full_determinism,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class CogVideoXTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = CogVideoXTransformer3DModel
main_input_name = "hidden_states"
@property
def dummy_input(self):
batch_size = 2
num_channels = 4
num_frames = 1
height = 8
width = 8
embedding_dim = 8
sequence_length = 8
hidden_states = torch.randn((batch_size, num_frames, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
}
@property
def input_shape(self):
return (1, 4, 8, 8)
@property
def output_shape(self):
return (1, 4, 8, 8)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
# Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings.
"num_attention_heads": 2,
"attention_head_dim": 8,
"in_channels": 4,
"out_channels": 4,
"time_embed_dim": 2,
"text_embed_dim": 8,
"num_layers": 1,
"sample_width": 8,
"sample_height": 8,
"sample_frames": 8,
"patch_size": 2,
"temporal_compression_ratio": 4,
"max_text_seq_length": 8,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
...@@ -125,11 +125,6 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -125,11 +125,6 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
# Cannot reduce because convolution kernel becomes bigger than sample # Cannot reduce because convolution kernel becomes bigger than sample
"height": 16, "height": 16,
"width": 16, "width": 16,
# TODO(aryan): improve this
# Cannot make this lower due to assert condition in pipeline at the moment.
# The reason why 8 can't be used here is due to how context-parallel cache works where the first
# second of video is decoded from latent frames (0, 3) instead of [(0, 2), (2, 3)]. If 8 is used,
# the number of output frames that you get are 5.
"num_frames": 8, "num_frames": 8,
"max_sequence_length": 16, "max_sequence_length": 16,
"output_type": "pt", "output_type": "pt",
...@@ -148,8 +143,8 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -148,8 +143,8 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
video = pipe(**inputs).frames video = pipe(**inputs).frames
generated_video = video[0] generated_video = video[0]
self.assertEqual(generated_video.shape, (9, 3, 16, 16)) self.assertEqual(generated_video.shape, (8, 3, 16, 16))
expected_video = torch.randn(9, 3, 16, 16) expected_video = torch.randn(8, 3, 16, 16)
max_diff = np.abs(generated_video - expected_video).max() max_diff = np.abs(generated_video - expected_video).max()
self.assertLessEqual(max_diff, 1e10) self.assertLessEqual(max_diff, 1e10)
...@@ -250,6 +245,36 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -250,6 +245,36 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"Attention slicing should not affect the inference results", "Attention slicing should not affect the inference results",
) )
def test_vae_tiling(self, expected_diff_max: float = 0.2):
generator_device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to("cpu")
pipe.set_progress_bar_config(disable=None)
# Without tiling
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_without_tiling = pipe(**inputs)[0]
# With tiling
pipe.vae.enable_tiling(
tile_sample_min_height=96,
tile_sample_min_width=96,
tile_overlap_factor_height=1 / 12,
tile_overlap_factor_width=1 / 12,
)
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_with_tiling = pipe(**inputs)[0]
self.assertLess(
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
expected_diff_max,
"VAE tiling should not affect the inference results",
)
@slow @slow
@require_torch_gpu @require_torch_gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment