Commit 45467a6b authored by Watebear, committed by GitHub

[feature]: add cogvideox t2v (#55)

parent c5048775
{
"seed": 42,
"text_len": 226,
"num_videos_per_prompt": 1,
"target_video_length": 81,
"num_inference_steps": 50,
"num_train_timesteps": 1000,
"timestep_spacing": "trailing",
"steps_offset": 0,
"latent_channels": 16,
"height": 768,
"width": 1360,
"vae_scale_factor_temporal": 4,
"vae_scale_factor_spatial": 8,
"vae_scaling_factor_image": 0.7,
"batch_size": 1,
"patch_size": 2,
"patch_size_t": 2,
"guidance_scale": 0,
"use_rotary_positional_embeddings": true,
"do_classifier_free_guidance": false,
"transformer_sample_width": 170,
"transformer_sample_height": 96,
"transformer_sample_frames": 81,
"transformer_attention_head_dim": 64,
"transformer_num_attention_heads": 48,
"transformer_temporal_compression_ratio": 4,
"transformer_temporal_interpolation_scale": 1.0,
"transformer_use_learned_positional_embeddings": false,
"transformer_spatial_interpolation_scale": 1.875,
"transformer_num_layers": 42,
"beta_schedule": "scaled_linear",
"scheduler_beta_start": 0.00085,
"scheduler_beta_end": 0.012,
"scheduler_set_alpha_to_one": true,
"scheduler_snr_shift_scale": 1.0,
"scheduler_rescale_betas_zero_snr": true,
"scheduler_prediction_type": "v_prediction",
"use_dynamic_cfg": true
}
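A minimal sketch of how a config file like the one above might be loaded for attribute-style access; the file name and the SimpleNamespace wrapper are illustrative assumptions, not part of this commit (the runner receives the path via --config_json).

import json
from types import SimpleNamespace

# Hypothetical loading sketch: wrap the JSON so fields read as cfg.text_len, etc.
with open("configs/cogvideox_t2v.json") as f:  # placeholder path
    cfg = SimpleNamespace(**json.load(f))

print(cfg.text_len, cfg.num_inference_steps, cfg.target_video_length)  # 226 50 81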
@@ -50,5 +50,5 @@ class LNWeight(LNWeightTemplate):
super().__init__(weight_name, bias_name, eps)
def apply(self, input_tensor):
-input_tensor = torch.nn.functional.layer_norm(input_tensor, (input_tensor.shape[1],), self.weight, self.bias, self.eps)
+input_tensor = torch.nn.functional.layer_norm(input_tensor, (input_tensor.shape[-1],), self.weight, self.bias, self.eps)
return input_tensor
@@ -15,6 +15,7 @@ from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner
from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner
from lightx2v.models.runners.graph_runner import GraphRunner
+from lightx2v.models.runners.cogvideox.cogvidex_runner import CogvideoxRunner
from lightx2v.common.ops import *
from loguru import logger
@@ -36,7 +37,7 @@ def init_runner(config):
async def main():
parser = argparse.ArgumentParser()
-parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causvid", "wan2.1_skyreels_v2_df"], default="hunyuan")
+parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox"], default="hunyuan")
parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True)
...
import torch
import os
from transformers import T5EncoderModel, T5Tokenizer
class T5EncoderModel_v1_1_xxl:
def __init__(self, config):
self.config = config
self.model = T5EncoderModel.from_pretrained(os.path.join(config.model_path, "text_encoder")).to(torch.bfloat16).to(torch.device("cuda"))
self.tokenizer = T5Tokenizer.from_pretrained(os.path.join(config.model_path, "tokenizer"), padding_side="right")
def to_cpu(self):
self.model = self.model.to("cpu")
def to_cuda(self):
self.model = self.model.to("cuda")
def infer(self, texts, config):
text_inputs = self.tokenizer(
texts,
padding="max_length",
max_length=config.text_len,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
).to("cuda")
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(texts, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, config.text_len - 1 : -1])
print(f"The following part of your input was truncated because `max_sequence_length` is set to {self.text_len} tokens: {removed_text}")
prompt_embeds = self.model(text_input_ids.to(torch.device("cuda")))[0]
return prompt_embeds
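A hedged usage sketch for the T5 encoder wrapper above; the prompt, call pattern, and shape note are illustrative and assume the default text_len of 226.

# Illustrative only; `config` is the same object the runner builds from --config_json.
# encoder = T5EncoderModel_v1_1_xxl(config)
# context = encoder.infer(["a golden retriever surfing at sunset"], config)
# context.shape should be (1, config.text_len, 4096) for the T5 v1.1 XXL encoder loaded here.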
import torch
class CogvideoxPostInfer:
def __init__(self, config):
self.config = config
def ada_layernorm(self, weight_mm, weight_ln, x, temb):
temb = torch.nn.functional.silu(temb)
temb = weight_mm.apply(temb)
shift, scale = temb.chunk(2, dim=1)
x = weight_ln.apply(x) * (1 + scale) + shift
return x
def infer(self, weight, hidden_states, encoder_hidden_states, temb, infer_shapes):
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=0)
hidden_states = weight.norm_final.apply(hidden_states)
hidden_states = hidden_states[self.config.text_len :,]
hidden_states = self.ada_layernorm(weight.norm_out_linear, weight.norm_out_norm, hidden_states, temb=temb)
hidden_states = weight.proj_out.apply(hidden_states)
p = self.config["patch_size"]
p_t = self.config["patch_size_t"]
num_frames, _, height, width = infer_shapes
output = hidden_states.reshape((num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p)
output = output.permute(0, 4, 3, 1, 5, 2, 6).flatten(5, 6).flatten(3, 4).flatten(0, 1)
return output
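The reshape/permute above undoes the patch embedding applied in the pre-infer step. Below is a standalone sanity check of the patch arithmetic for the default 768x1360, 81-frame config; the values are taken from the config JSON above and the snippet is illustrative, not part of the commit.

# Latent grid and token count implied by the default config values.
height, width, target_frames = 768, 1360, 81
vae_spatial, vae_temporal = 8, 4
p, p_t = 2, 2

lat_h, lat_w = height // vae_spatial, width // vae_spatial      # 96, 170
lat_f = (target_frames - 1) // vae_temporal + 1                  # 21 latent frames
lat_f += (p_t - lat_f % p_t) % p_t                               # padded to 22 so p_t divides evenly
num_tokens = (lat_f // p_t) * (lat_h // p) * (lat_w // p)        # 11 * 48 * 85 = 44880
print(lat_h, lat_w, lat_f, num_tokens)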
import torch
from diffusers.models.embeddings import get_timestep_embedding, get_3d_sincos_pos_embed
class CogvideoxPreInfer:
def __init__(self, config):
self.config = config
self.use_positional_embeddings = not self.config.use_rotary_positional_embeddings
self.inner_dim = self.config.transformer_num_attention_heads * self.config.transformer_attention_head_dim
self.freq_shift = 0
self.flip_sin_to_cos = True
self.scale = 1
self.act = "silu"
def _get_positional_embeddings(self, sample_height, sample_width, sample_frames, device):
post_patch_height = sample_height // self.config.patch_size
post_patch_width = sample_width // self.config.patch_size
post_time_compression_frames = (sample_frames - 1) // self.config.transformer_temporal_compression_ratio + 1
num_patches = post_patch_height * post_patch_width * post_time_compression_frames
pos_embedding = get_3d_sincos_pos_embed(
self.inner_dim,
(post_patch_width, post_patch_height),
post_time_compression_frames,
self.config.transformer_spatial_interpolation_scale,
self.config.transformer_temporal_interpolation_scale,
device=device,
output_type="pt",
)
pos_embedding = pos_embedding.flatten(0, 1)
joint_pos_embedding = pos_embedding.new_zeros(1, self.config.text_len + num_patches, self.inner_dim, requires_grad=False)
joint_pos_embedding.data[:, self.config.text_len :].copy_(pos_embedding)
return joint_pos_embedding
def infer(self, weights, hidden_states, timestep, encoder_hidden_states):
t_emb = get_timestep_embedding(
timestep,
self.inner_dim,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.freq_shift,
scale=self.scale,
)
t_emb = t_emb.to(dtype=hidden_states.dtype)
sample = weights.time_embedding_linear_1.apply(t_emb)
sample = torch.nn.functional.silu(sample)
emb = weights.time_embedding_linear_2.apply(sample)
text_embeds = weights.patch_embed_text_proj.apply(encoder_hidden_states)
num_frames, channels, height, width = hidden_states.shape
infer_shapes = (num_frames, channels, height, width)
p = self.config.patch_size
p_t = self.config.patch_size_t
image_embeds = hidden_states.permute(0, 2, 3, 1)
image_embeds = image_embeds.reshape(num_frames // p_t, p_t, height // p, p, width // p, p, channels)
image_embeds = image_embeds.permute(0, 2, 4, 6, 1, 3, 5).flatten(3, 6).flatten(0, 2)
image_embeds = weights.patch_embed_proj.apply(image_embeds)
embeds = torch.cat([text_embeds, image_embeds], dim=0).contiguous()
if self.use_positional_embeddings or self.config.transformer_use_learned_positional_embeddings:
if self.config.transformer_use_learned_positional_embeddings and (self.config.transformer_sample_width != width or self.config.transformer_sample_height != height):
raise ValueError(
"It is currently not possible to generate videos at a different resolution than the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'."
"If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
)
pre_time_compression_frames = (num_frames - 1) * self.config.transformer_temporal_compression_ratio + 1
if self.config.transformer_sample_height != height or self.config.transformer_sample_width != width or self.config.transformer_sample_frames != pre_time_compression_frames:
pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames, device=embeds.device)[0]
else:
pos_embedding = self.pos_embedding[0]
pos_embedding = pos_embedding.to(dtype=embeds.dtype)
embeds = embeds + pos_embedding
hidden_states = embeds
text_seq_length = encoder_hidden_states.shape[0]
encoder_hidden_states = hidden_states[:text_seq_length, :]
hidden_states = hidden_states[text_seq_length:, :]
return hidden_states, encoder_hidden_states, emb, infer_shapes
import torch
import torch.nn.functional as F
def apply_rotary_emb(x, freqs_cis, use_real=True, use_real_unbind_dim=-1):
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
tensors contain rotary embeddings and are returned as real tensors.
Args:
x (`torch.Tensor`):
Query or key tensor to which rotary embeddings are applied.
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensors for the complex exponentials, shaped `([S, D], [S, D])`.
Returns:
`torch.Tensor`: The input tensor with rotary embeddings applied, same shape as `x`.
"""
if use_real:
cos, sin = freqs_cis # [S, D]
cos = cos[None]
sin = sin[None]
cos, sin = cos.to(x.device), sin.to(x.device)
if use_real_unbind_dim == -1:
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
elif use_real_unbind_dim == -2:
# Used for Stable Audio
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
else:
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
else:
# used for lumina
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(2)
return x_out.type_as(x)
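# Shape note (illustrative): in this file `x` arrives in per-head layout [H, S, D] and
# `freqs_cis` is a (cos, sin) pair of [S, D] tensors, so with use_real=True the rotation is
# applied pairwise over the last dimension and the output keeps the shape of `x`, e.g.
#   q_rot = apply_rotary_emb(q, (cos, sin))  # q_rot.shape == q.shape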
class CogvideoxTransformerInfer:
def __init__(self, config):
self.config = config
self.attn_type = "torch_sdpa"
def set_scheduler(self, scheduler):
self.scheduler = scheduler
def infer(self, weights, hidden_states, encoder_hidden_states, temb):
image_rotary_emb = self.scheduler.image_rotary_emb
for i in range(self.config.transformer_num_layers):
hidden_states, encoder_hidden_states = self.infer_block(
weights.blocks_weights[i],
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
return hidden_states, encoder_hidden_states
def cogvideox_norm1(self, weights, hidden_states, encoder_hidden_states, temb):
temb = torch.nn.functional.silu(temb)
temb = weights.norm1_linear.apply(temb)
shift, scale, gate, enc_shift, enc_scale, enc_gate = temb.chunk(6, dim=1)
hidden_states = weights.norm1_norm.apply(hidden_states) * (1 + scale)[:, :] + shift[:, :]
encoder_hidden_states = weights.norm1_norm.apply(encoder_hidden_states) * (1 + enc_scale)[:, :] + enc_shift[:, :]
return hidden_states, encoder_hidden_states, gate, enc_gate
def cogvideox_norm2(self, weights, hidden_states, encoder_hidden_states, temb):
temb = torch.nn.functional.silu(temb)
temb = weights.norm2_linear.apply(temb)
shift, scale, gate, enc_shift, enc_scale, enc_gate = temb.chunk(6, dim=1)
hidden_states = weights.norm2_norm.apply(hidden_states) * (1 + scale)[:, :] + shift[:, :]
encoder_hidden_states = weights.norm2_norm.apply(encoder_hidden_states) * (1 + enc_scale)[:, :] + enc_shift[:, :]
return hidden_states, encoder_hidden_states, gate, enc_gate
def cogvideox_attention(self, weights, hidden_states, encoder_hidden_states, image_rotary_emb):
text_seq_length = encoder_hidden_states.size(0)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=0)
query = weights.attn1_to_q.apply(hidden_states)
key = weights.attn1_to_k.apply(hidden_states)
value = weights.attn1_to_v.apply(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // self.config.transformer_num_attention_heads
query = query.view(-1, self.config.transformer_num_attention_heads, head_dim).transpose(0, 1)
key = key.view(-1, self.config.transformer_num_attention_heads, head_dim).transpose(0, 1)
value = value.view(-1, self.config.transformer_num_attention_heads, head_dim).transpose(0, 1)
query = weights.attn1_norm_q.apply(query)
key = weights.attn1_norm_k.apply(key)
query[:, text_seq_length:] = apply_rotary_emb(query[:, text_seq_length:], image_rotary_emb)
key[:, text_seq_length:] = apply_rotary_emb(key[:, text_seq_length:], image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(query[None], key[None], value[None], attn_mask=None, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(1, -1, self.config.transformer_num_attention_heads * head_dim)
hidden_states = hidden_states.squeeze(0)
hidden_states = weights.attn1_to_out.apply(hidden_states)
encoder_hidden_states, hidden_states = hidden_states.split([text_seq_length, hidden_states.size(0) - text_seq_length], dim=0)
return hidden_states, encoder_hidden_states
def cogvideox_ff(self, weights, hidden_states):
hidden_states = weights.ff_net_0_proj.apply(hidden_states)
hidden_states = torch.nn.functional.gelu(hidden_states, approximate="tanh")
hidden_states = weights.ff_net_2_proj.apply(hidden_states)
return hidden_states
@torch.no_grad()
def infer_block(self, weights, hidden_states, encoder_hidden_states, temb, image_rotary_emb):
text_seq_length = encoder_hidden_states.size(0)
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.cogvideox_norm1(weights, hidden_states, encoder_hidden_states, temb)
attn_hidden_states, attn_encoder_hidden_states = self.cogvideox_attention(
weights,
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + gate_msa * attn_hidden_states
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.cogvideox_norm2(weights, hidden_states, encoder_hidden_states, temb)
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=0)
ff_output = self.cogvideox_ff(weights, norm_hidden_states)
hidden_states = hidden_states + gate_ff * ff_output[text_seq_length:,]
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:text_seq_length,]
return hidden_states, encoder_hidden_states
import torch
from safetensors import safe_open
import os
import glob
import math
import json
from lightx2v.models.networks.cogvideox.weights.pre_weights import CogvideoxPreWeights
from lightx2v.models.networks.cogvideox.weights.post_weights import CogvideoxPostWeights
from lightx2v.models.networks.cogvideox.weights.transformers_weights import CogvideoxTransformerWeights
from lightx2v.models.networks.cogvideox.infer.pre_infer import CogvideoxPreInfer
from lightx2v.models.networks.cogvideox.infer.transformer_infer import CogvideoxTransformerInfer
from lightx2v.models.networks.cogvideox.infer.post_infer import CogvideoxPostInfer
class CogvideoxModel:
pre_weight_class = CogvideoxPreWeights
post_weight_class = CogvideoxPostWeights
transformer_weight_class = CogvideoxTransformerWeights
def __init__(self, config):
self.config = config
self.device = torch.device("cuda")
self._init_infer_class()
self._init_weights()
self._init_infer()
def _init_infer_class(self):
self.pre_infer_class = CogvideoxPreInfer
self.post_infer_class = CogvideoxPostInfer
self.transformer_infer_class = CogvideoxTransformerInfer
def _load_safetensor_to_dict(self, file_path):
with safe_open(file_path, framework="pt") as f:
tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16).cuda() for key in f.keys()}
return tensor_dict
def _load_ckpt(self):
safetensors_pattern = os.path.join(self.config.model_path, "transformer", "*.safetensors")
safetensors_files = glob.glob(safetensors_pattern)
if not safetensors_files:
raise FileNotFoundError(f"No .safetensors files found in directory: {self.model_path}")
weight_dict = {}
for file_path in safetensors_files:
file_weights = self._load_safetensor_to_dict(file_path)
weight_dict.update(file_weights)
return weight_dict
def _init_weights(self):
weight_dict = self._load_ckpt()
with open(os.path.join(self.config.model_path, "transformer", "config.json"), "r") as f:
transformer_cfg = json.load(f)
# init weights
self.pre_weight = self.pre_weight_class(transformer_cfg)
self.transformer_weights = self.transformer_weight_class(transformer_cfg)
self.post_weight = self.post_weight_class(transformer_cfg)
# load weights
self.pre_weight.load_weights(weight_dict)
self.transformer_weights.load_weights(weight_dict)
self.post_weight.load_weights(weight_dict)
def _init_infer(self):
self.pre_infer = self.pre_infer_class(self.config)
self.transformer_infer = self.transformer_infer_class(self.config)
self.post_infer = self.post_infer_class(self.config)
def set_scheduler(self, scheduler):
self.scheduler = scheduler
self.transformer_infer.set_scheduler(scheduler)
def to_cpu(self):
self.pre_weight.to_cpu()
self.post_weight.to_cpu()
self.transformer_weights.to_cpu()
def to_cuda(self):
self.pre_weight.to_cuda()
self.post_weight.to_cuda()
self.transformer_weights.to_cuda()
@torch.no_grad()
def infer(self, inputs):
t = self.scheduler.timesteps[self.scheduler.step_index]
text_encoder_output = inputs["text_encoder_output"]["context"]
do_classifier_free_guidance = self.config.guidance_scale > 1.0
latent_model_input = self.scheduler.latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
timestep = t.expand(latent_model_input.shape[0])
hidden_states, encoder_hidden_states, emb, infer_shapes = self.pre_infer.infer(
self.pre_weight,
latent_model_input[0],
timestep,
text_encoder_output[0],
)
hidden_states, encoder_hidden_states = self.transformer_infer.infer(
self.transformer_weights,
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=emb,
)
noise_pred = self.post_infer.infer(self.post_weight, hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=emb, infer_shapes=infer_shapes)
noise_pred = noise_pred.float()
if self.config.use_dynamic_cfg: # True
self.scheduler.guidance_scale = 1 + self.scheduler.guidance_scale * ((1 - math.cos(math.pi * ((self.scheduler.infer_steps - t.item()) / self.scheduler.infer_steps) ** 5.0)) / 2)
if do_classifier_free_guidance: # False
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.scheduler.guidance_scale * (noise_pred_text - noise_pred_uncond)
self.scheduler.noise_pred = noise_pred
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
class CogvideoxPostWeights:
def __init__(self, config, mm_type="Default"):
self.config = config
self.mm_type = mm_type
def load_weights(self, weight_dict):
self.norm_out_linear = MM_WEIGHT_REGISTER[self.mm_type]("norm_out.linear.weight", "norm_out.linear.bias")
self.proj_out = MM_WEIGHT_REGISTER[self.mm_type]("proj_out.weight", "proj_out.bias")
self.norm_final = LN_WEIGHT_REGISTER[self.mm_type]("norm_final.weight", "norm_final.bias")
self.norm_out_norm = LN_WEIGHT_REGISTER[self.mm_type]("norm_out.norm.weight", "norm_out.norm.bias", eps=1e-5)
self.weight_list = [self.norm_out_linear, self.proj_out, self.norm_final, self.norm_out_norm]
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
mm_weight.load(weight_dict)
def to_cpu(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
mm_weight.to_cpu()
def to_cuda(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
mm_weight.to_cuda()
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
class CogvideoxPreWeights:
def __init__(self, config):
self.config = config
def load_weights(self, weight_dict):
self.time_embedding_linear_1 = MM_WEIGHT_REGISTER["Default"]("time_embedding.linear_1.weight", "time_embedding.linear_1.bias")
self.time_embedding_linear_2 = MM_WEIGHT_REGISTER["Default"]("time_embedding.linear_2.weight", "time_embedding.linear_2.bias")
self.patch_embed_proj = MM_WEIGHT_REGISTER["Default"]("patch_embed.proj.weight", "patch_embed.proj.bias")
self.patch_embed_text_proj = MM_WEIGHT_REGISTER["Default"]("patch_embed.text_proj.weight", "patch_embed.text_proj.bias")
self.weight_list = [self.time_embedding_linear_1, self.time_embedding_linear_2, self.patch_embed_proj, self.patch_embed_text_proj]
for mm_weight in self.weight_list:
mm_weight.set_config(self.config)
mm_weight.load(weight_dict)
def to_cpu(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
mm_weight.to_cpu()
def to_cuda(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
mm_weight.to_cuda()
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
class CogvideoxTransformerWeights:
def __init__(self, config, task="t2v", mm_type="Default"):
self.config = config
self.task = task
self.mm_type = mm_type
self.init()
def init(self):
self.num_layers = self.config["num_layers"]
def load_weights(self, weight_dict):
self.blocks_weights = [CogVideoXBlock(i, self.task, self.mm_type) for i in range(self.num_layers)]
for block in self.blocks_weights:
block.load_weights(weight_dict)
def to_cpu(self):
for block in self.blocks_weights:
block.to_cpu()
def to_cuda(self):
for block in self.blocks_weights:
block.to_cuda()
class CogVideoXBlock:
def __init__(self, block_index, task="t2v", mm_type="Default"):
super().__init__()
self.block_index = block_index
self.mm_type = mm_type
self.task = task
def load_weights(self, weight_dict):
self.attn1_to_k = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.attn1.to_k.weight", f"transformer_blocks.{self.block_index}.attn1.to_k.bias")
self.attn1_to_q = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.attn1.to_q.weight", f"transformer_blocks.{self.block_index}.attn1.to_q.bias")
self.attn1_to_v = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.attn1.to_v.weight", f"transformer_blocks.{self.block_index}.attn1.to_v.bias")
self.attn1_to_out = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.attn1.to_out.0.weight", f"transformer_blocks.{self.block_index}.attn1.to_out.0.bias")
self.ff_net_0_proj = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.ff.net.0.proj.weight", f"transformer_blocks.{self.block_index}.ff.net.0.proj.bias")
self.ff_net_2_proj = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.ff.net.2.weight", f"transformer_blocks.{self.block_index}.ff.net.2.bias")
self.norm1_linear = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.norm1.linear.weight", f"transformer_blocks.{self.block_index}.norm1.linear.bias")
self.norm2_linear = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.norm2.linear.weight", f"transformer_blocks.{self.block_index}.norm2.linear.bias")
self.attn1_norm_k = LN_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.attn1.norm_k.weight", f"transformer_blocks.{self.block_index}.attn1.norm_k.bias")
self.attn1_norm_q = LN_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.attn1.norm_q.weight", f"transformer_blocks.{self.block_index}.attn1.norm_q.bias")
self.norm1_norm = LN_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.norm1.norm.weight", f"transformer_blocks.{self.block_index}.norm1.norm.bias", eps=1e-05)
self.norm2_norm = LN_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.norm2.norm.weight", f"transformer_blocks.{self.block_index}.norm2.norm.bias", eps=1e-05)
self.weight_list = [
self.attn1_to_k,
self.attn1_to_q,
self.attn1_to_v,
self.attn1_to_out,
self.ff_net_0_proj,
self.ff_net_2_proj,
self.norm1_linear,
self.norm2_linear,
self.attn1_norm_k,
self.attn1_norm_q,
self.norm1_norm,
self.norm2_norm,
]
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
mm_weight.load(weight_dict)
def to_cpu(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
mm_weight.to_cpu()
def to_cuda(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
mm_weight.to_cuda()
from diffusers.utils import export_to_video
import imageio
import numpy as np
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.models.input_encoders.hf.t5_v1_1_xxl.model import T5EncoderModel_v1_1_xxl
from lightx2v.models.networks.cogvideox.model import CogvideoxModel
from lightx2v.models.video_encoders.hf.cogvideox.model import CogvideoxVAE
from lightx2v.models.schedulers.cogvideox.scheduler import CogvideoxXDPMScheduler
@RUNNER_REGISTER("cogvideox")
class CogvideoxRunner(DefaultRunner):
def __init__(self, config):
super().__init__(config)
@ProfilingContext("Load models")
def load_model(self):
text_encoder = T5EncoderModel_v1_1_xxl(self.config)
text_encoders = [text_encoder]
model = CogvideoxModel(self.config)
vae_model = CogvideoxVAE(self.config)
image_encoder = None
return model, text_encoders, vae_model, image_encoder
def init_scheduler(self):
scheduler = CogvideoxXDPMScheduler(self.config)
self.model.set_scheduler(scheduler)
def run_text_encoder(self, text, text_encoders, config, image_encoder_output):
text_encoder_output = {}
n_prompt = config.get("negative_prompt", "")
context = text_encoders[0].infer([text], config)
context_null = text_encoders[0].infer([n_prompt if n_prompt else ""], config)
text_encoder_output["context"] = context
text_encoder_output["context_null"] = context_null
return text_encoder_output
def set_target_shape(self):
num_frames = self.config.target_video_length
latent_frames = (num_frames - 1) // self.config.vae_scale_factor_temporal + 1
additional_frames = 0
patch_size_t = self.config.patch_size_t
if patch_size_t is not None and latent_frames % patch_size_t != 0:
additional_frames = patch_size_t - latent_frames % patch_size_t
num_frames += additional_frames * self.config.vae_scale_factor_temporal
self.config.target_shape = (
self.config.batch_size,
(num_frames - 1) // self.config.vae_scale_factor_temporal + 1,
self.config.latent_channels,
self.config.height // self.config.vae_scale_factor_spatial,
self.config.width // self.config.vae_scale_factor_spatial,
)
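# Worked example with the default config (values assumed from the config JSON):
# target_video_length=81 -> latent_frames=(81-1)//4+1=21; since 21 % patch_size_t(=2) != 0,
# one extra latent frame is padded (4 pixel frames), so num_frames=85 and the latent temporal
# size becomes (85-1)//4+1=22, giving target_shape = (1, 22, 16, 96, 170).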
def save_video(self, images):
with imageio.get_writer(self.config.save_video_path, fps=16) as writer:
for pil_image in images:
frame_np = np.array(pil_image, dtype=np.uint8)
writer.append_data(frame_np)
import torch
from diffusers.utils.torch_utils import randn_tensor
from diffusers.models.embeddings import get_3d_rotary_pos_embed
import numpy as np
from lightx2v.models.schedulers.scheduler import BaseScheduler
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
tw = tgt_width
th = tgt_height
h, w = src
r = h / w
if r > (th / tw):
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h))
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
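# Example (illustrative): with the default config, src=(grid_height, grid_width)=(48, 85) and a
# base grid of tgt_width=85, tgt_height=48, the aspect ratios match, so the whole grid is kept
# and the function returns ((0, 0), (48, 85)).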
def rescale_zero_terminal_snr(alphas_cumprod):
"""
Rescales `alphas_cumprod` so that the terminal SNR is zero, based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).
Args:
alphas_cumprod (`torch.Tensor`):
the cumulative product of alphas the scheduler is being initialized with.
Returns:
`torch.Tensor`: rescaled `alphas_cumprod` with zero terminal SNR
"""
alphas_bar_sqrt = alphas_cumprod.sqrt()
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
# Shift so the last timestep is zero.
alphas_bar_sqrt -= alphas_bar_sqrt_T
# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
return alphas_bar
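# Effect (for reference): after this rescale the final cumulative alpha is exactly zero, so the
# terminal SNR alpha_bar_T / (1 - alpha_bar_T) is zero and the last timestep is pure noise,
# matching the v_prediction / rescale_betas_zero_snr setup in the config above.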
class CogvideoxXDPMScheduler(BaseScheduler):
def __init__(self, config):
self.config = config
self.set_timesteps()
self.generator = torch.Generator().manual_seed(config.seed)
self.noise_pred = None
if self.config.beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = torch.linspace(self.config.scheduler_beta_start**0.5, self.config.scheduler_beta_end**0.5, self.config.num_train_timesteps, dtype=torch.float64) ** 2
else:
raise NotImplementedError(f"{self.config.beta_schedule} is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0).to(torch.device("cuda"))
# Modify: SNR shift following SD3
self.alphas_cumprod = self.alphas_cumprod / (self.config.scheduler_snr_shift_scale + (1 - self.config.scheduler_snr_shift_scale) * self.alphas_cumprod)
# Rescale for zero SNR
if self.config.scheduler_rescale_betas_zero_snr:
self.alphas_cumprod = rescale_zero_terminal_snr(self.alphas_cumprod)
# At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0
# `set_alpha_to_one` decides whether we set this parameter simply to one or
# whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = torch.tensor(1.0) if self.config.scheduler_set_alpha_to_one else self.alphas_cumprod[0]
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
def scale_model_input(self, sample, timestep=None):
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.Tensor`):
The input sample.
timestep (`int`, *optional*):
The current timestep in the diffusion chain.
Returns:
`torch.Tensor`:
A scaled input sample.
"""
return sample
def set_timesteps(self):
if self.config.num_inference_steps > self.config.num_train_timesteps:
raise ValueError(
f"`num_inference_steps`: {self.config.num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
f" maximal {self.config.num_train_timesteps} timesteps."
)
self.infer_steps = self.config.num_inference_steps
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, self.config.num_inference_steps).round()[::-1].copy().astype(np.int64)
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.config.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, self.infer_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.config.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
timesteps -= 1
else:
raise ValueError(f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'.")
self.timesteps = torch.Tensor(timesteps).to(torch.device("cuda")).int()
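# Worked example with the default config: num_train_timesteps=1000, num_inference_steps=50 and
# "trailing" spacing give step_ratio=20 and timesteps 999, 979, 959, ..., 19.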
def prepare(self, image_encoder_output):
self.image_encoder_output = image_encoder_output
self.prepare_latents(shape=self.config.target_shape, dtype=torch.bfloat16)
self.prepare_guidance()
self.prepare_rotary_pos_embedding()
def prepare_latents(self, shape, dtype):
latents = randn_tensor(shape, generator=self.generator, device=torch.device("cuda"), dtype=dtype)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.init_noise_sigma
self.latents = latents
self.old_pred_original_sample = None
def prepare_guidance(self):
self.guidance_scale = self.config.guidance_scale
def prepare_rotary_pos_embedding(self):
grid_height = self.config.height // (self.config.vae_scale_factor_spatial * self.config.patch_size)
grid_width = self.config.width // (self.config.vae_scale_factor_spatial * self.config.patch_size)
p = self.config.patch_size
p_t = self.config.patch_size_t
base_size_width = self.config.transformer_sample_width // p
base_size_height = self.config.transformer_sample_height // p
num_frames = self.latents.size(1)
if p_t is None:
# CogVideoX 1.0
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=self.config.transformer_attention_head_dim,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
device=torch.device("cuda"),
)
else:
# CogVideoX 1.5
base_num_frames = (num_frames + p_t - 1) // p_t
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=self.config.transformer_attention_head_dim,
crops_coords=None,
grid_size=(grid_height, grid_width),
temporal_size=base_num_frames,
grid_type="slice",
max_size=(base_size_height, base_size_width),
device=torch.device("cuda"),
)
self.freqs_cos = freqs_cos
self.freqs_sin = freqs_sin
self.image_rotary_emb = (freqs_cos, freqs_sin) if self.config.use_rotary_positional_embeddings else None
def get_variables(self, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back=None):
lamb = ((alpha_prod_t / (1 - alpha_prod_t)) ** 0.5).log()
lamb_next = ((alpha_prod_t_prev / (1 - alpha_prod_t_prev)) ** 0.5).log()
h = lamb_next - lamb
if alpha_prod_t_back is not None:
lamb_previous = ((alpha_prod_t_back / (1 - alpha_prod_t_back)) ** 0.5).log()
h_last = lamb - lamb_previous
r = h_last / h
return h, r, lamb, lamb_next
else:
return h, None, lamb, lamb_next
def get_mult(self, h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back):
mult1 = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 * (-h).exp()
mult2 = (-2 * h).expm1() * alpha_prod_t_prev**0.5
if alpha_prod_t_back is not None:
mult3 = 1 + 1 / (2 * r)
mult4 = 1 / (2 * r)
return mult1, mult2, mult3, mult4
else:
return mult1, mult2
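# Note (for reference, phrasing is ours): lamb is half the log-SNR, log(sqrt(alpha_bar / (1 - alpha_bar))),
# h is the step taken in that variable and r compares it with the previous step; get_mult returns the
# DPM-Solver-style coefficients that step_post combines with the model prediction (and, on multistep
# updates, with the previous prediction) to advance the latents.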
def step_post(self):
if self.infer_steps is None:
raise ValueError("Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler")
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read DDIM paper in-detail understanding
# Notation (<variable name> -> <name in paper>
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_sample -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_sample_direction -> "direction pointing to x_t"
# - pred_prev_sample -> "x_t-1"
timestep = self.timesteps[self.step_index]
timestep_back = self.timesteps[self.step_index - 1] if self.step_index > 0 else None
# 1. get previous step value (=t-1)
prev_timestep = timestep - self.config.num_train_timesteps // self.infer_steps
# 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
alpha_prod_t_back = self.alphas_cumprod[timestep_back] if timestep_back is not None else None
beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
# To make style tests pass, commented out `pred_epsilon` as it is an unused variable
if self.config.scheduler_prediction_type == "epsilon":
pred_original_sample = (self.latents - beta_prod_t ** (0.5) * self.noise_pred) / alpha_prod_t ** (0.5)
# pred_epsilon = model_output
elif self.config.scheduler_prediction_type == "sample":
pred_original_sample = self.noise_pred
# pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
elif self.config.scheduler_prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * self.latents - (beta_prod_t**0.5) * self.noise_pred
# pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
else:
raise ValueError(f"prediction_type given as {self.config.scheduler_prediction_type} must be one of `epsilon`, `sample`, or `v_prediction`")
h, r, lamb, lamb_next = self.get_variables(alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back)
mult = list(self.get_mult(h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back))
mult_noise = (1 - alpha_prod_t_prev) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5
noise = randn_tensor(self.latents.shape, generator=self.generator, device=self.latents.device, dtype=self.latents.dtype)
prev_sample = mult[0] * self.latents - mult[1] * pred_original_sample + mult_noise * noise
if self.old_pred_original_sample is None or prev_timestep < 0:
# Save a network evaluation if all noise levels are 0 or on the first step
self.latents = prev_sample
self.old_pred_original_sample = pred_original_sample
else:
denoised_d = mult[2] * pred_original_sample - mult[3] * self.old_pred_original_sample
noise = randn_tensor(self.latents.shape, generator=self.generator, device=self.latents.device, dtype=self.latents.dtype)
x_advanced = mult[0] * self.latents - mult[1] * denoised_d + mult_noise * noise
self.latents = x_advanced
self.old_pred_original_sample = pred_original_sample
self.latents = self.latents.to(torch.bfloat16)
@@ -211,7 +211,7 @@ def get_1d_rotary_pos_embed_riflex(
if isinstance(pos, int):
pos = torch.arange(pos)
if isinstance(pos, np.ndarray):
-pos = torch.from_numpy(pos)  # type: ignore  # [S]
+pos = torch.from_numpy(pos)  # [S]
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=pos.device)[: (dim // 2)].float() / dim))  # [D/2]
@@ -223,7 +223,7 @@ def get_1d_rotary_pos_embed_riflex(
freqs[k - 1] = 0.9 * 2 * torch.pi / L_test
# === Riflex modification end ===
-freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
+freqs = torch.outer(pos, freqs)  # [S, D/2]
if use_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
...
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders.single_file_model import FromOriginalModelMixin
from diffusers.utils import logging
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.models.activations import get_activation
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.upsampling import CogVideoXUpsample3D
from diffusers.models.downsampling import CogVideoXDownsample3D
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class CogVideoXSafeConv3d(nn.Conv3d):
r"""
A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
"""
def forward(self, input: torch.Tensor) -> torch.Tensor:
memory_count = (input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]) * 2 / 1024**3
# Set to 2GB, suitable for CuDNN
if memory_count > 2:
kernel_size = self.kernel_size[0]
part_num = int(memory_count / 2) + 1
input_chunks = torch.chunk(input, part_num, dim=2)
if kernel_size > 1:
input_chunks = [input_chunks[0]] + [torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2) for i in range(1, len(input_chunks))]
output_chunks = []
for input_chunk in input_chunks:
output_chunks.append(super().forward(input_chunk))
output = torch.cat(output_chunks, dim=2)
return output
else:
return super().forward(input)
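# Sizing note (for reference): memory_count estimates the half-precision input size in GiB
# (numel * 2 bytes / 1024**3). Inputs above ~2 GiB are chunked along the temporal dim, and for
# kernel_size > 1 each chunk is prepended with the last (kernel_size - 1) frames of the previous
# chunk so the chunked convolution matches the unchunked result.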
class CogVideoXCausalConv3d(nn.Module):
r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
Args:
in_channels (`int`): Number of channels in the input tensor.
out_channels (`int`): Number of output channels produced by the convolution.
kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
stride (`int`, defaults to `1`): Stride of the convolution.
dilation (`int`, defaults to `1`): Dilation rate of the convolution.
pad_mode (`str`, defaults to `"constant"`): Padding mode.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int, int]],
stride: int = 1,
dilation: int = 1,
pad_mode: str = "constant",
):
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size,) * 3
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
# TODO(aryan): configure calculation based on stride and dilation in the future.
# Since CogVideoX does not use it, it is currently tailored to "just work" with Mochi
time_pad = time_kernel_size - 1
height_pad = (height_kernel_size - 1) // 2
width_pad = (width_kernel_size - 1) // 2
self.pad_mode = pad_mode
self.height_pad = height_pad
self.width_pad = width_pad
self.time_pad = time_pad
self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
self.temporal_dim = 2
self.time_kernel_size = time_kernel_size
stride = stride if isinstance(stride, tuple) else (stride, 1, 1)
dilation = (dilation, 1, 1)
self.conv = CogVideoXSafeConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
)
def fake_context_parallel_forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.pad_mode == "replicate":
inputs = F.pad(inputs, self.time_causal_padding, mode="replicate")
else:
kernel_size = self.time_kernel_size
if kernel_size > 1:
cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
inputs = torch.cat(cached_inputs + [inputs], dim=2)
return inputs
def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
inputs = self.fake_context_parallel_forward(inputs, conv_cache)
if self.pad_mode == "replicate":
conv_cache = None
else:
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
output = self.conv(inputs)
return output, conv_cache
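# Padding note (for reference): the temporal axis is padded only on the "past" side with
# time_kernel_size - 1 frames (taken from conv_cache, or by repeating the first frame), while
# height/width get symmetric constant padding, so each output frame depends only on the current
# and earlier input frames.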
class CogVideoXSpatialNorm3D(nn.Module):
r"""
Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is specific
to 3D-video like data.
CogVideoXSafeConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model.
Args:
f_channels (`int`):
The number of channels for input to group normalization layer, and output of the spatial norm layer.
zq_channels (`int`):
The number of channels for the quantized vector as described in the paper.
groups (`int`):
Number of groups to separate the channels into for group normalization.
"""
def __init__(
self,
f_channels: int,
zq_channels: int,
groups: int = 32,
):
super().__init__()
self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
def forward(self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None) -> torch.Tensor:
new_conv_cache = {}
conv_cache = conv_cache or {}
if f.shape[2] > 1 and f.shape[2] % 2 == 1:
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
z_first = F.interpolate(z_first, size=f_first_size)
z_rest = F.interpolate(z_rest, size=f_rest_size)
zq = torch.cat([z_first, z_rest], dim=2)
else:
zq = F.interpolate(zq, size=f.shape[-3:])
conv_y, new_conv_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y"))
conv_b, new_conv_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b"))
norm_f = self.norm_layer(f)
new_f = norm_f * conv_y + conv_b
return new_f, new_conv_cache
class CogVideoXResnetBlock3D(nn.Module):
r"""
A 3D ResNet block used in the CogVideoX model.
Args:
in_channels (`int`):
Number of input channels.
out_channels (`int`, *optional*):
Number of output channels. If None, defaults to `in_channels`.
dropout (`float`, defaults to `0.0`):
Dropout rate.
temb_channels (`int`, defaults to `512`):
Number of time embedding channels.
groups (`int`, defaults to `32`):
Number of groups to separate the channels into for group normalization.
eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
non_linearity (`str`, defaults to `"swish"`):
Activation function to use.
conv_shortcut (bool, defaults to `False`):
Whether or not to use a convolution shortcut.
spatial_norm_dim (`int`, *optional*):
The dimension to use for spatial norm if it is to be used instead of group norm.
pad_mode (str, defaults to `"first"`):
Padding mode.
"""
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
dropout: float = 0.0,
temb_channels: int = 512,
groups: int = 32,
eps: float = 1e-6,
non_linearity: str = "swish",
conv_shortcut: bool = False,
spatial_norm_dim: Optional[int] = None,
pad_mode: str = "first",
):
super().__init__()
out_channels = out_channels or in_channels
self.in_channels = in_channels
self.out_channels = out_channels
self.nonlinearity = get_activation(non_linearity)
self.use_conv_shortcut = conv_shortcut
self.spatial_norm_dim = spatial_norm_dim
if spatial_norm_dim is None:
self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
else:
self.norm1 = CogVideoXSpatialNorm3D(
f_channels=in_channels,
zq_channels=spatial_norm_dim,
groups=groups,
)
self.norm2 = CogVideoXSpatialNorm3D(
f_channels=out_channels,
zq_channels=spatial_norm_dim,
groups=groups,
)
self.conv1 = CogVideoXCausalConv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode)
if temb_channels > 0:
self.temb_proj = nn.Linear(in_features=temb_channels, out_features=out_channels)
self.dropout = nn.Dropout(dropout)
self.conv2 = CogVideoXCausalConv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = CogVideoXCausalConv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode)
else:
self.conv_shortcut = CogVideoXSafeConv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
def forward(
self,
inputs: torch.Tensor,
temb: Optional[torch.Tensor] = None,
zq: Optional[torch.Tensor] = None,
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.Tensor:
new_conv_cache = {}
conv_cache = conv_cache or {}
hidden_states = inputs
if zq is not None:
hidden_states, new_conv_cache["norm1"] = self.norm1(hidden_states, zq, conv_cache=conv_cache.get("norm1"))
else:
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1"))
if temb is not None:
hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
if zq is not None:
hidden_states, new_conv_cache["norm2"] = self.norm2(hidden_states, zq, conv_cache=conv_cache.get("norm2"))
else:
hidden_states = self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2"))
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
inputs, new_conv_cache["conv_shortcut"] = self.conv_shortcut(inputs, conv_cache=conv_cache.get("conv_shortcut"))
else:
inputs = self.conv_shortcut(inputs)
hidden_states = hidden_states + inputs
return hidden_states, new_conv_cache
class CogVideoXDownBlock3D(nn.Module):
r"""
A downsampling block used in the CogVideoX model.
Args:
in_channels (`int`):
Number of input channels.
out_channels (`int`, *optional*):
Number of output channels. If None, defaults to `in_channels`.
temb_channels (`int`, defaults to `512`):
Number of time embedding channels.
num_layers (`int`, defaults to `1`):
Number of resnet layers.
dropout (`float`, defaults to `0.0`):
Dropout rate.
resnet_eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
resnet_act_fn (`str`, defaults to `"swish"`):
Activation function to use.
resnet_groups (`int`, defaults to `32`):
Number of groups to separate the channels into for group normalization.
add_downsample (`bool`, defaults to `True`):
Whether or not to use a downsampling layer. If not used, the output dimension is the same as the input dimension.
compress_time (`bool`, defaults to `False`):
Whether or not to downsample across temporal dimension.
pad_mode (str, defaults to `"first"`):
Padding mode.
"""
_supports_gradient_checkpointing = True
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
add_downsample: bool = True,
downsample_padding: int = 0,
compress_time: bool = False,
pad_mode: str = "first",
):
super().__init__()
resnets = []
for i in range(num_layers):
in_channel = in_channels if i == 0 else out_channels
resnets.append(
CogVideoXResnetBlock3D(
in_channels=in_channel,
out_channels=out_channels,
dropout=dropout,
temb_channels=temb_channels,
groups=resnet_groups,
eps=resnet_eps,
non_linearity=resnet_act_fn,
pad_mode=pad_mode,
)
)
self.resnets = nn.ModuleList(resnets)
self.downsamplers = None
if add_downsample:
self.downsamplers = nn.ModuleList([CogVideoXDownsample3D(out_channels, out_channels, padding=downsample_padding, compress_time=compress_time)])
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
zq: Optional[torch.Tensor] = None,
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.Tensor:
r"""Forward method of the `CogVideoXDownBlock3D` class."""
new_conv_cache = {}
conv_cache = conv_cache or {}
for i, resnet in enumerate(self.resnets):
conv_cache_key = f"resnet_{i}"
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
return module(*inputs)
return create_forward
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
zq,
conv_cache.get(conv_cache_key),
)
else:
hidden_states, new_conv_cache[conv_cache_key] = resnet(hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key))
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
return hidden_states, new_conv_cache
class CogVideoXMidBlock3D(nn.Module):
r"""
A middle block used in the CogVideoX model.
Args:
in_channels (`int`):
Number of input channels.
temb_channels (`int`, defaults to `512`):
Number of time embedding channels.
dropout (`float`, defaults to `0.0`):
Dropout rate.
num_layers (`int`, defaults to `1`):
Number of resnet layers.
resnet_eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
resnet_act_fn (`str`, defaults to `"swish"`):
Activation function to use.
resnet_groups (`int`, defaults to `32`):
Number of groups to separate the channels into for group normalization.
spatial_norm_dim (`int`, *optional*):
The dimension to use for spatial norm if it is to be used instead of group norm.
pad_mode (str, defaults to `"first"`):
Padding mode.
"""
_supports_gradient_checkpointing = True
def __init__(
self,
in_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
spatial_norm_dim: Optional[int] = None,
pad_mode: str = "first",
):
super().__init__()
resnets = []
for _ in range(num_layers):
resnets.append(
CogVideoXResnetBlock3D(
in_channels=in_channels,
out_channels=in_channels,
dropout=dropout,
temb_channels=temb_channels,
groups=resnet_groups,
eps=resnet_eps,
spatial_norm_dim=spatial_norm_dim,
non_linearity=resnet_act_fn,
pad_mode=pad_mode,
)
)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
zq: Optional[torch.Tensor] = None,
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.Tensor:
r"""Forward method of the `CogVideoXMidBlock3D` class."""
new_conv_cache = {}
conv_cache = conv_cache or {}
for i, resnet in enumerate(self.resnets):
conv_cache_key = f"resnet_{i}"
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
return module(*inputs)
return create_forward
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key))
else:
hidden_states, new_conv_cache[conv_cache_key] = resnet(hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key))
return hidden_states, new_conv_cache
class CogVideoXUpBlock3D(nn.Module):
r"""
An upsampling block used in the CogVideoX model.
Args:
in_channels (`int`):
Number of input channels.
out_channels (`int`, *optional*):
Number of output channels. If None, defaults to `in_channels`.
temb_channels (`int`, defaults to `512`):
Number of time embedding channels.
dropout (`float`, defaults to `0.0`):
Dropout rate.
num_layers (`int`, defaults to `1`):
Number of resnet layers.
resnet_eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
resnet_act_fn (`str`, defaults to `"swish"`):
Activation function to use.
resnet_groups (`int`, defaults to `32`):
Number of groups to separate the channels into for group normalization.
spatial_norm_dim (`int`, defaults to `16`):
The dimension to use for spatial norm if it is to be used instead of group norm.
add_upsample (`bool`, defaults to `True`):
Whether or not to use an upsampling layer. If not used, the output dimension is the same as the input dimension.
compress_time (`bool`, defaults to `False`):
Whether or not to upsample across the temporal dimension.
pad_mode (str, defaults to `"first"`):
Padding mode.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
spatial_norm_dim: int = 16,
add_upsample: bool = True,
upsample_padding: int = 1,
compress_time: bool = False,
pad_mode: str = "first",
):
super().__init__()
resnets = []
for i in range(num_layers):
in_channel = in_channels if i == 0 else out_channels
resnets.append(
CogVideoXResnetBlock3D(
in_channels=in_channel,
out_channels=out_channels,
dropout=dropout,
temb_channels=temb_channels,
groups=resnet_groups,
eps=resnet_eps,
non_linearity=resnet_act_fn,
spatial_norm_dim=spatial_norm_dim,
pad_mode=pad_mode,
)
)
self.resnets = nn.ModuleList(resnets)
self.upsamplers = None
if add_upsample:
self.upsamplers = nn.ModuleList([CogVideoXUpsample3D(out_channels, out_channels, padding=upsample_padding, compress_time=compress_time)])
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
zq: Optional[torch.Tensor] = None,
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
r"""Forward method of the `CogVideoXUpBlock3D` class."""
new_conv_cache = {}
conv_cache = conv_cache or {}
for i, resnet in enumerate(self.resnets):
conv_cache_key = f"resnet_{i}"
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
return module(*inputs)
return create_forward
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
zq,
conv_cache.get(conv_cache_key),
)
else:
hidden_states, new_conv_cache[conv_cache_key] = resnet(hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key))
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states, new_conv_cache
class CogVideoXEncoder3D(nn.Module):
r"""
The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation.
Args:
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 16):
The number of output channels.
down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("CogVideoXDownBlock3D", "CogVideoXDownBlock3D", "CogVideoXDownBlock3D", "CogVideoXDownBlock3D")`):
The types of down blocks to use. Only `"CogVideoXDownBlock3D"` is supported.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 256, 512)`):
The number of output channels for each block.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
layers_per_block (`int`, *optional*, defaults to 3):
The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
"""
_supports_gradient_checkpointing = True
def __init__(
self,
in_channels: int = 3,
out_channels: int = 16,
down_block_types: Tuple[str, ...] = (
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
),
block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
layers_per_block: int = 3,
act_fn: str = "silu",
norm_eps: float = 1e-6,
norm_num_groups: int = 32,
dropout: float = 0.0,
pad_mode: str = "first",
temporal_compression_ratio: float = 4,
):
super().__init__()
# log2 of temporal_compress_times
temporal_compress_level = int(np.log2(temporal_compression_ratio))
self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
self.down_blocks = nn.ModuleList([])
# down blocks
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
compress_time = i < temporal_compress_level
if down_block_type == "CogVideoXDownBlock3D":
down_block = CogVideoXDownBlock3D(
in_channels=input_channel,
out_channels=output_channel,
temb_channels=0,
dropout=dropout,
num_layers=layers_per_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
add_downsample=not is_final_block,
compress_time=compress_time,
)
else:
raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`")
self.down_blocks.append(down_block)
# mid block
self.mid_block = CogVideoXMidBlock3D(
in_channels=block_out_channels[-1],
temb_channels=0,
dropout=dropout,
num_layers=2,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
pad_mode=pad_mode,
)
self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = CogVideoXCausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode)
self.gradient_checkpointing = False
def forward(
self,
sample: torch.Tensor,
temb: Optional[torch.Tensor] = None,
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
r"""The forward method of the `CogVideoXEncoder3D` class."""
new_conv_cache = {}
conv_cache = conv_cache or {}
hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# 1. Down
for i, down_block in enumerate(self.down_blocks):
conv_cache_key = f"down_block_{i}"
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
create_custom_forward(down_block),
hidden_states,
temb,
None,
conv_cache.get(conv_cache_key),
)
# 2. Mid
hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
hidden_states,
temb,
None,
conv_cache.get("mid_block"),
)
else:
# 1. Down
for i, down_block in enumerate(self.down_blocks):
conv_cache_key = f"down_block_{i}"
hidden_states, new_conv_cache[conv_cache_key] = down_block(hidden_states, temb, None, conv_cache.get(conv_cache_key))
# 2. Mid
hidden_states, new_conv_cache["mid_block"] = self.mid_block(hidden_states, temb, None, conv_cache=conv_cache.get("mid_block"))
# 3. Post-process
hidden_states = self.norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
return hidden_states, new_conv_cache
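# Note on the encoder output: `conv_out` produces 2 * out_channels feature maps (the mean and
# log-variance halves that `AutoencoderKLCogVideoX` later wraps in a DiagonalGaussianDistribution).
# Rough shape sketch, assuming the default 4x temporal / 8x spatial compression and a hypothetical
# input size:
#
#     x = torch.randn(1, 3, 9, 256, 256)      # [B, C, T, H, W]
#     moments, cache = encoder(x)             # roughly [1, 32, 3, 32, 32] for out_channels=16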
class CogVideoXDecoder3D(nn.Module):
r"""
The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
sample.
Args:
in_channels (`int`, *optional*, defaults to 16):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("CogVideoXUpBlock3D", "CogVideoXUpBlock3D", "CogVideoXUpBlock3D", "CogVideoXUpBlock3D")`):
The types of up blocks to use. Only `"CogVideoXUpBlock3D"` is supported.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 256, 512)`):
The number of output channels for each block.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
layers_per_block (`int`, *optional*, defaults to 3):
The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
"""
_supports_gradient_checkpointing = True
def __init__(
self,
in_channels: int = 16,
out_channels: int = 3,
up_block_types: Tuple[str, ...] = (
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
),
block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
layers_per_block: int = 3,
act_fn: str = "silu",
norm_eps: float = 1e-6,
norm_num_groups: int = 32,
dropout: float = 0.0,
pad_mode: str = "first",
temporal_compression_ratio: float = 4,
):
super().__init__()
reversed_block_out_channels = list(reversed(block_out_channels))
self.conv_in = CogVideoXCausalConv3d(in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
# mid block
self.mid_block = CogVideoXMidBlock3D(
in_channels=reversed_block_out_channels[0],
temb_channels=0,
num_layers=2,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
spatial_norm_dim=in_channels,
pad_mode=pad_mode,
)
# up blocks
self.up_blocks = nn.ModuleList([])
output_channel = reversed_block_out_channels[0]
temporal_compress_level = int(np.log2(temporal_compression_ratio))
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
compress_time = i < temporal_compress_level
if up_block_type == "CogVideoXUpBlock3D":
up_block = CogVideoXUpBlock3D(
in_channels=prev_output_channel,
out_channels=output_channel,
temb_channels=0,
dropout=dropout,
num_layers=layers_per_block + 1,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
spatial_norm_dim=in_channels,
add_upsample=not is_final_block,
compress_time=compress_time,
pad_mode=pad_mode,
)
prev_output_channel = output_channel
else:
raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`")
self.up_blocks.append(up_block)
self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups)
self.conv_act = nn.SiLU()
self.conv_out = CogVideoXCausalConv3d(reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode)
self.gradient_checkpointing = False
def forward(
self,
sample: torch.Tensor,
temb: Optional[torch.Tensor] = None,
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
r"""The forward method of the `CogVideoXDecoder3D` class."""
new_conv_cache = {}
conv_cache = conv_cache or {}
hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# 1. Mid
hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
hidden_states,
temb,
sample,
conv_cache.get("mid_block"),
)
# 2. Up
for i, up_block in enumerate(self.up_blocks):
conv_cache_key = f"up_block_{i}"
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block),
hidden_states,
temb,
sample,
conv_cache.get(conv_cache_key),
)
else:
# 1. Mid
hidden_states, new_conv_cache["mid_block"] = self.mid_block(hidden_states, temb, sample, conv_cache=conv_cache.get("mid_block"))
# 2. Up
for i, up_block in enumerate(self.up_blocks):
conv_cache_key = f"up_block_{i}"
hidden_states, new_conv_cache[conv_cache_key] = up_block(hidden_states, temb, sample, conv_cache=conv_cache.get(conv_cache_key))
# 3. Post-process
hidden_states, new_conv_cache["norm_out"] = self.norm_out(hidden_states, sample, conv_cache=conv_cache.get("norm_out"))
hidden_states = self.conv_act(hidden_states)
hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
return hidden_states, new_conv_cache
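# Note: unlike the encoder (which passes `zq=None`), the decoder threads the raw latent `sample`
# into the mid and up blocks as `zq`, which CogVideoXSpatialNorm3D uses to condition its
# normalization; this is why `spatial_norm_dim` is set to the latent channel count above.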
class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
[CogVideoX](https://github.com/THUDM/CogVideo).
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
for all models (such as downloading or saving).
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to `("CogVideoXDownBlock3D", "CogVideoXDownBlock3D", "CogVideoXDownBlock3D", "CogVideoXDownBlock3D")`):
Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to `("CogVideoXUpBlock3D", "CogVideoXUpBlock3D", "CogVideoXUpBlock3D", "CogVideoXUpBlock3D")`):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(128, 256, 256, 512)`):
Tuple of block output channels.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
sample_height (`int`, *optional*, defaults to `480`): Sample input height.
sample_width (`int`, *optional*, defaults to `720`): Sample input width.
scaling_factor (`float`, *optional*, defaults to `1.15258426`):
The component-wise standard deviation of the trained latent space computed using the first batch of the
training set. This is used to scale the latent space to have unit variance when training the diffusion
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
force_upcast (`bool`, *optional*, defaults to `True`):
If enabled, it will force the VAE to run in float32 for high-resolution pipelines, such as SD-XL. The VAE
can be fine-tuned / trained to a lower range without losing too much precision, in which case
`force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["CogVideoXResnetBlock3D"]
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str] = (
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
),
up_block_types: Tuple[str] = (
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
),
block_out_channels: Tuple[int] = (128, 256, 256, 512),
latent_channels: int = 16,
layers_per_block: int = 3,
act_fn: str = "silu",
norm_eps: float = 1e-6,
norm_num_groups: int = 32,
temporal_compression_ratio: float = 4,
sample_height: int = 480,
sample_width: int = 720,
scaling_factor: float = 1.15258426,
shift_factor: Optional[float] = None,
latents_mean: Optional[Tuple[float]] = None,
latents_std: Optional[Tuple[float]] = None,
force_upcast: bool = True,
use_quant_conv: bool = False,
use_post_quant_conv: bool = False,
invert_scale_latents: bool = False,
):
super().__init__()
self.encoder = CogVideoXEncoder3D(
in_channels=in_channels,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
norm_eps=norm_eps,
norm_num_groups=norm_num_groups,
temporal_compression_ratio=temporal_compression_ratio,
)
self.decoder = CogVideoXDecoder3D(
in_channels=latent_channels,
out_channels=out_channels,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
norm_eps=norm_eps,
norm_num_groups=norm_num_groups,
temporal_compression_ratio=temporal_compression_ratio,
)
self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None
self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None
self.use_slicing = False
self.use_tiling = False
# Can be increased to decode more latent frames at once, but this comes at an additional memory cost and is not
# recommended, because the temporal parts of the VAE are tricky to reason about.
# If you decode X latent frames together, the number of output frames is:
# (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames
#
# Example with num_latent_frames_batch_size = 2:
# - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together
# => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
# => 6 * 8 = 48 frames
# - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together
# => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) +
# ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
# => 1 * 9 + 5 * 8 = 49 frames
# It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that
# setting it to anything other than 2 would give poor results because the VAE has not been trained to handle a
# different number of temporal frames.
self.num_latent_frames_batch_size = 2
self.num_sample_frames_batch_size = 8
# We set the minimum tile height and width to half of the generally supported sample size.
self.tile_sample_min_height = sample_height // 2
self.tile_sample_min_width = sample_width // 2
self.tile_latent_min_height = int(self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
# These are experimental overlap factors that were chosen based on experimentation and seem to work best for
# 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
# and so the tiling implementation has only been tested on those specific resolutions.
self.tile_overlap_factor_height = 1 / 6
self.tile_overlap_factor_width = 1 / 5
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
module.gradient_checkpointing = value
def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
tile_sample_min_width: Optional[int] = None,
tile_overlap_factor_height: Optional[float] = None,
tile_overlap_factor_width: Optional[float] = None,
) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow
processing larger images.
Args:
tile_sample_min_height (`int`, *optional*):
The minimum height required for a sample to be separated into tiles across the height dimension.
tile_sample_min_width (`int`, *optional*):
The minimum width required for a sample to be separated into tiles across the width dimension.
tile_overlap_factor_height (`float`, *optional*):
The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
value might cause more tiles to be processed, leading to a slowdown of the decoding process.
tile_overlap_factor_width (`float`, *optional*):
The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
value might cause more tiles to be processed, leading to a slowdown of the decoding process.
"""
self.use_tiling = True
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_latent_min_height = int(self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
return self.tiled_encode(x)
frame_batch_size = self.num_sample_frames_batch_size
# Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
# As the extra single frame is handled inside the loop, it is not required to round up here.
num_batches = max(num_frames // frame_batch_size, 1)
conv_cache = None
enc = []
for i in range(num_batches):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
end_frame = frame_batch_size * (i + 1) + remaining_frames
x_intermediate = x[:, :, start_frame:end_frame]
x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
if self.quant_conv is not None:
x_intermediate = self.quant_conv(x_intermediate)
enc.append(x_intermediate)
enc = torch.cat(enc, dim=2)
return enc
@apply_forward_hook
def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded videos. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
return self.tiled_decode(z, return_dict=return_dict)
frame_batch_size = self.num_latent_frames_batch_size
num_batches = max(num_frames // frame_batch_size, 1)
conv_cache = None
dec = []
for i in range(num_batches):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
end_frame = frame_batch_size * (i + 1) + remaining_frames
z_intermediate = z[:, :, start_frame:end_frame]
if self.post_quant_conv is not None:
z_intermediate = self.post_quant_conv(z_intermediate)
z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
dec.append(z_intermediate)
dec = torch.cat(dec, dim=2)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
"""
Decode a batch of images.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[4], b.shape[4], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
return b
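# blend_v / blend_h linearly cross-fade the overlapping border of two tiles: at offset y (or x)
# inside an overlap of size `blend_extent`, the previous tile is weighted (1 - y / blend_extent)
# and the current tile is weighted (y / blend_extent), so the seam fades smoothly from one tile
# into the next.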
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding
differs slightly from non-tiled encoding because each tile is encoded independently. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
output, but they should be much less noticeable.
Args:
x (`torch.Tensor`): Input batch of videos.
Returns:
`torch.Tensor`:
The latent representation of the encoded videos.
"""
# For a rough memory estimate, take a look at the `tiled_decode` method.
batch_size, num_channels, num_frames, height, width = x.shape
overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
row_limit_height = self.tile_latent_min_height - blend_extent_height
row_limit_width = self.tile_latent_min_width - blend_extent_width
frame_batch_size = self.num_sample_frames_batch_size
# Split x into overlapping tiles and encode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, overlap_height):
row = []
for j in range(0, width, overlap_width):
# Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
# As the extra single frame is handled inside the loop, it is not required to round up here.
num_batches = max(num_frames // frame_batch_size, 1)
conv_cache = None
time = []
for k in range(num_batches):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
end_frame = frame_batch_size * (k + 1) + remaining_frames
tile = x[
:,
:,
start_frame:end_frame,
i : i + self.tile_sample_min_height,
j : j + self.tile_sample_min_width,
]
tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
if self.quant_conv is not None:
tile = self.quant_conv(tile)
time.append(tile)
row.append(torch.cat(time, dim=2))
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent_width)
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
result_rows.append(torch.cat(result_row, dim=4))
enc = torch.cat(result_rows, dim=3)
return enc
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
# Rough memory assessment:
# - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
# - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
# - Assume fp16 (2 bytes per value).
# Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
#
# Memory assessment when using tiling:
# - Assume everything as above but now HxW is 240x360 by tiling in half
# Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB
batch_size, num_channels, num_frames, height, width = z.shape
overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
row_limit_height = self.tile_sample_min_height - blend_extent_height
row_limit_width = self.tile_sample_min_width - blend_extent_width
frame_batch_size = self.num_latent_frames_batch_size
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, overlap_height):
row = []
for j in range(0, width, overlap_width):
num_batches = max(num_frames // frame_batch_size, 1)
conv_cache = None
time = []
for k in range(num_batches):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
end_frame = frame_batch_size * (k + 1) + remaining_frames
tile = z[
:,
:,
start_frame:end_frame,
i : i + self.tile_latent_min_height,
j : j + self.tile_latent_min_width,
]
if self.post_quant_conv is not None:
tile = self.post_quant_conv(tile)
tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
time.append(tile)
row.append(torch.cat(time, dim=2))
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent_width)
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
result_rows.append(torch.cat(result_row, dim=4))
dec = torch.cat(result_rows, dim=3)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
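# Example usage of the autoencoder on its own (a sketch, not part of the model definition; the
# checkpoint path and tensor sizes are hypothetical):
#
#     vae = AutoencoderKLCogVideoX.from_config(AutoencoderKLCogVideoX.load_config("<model>/vae"))
#     vae.enable_tiling()                                   # bound memory for large frames
#     video = torch.randn(1, 3, 49, 480, 720)               # [B, C, T, H, W]
#     latents = vae.encode(video).latent_dist.sample()      # roughly [1, 16, 13, 60, 90]
#     frames = vae.decode(latents).sample                   # back to pixel space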
import os
import glob
import torch # type: ignore
from safetensors import safe_open # type: ignore
from diffusers.video_processor import VideoProcessor # type: ignore
from lightx2v.models.video_encoders.hf.cogvideox.autoencoder_ks_cogvidex import AutoencoderKLCogVideoX
class CogvideoxVAE:
def __init__(self, config):
self.config = config
self.load()
def _load_safetensor_to_dict(self, file_path):
with safe_open(file_path, framework="pt") as f:
tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16).cuda() for key in f.keys()}
return tensor_dict
def _load_ckpt(self, model_path):
safetensors_pattern = os.path.join(model_path, "*.safetensors")
safetensors_files = glob.glob(safetensors_pattern)
if not safetensors_files:
raise FileNotFoundError(f"No .safetensors files found in directory: {model_path}")
weight_dict = {}
for file_path in safetensors_files:
file_weights = self._load_safetensor_to_dict(file_path)
weight_dict.update(file_weights)
return weight_dict
def load(self):
vae_path = os.path.join(self.config.model_path, "vae")
self.vae_config = AutoencoderKLCogVideoX.load_config(vae_path)
self.model = AutoencoderKLCogVideoX.from_config(self.vae_config)
vae_ckpt = self._load_ckpt(vae_path)
self.vae_scale_factor_spatial = 2 ** (len(self.vae_config["block_out_channels"]) - 1) # 8
self.vae_scale_factor_temporal = self.vae_config["temporal_compression_ratio"] # 4
self.vae_scaling_factor_image = self.vae_config["scaling_factor"] # 0.7
self.model.load_state_dict(vae_ckpt)
self.model.to(torch.bfloat16).to(torch.device("cuda"))
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
@torch.no_grad()
def decode(self, latents, generator, config):
latents = latents.permute(0, 2, 1, 3, 4)
latents = 1 / self.config.vae_scaling_factor_image * latents
frames = self.model.decode(latents).sample
images = self.video_processor.postprocess_video(video=frames, output_type="pil")[0]
return images
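# Hypothetical usage of the wrapper above (the config attributes mirror the fields accessed in
# `load` and `decode`; names and shapes here are assumptions, not part of this file):
#
#     vae = CogvideoxVAE(config)                    # expects config.model_path/vae with *.safetensors
#     latents = latents.to(torch.bfloat16).cuda()   # [B, T, C, H, W] before the permute in `decode`
#     images = vae.decode(latents, generator=None, config=config)   # list of PIL frames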
...@@ -39,8 +39,9 @@ def set_config(args):
         model_config = json.load(f)
     config.update(model_config)
 
-    if config.target_video_length % config.vae_stride[0] != 1:
-        logger.warning(f"`num_frames - 1` has to be divisible by {config.vae_stride[0]}. Rounding to the nearest number.")
-        config.target_video_length = config.target_video_length // config.vae_stride[0] * config.vae_stride[0] + 1
+    if config.task == "i2v":
+        if config.target_video_length % config.vae_stride[0] != 1:
+            logger.warning(f"`num_frames - 1` has to be divisible by {config.vae_stride[0]}. Rounding to the nearest number.")
+            config.target_video_length = config.target_video_length // config.vae_stride[0] * config.vae_stride[0] + 1
 
     return config