Unverified Commit 3b283061 authored by Yuxuan.Zhang's avatar Yuxuan.Zhang Committed by GitHub
Browse files

CogVideoX 1.5 (#9877)



* CogVideoX1_1PatchEmbed test

* 1360 * 768

* refactor

* make style

* update docs

* add modeling tests for cogvideox 1.5

* update

* make fix-copies

* add ofs embed(for convert)

* add ofs embed(for convert)

* more resolution for cogvideox1.5-5b-i2v

* use even number of latent frames only

* update pipeline implementations

* make style

* set patch_size_t as None by default

* #skip frames 0

* refactor

* make style

* update docs

* fix ofs_embed

* update docs

* invert_scale_latents

* update

* fix

* Update docs/source/en/api/pipelines/cogvideox.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/api/pipelines/cogvideox.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/api/pipelines/cogvideox.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/api/pipelines/cogvideox.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/diffusers/models/transformers/cogvideox_transformer_3d.py

* update conversion script

* remove copied from

* fix test

* Update docs/source/en/api/pipelines/cogvideox.md

* Update docs/source/en/api/pipelines/cogvideox.md

* Update docs/source/en/api/pipelines/cogvideox.md

* Update docs/source/en/api/pipelines/cogvideox.md

---------
Co-authored-by: default avatarAryan <aryan@huggingface.co>
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>
parent c3c94fe7
......@@ -29,16 +29,29 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).
There are two models available that can be used with the text-to-video and video-to-video CogVideoX pipelines:
- [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b): The recommended dtype for running this model is `fp16`.
- [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b): The recommended dtype for running this model is `bf16`.
There is one model available that can be used with the image-to-video CogVideoX pipeline:
- [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V): The recommended dtype for running this model is `bf16`.
There are two models that support pose controllable generation (by the [Alibaba-PAI](https://huggingface.co/alibaba-pai) team):
- [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose): The recommended dtype for running this model is `bf16`.
- [`alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose): The recommended dtype for running this model is `bf16`.
There are three official CogVideoX checkpoints for text-to-video and video-to-video.
| checkpoints | recommended inference dtype |
|---|---|
| [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b) | torch.float16 |
| [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b) | torch.bfloat16 |
| [`THUDM/CogVideoX1.5-5b`](https://huggingface.co/THUDM/CogVideoX1.5-5b) | torch.bfloat16 |
There are two official CogVideoX checkpoints available for image-to-video.
| checkpoints | recommended inference dtype |
|---|---|
| [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V) | torch.bfloat16 |
| [`THUDM/CogVideoX-1.5-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-1.5-5b-I2V) | torch.bfloat16 |
For the CogVideoX 1.5 series:
- Text-to-video (T2V) works best at a resolution of 1360x768 because it was trained with that specific resolution.
- Image-to-video (I2V) works for multiple resolutions. The width can vary from 768 to 1360, but the height must be 768. The height/width must be divisible by 16.
- Both T2V and I2V models support generation with 81 and 161 frames and work best at this value. Exporting videos at 16 FPS is recommended.
There are two official CogVideoX checkpoints that support pose controllable generation (by the [Alibaba-PAI](https://huggingface.co/alibaba-pai) team).
| checkpoints | recommended inference dtype |
|---|---|
| [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose) | torch.bfloat16 |
| [`alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose) | torch.bfloat16 |
## Inference
......
......@@ -80,6 +80,8 @@ TRANSFORMER_KEYS_RENAME_DICT = {
"post_attn1_layernorm": "norm2.norm",
"time_embed.0": "time_embedding.linear_1",
"time_embed.2": "time_embedding.linear_2",
"ofs_embed.0": "ofs_embedding.linear_1",
"ofs_embed.2": "ofs_embedding.linear_2",
"mixins.patch_embed": "patch_embed",
"mixins.final_layer.norm_final": "norm_out.norm",
"mixins.final_layer.linear": "proj_out",
......@@ -140,6 +142,7 @@ def convert_transformer(
use_rotary_positional_embeddings: bool,
i2v: bool,
dtype: torch.dtype,
init_kwargs: Dict[str, Any],
):
PREFIX_KEY = "model.diffusion_model."
......@@ -149,7 +152,9 @@ def convert_transformer(
num_layers=num_layers,
num_attention_heads=num_attention_heads,
use_rotary_positional_embeddings=use_rotary_positional_embeddings,
use_learned_positional_embeddings=i2v,
ofs_embed_dim=512 if (i2v and init_kwargs["patch_size_t"] is not None) else None, # CogVideoX1.5-5B-I2V
use_learned_positional_embeddings=i2v and init_kwargs["patch_size_t"] is None, # CogVideoX-5B-I2V
**init_kwargs,
).to(dtype=dtype)
for key in list(original_state_dict.keys()):
......@@ -163,13 +168,18 @@ def convert_transformer(
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
transformer.load_state_dict(original_state_dict, strict=True)
return transformer
def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
def convert_vae(ckpt_path: str, scaling_factor: float, version: str, dtype: torch.dtype):
init_kwargs = {"scaling_factor": scaling_factor}
if version == "1.5":
init_kwargs.update({"invert_scale_latents": True})
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype)
vae = AutoencoderKLCogVideoX(**init_kwargs).to(dtype=dtype)
for key in list(original_state_dict.keys()):
new_key = key[:]
......@@ -187,6 +197,34 @@ def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
return vae
def get_transformer_init_kwargs(version: str):
if version == "1.0":
vae_scale_factor_spatial = 8
init_kwargs = {
"patch_size": 2,
"patch_size_t": None,
"patch_bias": True,
"sample_height": 480 // vae_scale_factor_spatial,
"sample_width": 720 // vae_scale_factor_spatial,
"sample_frames": 49,
}
elif version == "1.5":
vae_scale_factor_spatial = 8
init_kwargs = {
"patch_size": 2,
"patch_size_t": 2,
"patch_bias": False,
"sample_height": 300,
"sample_width": 300,
"sample_frames": 81,
}
else:
raise ValueError("Unsupported version of CogVideoX.")
return init_kwargs
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
......@@ -202,6 +240,12 @@ def get_args():
parser.add_argument(
"--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
)
parser.add_argument(
"--typecast_text_encoder",
action="store_true",
default=False,
help="Whether or not to apply fp16/bf16 precision to text_encoder",
)
# For CogVideoX-2B, num_layers is 30. For 5B, it is 42
parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
# For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
......@@ -214,7 +258,18 @@ def get_args():
parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
# For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
parser.add_argument("--i2v", action="store_true", default=False, help="Whether to save the model weights in fp16")
parser.add_argument(
"--i2v",
action="store_true",
default=False,
help="Whether the model to be converted is the Image-to-Video version of CogVideoX.",
)
parser.add_argument(
"--version",
choices=["1.0", "1.5"],
default="1.0",
help="Which version of CogVideoX to use for initializing default modeling parameters.",
)
return parser.parse_args()
......@@ -230,6 +285,7 @@ if __name__ == "__main__":
dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32
if args.transformer_ckpt_path is not None:
init_kwargs = get_transformer_init_kwargs(args.version)
transformer = convert_transformer(
args.transformer_ckpt_path,
args.num_layers,
......@@ -237,14 +293,19 @@ if __name__ == "__main__":
args.use_rotary_positional_embeddings,
args.i2v,
dtype,
init_kwargs,
)
if args.vae_ckpt_path is not None:
vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
# Keep VAE in float32 for better quality
vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, args.version, torch.float32)
text_encoder_id = "google/t5-v1_1-xxl"
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
if args.typecast_text_encoder:
text_encoder = text_encoder.to(dtype=dtype)
# Apparently, the conversion does not work anymore without this :shrug:
for param in text_encoder.parameters():
param.data = param.data.contiguous()
......@@ -276,11 +337,6 @@ if __name__ == "__main__":
scheduler=scheduler,
)
if args.fp16:
pipe = pipe.to(dtype=torch.float16)
if args.bf16:
pipe = pipe.to(dtype=torch.bfloat16)
# We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
# for users to specify variant when the default is not fp32 and they want to run with the correct default (which
# is either fp16/bf16 here).
......
......@@ -1057,6 +1057,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
force_upcast: float = True,
use_quant_conv: bool = False,
use_post_quant_conv: bool = False,
invert_scale_latents: bool = False,
):
super().__init__()
......
......@@ -338,6 +338,7 @@ class CogVideoXPatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 2,
patch_size_t: Optional[int] = None,
in_channels: int = 16,
embed_dim: int = 1920,
text_embed_dim: int = 4096,
......@@ -355,6 +356,7 @@ class CogVideoXPatchEmbed(nn.Module):
super().__init__()
self.patch_size = patch_size
self.patch_size_t = patch_size_t
self.embed_dim = embed_dim
self.sample_height = sample_height
self.sample_width = sample_width
......@@ -366,9 +368,15 @@ class CogVideoXPatchEmbed(nn.Module):
self.use_positional_embeddings = use_positional_embeddings
self.use_learned_positional_embeddings = use_learned_positional_embeddings
self.proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
)
if patch_size_t is None:
# CogVideoX 1.0 checkpoints
self.proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
)
else:
# CogVideoX 1.5 checkpoints
self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)
self.text_proj = nn.Linear(text_embed_dim, embed_dim)
if use_positional_embeddings or use_learned_positional_embeddings:
......@@ -407,12 +415,24 @@ class CogVideoXPatchEmbed(nn.Module):
"""
text_embeds = self.text_proj(text_embeds)
batch, num_frames, channels, height, width = image_embeds.shape
image_embeds = image_embeds.reshape(-1, channels, height, width)
image_embeds = self.proj(image_embeds)
image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
batch_size, num_frames, channels, height, width = image_embeds.shape
if self.patch_size_t is None:
image_embeds = image_embeds.reshape(-1, channels, height, width)
image_embeds = self.proj(image_embeds)
image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
else:
p = self.patch_size
p_t = self.patch_size_t
image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
image_embeds = image_embeds.reshape(
batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
)
image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
image_embeds = self.proj(image_embeds)
embeds = torch.cat(
[text_embeds, image_embeds], dim=1
......@@ -497,7 +517,14 @@ class CogView3PlusPatchEmbed(nn.Module):
def get_3d_rotary_pos_embed(
embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
embed_dim,
crops_coords,
grid_size,
temporal_size,
theta: int = 10000,
use_real: bool = True,
grid_type: str = "linspace",
max_size: Optional[Tuple[int, int]] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
RoPE for video tokens with 3D structure.
......@@ -513,17 +540,30 @@ def get_3d_rotary_pos_embed(
The size of the temporal dimension.
theta (`float`):
Scaling factor for frequency computation.
grid_type (`str`):
Whether to use "linspace" or "slice" to compute grids.
Returns:
`torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
"""
if use_real is not True:
raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
start, stop = crops_coords
grid_size_h, grid_size_w = grid_size
grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
if grid_type == "linspace":
start, stop = crops_coords
grid_size_h, grid_size_w = grid_size
grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
grid_t = np.arange(temporal_size, dtype=np.float32)
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
elif grid_type == "slice":
max_h, max_w = max_size
grid_size_h, grid_size_w = grid_size
grid_h = np.arange(max_h, dtype=np.float32)
grid_w = np.arange(max_w, dtype=np.float32)
grid_t = np.arange(temporal_size, dtype=np.float32)
else:
raise ValueError("Invalid value passed for `grid_type`.")
# Compute dimensions for each axis
dim_t = embed_dim // 4
......@@ -559,6 +599,12 @@ def get_3d_rotary_pos_embed(
t_cos, t_sin = freqs_t # both t_cos and t_sin has shape: temporal_size, dim_t
h_cos, h_sin = freqs_h # both h_cos and h_sin has shape: grid_size_h, dim_h
w_cos, w_sin = freqs_w # both w_cos and w_sin has shape: grid_size_w, dim_w
if grid_type == "slice":
t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]
cos = combine_time_height_width(t_cos, h_cos, w_cos)
sin = combine_time_height_width(t_sin, h_sin, w_sin)
return cos, sin
......
......@@ -170,6 +170,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
Whether to flip the sin to cos in the time embedding.
time_embed_dim (`int`, defaults to `512`):
Output dimension of timestep embeddings.
ofs_embed_dim (`int`, defaults to `512`):
Output dimension of "ofs" embeddings used in CogVideoX-5b-I2B in version 1.5
text_embed_dim (`int`, defaults to `4096`):
Input dimension of text embeddings from the text encoder.
num_layers (`int`, defaults to `30`):
......@@ -177,7 +179,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
dropout (`float`, defaults to `0.0`):
The dropout probability to use.
attention_bias (`bool`, defaults to `True`):
Whether or not to use bias in the attention projection layers.
Whether to use bias in the attention projection layers.
sample_width (`int`, defaults to `90`):
The width of the input latents.
sample_height (`int`, defaults to `60`):
......@@ -198,7 +200,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
timestep_activation_fn (`str`, defaults to `"silu"`):
Activation function to use when generating the timestep embeddings.
norm_elementwise_affine (`bool`, defaults to `True`):
Whether or not to use elementwise affine in normalization layers.
Whether to use elementwise affine in normalization layers.
norm_eps (`float`, defaults to `1e-5`):
The epsilon value to use in normalization layers.
spatial_interpolation_scale (`float`, defaults to `1.875`):
......@@ -219,6 +221,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
time_embed_dim: int = 512,
ofs_embed_dim: Optional[int] = None,
text_embed_dim: int = 4096,
num_layers: int = 30,
dropout: float = 0.0,
......@@ -227,6 +230,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
sample_height: int = 60,
sample_frames: int = 49,
patch_size: int = 2,
patch_size_t: Optional[int] = None,
temporal_compression_ratio: int = 4,
max_text_seq_length: int = 226,
activation_fn: str = "gelu-approximate",
......@@ -237,6 +241,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
temporal_interpolation_scale: float = 1.0,
use_rotary_positional_embeddings: bool = False,
use_learned_positional_embeddings: bool = False,
patch_bias: bool = True,
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
......@@ -251,10 +256,11 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
# 1. Patch embedding
self.patch_embed = CogVideoXPatchEmbed(
patch_size=patch_size,
patch_size_t=patch_size_t,
in_channels=in_channels,
embed_dim=inner_dim,
text_embed_dim=text_embed_dim,
bias=True,
bias=patch_bias,
sample_width=sample_width,
sample_height=sample_height,
sample_frames=sample_frames,
......@@ -267,10 +273,19 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
)
self.embedding_dropout = nn.Dropout(dropout)
# 2. Time embeddings
# 2. Time embeddings and ofs embedding(Only CogVideoX1.5-5B I2V have)
self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
self.ofs_proj = None
self.ofs_embedding = None
if ofs_embed_dim:
self.ofs_proj = Timesteps(ofs_embed_dim, flip_sin_to_cos, freq_shift)
self.ofs_embedding = TimestepEmbedding(
ofs_embed_dim, ofs_embed_dim, timestep_activation_fn
) # same as time embeddings, for ofs
# 3. Define spatio-temporal transformers blocks
self.transformer_blocks = nn.ModuleList(
[
......@@ -298,7 +313,15 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
norm_eps=norm_eps,
chunk_dim=1,
)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
if patch_size_t is None:
# For CogVideox 1.0
output_dim = patch_size * patch_size * out_channels
else:
# For CogVideoX 1.5
output_dim = patch_size * patch_size * patch_size_t * out_channels
self.proj_out = nn.Linear(inner_dim, output_dim)
self.gradient_checkpointing = False
......@@ -411,6 +434,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
encoder_hidden_states: torch.Tensor,
timestep: Union[int, float, torch.LongTensor],
timestep_cond: Optional[torch.Tensor] = None,
ofs: Optional[Union[int, float, torch.LongTensor]] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
......@@ -442,6 +466,12 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
t_emb = t_emb.to(dtype=hidden_states.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
if self.ofs_embedding is not None:
ofs_emb = self.ofs_proj(ofs)
ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
ofs_emb = self.ofs_embedding(ofs_emb)
emb = emb + ofs_emb
# 2. Patch embedding
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
hidden_states = self.embedding_dropout(hidden_states)
......@@ -491,12 +521,17 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
hidden_states = self.proj_out(hidden_states)
# 5. Unpatchify
# Note: we use `-1` instead of `channels`:
# - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels)
# - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels)
p = self.config.patch_size
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
p_t = self.config.patch_size_t
if p_t is None:
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
else:
output = hidden_states.reshape(
batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
)
output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
......
......@@ -442,8 +442,13 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
p = self.transformer.config.patch_size
p_t = self.transformer.config.patch_size_t or 1
base_size_width = self.transformer.config.sample_width // p
base_size_height = self.transformer.config.sample_height // p
base_num_frames = (num_frames + p_t - 1) // p_t
grid_crops_coords = get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
......@@ -452,7 +457,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
temporal_size=base_num_frames,
)
freqs_cos = freqs_cos.to(device=device)
......@@ -481,9 +486,9 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
self,
prompt: Optional[Union[str, List[str]]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 480,
width: int = 720,
num_frames: int = 49,
height: Optional[int] = None,
width: Optional[int] = None,
num_frames: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
guidance_scale: float = 6,
......@@ -583,14 +588,13 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
if num_frames > 49:
raise ValueError(
"The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
)
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
num_frames = num_frames or self.transformer.config.sample_frames
num_videos_per_prompt = 1
# 1. Check inputs. Raise error if not correct
......@@ -640,7 +644,16 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
self._num_timesteps = len(timesteps)
# 5. Prepare latents.
# 5. Prepare latents
latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
# For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
patch_size_t = self.transformer.config.patch_size_t
additional_frames = 0
if patch_size_t is not None and latent_frames % patch_size_t != 0:
additional_frames = patch_size_t - latent_frames % patch_size_t
num_frames += additional_frames * self.vae_scale_factor_temporal
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
......@@ -730,6 +743,8 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
progress_bar.update()
if not output_type == "latent":
# Discard any padding frames that were added for CogVideoX 1.5
latents = latents[:, additional_frames:]
video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
else:
......
......@@ -488,8 +488,13 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
p = self.transformer.config.patch_size
p_t = self.transformer.config.patch_size_t or 1
base_size_width = self.transformer.config.sample_width // p
base_size_height = self.transformer.config.sample_height // p
base_num_frames = (num_frames + p_t - 1) // p_t
grid_crops_coords = get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
......@@ -498,7 +503,7 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
temporal_size=base_num_frames,
)
freqs_cos = freqs_cos.to(device=device)
......@@ -528,8 +533,8 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
prompt: Optional[Union[str, List[str]]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
control_video: Optional[List[Image.Image]] = None,
height: int = 480,
width: int = 720,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
guidance_scale: float = 6,
......@@ -634,6 +639,13 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
if control_video is not None and isinstance(control_video[0], Image.Image):
control_video = [control_video]
height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
num_frames = len(control_video[0]) if control_video is not None else control_video_latents.size(2)
num_videos_per_prompt = 1
# 1. Check inputs. Raise error if not correct
......@@ -660,9 +672,6 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
else:
batch_size = prompt_embeds.shape[0]
if control_video is not None and isinstance(control_video[0], Image.Image):
control_video = [control_video]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
......@@ -688,9 +697,18 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
self._num_timesteps = len(timesteps)
# 5. Prepare latents.
# 5. Prepare latents
latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
# For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
patch_size_t = self.transformer.config.patch_size_t
if patch_size_t is not None and latent_frames % patch_size_t != 0:
raise ValueError(
f"The number of latent frames must be divisible by `{patch_size_t=}` but the given video "
f"contains {latent_frames=}, which is not divisible."
)
latent_channels = self.transformer.config.in_channels // 2
num_frames = len(control_video[0]) if control_video is not None else control_video_latents.size(2)
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
latent_channels,
......
......@@ -367,6 +367,10 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
width // self.vae_scale_factor_spatial,
)
# For CogVideoX1.5, the latent should add 1 for padding (Not use)
if self.transformer.config.patch_size_t is not None:
shape = shape[:1] + (shape[1] + shape[1] % self.transformer.config.patch_size_t,) + shape[2:]
image = image.unsqueeze(2) # [B, C, F, H, W]
if isinstance(generator, list):
......@@ -377,7 +381,13 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
image_latents = self.vae_scaling_factor_image * image_latents
if not self.vae.config.invert_scale_latents:
image_latents = self.vae_scaling_factor_image * image_latents
else:
# This is awkward but required because the CogVideoX team forgot to multiply the
# scaling factor during training :)
image_latents = 1 / self.vae_scaling_factor_image * image_latents
padding_shape = (
batch_size,
......@@ -386,9 +396,15 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
)
latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype)
image_latents = torch.cat([image_latents, latent_padding], dim=1)
# Select the first frame along the second dimension
if self.transformer.config.patch_size_t is not None:
first_frame = image_latents[:, : image_latents.size(1) % self.transformer.config.patch_size_t, ...]
image_latents = torch.cat([first_frame, image_latents], dim=1)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
......@@ -512,7 +528,6 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
self.transformer.unfuse_qkv_projections()
self.fusing_transformer = False
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings
def _prepare_rotary_positional_embeddings(
self,
height: int,
......@@ -522,18 +537,38 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_crops_coords = get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
)
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
)
p = self.transformer.config.patch_size
p_t = self.transformer.config.patch_size_t
if p_t is None:
# CogVideoX 1.0 I2V
base_size_width = self.transformer.config.sample_width // p
base_size_height = self.transformer.config.sample_height // p
grid_crops_coords = get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
)
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
)
else:
# CogVideoX 1.5 I2V
base_size_width = self.transformer.config.sample_width // p
base_size_height = self.transformer.config.sample_height // p
base_num_frames = (num_frames + p_t - 1) // p_t
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=None,
grid_size=(grid_height, grid_width),
temporal_size=base_num_frames,
grid_type="slice",
max_size=(base_size_height, base_size_width),
)
freqs_cos = freqs_cos.to(device=device)
freqs_sin = freqs_sin.to(device=device)
......@@ -562,8 +597,8 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
image: PipelineImageInput,
prompt: Optional[Union[str, List[str]]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 480,
width: int = 720,
height: Optional[int] = None,
width: Optional[int] = None,
num_frames: int = 49,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
......@@ -666,14 +701,13 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
if num_frames > 49:
raise ValueError(
"The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
)
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
num_frames = num_frames or self.transformer.config.sample_frames
num_videos_per_prompt = 1
# 1. Check inputs. Raise error if not correct
......@@ -726,6 +760,15 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
self._num_timesteps = len(timesteps)
# 5. Prepare latents
latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
# For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
patch_size_t = self.transformer.config.patch_size_t
additional_frames = 0
if patch_size_t is not None and latent_frames % patch_size_t != 0:
additional_frames = patch_size_t - latent_frames % patch_size_t
num_frames += additional_frames * self.vae_scale_factor_temporal
image = self.video_processor.preprocess(image, height=height, width=width).to(
device, dtype=prompt_embeds.dtype
)
......@@ -754,6 +797,9 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
else None
)
# 8. Create ofs embeds if required
ofs_emb = None if self.transformer.config.ofs_embed_dim is None else latents.new_full((1,), fill_value=2.0)
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
......@@ -778,6 +824,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
ofs=ofs_emb,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
......@@ -823,6 +870,8 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
progress_bar.update()
if not output_type == "latent":
# Discard any padding frames that were added for CogVideoX 1.5
latents = latents[:, additional_frames:]
video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
else:
......
......@@ -518,8 +518,13 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
p = self.transformer.config.patch_size
p_t = self.transformer.config.patch_size_t or 1
base_size_width = self.transformer.config.sample_width // p
base_size_height = self.transformer.config.sample_height // p
base_num_frames = (num_frames + p_t - 1) // p_t
grid_crops_coords = get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
......@@ -528,7 +533,7 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
temporal_size=base_num_frames,
)
freqs_cos = freqs_cos.to(device=device)
......@@ -558,8 +563,8 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
video: List[Image.Image] = None,
prompt: Optional[Union[str, List[str]]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 480,
width: int = 720,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
strength: float = 0.8,
......@@ -662,6 +667,10 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
num_frames = len(video) if latents is None else latents.size(1)
num_videos_per_prompt = 1
# 1. Check inputs. Raise error if not correct
......@@ -717,6 +726,16 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
self._num_timesteps = len(timesteps)
# 5. Prepare latents
latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
# For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
patch_size_t = self.transformer.config.patch_size_t
if patch_size_t is not None and latent_frames % patch_size_t != 0:
raise ValueError(
f"The number of latent frames must be divisible by `{patch_size_t=}` but the given video "
f"contains {latent_frames=}, which is not divisible."
)
if latents is None:
video = self.video_processor.preprocess_video(video, height=height, width=width)
video = video.to(device=device, dtype=prompt_embeds.dtype)
......
......@@ -76,6 +76,7 @@ class CogVideoXTransformerTests(ModelTesterMixin, unittest.TestCase):
"sample_height": 8,
"sample_frames": 8,
"patch_size": 2,
"patch_size_t": None,
"temporal_compression_ratio": 4,
"max_text_seq_length": 8,
}
......@@ -85,3 +86,63 @@ class CogVideoXTransformerTests(ModelTesterMixin, unittest.TestCase):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"CogVideoXTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class CogVideoX1_5TransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = CogVideoXTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def dummy_input(self):
batch_size = 2
num_channels = 4
num_frames = 2
height = 8
width = 8
embedding_dim = 8
sequence_length = 8
hidden_states = torch.randn((batch_size, num_frames, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
}
@property
def input_shape(self):
return (1, 4, 8, 8)
@property
def output_shape(self):
return (1, 4, 8, 8)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
# Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings.
"num_attention_heads": 2,
"attention_head_dim": 8,
"in_channels": 4,
"out_channels": 4,
"time_embed_dim": 2,
"text_embed_dim": 8,
"num_layers": 1,
"sample_width": 8,
"sample_height": 8,
"sample_frames": 8,
"patch_size": 2,
"patch_size_t": 2,
"temporal_compression_ratio": 4,
"max_text_seq_length": 8,
"use_rotary_positional_embeddings": True,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"CogVideoXTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment