Unverified Commit 3b283061 authored by Yuxuan.Zhang, committed by GitHub

CogVideoX 1.5 (#9877)



* CogVideoX1_1PatchEmbed test

* 1360 * 768

* refactor

* make style

* update docs

* add modeling tests for cogvideox 1.5

* update

* make fix-copies

* add ofs embed(for convert)

* add ofs embed(for convert)

* more resolution for cogvideox1.5-5b-i2v

* use even number of latent frames only

* update pipeline implementations

* make style

* set patch_size_t as None by default

* #skip frames 0

* refactor

* make style

* update docs

* fix ofs_embed

* update docs

* invert_scale_latents

* update

* fix

* Update docs/source/en/api/pipelines/cogvideox.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/api/pipelines/cogvideox.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/api/pipelines/cogvideox.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/api/pipelines/cogvideox.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/diffusers/models/transformers/cogvideox_transformer_3d.py

* update conversion script

* remove copied from

* fix test

* Update docs/source/en/api/pipelines/cogvideox.md

* Update docs/source/en/api/pipelines/cogvideox.md

* Update docs/source/en/api/pipelines/cogvideox.md

* Update docs/source/en/api/pipelines/cogvideox.md

---------
Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
parent c3c94fe7
...@@ -29,16 +29,29 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m

This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).

There are three official CogVideoX checkpoints for text-to-video and video-to-video.

| checkpoints | recommended inference dtype |
|---|---|
| [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b) | torch.float16 |
| [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b) | torch.bfloat16 |
| [`THUDM/CogVideoX1.5-5b`](https://huggingface.co/THUDM/CogVideoX1.5-5b) | torch.bfloat16 |

There are two official CogVideoX checkpoints available for image-to-video.

| checkpoints | recommended inference dtype |
|---|---|
| [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V) | torch.bfloat16 |
| [`THUDM/CogVideoX-1.5-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-1.5-5b-I2V) | torch.bfloat16 |

For the CogVideoX 1.5 series:
- Text-to-video (T2V) works best at a resolution of 1360x768 because it was trained with that specific resolution.
- Image-to-video (I2V) works for multiple resolutions. The width can vary from 768 to 1360, but the height must be 768. Both height and width must be divisible by 16.
- Both T2V and I2V models support generation with 81 and 161 frames and work best at these values. Exporting videos at 16 FPS is recommended (a minimal usage sketch follows below).
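The following is a minimal sketch of text-to-video generation with the 1.5 checkpoint at the recommended settings (1360x768, 81 frames, 16 FPS export). It is illustrative rather than part of this diff, and assumes a CUDA GPU with enough memory.

```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

# Load the CogVideoX 1.5 text-to-video checkpoint in the recommended dtype.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX1.5-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Optionally trade speed for memory on smaller GPUs:
# pipe.enable_model_cpu_offload()

prompt = "A panda playing a guitar in a bamboo forest, cinematic lighting"
video = pipe(
    prompt=prompt,
    height=768,            # 1.5 T2V was trained at 1360x768
    width=1360,
    num_frames=81,         # 81 or 161 frames work best
    num_inference_steps=50,
    guidance_scale=6.0,
).frames[0]

export_to_video(video, "output.mp4", fps=16)
```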
There are two official CogVideoX checkpoints that support pose controllable generation (by the [Alibaba-PAI](https://huggingface.co/alibaba-pai) team).

| checkpoints | recommended inference dtype |
|---|---|
| [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose) | torch.bfloat16 |
| [`alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose) | torch.bfloat16 |
## Inference
......
...@@ -80,6 +80,8 @@ TRANSFORMER_KEYS_RENAME_DICT = {
    "post_attn1_layernorm": "norm2.norm",
    "time_embed.0": "time_embedding.linear_1",
    "time_embed.2": "time_embedding.linear_2",
    "ofs_embed.0": "ofs_embedding.linear_1",
    "ofs_embed.2": "ofs_embedding.linear_2",
    "mixins.patch_embed": "patch_embed",
    "mixins.final_layer.norm_final": "norm_out.norm",
    "mixins.final_layer.linear": "proj_out",

...@@ -140,6 +142,7 @@ def convert_transformer(
    use_rotary_positional_embeddings: bool,
    i2v: bool,
    dtype: torch.dtype,
    init_kwargs: Dict[str, Any],
):
    PREFIX_KEY = "model.diffusion_model."

...@@ -149,7 +152,9 @@ def convert_transformer(
        num_layers=num_layers,
        num_attention_heads=num_attention_heads,
        use_rotary_positional_embeddings=use_rotary_positional_embeddings,
        ofs_embed_dim=512 if (i2v and init_kwargs["patch_size_t"] is not None) else None,  # CogVideoX1.5-5B-I2V
        use_learned_positional_embeddings=i2v and init_kwargs["patch_size_t"] is None,  # CogVideoX-5B-I2V
        **init_kwargs,
    ).to(dtype=dtype)

    for key in list(original_state_dict.keys()):

...@@ -163,13 +168,18 @@ def convert_transformer(
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    transformer.load_state_dict(original_state_dict, strict=True)
    return transformer


def convert_vae(ckpt_path: str, scaling_factor: float, version: str, dtype: torch.dtype):
    init_kwargs = {"scaling_factor": scaling_factor}
    if version == "1.5":
        init_kwargs.update({"invert_scale_latents": True})

    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
    vae = AutoencoderKLCogVideoX(**init_kwargs).to(dtype=dtype)

    for key in list(original_state_dict.keys()):
        new_key = key[:]

...@@ -187,6 +197,34 @@ def convert_vae(ckpt_path: str, scaling_factor: float, version: str, dtype: torch.dtype):
    return vae


def get_transformer_init_kwargs(version: str):
    if version == "1.0":
        vae_scale_factor_spatial = 8
        init_kwargs = {
            "patch_size": 2,
            "patch_size_t": None,
            "patch_bias": True,
            "sample_height": 480 // vae_scale_factor_spatial,
            "sample_width": 720 // vae_scale_factor_spatial,
            "sample_frames": 49,
        }
    elif version == "1.5":
        vae_scale_factor_spatial = 8
        init_kwargs = {
            "patch_size": 2,
            "patch_size_t": 2,
            "patch_bias": False,
            "sample_height": 300,
            "sample_width": 300,
            "sample_frames": 81,
        }
    else:
        raise ValueError("Unsupported version of CogVideoX.")

    return init_kwargs


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(

...@@ -202,6 +240,12 @@ def get_args():
    parser.add_argument(
        "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
    )
    parser.add_argument(
        "--typecast_text_encoder",
        action="store_true",
        default=False,
        help="Whether or not to apply fp16/bf16 precision to text_encoder",
    )
    # For CogVideoX-2B, num_layers is 30. For 5B, it is 42
    parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
    # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48

...@@ -214,7 +258,18 @@
    parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
    # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
    parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
    parser.add_argument(
        "--i2v",
        action="store_true",
        default=False,
        help="Whether the model to be converted is the Image-to-Video version of CogVideoX.",
    )
    parser.add_argument(
        "--version",
        choices=["1.0", "1.5"],
        default="1.0",
        help="Which version of CogVideoX to use for initializing default modeling parameters.",
    )
    return parser.parse_args()

...@@ -230,6 +285,7 @@ if __name__ == "__main__":
    dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32

    if args.transformer_ckpt_path is not None:
        init_kwargs = get_transformer_init_kwargs(args.version)
        transformer = convert_transformer(
            args.transformer_ckpt_path,
            args.num_layers,

...@@ -237,14 +293,19 @@
            args.use_rotary_positional_embeddings,
            args.i2v,
            dtype,
            init_kwargs,
        )
    if args.vae_ckpt_path is not None:
        # Keep VAE in float32 for better quality
        vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, args.version, torch.float32)

    text_encoder_id = "google/t5-v1_1-xxl"
    tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
    text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)

    if args.typecast_text_encoder:
        text_encoder = text_encoder.to(dtype=dtype)

    # Apparently, the conversion does not work anymore without this :shrug:
    for param in text_encoder.parameters():
        param.data = param.data.contiguous()

...@@ -276,11 +337,6 @@
        scheduler=scheduler,
    )

(removed) if args.fp16:
(removed)     pipe = pipe.to(dtype=torch.float16)
(removed) if args.bf16:
(removed)     pipe = pipe.to(dtype=torch.bfloat16)

    # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
    # for users to specify variant when the default is not fp32 and they want to run with the correct default (which
    # is either fp16/bf16 here).
......
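Purely as a reader aid (not part of the diff), the short snippet below restates how the conversion above derives the two new transformer flags from `--version` and `--i2v`, using the same expressions as `convert_transformer` and the `patch_size_t` values set by `get_transformer_init_kwargs`:

```python
# Reader aid: how ofs_embed_dim / use_learned_positional_embeddings fall out of (--version, --i2v),
# given that get_transformer_init_kwargs sets patch_size_t=None for 1.0 and patch_size_t=2 for 1.5.
for version, i2v in [("1.0", False), ("1.0", True), ("1.5", False), ("1.5", True)]:
    patch_size_t = None if version == "1.0" else 2
    ofs_embed_dim = 512 if (i2v and patch_size_t is not None) else None   # only CogVideoX 1.5 I2V
    use_learned_positional_embeddings = i2v and patch_size_t is None      # only CogVideoX 1.0 I2V
    print(f"{version=} {i2v=} -> {ofs_embed_dim=} {use_learned_positional_embeddings=}")
```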
...@@ -1057,6 +1057,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
        force_upcast: float = True,
        use_quant_conv: bool = False,
        use_post_quant_conv: bool = False,
        invert_scale_latents: bool = False,
    ):
        super().__init__()
......
...@@ -338,6 +338,7 @@ class CogVideoXPatchEmbed(nn.Module):
    def __init__(
        self,
        patch_size: int = 2,
        patch_size_t: Optional[int] = None,
        in_channels: int = 16,
        embed_dim: int = 1920,
        text_embed_dim: int = 4096,

...@@ -355,6 +356,7 @@ class CogVideoXPatchEmbed(nn.Module):
        super().__init__()

        self.patch_size = patch_size
        self.patch_size_t = patch_size_t
        self.embed_dim = embed_dim
        self.sample_height = sample_height
        self.sample_width = sample_width

...@@ -366,9 +368,15 @@ class CogVideoXPatchEmbed(nn.Module):
        self.use_positional_embeddings = use_positional_embeddings
        self.use_learned_positional_embeddings = use_learned_positional_embeddings

        if patch_size_t is None:
            # CogVideoX 1.0 checkpoints
            self.proj = nn.Conv2d(
                in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
            )
        else:
            # CogVideoX 1.5 checkpoints
            self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)

        self.text_proj = nn.Linear(text_embed_dim, embed_dim)

        if use_positional_embeddings or use_learned_positional_embeddings:

...@@ -407,12 +415,24 @@
        """
        text_embeds = self.text_proj(text_embeds)

        batch_size, num_frames, channels, height, width = image_embeds.shape

        if self.patch_size_t is None:
            image_embeds = image_embeds.reshape(-1, channels, height, width)
            image_embeds = self.proj(image_embeds)
            image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
            image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
            image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]
        else:
            p = self.patch_size
            p_t = self.patch_size_t

            image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
            image_embeds = image_embeds.reshape(
                batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
            )
            image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
            image_embeds = self.proj(image_embeds)

        embeds = torch.cat(
            [text_embeds, image_embeds], dim=1
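For intuition, the 1.5 branch above is a standard 3D patchify. The toy snippet below (illustrative only, with made-up sizes) reproduces just the reshape/permute steps to show how a `[B, F, C, H, W]` latent becomes `(F/p_t)*(H/p)*(W/p)` tokens of size `C*p_t*p*p`:

```python
import torch

# Hypothetical toy sizes; the real model uses larger latents.
B, F, C, H, W = 1, 2, 16, 8, 8
p, p_t = 2, 2

x = torch.randn(B, F, C, H, W)
x = x.permute(0, 1, 3, 4, 2)                              # [B, F, H, W, C]
x = x.reshape(B, F // p_t, p_t, H // p, p, W // p, p, C)  # split frames/height/width into patches
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)

# One token per (p_t x p x p) spatio-temporal patch, each of size C * p_t * p * p.
assert x.shape == (B, (F // p_t) * (H // p) * (W // p), C * p_t * p * p)
print(x.shape)  # torch.Size([1, 16, 128])
```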
...@@ -497,7 +517,14 @@ class CogView3PlusPatchEmbed(nn.Module):
def get_3d_rotary_pos_embed(
    embed_dim,
    crops_coords,
    grid_size,
    temporal_size,
    theta: int = 10000,
    use_real: bool = True,
    grid_type: str = "linspace",
    max_size: Optional[Tuple[int, int]] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    RoPE for video tokens with 3D structure.

...@@ -513,17 +540,30 @@ def get_3d_rotary_pos_embed(
            The size of the temporal dimension.
        theta (`float`):
            Scaling factor for frequency computation.
        grid_type (`str`):
            Whether to use "linspace" or "slice" to compute grids.

    Returns:
        `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
    """
    if use_real is not True:
        raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")

    if grid_type == "linspace":
        start, stop = crops_coords
        grid_size_h, grid_size_w = grid_size
        grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
        grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
        grid_t = np.arange(temporal_size, dtype=np.float32)
        grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
    elif grid_type == "slice":
        max_h, max_w = max_size
        grid_size_h, grid_size_w = grid_size
        grid_h = np.arange(max_h, dtype=np.float32)
        grid_w = np.arange(max_w, dtype=np.float32)
        grid_t = np.arange(temporal_size, dtype=np.float32)
    else:
        raise ValueError("Invalid value passed for `grid_type`.")

    # Compute dimensions for each axis
    dim_t = embed_dim // 4
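To make the new `grid_type` argument concrete, here is a small illustration (not from the diff; the crop coordinates are invented) of the height grid in each mode. "linspace" rescales positions into a cropped base range, while "slice" uses absolute integer positions up to `max_size`; in the real function the truncation happens on the cos/sin tables (see the next hunk), which this sketch simplifies by slicing the grid directly:

```python
import numpy as np

grid_size_h = 6                # tokens needed along height for the current resolution
max_h = 8                      # base/maximum grid height (hypothetical)
start, stop = (1, 0), (7, 0)   # hypothetical crop coordinates used by the linspace mode

grid_h_linspace = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
grid_h_slice = np.arange(max_h, dtype=np.float32)  # built on the full base grid ...
grid_h_slice = grid_h_slice[:grid_size_h]          # ... then sliced, keeping absolute positions

print(grid_h_linspace)  # [1. 2. 3. 4. 5. 6.]
print(grid_h_slice)     # [0. 1. 2. 3. 4. 5.]
```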
...@@ -559,6 +599,12 @@ def get_3d_rotary_pos_embed(
    t_cos, t_sin = freqs_t  # both t_cos and t_sin has shape: temporal_size, dim_t
    h_cos, h_sin = freqs_h  # both h_cos and h_sin has shape: grid_size_h, dim_h
    w_cos, w_sin = freqs_w  # both w_cos and w_sin has shape: grid_size_w, dim_w

    if grid_type == "slice":
        t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
        h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
        w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]

    cos = combine_time_height_width(t_cos, h_cos, w_cos)
    sin = combine_time_height_width(t_sin, h_sin, w_sin)
    return cos, sin
......
...@@ -170,6 +170,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
            Whether to flip the sin to cos in the time embedding.
        time_embed_dim (`int`, defaults to `512`):
            Output dimension of timestep embeddings.
        ofs_embed_dim (`int`, defaults to `512`):
            Output dimension of "ofs" embeddings used in the CogVideoX 1.5 5B I2V model.
        text_embed_dim (`int`, defaults to `4096`):
            Input dimension of text embeddings from the text encoder.
        num_layers (`int`, defaults to `30`):

...@@ -177,7 +179,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        attention_bias (`bool`, defaults to `True`):
            Whether to use bias in the attention projection layers.
        sample_width (`int`, defaults to `90`):
            The width of the input latents.
        sample_height (`int`, defaults to `60`):

...@@ -198,7 +200,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
        timestep_activation_fn (`str`, defaults to `"silu"`):
            Activation function to use when generating the timestep embeddings.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether to use elementwise affine in normalization layers.
        norm_eps (`float`, defaults to `1e-5`):
            The epsilon value to use in normalization layers.
        spatial_interpolation_scale (`float`, defaults to `1.875`):

...@@ -219,6 +221,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        ofs_embed_dim: Optional[int] = None,
        text_embed_dim: int = 4096,
        num_layers: int = 30,
        dropout: float = 0.0,

...@@ -227,6 +230,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
        sample_height: int = 60,
        sample_frames: int = 49,
        patch_size: int = 2,
        patch_size_t: Optional[int] = None,
        temporal_compression_ratio: int = 4,
        max_text_seq_length: int = 226,
        activation_fn: str = "gelu-approximate",

...@@ -237,6 +241,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
        temporal_interpolation_scale: float = 1.0,
        use_rotary_positional_embeddings: bool = False,
        use_learned_positional_embeddings: bool = False,
        patch_bias: bool = True,
    ):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim

...@@ -251,10 +256,11 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
        # 1. Patch embedding
        self.patch_embed = CogVideoXPatchEmbed(
            patch_size=patch_size,
            patch_size_t=patch_size_t,
            in_channels=in_channels,
            embed_dim=inner_dim,
            text_embed_dim=text_embed_dim,
            bias=patch_bias,
            sample_width=sample_width,
            sample_height=sample_height,
            sample_frames=sample_frames,

...@@ -267,10 +273,19 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
        )
        self.embedding_dropout = nn.Dropout(dropout)

        # 2. Time embeddings and ofs embedding (only the CogVideoX 1.5 5B I2V model has the latter)
        self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)

        self.ofs_proj = None
        self.ofs_embedding = None
        if ofs_embed_dim:
            self.ofs_proj = Timesteps(ofs_embed_dim, flip_sin_to_cos, freq_shift)
            self.ofs_embedding = TimestepEmbedding(
                ofs_embed_dim, ofs_embed_dim, timestep_activation_fn
            )  # same as time embeddings, for ofs

        # 3. Define spatio-temporal transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
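The `ofs` path mirrors the timestep path: a sinusoidal projection followed by a small MLP whose output is added to the time embedding (the forward pass below does `emb = emb + ofs_emb`). The following rough, pure-PyTorch sketch of that idea uses stand-in modules, not the actual diffusers `Timesteps`/`TimestepEmbedding` classes:

```python
import torch
import torch.nn as nn

def sinusoidal_features(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Sin/cos features, analogous to what the timestep projection produces.
    half = dim // 2
    freqs = torch.exp(-torch.arange(half, dtype=torch.float32) * torch.log(torch.tensor(10000.0)) / half)
    angles = x[:, None].float() * freqs[None, :]
    return torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)

ofs_embed_dim = 512
# Stand-in for an MLP like TimestepEmbedding(ofs_embed_dim, ofs_embed_dim, "silu").
ofs_embedding = nn.Sequential(
    nn.Linear(ofs_embed_dim, ofs_embed_dim), nn.SiLU(), nn.Linear(ofs_embed_dim, ofs_embed_dim)
)

ofs = torch.full((1,), 2.0)               # the I2V pipeline passes a constant fill value of 2.0
ofs_emb = ofs_embedding(sinusoidal_features(ofs, ofs_embed_dim))
time_emb = torch.zeros(1, ofs_embed_dim)  # placeholder for the real timestep embedding
emb = time_emb + ofs_emb                  # conditioning consumed by the transformer blocks
print(emb.shape)                          # torch.Size([1, 512])
```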
...@@ -298,7 +313,15 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
            norm_eps=norm_eps,
            chunk_dim=1,
        )

        if patch_size_t is None:
            # For CogVideoX 1.0
            output_dim = patch_size * patch_size * out_channels
        else:
            # For CogVideoX 1.5
            output_dim = patch_size * patch_size * patch_size_t * out_channels

        self.proj_out = nn.Linear(inner_dim, output_dim)

        self.gradient_checkpointing = False

...@@ -411,6 +434,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
        encoder_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        ofs: Optional[Union[int, float, torch.LongTensor]] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,

...@@ -442,6 +466,12 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
        t_emb = t_emb.to(dtype=hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        if self.ofs_embedding is not None:
            ofs_emb = self.ofs_proj(ofs)
            ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
            ofs_emb = self.ofs_embedding(ofs_emb)
            emb = emb + ofs_emb

        # 2. Patch embedding
        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
        hidden_states = self.embedding_dropout(hidden_states)

...@@ -491,12 +521,17 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
        hidden_states = self.proj_out(hidden_states)

        # 5. Unpatchify
        # Note: we use `-1` instead of `channels`:
        # - It is okay to use `channels` for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels)
        # - However, CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels)
        p = self.config.patch_size
        p_t = self.config.patch_size_t

        if p_t is None:
            output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
        else:
            output = hidden_states.reshape(
                batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
            )
            output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
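For a quick sanity check of the 1.5 unpatchify branch above, this standalone snippet (toy sizes, not part of the diff) verifies that tokens are mapped back to a `[batch, frames, channels, height, width]` tensor:

```python
import torch

# Hypothetical toy sizes matching the else-branch above.
B, C = 1, 16          # batch size and output latent channels
F, H, W = 4, 8, 8     # latent frames / height / width to reconstruct
p, p_t = 2, 2

num_tokens = (F // p_t) * (H // p) * (W // p)
hidden_states = torch.randn(B, num_tokens, p_t * p * p * C)   # stand-in for the proj_out output

output = hidden_states.reshape(B, (F + p_t - 1) // p_t, H // p, W // p, -1, p_t, p, p)
output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)

assert output.shape == (B, F, C, H, W)
print(output.shape)  # torch.Size([1, 4, 16, 8, 8])
```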
......
...@@ -442,8 +442,13 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
        grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)

        p = self.transformer.config.patch_size
        p_t = self.transformer.config.patch_size_t or 1

        base_size_width = self.transformer.config.sample_width // p
        base_size_height = self.transformer.config.sample_height // p
        base_num_frames = (num_frames + p_t - 1) // p_t

        grid_crops_coords = get_resize_crop_region_for_grid(
            (grid_height, grid_width), base_size_width, base_size_height

...@@ -452,7 +457,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
            embed_dim=self.transformer.config.attention_head_dim,
            crops_coords=grid_crops_coords,
            grid_size=(grid_height, grid_width),
            temporal_size=base_num_frames,
        )

        freqs_cos = freqs_cos.to(device=device)

...@@ -481,9 +486,9 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_frames: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: Optional[List[int]] = None,
        guidance_scale: float = 6,

...@@ -583,14 +588,13 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
            `tuple`. When returning a tuple, the first element is a list with the generated images.
        """
(removed) if num_frames > 49:
(removed)     raise ValueError(
(removed)         "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
(removed)     )

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
        width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
        num_frames = num_frames or self.transformer.config.sample_frames

        num_videos_per_prompt = 1

        # 1. Check inputs. Raise error if not correct

...@@ -640,7 +644,16 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
        self._num_timesteps = len(timesteps)

        # 5. Prepare latents
        latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1

        # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
        patch_size_t = self.transformer.config.patch_size_t
        additional_frames = 0
        if patch_size_t is not None and latent_frames % patch_size_t != 0:
            additional_frames = patch_size_t - latent_frames % patch_size_t
            num_frames += additional_frames * self.vae_scale_factor_temporal

        latent_channels = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
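A worked example of the padding above (my own illustration, assuming the usual 4x temporal VAE compression): requesting 81 frames yields `(81 - 1) // 4 + 1 = 21` latent frames; with `patch_size_t = 2` that is odd, so one extra latent frame (four pixel frames) is added and then dropped again before decoding.

```python
# Hypothetical numbers reproducing the padding logic above.
num_frames = 81
vae_scale_factor_temporal = 4
patch_size_t = 2

latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1       # 21
additional_frames = 0
if patch_size_t is not None and latent_frames % patch_size_t != 0:
    additional_frames = patch_size_t - latent_frames % patch_size_t     # 1
    num_frames += additional_frames * vae_scale_factor_temporal         # 85

print(latent_frames + additional_frames, num_frames)  # 22 85
# After denoising, `latents[:, additional_frames:]` drops the padded latent frame again.
```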
...@@ -730,6 +743,8 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                    progress_bar.update()

        if not output_type == "latent":
            # Discard any padding frames that were added for CogVideoX 1.5
            latents = latents[:, additional_frames:]
            video = self.decode_latents(latents)
            video = self.video_processor.postprocess_video(video=video, output_type=output_type)
        else:
......
...@@ -488,8 +488,13 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
        grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)

        p = self.transformer.config.patch_size
        p_t = self.transformer.config.patch_size_t or 1

        base_size_width = self.transformer.config.sample_width // p
        base_size_height = self.transformer.config.sample_height // p
        base_num_frames = (num_frames + p_t - 1) // p_t

        grid_crops_coords = get_resize_crop_region_for_grid(
            (grid_height, grid_width), base_size_width, base_size_height

...@@ -498,7 +503,7 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
            embed_dim=self.transformer.config.attention_head_dim,
            crops_coords=grid_crops_coords,
            grid_size=(grid_height, grid_width),
            temporal_size=base_num_frames,
        )

        freqs_cos = freqs_cos.to(device=device)

...@@ -528,8 +533,8 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
        prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        control_video: Optional[List[Image.Image]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: Optional[List[int]] = None,
        guidance_scale: float = 6,

...@@ -634,6 +639,13 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        if control_video is not None and isinstance(control_video[0], Image.Image):
            control_video = [control_video]

        height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
        width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
        num_frames = len(control_video[0]) if control_video is not None else control_video_latents.size(2)

        num_videos_per_prompt = 1

        # 1. Check inputs. Raise error if not correct

...@@ -660,9 +672,6 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)

...@@ -688,9 +697,18 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
        self._num_timesteps = len(timesteps)

        # 5. Prepare latents
        latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1

        # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
        patch_size_t = self.transformer.config.patch_size_t
        if patch_size_t is not None and latent_frames % patch_size_t != 0:
            raise ValueError(
                f"The number of latent frames must be divisible by `{patch_size_t=}` but the given video "
                f"contains {latent_frames=}, which is not divisible."
            )

        latent_channels = self.transformer.config.in_channels // 2

        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            latent_channels,
......
...@@ -367,6 +367,10 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
            width // self.vae_scale_factor_spatial,
        )

        # For CogVideoX 1.5, the latent shape adds one frame for padding (not used)
        if self.transformer.config.patch_size_t is not None:
            shape = shape[:1] + (shape[1] + shape[1] % self.transformer.config.patch_size_t,) + shape[2:]

        image = image.unsqueeze(2)  # [B, C, F, H, W]

        if isinstance(generator, list):

...@@ -377,7 +381,13 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
            image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]

        image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]

        if not self.vae.config.invert_scale_latents:
            image_latents = self.vae_scaling_factor_image * image_latents
        else:
            # This is awkward but required because the CogVideoX team forgot to multiply the
            # scaling factor during training :)
            image_latents = 1 / self.vae_scaling_factor_image * image_latents

        padding_shape = (
            batch_size,

...@@ -386,9 +396,15 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
            height // self.vae_scale_factor_spatial,
            width // self.vae_scale_factor_spatial,
        )

        latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype)
        image_latents = torch.cat([image_latents, latent_padding], dim=1)

        # Select the first frame along the second dimension
        if self.transformer.config.patch_size_t is not None:
            first_frame = image_latents[:, : image_latents.size(1) % self.transformer.config.patch_size_t, ...]
            image_latents = torch.cat([first_frame, image_latents], dim=1)

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:

...@@ -512,7 +528,6 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
        self.transformer.unfuse_qkv_projections()
        self.fusing_transformer = False

    def _prepare_rotary_positional_embeddings(
        self,
        height: int,

...@@ -522,18 +537,38 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
        grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)

        p = self.transformer.config.patch_size
        p_t = self.transformer.config.patch_size_t

        if p_t is None:
            # CogVideoX 1.0 I2V
            base_size_width = self.transformer.config.sample_width // p
            base_size_height = self.transformer.config.sample_height // p

            grid_crops_coords = get_resize_crop_region_for_grid(
                (grid_height, grid_width), base_size_width, base_size_height
            )
            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
                embed_dim=self.transformer.config.attention_head_dim,
                crops_coords=grid_crops_coords,
                grid_size=(grid_height, grid_width),
                temporal_size=num_frames,
            )
        else:
            # CogVideoX 1.5 I2V
            base_size_width = self.transformer.config.sample_width // p
            base_size_height = self.transformer.config.sample_height // p
            base_num_frames = (num_frames + p_t - 1) // p_t

            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
                embed_dim=self.transformer.config.attention_head_dim,
                crops_coords=None,
                grid_size=(grid_height, grid_width),
                temporal_size=base_num_frames,
                grid_type="slice",
                max_size=(base_size_height, base_size_width),
            )

        freqs_cos = freqs_cos.to(device=device)
        freqs_sin = freqs_sin.to(device=device)

...@@ -562,8 +597,8 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
        image: PipelineImageInput,
        prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_frames: int = 49,
        num_inference_steps: int = 50,
        timesteps: Optional[List[int]] = None,

...@@ -666,14 +701,13 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
            `tuple`. When returning a tuple, the first element is a list with the generated images.
        """
(removed) if num_frames > 49:
(removed)     raise ValueError(
(removed)         "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
(removed)     )

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
        width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
        num_frames = num_frames or self.transformer.config.sample_frames

        num_videos_per_prompt = 1

        # 1. Check inputs. Raise error if not correct

...@@ -726,6 +760,15 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
        self._num_timesteps = len(timesteps)

        # 5. Prepare latents
        latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1

        # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
        patch_size_t = self.transformer.config.patch_size_t
        additional_frames = 0
        if patch_size_t is not None and latent_frames % patch_size_t != 0:
            additional_frames = patch_size_t - latent_frames % patch_size_t
            num_frames += additional_frames * self.vae_scale_factor_temporal

        image = self.video_processor.preprocess(image, height=height, width=width).to(
            device, dtype=prompt_embeds.dtype
        )

...@@ -754,6 +797,9 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
            else None
        )

        # 8. Create ofs embeds if required
        ofs_emb = None if self.transformer.config.ofs_embed_dim is None else latents.new_full((1,), fill_value=2.0)

        # 8. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

...@@ -778,6 +824,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
                    hidden_states=latent_model_input,
                    encoder_hidden_states=prompt_embeds,
                    timestep=timestep,
                    ofs=ofs_emb,
                    image_rotary_emb=image_rotary_emb,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,

...@@ -823,6 +870,8 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
                    progress_bar.update()

        if not output_type == "latent":
            # Discard any padding frames that were added for CogVideoX 1.5
            latents = latents[:, additional_frames:]
            video = self.decode_latents(latents)
            video = self.video_processor.postprocess_video(video=video, output_type=output_type)
        else:
......
...@@ -518,8 +518,13 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
        grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)

        p = self.transformer.config.patch_size
        p_t = self.transformer.config.patch_size_t or 1

        base_size_width = self.transformer.config.sample_width // p
        base_size_height = self.transformer.config.sample_height // p
        base_num_frames = (num_frames + p_t - 1) // p_t

        grid_crops_coords = get_resize_crop_region_for_grid(
            (grid_height, grid_width), base_size_width, base_size_height

...@@ -528,7 +533,7 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
            embed_dim=self.transformer.config.attention_head_dim,
            crops_coords=grid_crops_coords,
            grid_size=(grid_height, grid_width),
            temporal_size=base_num_frames,
        )

        freqs_cos = freqs_cos.to(device=device)

...@@ -558,8 +563,8 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
        video: List[Image.Image] = None,
        prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: Optional[List[int]] = None,
        strength: float = 0.8,

...@@ -662,6 +667,10 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
        width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
        num_frames = len(video) if latents is None else latents.size(1)

        num_videos_per_prompt = 1

        # 1. Check inputs. Raise error if not correct

...@@ -717,6 +726,16 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
        self._num_timesteps = len(timesteps)

        # 5. Prepare latents
        latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1

        # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
        patch_size_t = self.transformer.config.patch_size_t
        if patch_size_t is not None and latent_frames % patch_size_t != 0:
            raise ValueError(
                f"The number of latent frames must be divisible by `{patch_size_t=}` but the given video "
                f"contains {latent_frames=}, which is not divisible."
            )

        if latents is None:
            video = self.video_processor.preprocess_video(video, height=height, width=width)
            video = video.to(device=device, dtype=prompt_embeds.dtype)
......
...@@ -76,6 +76,7 @@ class CogVideoXTransformerTests(ModelTesterMixin, unittest.TestCase):
            "sample_height": 8,
            "sample_frames": 8,
            "patch_size": 2,
            "patch_size_t": None,
            "temporal_compression_ratio": 4,
            "max_text_seq_length": 8,
        }

...@@ -85,3 +86,63 @@
    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"CogVideoXTransformer3DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class CogVideoX1_5TransformerTests(ModelTesterMixin, unittest.TestCase):
    model_class = CogVideoXTransformer3DModel
    main_input_name = "hidden_states"
    uses_custom_attn_processor = True

    @property
    def dummy_input(self):
        batch_size = 2
        num_channels = 4
        num_frames = 2
        height = 8
        width = 8
        embedding_dim = 8
        sequence_length = 8

        hidden_states = torch.randn((batch_size, num_frames, num_channels, height, width)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "timestep": timestep,
        }

    @property
    def input_shape(self):
        return (1, 4, 8, 8)

    @property
    def output_shape(self):
        return (1, 4, 8, 8)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings.
            "num_attention_heads": 2,
            "attention_head_dim": 8,
            "in_channels": 4,
            "out_channels": 4,
            "time_embed_dim": 2,
            "text_embed_dim": 8,
            "num_layers": 1,
            "sample_width": 8,
            "sample_height": 8,
            "sample_frames": 8,
            "patch_size": 2,
            "patch_size_t": 2,
            "temporal_compression_ratio": 4,
            "max_text_seq_length": 8,
            "use_rotary_positional_embeddings": True,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"CogVideoXTransformer3DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)