Unverified Commit 2dad462d authored by zR's avatar zR Committed by GitHub
Browse files

Add CogVideoX text-to-video generation model (#9082)



* add CogVideoX

---------
Co-authored-by: default avatarAryan <aryan@huggingface.co>
Co-authored-by: default avatarsayakpaul <spsayakpaul@gmail.com>
Co-authored-by: default avatarAryan <contact.aryanvs@gmail.com>
Co-authored-by: default avataryiyixuxu <yixu310@gmail.com>
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>
parent e3568d14
......@@ -239,6 +239,8 @@
title: VQModel
- local: api/models/autoencoderkl
title: AutoencoderKL
- local: api/models/autoencoderkl_cogvideox
title: AutoencoderKLCogVideoX
- local: api/models/asymmetricautoencoderkl
title: AsymmetricAutoencoderKL
- local: api/models/stable_cascade_unet
......@@ -263,6 +265,8 @@
title: FluxTransformer2DModel
- local: api/models/latte_transformer3d
title: LatteTransformer3DModel
- local: api/models/cogvideox_transformer3d
title: CogVideoXTransformer3DModel
- local: api/models/lumina_nextdit2d
title: LuminaNextDiT2DModel
- local: api/models/transformer_temporal
......@@ -302,6 +306,8 @@
title: AutoPipeline
- local: api/pipelines/blip_diffusion
title: BLIP-Diffusion
- local: api/pipelines/cogvideox
title: CogVideoX
- local: api/pipelines/consistency_models
title: Consistency Models
- local: api/pipelines/controlnet
......
......@@ -22,6 +22,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
## Supported pipelines
- [`CogVideoXPipeline`]
- [`StableDiffusionPipeline`]
- [`StableDiffusionImg2ImgPipeline`]
- [`StableDiffusionInpaintPipeline`]
......@@ -49,6 +50,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
- [`UNet2DConditionModel`]
- [`StableCascadeUNet`]
- [`AutoencoderKL`]
- [`AutoencoderKLCogVideoX`]
- [`ControlNetModel`]
- [`SD3Transformer2DModel`]
- [`FluxTransformer2DModel`]
......
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLCogVideoX
The 3D variational autoencoder (VAE) model with KL loss used in [CogVideoX](https://github.com/THUDM/CogVideo) was introduced in [CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://github.com/THUDM/CogVideo/blob/main/resources/CogVideoX.pdf) by Tsinghua University & ZhipuAI.
The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLCogVideoX
vae = AutoencoderKLCogVideoX.from_pretrained("THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16).to("cuda")
```
## AutoencoderKLCogVideoX
[[autodoc]] AutoencoderKLCogVideoX
- decode
- encode
- all
## AutoencoderKLOutput
[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
## DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# CogVideoXTransformer3DModel
A Diffusion Transformer model for 3D data from [CogVideoX](https://github.com/THUDM/CogVideo) was introduced in [CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://github.com/THUDM/CogVideo/blob/main/resources/CogVideoX.pdf) by Tsinghua University & ZhipuAI.
The model can be loaded with the following code snippet.
```python
from diffusers import CogVideoXTransformer3DModel
vae = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-2b", subfolder="transformer", torch_dtype=torch.float16).to("cuda")
```
## CogVideoXTransformer3DModel
[[autodoc]] CogVideoXTransformer3DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-->
# CogVideoX
<!-- TODO: update paper with ArXiv link when ready. -->
[CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://github.com/THUDM/CogVideo/blob/main/resources/CogVideoX.pdf) from Tsinghua University & ZhipuAI.
The abstract from the paper is:
*We introduce CogVideoX, a large-scale diffusion transformer model designed for generating videos based on text prompts. To efficently model video data, we propose to levearge a 3D Variational Autoencoder (VAE) to compresses videos along both spatial and temporal dimensions. To improve the text-video alignment, we propose an expert transformer with the expert adaptive LayerNorm to facilitate the deep fusion between the two modalities. By employing a progressive training technique, CogVideoX is adept at producing coherent, long-duration videos characterized by significant motion. In addition, we develop an effectively text-video data processing pipeline that includes various data preprocessing strategies and a video captioning method. It significantly helps enhance the performance of CogVideoX, improving both generation quality and semantic alignment. Results show that CogVideoX demonstrates state-of-the-art performance across both multiple machine metrics and human evaluations. The model weight of CogVideoX-2B is publicly available at https://github.com/THUDM/CogVideo.*
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).
## Inference
Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.
First, load the pipeline:
```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b").to("cuda")
prompt = (
"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
"The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
"pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
"casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
"The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
"atmosphere of this unique musical performance."
)
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)
```
Then change the memory layout of the pipelines `transformer` and `vae` components to `torch.channels-last`:
```python
pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.vae.to(memory_format=torch.channels_last)
```
Finally, compile the components and run inference:
```python
pipeline.transformer = torch.compile(pipeline.transformer)
pipeline.vae.decode = torch.compile(pipeline.vae.decode)
# CogVideoX works very well with long and well-described prompts
prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance."
video = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
```
The [benchmark](TODO: link) results on an 80GB A100 machine are:
```
Without torch.compile(): Average inference time: TODO seconds.
With torch.compile(): Average inference time: TODO seconds.
```
## CogVideoXPipeline
[[autodoc]] CogVideoXPipeline
- all
- __call__
## CogVideoXPipelineOutput
[[autodoc]] pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput
import argparse
from typing import Any, Dict
import torch
from transformers import T5EncoderModel, T5Tokenizer
from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
to_q_key = key.replace("query_key_value", "to_q")
to_k_key = key.replace("query_key_value", "to_k")
to_v_key = key.replace("query_key_value", "to_v")
to_q, to_k, to_v = torch.chunk(state_dict[key], chunks=3, dim=0)
state_dict[to_q_key] = to_q
state_dict[to_k_key] = to_k
state_dict[to_v_key] = to_v
state_dict.pop(key)
def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]):
layer_id, weight_or_bias = key.split(".")[-2:]
if "query" in key:
new_key = f"transformer_blocks.{layer_id}.attn1.norm_q.{weight_or_bias}"
elif "key" in key:
new_key = f"transformer_blocks.{layer_id}.attn1.norm_k.{weight_or_bias}"
state_dict[new_key] = state_dict.pop(key)
def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]):
layer_id, _, weight_or_bias = key.split(".")[-3:]
weights_or_biases = state_dict[key].chunk(12, dim=0)
norm1_weights_or_biases = torch.cat(weights_or_biases[0:3] + weights_or_biases[6:9])
norm2_weights_or_biases = torch.cat(weights_or_biases[3:6] + weights_or_biases[9:12])
norm1_key = f"transformer_blocks.{layer_id}.norm1.linear.{weight_or_bias}"
state_dict[norm1_key] = norm1_weights_or_biases
norm2_key = f"transformer_blocks.{layer_id}.norm2.linear.{weight_or_bias}"
state_dict[norm2_key] = norm2_weights_or_biases
state_dict.pop(key)
def remove_keys_inplace(key: str, state_dict: Dict[str, Any]):
state_dict.pop(key)
def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
key_split = key.split(".")
layer_index = int(key_split[2])
replace_layer_index = 4 - 1 - layer_index
key_split[1] = "up_blocks"
key_split[2] = str(replace_layer_index)
new_key = ".".join(key_split)
state_dict[new_key] = state_dict.pop(key)
TRANSFORMER_KEYS_RENAME_DICT = {
"transformer.final_layernorm": "norm_final",
"transformer": "transformer_blocks",
"attention": "attn1",
"mlp": "ff.net",
"dense_h_to_4h": "0.proj",
"dense_4h_to_h": "2",
".layers": "",
"dense": "to_out.0",
"input_layernorm": "norm1.norm",
"post_attn1_layernorm": "norm2.norm",
"time_embed.0": "time_embedding.linear_1",
"time_embed.2": "time_embedding.linear_2",
"mixins.patch_embed": "patch_embed",
"mixins.final_layer.norm_final": "norm_out.norm",
"mixins.final_layer.linear": "proj_out",
"mixins.final_layer.adaLN_modulation.1": "norm_out.linear",
}
TRANSFORMER_SPECIAL_KEYS_REMAP = {
"query_key_value": reassign_query_key_value_inplace,
"query_layernorm_list": reassign_query_key_layernorm_inplace,
"key_layernorm_list": reassign_query_key_layernorm_inplace,
"adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace,
"embed_tokens": remove_keys_inplace,
}
VAE_KEYS_RENAME_DICT = {
"block.": "resnets.",
"down.": "down_blocks.",
"downsample": "downsamplers.0",
"upsample": "upsamplers.0",
"nin_shortcut": "conv_shortcut",
"encoder.mid.block_1": "encoder.mid_block.resnets.0",
"encoder.mid.block_2": "encoder.mid_block.resnets.1",
"decoder.mid.block_1": "decoder.mid_block.resnets.0",
"decoder.mid.block_2": "decoder.mid_block.resnets.1",
}
VAE_SPECIAL_KEYS_REMAP = {
"loss": remove_keys_inplace,
"up.": replace_up_keys_inplace,
}
TOKENIZER_MAX_LENGTH = 226
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
state_dict = saved_dict
if "model" in saved_dict.keys():
state_dict = state_dict["model"]
if "module" in saved_dict.keys():
state_dict = state_dict["module"]
if "state_dict" in saved_dict.keys():
state_dict = state_dict["state_dict"]
return state_dict
def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
state_dict[new_key] = state_dict.pop(old_key)
def convert_transformer(ckpt_path: str):
PREFIX_KEY = "model.diffusion_model."
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
transformer = CogVideoXTransformer3DModel()
for key in list(original_state_dict.keys()):
new_key = key[len(PREFIX_KEY) :]
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
transformer.load_state_dict(original_state_dict, strict=True)
return transformer
def convert_vae(ckpt_path: str):
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
vae = AutoencoderKLCogVideoX()
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
vae.load_state_dict(original_state_dict, strict=True)
return vae
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
)
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--fp16", action="store_true", default=True, help="Whether to save the model weights in fp16")
parser.add_argument(
"--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
)
parser.add_argument(
"--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
)
return parser.parse_args()
if __name__ == "__main__":
args = get_args()
transformer = None
vae = None
if args.transformer_ckpt_path is not None:
transformer = convert_transformer(args.transformer_ckpt_path)
if args.vae_ckpt_path is not None:
vae = convert_vae(args.vae_ckpt_path)
text_encoder_id = "google/t5-v1_1-xxl"
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
scheduler = CogVideoXDDIMScheduler.from_config(
{
"snr_shift_scale": 3.0,
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"clip_sample": False,
"num_train_timesteps": 1000,
"prediction_type": "v_prediction",
"rescale_betas_zero_snr": True,
"set_alpha_to_one": True,
"timestep_spacing": "linspace",
}
)
pipe = CogVideoXPipeline(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
if args.fp16:
pipe = pipe.to(dtype=torch.float16)
pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub)
......@@ -79,9 +79,11 @@ else:
"AsymmetricAutoencoderKL",
"AuraFlowTransformer2DModel",
"AutoencoderKL",
"AutoencoderKLCogVideoX",
"AutoencoderKLTemporalDecoder",
"AutoencoderOobleck",
"AutoencoderTiny",
"CogVideoXTransformer3DModel",
"ConsistencyDecoderVAE",
"ControlNetModel",
"ControlNetXSAdapter",
......@@ -155,6 +157,8 @@ else:
[
"AmusedScheduler",
"CMStochasticIterativeScheduler",
"CogVideoXDDIMScheduler",
"CogVideoXDPMScheduler",
"DDIMInverseScheduler",
"DDIMParallelScheduler",
"DDIMScheduler",
......@@ -248,6 +252,7 @@ else:
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
"CLIPImageProjection",
"CogVideoXPipeline",
"CycleDiffusionPipeline",
"FluxPipeline",
"HunyuanDiTControlNetPipeline",
......@@ -535,9 +540,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AsymmetricAutoencoderKL,
AuraFlowTransformer2DModel,
AutoencoderKL,
AutoencoderKLCogVideoX,
AutoencoderKLTemporalDecoder,
AutoencoderOobleck,
AutoencoderTiny,
CogVideoXTransformer3DModel,
ConsistencyDecoderVAE,
ControlNetModel,
ControlNetXSAdapter,
......@@ -608,6 +615,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .schedulers import (
AmusedScheduler,
CMStochasticIterativeScheduler,
CogVideoXDDIMScheduler,
CogVideoXDPMScheduler,
DDIMInverseScheduler,
DDIMParallelScheduler,
DDIMScheduler,
......@@ -682,6 +691,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AudioLDMPipeline,
AuraFlowPipeline,
CLIPImageProjection,
CogVideoXPipeline,
CycleDiffusionPipeline,
FluxPipeline,
HunyuanDiTControlNetPipeline,
......
......@@ -28,6 +28,7 @@ if is_torch_available():
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
......@@ -41,6 +42,7 @@ if is_torch_available():
_import_structure["embeddings"] = ["ImageProjection"]
_import_structure["modeling_utils"] = ["ModelMixin"]
_import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
_import_structure["transformers.cogvideox_transformer_3d"] = ["CogVideoXTransformer3DModel"]
_import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
_import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"]
_import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
......@@ -77,6 +79,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .autoencoders import (
AsymmetricAutoencoderKL,
AutoencoderKL,
AutoencoderKLCogVideoX,
AutoencoderKLTemporalDecoder,
AutoencoderOobleck,
AutoencoderTiny,
......@@ -92,6 +95,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .modeling_utils import ModelMixin
from .transformers import (
AuraFlowTransformer2DModel,
CogVideoXTransformer3DModel,
DiTTransformer2DModel,
DualTransformer2DModel,
FluxTransformer2DModel,
......
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_oobleck import AutoencoderOobleck
from .autoencoder_tiny import AutoencoderTiny
......
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import logging
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..downsampling import CogVideoXDownsample3D
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..upsampling import CogVideoXUpsample3D
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class CogVideoXSafeConv3d(nn.Conv3d):
"""
A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
"""
def forward(self, input: torch.Tensor) -> torch.Tensor:
memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
# Set to 2GB, suitable for CuDNN
if memory_count > 2:
kernel_size = self.kernel_size[0]
part_num = int(memory_count / 2) + 1
input_chunks = torch.chunk(input, part_num, dim=2)
if kernel_size > 1:
input_chunks = [input_chunks[0]] + [
torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2)
for i in range(1, len(input_chunks))
]
output_chunks = []
for input_chunk in input_chunks:
output_chunks.append(super().forward(input_chunk))
output = torch.cat(output_chunks, dim=2)
return output
else:
return super().forward(input)
class CogVideoXCausalConv3d(nn.Module):
r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
Args:
in_channels (int): Number of channels in the input tensor.
out_channels (int): Number of output channels.
kernel_size (Union[int, Tuple[int, int, int]]): Size of the convolutional kernel.
stride (int, optional): Stride of the convolution. Default is 1.
dilation (int, optional): Dilation rate of the convolution. Default is 1.
pad_mode (str, optional): Padding mode. Default is "constant".
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int, int]],
stride: int = 1,
dilation: int = 1,
pad_mode: str = "constant",
):
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size,) * 3
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
self.pad_mode = pad_mode
time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
height_pad = height_kernel_size // 2
width_pad = width_kernel_size // 2
self.height_pad = height_pad
self.width_pad = width_pad
self.time_pad = time_pad
self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
self.temporal_dim = 2
self.time_kernel_size = time_kernel_size
stride = (stride, 1, 1)
dilation = (dilation, 1, 1)
self.conv = CogVideoXSafeConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
)
self.conv_cache = None
def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
dim = self.temporal_dim
kernel_size = self.time_kernel_size
if kernel_size == 1:
return inputs
inputs = inputs.transpose(0, dim)
if self.conv_cache is not None:
inputs = torch.cat([self.conv_cache.transpose(0, dim).to(inputs.device), inputs], dim=0)
else:
inputs = torch.cat([inputs[:1]] * (kernel_size - 1) + [inputs], dim=0)
inputs = inputs.transpose(0, dim).contiguous()
return inputs
def _clear_fake_context_parallel_cache(self):
del self.conv_cache
self.conv_cache = None
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
input_parallel = self.fake_context_parallel_forward(inputs)
self._clear_fake_context_parallel_cache()
self.conv_cache = input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu()
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0)
output_parallel = self.conv(input_parallel)
output = output_parallel
return output
class CogVideoXSpatialNorm3D(nn.Module):
r"""
Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is specific
to 3D-video like data.
CogVideoXSafeConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model.
Args:
f_channels (`int`):
The number of channels for input to group normalization layer, and output of the spatial norm layer.
zq_channels (`int`):
The number of channels for the quantized vector as described in the paper.
"""
def __init__(
self,
f_channels: int,
zq_channels: int,
groups: int = 32,
):
super().__init__()
self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
if f.shape[2] > 1 and f.shape[2] % 2 == 1:
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
z_first = F.interpolate(z_first, size=f_first_size)
z_rest = F.interpolate(z_rest, size=f_rest_size)
zq = torch.cat([z_first, z_rest], dim=2)
else:
zq = F.interpolate(zq, size=f.shape[-3:])
norm_f = self.norm_layer(f)
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
return new_f
class CogVideoXResnetBlock3D(nn.Module):
r"""
A 3D ResNet block used in the CogVideoX model.
Args:
in_channels (int): Number of input channels.
out_channels (Optional[int], optional):
Number of output channels. If None, defaults to `in_channels`. Default is None.
dropout (float, optional): Dropout rate. Default is 0.0.
temb_channels (int, optional): Number of time embedding channels. Default is 512.
groups (int, optional): Number of groups for group normalization. Default is 32.
eps (float, optional): Epsilon value for normalization layers. Default is 1e-6.
non_linearity (str, optional): Activation function to use. Default is "swish".
conv_shortcut (bool, optional): If True, use a convolutional shortcut. Default is False.
spatial_norm_dim (Optional[int], optional): Dimension of the spatial normalization. Default is None.
pad_mode (str, optional): Padding mode. Default is "first".
"""
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
dropout: float = 0.0,
temb_channels: int = 512,
groups: int = 32,
eps: float = 1e-6,
non_linearity: str = "swish",
conv_shortcut: bool = False,
spatial_norm_dim: Optional[int] = None,
pad_mode: str = "first",
):
super().__init__()
out_channels = out_channels or in_channels
self.in_channels = in_channels
self.out_channels = out_channels
self.nonlinearity = get_activation(non_linearity)
self.use_conv_shortcut = conv_shortcut
if spatial_norm_dim is None:
self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
else:
self.norm1 = CogVideoXSpatialNorm3D(
f_channels=in_channels,
zq_channels=spatial_norm_dim,
groups=groups,
)
self.norm2 = CogVideoXSpatialNorm3D(
f_channels=out_channels,
zq_channels=spatial_norm_dim,
groups=groups,
)
self.conv1 = CogVideoXCausalConv3d(
in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
)
if temb_channels > 0:
self.temb_proj = nn.Linear(in_features=temb_channels, out_features=out_channels)
self.dropout = nn.Dropout(dropout)
self.conv2 = CogVideoXCausalConv3d(
in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = CogVideoXCausalConv3d(
in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
)
else:
self.conv_shortcut = CogVideoXSafeConv3d(
in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
)
def forward(
self,
inputs: torch.Tensor,
temb: Optional[torch.Tensor] = None,
zq: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = inputs
if zq is not None:
hidden_states = self.norm1(hidden_states, zq)
else:
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
if temb is not None:
hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
if zq is not None:
hidden_states = self.norm2(hidden_states, zq)
else:
hidden_states = self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.in_channels != self.out_channels:
inputs = self.conv_shortcut(inputs)
hidden_states = hidden_states + inputs
return hidden_states
class CogVideoXDownBlock3D(nn.Module):
r"""
A downsampling block used in the CogVideoX model.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
temb_channels (int): Number of time embedding channels.
dropout (float, optional): Dropout rate. Default is 0.0.
num_layers (int, optional): Number of layers in the block. Default is 1.
resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
add_downsample (bool, optional): If True, add a downsampling layer at the end of the block. Default is True.
downsample_padding (int, optional): Padding for the downsampling layer. Default is 0.
compress_time (bool, optional): If True, apply temporal compression. Default is False.
pad_mode (str, optional): Padding mode. Default is "first".
"""
_supports_gradient_checkpointing = True
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
add_downsample: bool = True,
downsample_padding: int = 0,
compress_time: bool = False,
pad_mode: str = "first",
):
super().__init__()
resnets = []
for i in range(num_layers):
in_channel = in_channels if i == 0 else out_channels
resnets.append(
CogVideoXResnetBlock3D(
in_channels=in_channel,
out_channels=out_channels,
dropout=dropout,
temb_channels=temb_channels,
groups=resnet_groups,
eps=resnet_eps,
non_linearity=resnet_act_fn,
pad_mode=pad_mode,
)
)
self.resnets = nn.ModuleList(resnets)
self.downsamplers = None
if add_downsample:
self.downsamplers = nn.ModuleList(
[
CogVideoXDownsample3D(
out_channels, out_channels, padding=downsample_padding, compress_time=compress_time
)
]
)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
zq: Optional[torch.Tensor] = None,
) -> torch.Tensor:
for resnet in self.resnets:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
return module(*inputs)
return create_forward
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, zq
)
else:
hidden_states = resnet(hidden_states, temb, zq)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
return hidden_states
class CogVideoXMidBlock3D(nn.Module):
r"""
A middle block used in the CogVideoX model.
Args:
in_channels (int): Number of input channels.
temb_channels (int): Number of time embedding channels.
dropout (float, optional): Dropout rate. Default is 0.0.
num_layers (int, optional): Number of layers in the block. Default is 1.
resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
spatial_norm_dim (Optional[int], optional): Dimension of the spatial normalization. Default is None.
pad_mode (str, optional): Padding mode. Default is "first".
"""
_supports_gradient_checkpointing = True
def __init__(
self,
in_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
spatial_norm_dim: Optional[int] = None,
pad_mode: str = "first",
):
super().__init__()
resnets = []
for _ in range(num_layers):
resnets.append(
CogVideoXResnetBlock3D(
in_channels=in_channels,
out_channels=in_channels,
dropout=dropout,
temb_channels=temb_channels,
groups=resnet_groups,
eps=resnet_eps,
spatial_norm_dim=spatial_norm_dim,
non_linearity=resnet_act_fn,
pad_mode=pad_mode,
)
)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
zq: Optional[torch.Tensor] = None,
) -> torch.Tensor:
for resnet in self.resnets:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
return module(*inputs)
return create_forward
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, zq
)
else:
hidden_states = resnet(hidden_states, temb, zq)
return hidden_states
class CogVideoXUpBlock3D(nn.Module):
r"""
An upsampling block used in the CogVideoX model.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
temb_channels (int): Number of time embedding channels.
dropout (float, optional): Dropout rate. Default is 0.0.
num_layers (int, optional): Number of layers in the block. Default is 1.
resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
spatial_norm_dim (int, optional): Dimension of the spatial normalization. Default is 16.
add_upsample (bool, optional): If True, add an upsampling layer at the end of the block. Default is True.
upsample_padding (int, optional): Padding for the upsampling layer. Default is 1.
compress_time (bool, optional): If True, apply temporal compression. Default is False.
pad_mode (str, optional): Padding mode. Default is "first".
"""
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
spatial_norm_dim: int = 16,
add_upsample: bool = True,
upsample_padding: int = 1,
compress_time: bool = False,
pad_mode: str = "first",
):
super().__init__()
resnets = []
for i in range(num_layers):
in_channel = in_channels if i == 0 else out_channels
resnets.append(
CogVideoXResnetBlock3D(
in_channels=in_channel,
out_channels=out_channels,
dropout=dropout,
temb_channels=temb_channels,
groups=resnet_groups,
eps=resnet_eps,
non_linearity=resnet_act_fn,
spatial_norm_dim=spatial_norm_dim,
pad_mode=pad_mode,
)
)
self.resnets = nn.ModuleList(resnets)
self.upsamplers = None
if add_upsample:
self.upsamplers = nn.ModuleList(
[
CogVideoXUpsample3D(
out_channels, out_channels, padding=upsample_padding, compress_time=compress_time
)
]
)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
zq: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""Forward method of the `CogVideoXUpBlock3D` class."""
for resnet in self.resnets:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
return module(*inputs)
return create_forward
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, zq
)
else:
hidden_states = resnet(hidden_states, temb, zq)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states
class CogVideoXEncoder3D(nn.Module):
r"""
The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation.
Args:
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
options.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
double_z (`bool`, *optional*, defaults to `True`):
Whether to double the number of output channels for the last block.
"""
_supports_gradient_checkpointing = True
def __init__(
self,
in_channels: int = 3,
out_channels: int = 16,
down_block_types: Tuple[str, ...] = (
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
),
block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
layers_per_block: int = 3,
act_fn: str = "silu",
norm_eps: float = 1e-6,
norm_num_groups: int = 32,
dropout: float = 0.0,
pad_mode: str = "first",
temporal_compression_ratio: float = 4,
):
super().__init__()
# log2 of temporal_compress_times
temporal_compress_level = int(np.log2(temporal_compression_ratio))
self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
self.down_blocks = nn.ModuleList([])
# down blocks
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
compress_time = i < temporal_compress_level
if down_block_type == "CogVideoXDownBlock3D":
down_block = CogVideoXDownBlock3D(
in_channels=input_channel,
out_channels=output_channel,
temb_channels=0,
dropout=dropout,
num_layers=layers_per_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
add_downsample=not is_final_block,
compress_time=compress_time,
)
else:
raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`")
self.down_blocks.append(down_block)
# mid block
self.mid_block = CogVideoXMidBlock3D(
in_channels=block_out_channels[-1],
temb_channels=0,
dropout=dropout,
num_layers=2,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
pad_mode=pad_mode,
)
self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = CogVideoXCausalConv3d(
block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode
)
self.gradient_checkpointing = False
def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
r"""The forward method of the `CogVideoXEncoder3D` class."""
hidden_states = self.conv_in(sample)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# 1. Down
for down_block in self.down_blocks:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(down_block), hidden_states, temb, None
)
# 2. Mid
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), hidden_states, temb, None
)
else:
# 1. Down
for down_block in self.down_blocks:
hidden_states = down_block(hidden_states, temb, None)
# 2. Mid
hidden_states = self.mid_block(hidden_states, temb, None)
# 3. Post-process
hidden_states = self.norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class CogVideoXDecoder3D(nn.Module):
r"""
The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
sample.
Args:
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
norm_type (`str`, *optional*, defaults to `"group"`):
The normalization type to use. Can be either `"group"` or `"spatial"`.
"""
_supports_gradient_checkpointing = True
def __init__(
self,
in_channels: int = 16,
out_channels: int = 3,
up_block_types: Tuple[str, ...] = (
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
),
block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
layers_per_block: int = 3,
act_fn: str = "silu",
norm_eps: float = 1e-6,
norm_num_groups: int = 32,
dropout: float = 0.0,
pad_mode: str = "first",
temporal_compression_ratio: float = 4,
):
super().__init__()
reversed_block_out_channels = list(reversed(block_out_channels))
self.conv_in = CogVideoXCausalConv3d(
in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode
)
# mid block
self.mid_block = CogVideoXMidBlock3D(
in_channels=reversed_block_out_channels[0],
temb_channels=0,
num_layers=2,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
spatial_norm_dim=in_channels,
pad_mode=pad_mode,
)
# up blocks
self.up_blocks = nn.ModuleList([])
output_channel = reversed_block_out_channels[0]
temporal_compress_level = int(np.log2(temporal_compression_ratio))
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
compress_time = i < temporal_compress_level
if up_block_type == "CogVideoXUpBlock3D":
up_block = CogVideoXUpBlock3D(
in_channels=prev_output_channel,
out_channels=output_channel,
temb_channels=0,
dropout=dropout,
num_layers=layers_per_block + 1,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
spatial_norm_dim=in_channels,
add_upsample=not is_final_block,
compress_time=compress_time,
pad_mode=pad_mode,
)
prev_output_channel = output_channel
else:
raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`")
self.up_blocks.append(up_block)
self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups)
self.conv_act = nn.SiLU()
self.conv_out = CogVideoXCausalConv3d(
reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode
)
self.gradient_checkpointing = False
def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
r"""The forward method of the `CogVideoXDecoder3D` class."""
hidden_states = self.conv_in(sample)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# 1. Mid
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), hidden_states, temb, sample
)
# 2. Up
for up_block in self.up_blocks:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block), hidden_states, temb, sample
)
else:
# 1. Mid
hidden_states = self.mid_block(hidden_states, temb, sample)
# 2. Up
for up_block in self.up_blocks:
hidden_states = up_block(hidden_states, temb, sample)
# 3. Post-process
hidden_states = self.norm_out(hidden_states, sample)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
[CogVideoX](https://github.com/THUDM/CogVideo).
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
Tuple of block output channels.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
scaling_factor (`float`, *optional*, defaults to 0.18215):
The component-wise standard deviation of the trained latent space computed using the first batch of the
training set. This is used to scale the latent space to have unit variance when training the diffusion
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
force_upcast (`bool`, *optional*, default to `True`):
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
can be fine-tuned / trained to a lower range without loosing too much precision in which case
`force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["CogVideoXResnetBlock3D"]
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str] = (
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
),
up_block_types: Tuple[str] = (
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
),
block_out_channels: Tuple[int] = (128, 256, 256, 512),
latent_channels: int = 16,
layers_per_block: int = 3,
act_fn: str = "silu",
norm_eps: float = 1e-6,
norm_num_groups: int = 32,
temporal_compression_ratio: float = 4,
sample_size: int = 256,
scaling_factor: float = 1.15258426,
shift_factor: Optional[float] = None,
latents_mean: Optional[Tuple[float]] = None,
latents_std: Optional[Tuple[float]] = None,
force_upcast: float = True,
use_quant_conv: bool = False,
use_post_quant_conv: bool = False,
):
super().__init__()
self.encoder = CogVideoXEncoder3D(
in_channels=in_channels,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
norm_eps=norm_eps,
norm_num_groups=norm_num_groups,
temporal_compression_ratio=temporal_compression_ratio,
)
self.decoder = CogVideoXDecoder3D(
in_channels=latent_channels,
out_channels=out_channels,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
norm_eps=norm_eps,
norm_num_groups=norm_num_groups,
temporal_compression_ratio=temporal_compression_ratio,
)
self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None
self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None
self.use_slicing = False
self.use_tiling = False
self.tile_sample_min_size = self.config.sample_size
sample_size = (
self.config.sample_size[0]
if isinstance(self.config.sample_size, (list, tuple))
else self.config.sample_size
)
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
module.gradient_checkpointing = value
def clear_fake_context_parallel_cache(self):
for name, module in self.named_modules():
if isinstance(module, CogVideoXCausalConv3d):
logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
module._clear_fake_context_parallel_cache()
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded images. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
h = self.encoder(x)
if self.quant_conv is not None:
h = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
@apply_forward_hook
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
"""
Decode a batch of images.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.post_quant_conv is not None:
z = self.post_quant_conv(z)
dec = self.decoder(z)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[torch.Tensor, torch.Tensor]:
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z)
if not return_dict:
return (dec,)
return dec
......@@ -285,6 +285,74 @@ class KDownsample2D(nn.Module):
return F.conv2d(inputs, weight, stride=2)
class CogVideoXDownsample3D(nn.Module):
# Todo: Wait for paper relase.
r"""
A 3D Downsampling layer using in [CogVideoX]() by Tsinghua University & ZhipuAI
Args:
in_channels (`int`):
Number of channels in the input image.
out_channels (`int`):
Number of channels produced by the convolution.
kernel_size (`int`, defaults to `3`):
Size of the convolving kernel.
stride (`int`, defaults to `2`):
Stride of the convolution.
padding (`int`, defaults to `0`):
Padding added to all four sides of the input.
compress_time (`bool`, defaults to `False`):
Whether or not to compress the time dimension.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 2,
padding: int = 0,
compress_time: bool = False,
):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.compress_time = compress_time
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.compress_time:
batch_size, channels, frames, height, width = x.shape
# (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
if x.shape[-1] % 2 == 1:
x_first, x_rest = x[..., 0], x[..., 1:]
if x_rest.shape[-1] > 0:
# (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
x = torch.cat([x_first[..., None], x_rest], dim=-1)
# (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
else:
# (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
x = F.avg_pool1d(x, kernel_size=2, stride=2)
# (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
# Pad the tensor
pad = (0, 1, 0, 1)
x = F.pad(x, pad, mode="constant", value=0)
batch_size, channels, frames, height, width = x.shape
# (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
x = self.conv(x)
# (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
return x
def downsample_2d(
hidden_states: torch.Tensor,
kernel: Optional[torch.Tensor] = None,
......
......@@ -78,6 +78,53 @@ def get_timestep_embedding(
return emb
def get_3d_sincos_pos_embed(
embed_dim: int,
spatial_size: Union[int, Tuple[int, int]],
temporal_size: int,
spatial_interpolation_scale: float = 1.0,
temporal_interpolation_scale: float = 1.0,
) -> np.ndarray:
r"""
Args:
embed_dim (`int`):
spatial_size (`int` or `Tuple[int, int]`):
temporal_size (`int`):
spatial_interpolation_scale (`float`, defaults to 1.0):
temporal_interpolation_scale (`float`, defaults to 1.0):
"""
if embed_dim % 4 != 0:
raise ValueError("`embed_dim` must be divisible by 4")
if isinstance(spatial_size, int):
spatial_size = (spatial_size, spatial_size)
embed_dim_spatial = 3 * embed_dim // 4
embed_dim_temporal = embed_dim // 4
# 1. Spatial
grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale
grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)
# 2. Temporal
grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale
pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)
# 3. Concat
pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0) # [T, H*W, D // 4 * 3]
pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1) # [T, H*W, D // 4]
pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1) # [T, H*W, D]
return pos_embed
def get_2d_sincos_pos_embed(
embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
......@@ -287,6 +334,46 @@ class LuminaPatchEmbed(nn.Module):
)
class CogVideoXPatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 2,
in_channels: int = 16,
embed_dim: int = 1920,
text_embed_dim: int = 4096,
bias: bool = True,
) -> None:
super().__init__()
self.patch_size = patch_size
self.proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
)
self.text_proj = nn.Linear(text_embed_dim, embed_dim)
def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
r"""
Args:
text_embeds (`torch.Tensor`):
Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
image_embeds (`torch.Tensor`):
Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
"""
text_embeds = self.text_proj(text_embeds)
batch, num_frames, channels, height, width = image_embeds.shape
image_embeds = image_embeds.reshape(-1, channels, height, width)
image_embeds = self.proj(image_embeds)
image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
embeds = torch.cat(
[text_embeds, image_embeds], dim=1
).contiguous() # [batch, seq_length + num_frames x height x width, channels]
return embeds
def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
"""
RoPE for image tokens with 2d structure.
......
......@@ -34,19 +34,53 @@ class AdaLayerNorm(nn.Module):
Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`): The size of the embeddings dictionary.
num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
output_dim (`int`, *optional*):
norm_elementwise_affine (`bool`, defaults to `False):
norm_eps (`bool`, defaults to `False`):
chunk_dim (`int`, defaults to `0`):
"""
def __init__(self, embedding_dim: int, num_embeddings: int):
def __init__(
self,
embedding_dim: int,
num_embeddings: Optional[int] = None,
output_dim: Optional[int] = None,
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-5,
chunk_dim: int = 0,
):
super().__init__()
self.emb = nn.Embedding(num_embeddings, embedding_dim)
self.chunk_dim = chunk_dim
output_dim = output_dim or embedding_dim * 2
if num_embeddings is not None:
self.emb = nn.Embedding(num_embeddings, embedding_dim)
else:
self.emb = None
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
self.linear = nn.Linear(embedding_dim, output_dim)
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
def forward(
self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
) -> torch.Tensor:
if self.emb is not None:
temb = self.emb(timestep)
temb = self.linear(self.silu(temb))
if self.chunk_dim == 1:
# This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
# other if-branch. This branch is specific to CogVideoX for now.
shift, scale = temb.chunk(2, dim=1)
shift = shift[:, None, :]
scale = scale[:, None, :]
else:
scale, shift = temb.chunk(2, dim=0)
def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
emb = self.linear(self.silu(self.emb(timestep)))
scale, shift = torch.chunk(emb, 2)
x = self.norm(x) * (1 + scale) + shift
return x
......@@ -321,6 +355,30 @@ class LuminaLayerNormContinuous(nn.Module):
return x
class CogVideoXLayerNormZero(nn.Module):
def __init__(
self,
conditioning_dim: int,
embedding_dim: int,
elementwise_affine: bool = True,
eps: float = 1e-5,
bias: bool = True,
) -> None:
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
def forward(
self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
if is_torch_version(">=", "2.1.0"):
LayerNorm = nn.LayerNorm
else:
......
......@@ -3,6 +3,7 @@ from ...utils import is_torch_available
if is_torch_available():
from .auraflow_transformer_2d import AuraFlowTransformer2DModel
from .cogvideox_transformer_3d import CogVideoXTransformer3DModel
from .dit_transformer_2d import DiTTransformer2DModel
from .dual_transformer_2d import DualTransformer2DModel
from .hunyuan_transformer_2d import HunyuanDiT2DModel
......
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Union
import torch
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, FeedForward
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
r"""
Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
qk_norm (`bool`, defaults to `True`):
Whether or not to use normalization after query and key projections in Attention.
norm_elementwise_affine (`bool`, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_eps (`float`, defaults to `1e-5`):
Epsilon value for normalization layers.
final_dropout (`bool` defaults to `False`):
Whether to apply a final dropout after the last feed-forward layer.
ff_inner_dim (`int`, *optional*, defaults to `None`):
Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
ff_bias (`bool`, defaults to `True`):
Whether or not to use bias in Feed-forward layer.
attention_out_bias (`bool`, defaults to `True`):
Whether or not to use bias in Attention output projection layer.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
time_embed_dim: int,
dropout: float = 0.0,
activation_fn: str = "gelu-approximate",
attention_bias: bool = False,
qk_norm: bool = True,
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
final_dropout: bool = True,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
):
super().__init__()
# 1. Self Attention
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
self.attn1 = Attention(
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
qk_norm="layer_norm" if qk_norm else None,
eps=1e-6,
bias=attention_bias,
out_bias=attention_out_bias,
)
# 2. Feed Forward
self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
) -> torch.Tensor:
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
hidden_states, encoder_hidden_states, temb
)
# attention
text_length = norm_encoder_hidden_states.size(1)
# CogVideoX uses concatenated text + video embeddings with self-attention instead of using
# them in cross-attention individually
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
attn_output = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=None,
)
hidden_states = hidden_states + gate_msa * attn_output[:, text_length:]
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_output[:, :text_length]
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
hidden_states, encoder_hidden_states, temb
)
# feed-forward
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + gate_ff * ff_output[:, text_length:]
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_length]
return hidden_states, encoder_hidden_states
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
"""
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
The number of channels in the input.
out_channels (`int`, *optional*):
The number of channels in the output.
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
attention_bias (`bool`, *optional*):
Configure if the `TransformerBlocks` attention should contain a bias parameter.
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
This is fixed during training since it is used to learn a number of position embeddings.
patch_size (`int`, *optional*):
The size of the patches to use in the patch embedding layer.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
num_embeds_ada_norm ( `int`, *optional*):
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
added to the hidden states. During inference, you can denoise for up to but not more steps than
`num_embeds_ada_norm`.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The type of normalization to use. Options are `"layer_norm"` or `"ada_layer_norm"`.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether or not to use elementwise affine in normalization layers.
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use in normalization layers.
caption_channels (`int`, *optional*):
The number of channels in the caption embeddings.
video_length (`int`, *optional*):
The number of frames in the video-like data.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
num_attention_heads: int = 30,
attention_head_dim: int = 64,
in_channels: Optional[int] = 16,
out_channels: Optional[int] = 16,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
time_embed_dim: int = 512,
text_embed_dim: int = 4096,
num_layers: int = 30,
dropout: float = 0.0,
attention_bias: bool = True,
sample_width: int = 90,
sample_height: int = 60,
sample_frames: int = 49,
patch_size: int = 2,
temporal_compression_ratio: int = 4,
max_text_seq_length: int = 226,
activation_fn: str = "gelu-approximate",
timestep_activation_fn: str = "silu",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.875,
temporal_interpolation_scale: float = 1.0,
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
post_patch_height = sample_height // patch_size
post_patch_width = sample_width // patch_size
post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
# 1. Patch embedding
self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
self.embedding_dropout = nn.Dropout(dropout)
# 2. 3D positional embeddings
spatial_pos_embedding = get_3d_sincos_pos_embed(
inner_dim,
(post_patch_width, post_patch_height),
post_time_compression_frames,
spatial_interpolation_scale,
temporal_interpolation_scale,
)
spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
self.register_buffer("pos_embedding", pos_embedding, persistent=False)
# 3. Time embeddings
self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
# 4. Define spatio-temporal transformers blocks
self.transformer_blocks = nn.ModuleList(
[
CogVideoXBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
# 5. Output blocks
self.norm_out = AdaLayerNorm(
embedding_dim=time_embed_dim,
output_dim=2 * inner_dim,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
chunk_dim=1,
)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: Union[int, float, torch.LongTensor],
timestep_cond: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
batch_size, num_frames, channels, height, width = hidden_states.shape
# 1. Time embedding
timesteps = timestep
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=hidden_states.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
# 2. Patch embedding
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
# 3. Position embedding
seq_length = height * width * num_frames // (self.config.patch_size**2)
pos_embeds = self.pos_embedding[:, : self.config.max_text_seq_length + seq_length]
hidden_states = hidden_states + pos_embeds
hidden_states = self.embedding_dropout(hidden_states)
encoder_hidden_states = hidden_states[:, : self.config.max_text_seq_length]
hidden_states = hidden_states[:, self.config.max_text_seq_length :]
# 5. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
emb,
**ckpt_kwargs,
)
else:
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=emb,
)
hidden_states = self.norm_final(hidden_states)
# 6. Final block
hidden_states = self.norm_out(hidden_states, temb=emb)
hidden_states = self.proj_out(hidden_states)
# 7. Unpatchify
p = self.config.patch_size
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
......@@ -348,6 +348,70 @@ class KUpsample2D(nn.Module):
return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)
class CogVideoXUpsample3D(nn.Module):
r"""
A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper relase.
Args:
in_channels (`int`):
Number of channels in the input image.
out_channels (`int`):
Number of channels produced by the convolution.
kernel_size (`int`, defaults to `3`):
Size of the convolving kernel.
stride (`int`, defaults to `1`):
Stride of the convolution.
padding (`int`, defaults to `1`):
Padding added to all four sides of the input.
compress_time (`bool`, defaults to `False`):
Whether or not to compress the time dimension.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 1,
padding: int = 1,
compress_time: bool = False,
) -> None:
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.compress_time = compress_time
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
if self.compress_time:
if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
# split first frame
x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
x_first = F.interpolate(x_first, scale_factor=2.0)
x_rest = F.interpolate(x_rest, scale_factor=2.0)
x_first = x_first[:, :, None, :, :]
inputs = torch.cat([x_first, x_rest], dim=2)
elif inputs.shape[2] > 1:
inputs = F.interpolate(inputs, scale_factor=2.0)
else:
inputs = inputs.squeeze(2)
inputs = F.interpolate(inputs, scale_factor=2.0)
inputs = inputs[:, :, None, :, :]
else:
# only interpolate 2D
b, c, t, h, w = inputs.shape
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
inputs = F.interpolate(inputs, scale_factor=2.0)
inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
b, c, t, h, w = inputs.shape
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
inputs = self.conv(inputs)
inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
return inputs
def upfirdn2d_native(
tensor: torch.Tensor,
kernel: torch.Tensor,
......
......@@ -132,6 +132,7 @@ else:
"AudioLDM2UNet2DConditionModel",
]
_import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
_import_structure["cogvideo"] = ["CogVideoXPipeline"]
_import_structure["controlnet"].extend(
[
"BlipDiffusionControlNetPipeline",
......@@ -451,6 +452,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
)
from .aura_flow import AuraFlowPipeline
from .blip_diffusion import BlipDiffusionPipeline
from .cogvideo import CogVideoXPipeline
from .controlnet import (
BlipDiffusionControlNetPipeline,
StableDiffusionControlNetImg2ImgPipeline,
......
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_cogvideox"] = ["CogVideoXPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_cogvideox import CogVideoXPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import math
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import T5EncoderModel, T5Tokenizer
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from ...utils import BaseOutput, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import CogVideoXPipeline
>>> from diffusers.utils import export_to_video
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
>>> prompt = (
... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
... "atmosphere of this unique musical performance."
... )
>>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
>>> export_to_video(video, "output.mp4", fps=8)
```
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
@dataclass
class CogVideoXPipelineOutput(BaseOutput):
r"""
Output class for CogVideo pipelines.
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`.
"""
frames: torch.Tensor
class CogVideoXPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-video generation using CogVideoX.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder. CogVideoX uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
tokenizer (`T5Tokenizer`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
transformer ([`CogVideoXTransformer3DModel`]):
A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
"""
_optional_components = []
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
]
def __init__(
self,
tokenizer: T5Tokenizer,
text_encoder: T5EncoderModel,
vae: AutoencoderKLCogVideoX,
transformer: CogVideoXTransformer3DModel,
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
):
super().__init__()
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor_spatial = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.vae_scale_factor_temporal = (
self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_videos_per_prompt: int = 1,
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
return prompt_embeds
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = self._get_t5_prompt_embeds(
prompt=negative_prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
return prompt_embeds, negative_prompt_embeds
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
shape = (
batch_size,
(num_frames - 1) // self.vae_scale_factor_temporal + 1,
num_channels_latents,
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
def decode_latents(self, latents: torch.Tensor, num_seconds: int):
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
latents = 1 / self.vae.config.scaling_factor * latents
frames = []
for i in range(num_seconds):
start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3)
current_frames = self.vae.decode(latents[:, :, start_frame:end_frame]).sample
frames.append(current_frames)
self.vae.clear_fake_context_parallel_cache()
frames = torch.cat(frames, dim=2)
return frames
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
def check_inputs(
self,
prompt,
height,
width,
negative_prompt,
callback_on_step_end_tensor_inputs,
prompt_embeds=None,
negative_prompt_embeds=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
@property
def guidance_scale(self):
return self._guidance_scale
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 480,
width: int = 720,
num_frames: int = 48,
fps: int = 8,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
guidance_scale: float = 6,
use_dynamic_cfg: bool = False,
num_videos_per_prompt: int = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: str = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 226,
) -> Union[CogVideoXPipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_frames (`int`, defaults to `48`):
Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
needs to be satisfied is that of divisibility mentioned above.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, defaults to `226`):
Maximum sequence length in encoded prompt. Must be consistent with
`self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
Examples:
Returns:
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
assert (
num_frames <= 48 and num_frames % fps == 0 and fps == 8
), f"The number of frames must be divisible by {fps=} and less than 48 frames (for now). Other values are not supported in CogVideoX."
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
num_videos_per_prompt = 1
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt,
callback_on_step_end_tensor_inputs,
prompt_embeds,
negative_prompt_embeds,
)
self._guidance_scale = guidance_scale
self._interrupt = False
# 2. Default call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
negative_prompt,
do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
device=device,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
self._num_timesteps = len(timesteps)
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
num_frames += 1
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
latent_channels,
num_frames,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
# for DPM-solver++
old_pred_original_sample = None
for i, t in enumerate(timesteps):
if self.interrupt:
continue
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
# perform guidance
if use_dynamic_cfg:
self._guidance_scale = 1 + guidance_scale * (
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
)
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
else:
latents, old_pred_original_sample = self.scheduler.step(
noise_pred,
old_pred_original_sample,
t,
timesteps[i - 1] if i > 0 else None,
latents,
**extra_step_kwargs,
return_dict=False,
)
latents = latents.to(prompt_embeds.dtype)
# call the callback, if provided
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if not output_type == "latent":
video = self.decode_latents(latents, num_frames // fps)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return CogVideoXPipelineOutput(frames=video)
......@@ -43,12 +43,14 @@ else:
_import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"]
_import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"]
_import_structure["scheduling_ddim"] = ["DDIMScheduler"]
_import_structure["scheduling_ddim_cogvideox"] = ["CogVideoXDDIMScheduler"]
_import_structure["scheduling_ddim_inverse"] = ["DDIMInverseScheduler"]
_import_structure["scheduling_ddim_parallel"] = ["DDIMParallelScheduler"]
_import_structure["scheduling_ddpm"] = ["DDPMScheduler"]
_import_structure["scheduling_ddpm_parallel"] = ["DDPMParallelScheduler"]
_import_structure["scheduling_ddpm_wuerstchen"] = ["DDPMWuerstchenScheduler"]
_import_structure["scheduling_deis_multistep"] = ["DEISMultistepScheduler"]
_import_structure["scheduling_dpm_cogvideox"] = ["CogVideoXDPMScheduler"]
_import_structure["scheduling_dpmsolver_multistep"] = ["DPMSolverMultistepScheduler"]
_import_structure["scheduling_dpmsolver_multistep_inverse"] = ["DPMSolverMultistepInverseScheduler"]
_import_structure["scheduling_dpmsolver_singlestep"] = ["DPMSolverSinglestepScheduler"]
......@@ -141,12 +143,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .scheduling_consistency_decoder import ConsistencyDecoderScheduler
from .scheduling_consistency_models import CMStochasticIterativeScheduler
from .scheduling_ddim import DDIMScheduler
from .scheduling_ddim_cogvideox import CogVideoXDDIMScheduler
from .scheduling_ddim_inverse import DDIMInverseScheduler
from .scheduling_ddim_parallel import DDIMParallelScheduler
from .scheduling_ddpm import DDPMScheduler
from .scheduling_ddpm_parallel import DDPMParallelScheduler
from .scheduling_ddpm_wuerstchen import DDPMWuerstchenScheduler
from .scheduling_deis_multistep import DEISMultistepScheduler
from .scheduling_dpm_cogvideox import CogVideoXDPMScheduler
from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from .scheduling_dpmsolver_multistep_inverse import DPMSolverMultistepInverseScheduler
from .scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment