Commit 45467a6b authored by Watebear, committed by GitHub

[feature]: add cogvideox t2v (#55)

parent c5048775
{
"seed": 42,
"text_len": 226,
"num_videos_per_prompt": 1,
"target_video_length": 81,
"num_inference_steps": 50,
"num_train_timesteps": 1000,
"timestep_spacing": "trailing",
"steps_offset": 0,
"latent_channels": 16,
"height": 768,
"width": 1360,
"vae_scale_factor_temporal": 4,
"vae_scale_factor_spatial": 8,
"vae_scaling_factor_image": 0.7,
"batch_size": 1,
"patch_size": 2,
"patch_size_t": 2,
"guidance_scale": 0,
"use_rotary_positional_embeddings": true,
"do_classifier_free_guidance": false,
"transformer_sample_width": 170,
"transformer_sample_height": 96,
"transformer_sample_frames": 81,
"transformer_attention_head_dim": 64,
"transformer_num_attention_heads": 48,
"transformer_temporal_compression_ratio": 4,
"transformer_temporal_interpolation_scale": 1.0,
"transformer_use_learned_positional_embeddings": false,
"transformer_spatial_interpolation_scale": 1.875,
"transformer_num_layers": 42,
"beta_schedule": "scaled_linear",
"scheduler_beta_start": 0.00085,
"scheduler_beta_end": 0.012,
"scheduler_set_alpha_to_one": true,
"scheduler_snr_shift_scale": 1.0,
"scheduler_rescale_betas_zero_snr": true,
"scheduler_prediction_type": "v_prediction",
"use_dynamic_cfg": true
}
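For orientation, below is a minimal sketch of how a config like the one above could be loaded for experimentation. The `Config` wrapper and the file name are assumptions made for illustration only; the code in this commit reads fields both as attributes (`config.text_len`) and as items (`config["patch_size"]`), so the sketch supports both access styles.

import json
from types import SimpleNamespace

class Config(SimpleNamespace):
    # Hypothetical helper: allow attribute access, item access, and .get(),
    # matching the access patterns used by the CogVideoX code below.
    def __getitem__(self, key):
        return getattr(self, key)
    def get(self, key, default=None):
        return getattr(self, key, default)

with open("cogvideox_t2v.json") as f:  # illustrative file name
    config = Config(**json.load(f))

print(config.text_len, config["patch_size"], config.get("negative_prompt", ""))  # 226 2 ''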
@@ -50,5 +50,5 @@ class LNWeight(LNWeightTemplate):
         super().__init__(weight_name, bias_name, eps)
 
     def apply(self, input_tensor):
-        input_tensor = torch.nn.functional.layer_norm(input_tensor, (input_tensor.shape[1],), self.weight, self.bias, self.eps)
+        input_tensor = torch.nn.functional.layer_norm(input_tensor, (input_tensor.shape[-1],), self.weight, self.bias, self.eps)
         return input_tensor
@@ -15,6 +15,7 @@ from lightx2v.models.runners.wan.wan_runner import WanRunner
 from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner
 from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner
 from lightx2v.models.runners.graph_runner import GraphRunner
+from lightx2v.models.runners.cogvideox.cogvidex_runner import CogvideoxRunner
 from lightx2v.common.ops import *
 from loguru import logger
 
@@ -36,7 +37,7 @@ def init_runner(config):
 async def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causvid", "wan2.1_skyreels_v2_df"], default="hunyuan")
+    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox"], default="hunyuan")
     parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
     parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--config_json", type=str, required=True)
...
import torch
import os
from transformers import T5EncoderModel, T5Tokenizer
class T5EncoderModel_v1_1_xxl:
def __init__(self, config):
self.config = config
self.model = T5EncoderModel.from_pretrained(os.path.join(config.model_path, "text_encoder")).to(torch.bfloat16).to(torch.device("cuda"))
self.tokenizer = T5Tokenizer.from_pretrained(os.path.join(config.model_path, "tokenizer"), padding_side="right")
def to_cpu(self):
self.model = self.model.to("cpu")
def to_cuda(self):
self.model = self.model.to("cuda")
def infer(self, texts, config):
text_inputs = self.tokenizer(
texts,
padding="max_length",
max_length=config.text_len,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
).to("cuda")
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(texts, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, config.text_len - 1 : -1])
print(f"The following part of your input was truncated because `max_sequence_length` is set to {self.text_len} tokens: {removed_text}")
prompt_embeds = self.model(text_input_ids.to(torch.device("cuda")))[0]
return prompt_embeds
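A minimal usage sketch for the encoder above, assuming the standard CogVideoX checkpoint layout with text_encoder/ and tokenizer/ subfolders and an available CUDA device; the path is illustrative.

from types import SimpleNamespace

# Hypothetical driver code, not part of the module.
cfg = SimpleNamespace(model_path="/path/to/CogVideoX-checkpoint", text_len=226)
encoder = T5EncoderModel_v1_1_xxl(cfg)
prompt_embeds = encoder.infer(["a golden retriever surfing a wave at sunset"], cfg)
print(prompt_embeds.shape)  # expected (1, 226, 4096) for T5-v1.1-XXL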
import torch
class CogvideoxPostInfer:
def __init__(self, config):
self.config = config
def ada_layernorm(self, weight_mm, weight_ln, x, temb):
temb = torch.nn.functional.silu(temb)
temb = weight_mm.apply(temb)
shift, scale = temb.chunk(2, dim=1)
x = weight_ln.apply(x) * (1 + scale) + shift
return x
def infer(self, weight, hidden_states, encoder_hidden_states, temb, infer_shapes):
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=0)
hidden_states = weight.norm_final.apply(hidden_states)
hidden_states = hidden_states[self.config.text_len :,]
hidden_states = self.ada_layernorm(weight.norm_out_linear, weight.norm_out_norm, hidden_states, temb=temb)
hidden_states = weight.proj_out.apply(hidden_states)
p = self.config["patch_size"]
p_t = self.config["patch_size_t"]
num_frames, _, height, width = infer_shapes
output = hidden_states.reshape((num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p)
output = output.permute(0, 4, 3, 1, 5, 2, 6).flatten(5, 6).flatten(3, 4).flatten(0, 1)
return output
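As a sanity check of the unpatchify performed in `infer` above, the following standalone sketch replays the same reshape/permute chain on a dummy tensor; the dimensions are taken from the t2v config (latent 22x16x96x170 with patch sizes 2/2) and are purely illustrative.

import torch

F, C, H, W = 22, 16, 96, 170  # latent frames, channels, height, width
p, p_t = 2, 2                 # patch_size, patch_size_t
num_patches = ((F + p_t - 1) // p_t) * (H // p) * (W // p)
x = torch.randn(num_patches, C * p * p * p_t)  # stands in for the proj_out output
out = x.reshape((F + p_t - 1) // p_t, H // p, W // p, -1, p_t, p, p)
out = out.permute(0, 4, 3, 1, 5, 2, 6).flatten(5, 6).flatten(3, 4).flatten(0, 1)
print(out.shape)  # torch.Size([22, 16, 96, 170]) -- back to (frames, C, H, W)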
import torch
from diffusers.models.embeddings import get_timestep_embedding, get_3d_sincos_pos_embed
class CogvideoxPreInfer:
def __init__(self, config):
self.config = config
self.use_positional_embeddings = not self.config.use_rotary_positional_embeddings
self.inner_dim = self.config.transformer_num_attention_heads * self.config.transformer_attention_head_dim
self.freq_shift = 0
self.flip_sin_to_cos = True
self.scale = 1
self.act = "silu"
def _get_positional_embeddings(self, sample_height, sample_width, sample_frames, device):
post_patch_height = sample_height // self.config.patch_size
post_patch_width = sample_width // self.config.patch_size
post_time_compression_frames = (sample_frames - 1) // self.config.transformer_temporal_compression_ratio + 1
num_patches = post_patch_height * post_patch_width * post_time_compression_frames
pos_embedding = get_3d_sincos_pos_embed(
self.inner_dim,
(post_patch_width, post_patch_height),
post_time_compression_frames,
self.config.transformer_spatial_interpolation_scale,
self.config.transformer_temporal_interpolation_scale,
device=device,
output_type="pt",
)
pos_embedding = pos_embedding.flatten(0, 1)
joint_pos_embedding = pos_embedding.new_zeros(1, self.config.text_len + num_patches, self.inner_dim, requires_grad=False)
joint_pos_embedding.data[:, self.config.text_len :].copy_(pos_embedding)
return joint_pos_embedding
def infer(self, weights, hidden_states, timestep, encoder_hidden_states):
t_emb = get_timestep_embedding(
timestep,
self.inner_dim,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.freq_shift,
scale=self.scale,
)
t_emb = t_emb.to(dtype=hidden_states.dtype)
sample = weights.time_embedding_linear_1.apply(t_emb)
sample = torch.nn.functional.silu(sample)
emb = weights.time_embedding_linear_2.apply(sample)
text_embeds = weights.patch_embed_text_proj.apply(encoder_hidden_states)
num_frames, channels, height, width = hidden_states.shape
infer_shapes = (num_frames, channels, height, width)
p = self.config.patch_size
p_t = self.config.patch_size_t
image_embeds = hidden_states.permute(0, 2, 3, 1)
image_embeds = image_embeds.reshape(num_frames // p_t, p_t, height // p, p, width // p, p, channels)
image_embeds = image_embeds.permute(0, 2, 4, 6, 1, 3, 5).flatten(3, 6).flatten(0, 2)
image_embeds = weights.patch_embed_proj.apply(image_embeds)
embeds = torch.cat([text_embeds, image_embeds], dim=0).contiguous()
if self.use_positional_embeddings or self.config.transformer_use_learned_positional_embeddings:
if self.config.transformer_use_learned_positional_embeddings and (self.config.transformer_sample_width != width or self.config.transformer_sample_height != height):
raise ValueError(
"It is currently not possible to generate videos at a different resolution that the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'."
"If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
)
pre_time_compression_frames = (num_frames - 1) * self.config.transformer_temporal_compression_ratio + 1
if self.config.transformer_sample_height != height or self.config.transformer_sample_width != width or self.config.transformer_sample_frames != pre_time_compression_frames:
pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames, device=embeds.device)[0]
else:
# No cached positional-embedding buffer in this implementation, so recompute at the configured sample size.
pos_embedding = self._get_positional_embeddings(self.config.transformer_sample_height, self.config.transformer_sample_width, self.config.transformer_sample_frames, device=embeds.device)[0]
pos_embedding = pos_embedding.to(dtype=embeds.dtype)
embeds = embeds + pos_embedding
hidden_states = embeds
text_seq_length = encoder_hidden_states.shape[0]
encoder_hidden_states = hidden_states[:text_seq_length, :]
hidden_states = hidden_states[text_seq_length:, :]
return hidden_states, encoder_hidden_states, emb, infer_shapes
import torch
import torch.nn.functional as F
def apply_rotary_emb(x, freqs_cis, use_real=True, use_real_unbind_dim=-1):
"""
Apply rotary embeddings to an input tensor using the given frequency tensor. The query or key tensor 'x' is
split into interleaved real/imaginary pairs, rotated by the precomputed cos/sin (or complex exponential)
frequencies in 'freqs_cis', and returned as a real tensor with the same shape and dtype.
Args:
    x (`torch.Tensor`):
        Query or key tensor to apply rotary embeddings to, shaped [..., S, D].
    freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensors for complex exponentials, ([S, D], [S, D]).
Returns:
    `torch.Tensor`: The input tensor with rotary embeddings applied.
"""
if use_real:
cos, sin = freqs_cis # [S, D]
cos = cos[None]
sin = sin[None]
cos, sin = cos.to(x.device), sin.to(x.device)
if use_real_unbind_dim == -1:
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
elif use_real_unbind_dim == -2:
# Used for Stable Audio
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
else:
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
else:
# used for lumina
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(2)
return x_out.type_as(x)
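# Standalone shape sketch (not part of the module) of apply_rotary_emb as it is
# called from the attention below: x is a per-head tensor [heads, seq, head_dim]
# and freqs_cis a (cos, sin) pair of shape [seq, head_dim]; sizes are illustrative.
_heads, _seq, _dim = 48, 10, 64
_x = torch.randn(_heads, _seq, _dim)
_cos, _sin = torch.randn(_seq, _dim), torch.randn(_seq, _dim)
_out = apply_rotary_emb(_x, (_cos, _sin))
print(_out.shape)  # torch.Size([48, 10, 64]) -- same shape, adjacent pairs rotated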
class CogvideoxTransformerInfer:
def __init__(self, config):
self.config = config
self.attn_type = "torch_sdpa"
def set_scheduler(self, scheduler):
self.scheduler = scheduler
def infer(self, weights, hidden_states, encoder_hidden_states, temb):
image_rotary_emb = self.scheduler.image_rotary_emb
for i in range(self.config.transformer_num_layers):
hidden_states, encoder_hidden_states = self.infer_block(
weights.blocks_weights[i],
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
return hidden_states, encoder_hidden_states
def cogvideox_norm1(self, weights, hidden_states, encoder_hidden_states, temb):
temb = torch.nn.functional.silu(temb)
temb = weights.norm1_linear.apply(temb)
shift, scale, gate, enc_shift, enc_scale, enc_gate = temb.chunk(6, dim=1)
hidden_states = weights.norm1_norm.apply(hidden_states) * (1 + scale)[:, :] + shift[:, :]
encoder_hidden_states = weights.norm1_norm.apply(encoder_hidden_states) * (1 + enc_scale)[:, :] + enc_shift[:, :]
return hidden_states, encoder_hidden_states, gate, enc_gate
def cogvideox_norm2(self, weights, hidden_states, encoder_hidden_states, temb):
temb = torch.nn.functional.silu(temb)
temb = weights.norm2_linear.apply(temb)
shift, scale, gate, enc_shift, enc_scale, enc_gate = temb.chunk(6, dim=1)
hidden_states = weights.norm2_norm.apply(hidden_states) * (1 + scale)[:, :] + shift[:, :]
encoder_hidden_states = weights.norm2_norm.apply(encoder_hidden_states) * (1 + enc_scale)[:, :] + enc_shift[:, :]
return hidden_states, encoder_hidden_states, gate, enc_gate
def cogvideox_attention(self, weights, hidden_states, encoder_hidden_states, image_rotary_emb):
text_seq_length = encoder_hidden_states.size(0)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=0)
query = weights.attn1_to_q.apply(hidden_states)
key = weights.attn1_to_k.apply(hidden_states)
value = weights.attn1_to_v.apply(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // self.config.transformer_num_attention_heads
query = query.view(-1, self.config.transformer_num_attention_heads, head_dim).transpose(0, 1)
key = key.view(-1, self.config.transformer_num_attention_heads, head_dim).transpose(0, 1)
value = value.view(-1, self.config.transformer_num_attention_heads, head_dim).transpose(0, 1)
query = weights.attn1_norm_q.apply(query)
key = weights.attn1_norm_k.apply(key)
query[:, text_seq_length:] = apply_rotary_emb(query[:, text_seq_length:], image_rotary_emb)
key[:, text_seq_length:] = apply_rotary_emb(key[:, text_seq_length:], image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(query[None], key[None], value[None], attn_mask=None, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(1, -1, self.config.transformer_num_attention_heads * head_dim)
hidden_states = hidden_states.squeeze(0)
hidden_states = weights.attn1_to_out.apply(hidden_states)
encoder_hidden_states, hidden_states = hidden_states.split([text_seq_length, hidden_states.size(0) - text_seq_length], dim=0)
return hidden_states, encoder_hidden_states
def cogvideox_ff(self, weights, hidden_states):
hidden_states = weights.ff_net_0_proj.apply(hidden_states)
hidden_states = torch.nn.functional.gelu(hidden_states, approximate="tanh")
hidden_states = weights.ff_net_2_proj.apply(hidden_states)
return hidden_states
@torch.no_grad()
def infer_block(self, weights, hidden_states, encoder_hidden_states, temb, image_rotary_emb):
text_seq_length = encoder_hidden_states.size(0)
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.cogvideox_norm1(weights, hidden_states, encoder_hidden_states, temb)
attn_hidden_states, attn_encoder_hidden_states = self.cogvideox_attention(
weights,
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + gate_msa * attn_hidden_states
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.cogvideox_norm2(weights, hidden_states, encoder_hidden_states, temb)
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=0)
ff_output = self.cogvideox_ff(weights, norm_hidden_states)
hidden_states = hidden_states + gate_ff * ff_output[text_seq_length:,]
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:text_seq_length,]
return hidden_states, encoder_hidden_states
import torch
from safetensors import safe_open
import os
import glob
import math
import json
from lightx2v.models.networks.cogvideox.weights.pre_weights import CogvideoxPreWeights
from lightx2v.models.networks.cogvideox.weights.post_weights import CogvideoxPostWeights
from lightx2v.models.networks.cogvideox.weights.transformers_weights import CogvideoxTransformerWeights
from lightx2v.models.networks.cogvideox.infer.pre_infer import CogvideoxPreInfer
from lightx2v.models.networks.cogvideox.infer.transformer_infer import CogvideoxTransformerInfer
from lightx2v.models.networks.cogvideox.infer.post_infer import CogvideoxPostInfer
class CogvideoxModel:
pre_weight_class = CogvideoxPreWeights
post_weight_class = CogvideoxPostWeights
transformer_weight_class = CogvideoxTransformerWeights
def __init__(self, config):
self.config = config
self.device = torch.device("cuda")
self._init_infer_class()
self._init_weights()
self._init_infer()
def _init_infer_class(self):
self.pre_infer_class = CogvideoxPreInfer
self.post_infer_class = CogvideoxPostInfer
self.transformer_infer_class = CogvideoxTransformerInfer
def _load_safetensor_to_dict(self, file_path):
with safe_open(file_path, framework="pt") as f:
tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16).cuda() for key in f.keys()}
return tensor_dict
def _load_ckpt(self):
safetensors_pattern = os.path.join(self.config.model_path, "transformer", "*.safetensors")
safetensors_files = glob.glob(safetensors_pattern)
if not safetensors_files:
raise FileNotFoundError(f"No .safetensors files found in directory: {self.config.model_path}")
weight_dict = {}
for file_path in safetensors_files:
file_weights = self._load_safetensor_to_dict(file_path)
weight_dict.update(file_weights)
return weight_dict
def _init_weights(self):
weight_dict = self._load_ckpt()
with open(os.path.join(self.config.model_path, "transformer", "config.json"), "r") as f:
transformer_cfg = json.load(f)
# init weights
self.pre_weight = self.pre_weight_class(transformer_cfg)
self.transformer_weights = self.transformer_weight_class(transformer_cfg)
self.post_weight = self.post_weight_class(transformer_cfg)
# load weights
self.pre_weight.load_weights(weight_dict)
self.transformer_weights.load_weights(weight_dict)
self.post_weight.load_weights(weight_dict)
def _init_infer(self):
self.pre_infer = self.pre_infer_class(self.config)
self.transformer_infer = self.transformer_infer_class(self.config)
self.post_infer = self.post_infer_class(self.config)
def set_scheduler(self, scheduler):
self.scheduler = scheduler
self.transformer_infer.set_scheduler(scheduler)
def to_cpu(self):
self.pre_weight.to_cpu()
self.post_weight.to_cpu()
self.transformer_weights.to_cpu()
def to_cuda(self):
self.pre_weight.to_cuda()
self.post_weight.to_cuda()
self.transformer_weights.to_cuda()
@torch.no_grad()
def infer(self, inputs):
t = self.scheduler.timesteps[self.scheduler.step_index]
text_encoder_output = inputs["text_encoder_output"]["context"]
do_classifier_free_guidance = self.config.guidance_scale > 1.0
latent_model_input = self.scheduler.latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
timestep = t.expand(latent_model_input.shape[0])
hidden_states, encoder_hidden_states, emb, infer_shapes = self.pre_infer.infer(
self.pre_weight,
latent_model_input[0],
timestep,
text_encoder_output[0],
)
hidden_states, encoder_hidden_states = self.transformer_infer.infer(
self.transformer_weights,
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=emb,
)
noise_pred = self.post_infer.infer(self.post_weight, hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=emb, infer_shapes=infer_shapes)
noise_pred = noise_pred.float()
if self.config.use_dynamic_cfg: # True
self.scheduler.guidance_scale = 1 + self.scheduler.guidance_scale * ((1 - math.cos(math.pi * ((self.scheduler.infer_steps - t.item()) / self.scheduler.infer_steps) ** 5.0)) / 2)
if do_classifier_free_guidance: # False
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.scheduler.guidance_scale * (noise_pred_text - noise_pred_uncond)
self.scheduler.noise_pred = noise_pred
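For reference, a standalone sketch of the classifier-free guidance combination applied in `infer` above when `guidance_scale > 1` (disabled with the t2v config shipped in this commit); the tensor shapes are illustrative, with the unconditional and text-conditioned predictions stacked along a batch dimension.

import torch

guidance_scale = 6.0                           # illustrative value, not from the config
noise_pred = torch.randn(2, 22, 16, 96, 170)   # [uncond, text] stacked on the batch dim
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(guided.shape)  # torch.Size([1, 22, 16, 96, 170])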
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
class CogvideoxPostWeights:
def __init__(self, config, mm_type="Default"):
self.config = config
self.mm_type = mm_type
def load_weights(self, weight_dict):
self.norm_out_linear = MM_WEIGHT_REGISTER[self.mm_type]("norm_out.linear.weight", "norm_out.linear.bias")
self.proj_out = MM_WEIGHT_REGISTER[self.mm_type]("proj_out.weight", "proj_out.bias")
self.norm_final = LN_WEIGHT_REGISTER[self.mm_type]("norm_final.weight", "norm_final.bias")
self.norm_out_norm = LN_WEIGHT_REGISTER[self.mm_type]("norm_out.norm.weight", "norm_out.norm.bias", eps=1e-5)
self.weight_list = [self.norm_out_linear, self.proj_out, self.norm_final, self.norm_out_norm]
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
mm_weight.load(weight_dict)
def to_cpu(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
mm_weight.to_cpu()
def to_cuda(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
mm_weight.to_cuda()
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
class CogvideoxPreWeights:
def __init__(self, config):
self.config = config
def load_weights(self, weight_dict):
self.time_embedding_linear_1 = MM_WEIGHT_REGISTER["Default"]("time_embedding.linear_1.weight", "time_embedding.linear_1.bias")
self.time_embedding_linear_2 = MM_WEIGHT_REGISTER["Default"]("time_embedding.linear_2.weight", "time_embedding.linear_2.bias")
self.patch_embed_proj = MM_WEIGHT_REGISTER["Default"]("patch_embed.proj.weight", "patch_embed.proj.bias")
self.patch_embed_text_proj = MM_WEIGHT_REGISTER["Default"]("patch_embed.text_proj.weight", "patch_embed.text_proj.bias")
self.weight_list = [self.time_embedding_linear_1, self.time_embedding_linear_2, self.patch_embed_proj, self.patch_embed_text_proj]
for mm_weight in self.weight_list:
mm_weight.set_config(self.config)
mm_weight.load(weight_dict)
def to_cpu(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
mm_weight.to_cpu()
def to_cuda(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
mm_weight.to_cuda()
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
class CogvideoxTransformerWeights:
def __init__(self, config, task="t2v", mm_type="Default"):
self.config = config
self.task = task
self.mm_type = mm_type
self.init()
def init(self):
self.num_layers = self.config["num_layers"]
def load_weights(self, weight_dict):
self.blocks_weights = [CogVideoXBlock(i, self.task, self.mm_type) for i in range(self.num_layers)]
for block in self.blocks_weights:
block.load_weights(weight_dict)
def to_cpu(self):
for block in self.blocks_weights:
block.to_cpu()
def to_cuda(self):
for block in self.blocks_weights:
block.to_cuda()
class CogVideoXBlock:
def __init__(self, block_index, task="t2v", mm_type="Default"):
super().__init__()
self.block_index = block_index
self.mm_type = mm_type
self.task = task
def load_weights(self, weight_dict):
self.attn1_to_k = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.attn1.to_k.weight", f"transformer_blocks.{self.block_index}.attn1.to_k.bias")
self.attn1_to_q = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.attn1.to_q.weight", f"transformer_blocks.{self.block_index}.attn1.to_q.bias")
self.attn1_to_v = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.attn1.to_v.weight", f"transformer_blocks.{self.block_index}.attn1.to_v.bias")
self.attn1_to_out = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.attn1.to_out.0.weight", f"transformer_blocks.{self.block_index}.attn1.to_out.0.bias")
self.ff_net_0_proj = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.ff.net.0.proj.weight", f"transformer_blocks.{self.block_index}.ff.net.0.proj.bias")
self.ff_net_2_proj = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.ff.net.2.weight", f"transformer_blocks.{self.block_index}.ff.net.2.bias")
self.norm1_linear = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.norm1.linear.weight", f"transformer_blocks.{self.block_index}.norm1.linear.bias")
self.norm2_linear = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.norm2.linear.weight", f"transformer_blocks.{self.block_index}.norm2.linear.bias")
self.attn1_norm_k = LN_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.attn1.norm_k.weight", f"transformer_blocks.{self.block_index}.attn1.norm_k.bias")
self.attn1_norm_q = LN_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.attn1.norm_q.weight", f"transformer_blocks.{self.block_index}.attn1.norm_q.bias")
self.norm1_norm = LN_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.norm1.norm.weight", f"transformer_blocks.{self.block_index}.norm1.norm.bias", eps=1e-05)
self.norm2_norm = LN_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.norm2.norm.weight", f"transformer_blocks.{self.block_index}.norm2.norm.bias", eps=1e-05)
self.weight_list = [
self.attn1_to_k,
self.attn1_to_q,
self.attn1_to_v,
self.attn1_to_out,
self.ff_net_0_proj,
self.ff_net_2_proj,
self.norm1_linear,
self.norm2_linear,
self.attn1_norm_k,
self.attn1_norm_q,
self.norm1_norm,
self.norm2_norm,
]
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
mm_weight.load(weight_dict)
def to_cpu(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
mm_weight.to_cpu()
def to_cuda(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
mm_weight.to_cuda()
from diffusers.utils import export_to_video
import imageio
import numpy as np
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.models.input_encoders.hf.t5_v1_1_xxl.model import T5EncoderModel_v1_1_xxl
from lightx2v.models.networks.cogvideox.model import CogvideoxModel
from lightx2v.models.video_encoders.hf.cogvideox.model import CogvideoxVAE
from lightx2v.models.schedulers.cogvideox.scheduler import CogvideoxXDPMScheduler
@RUNNER_REGISTER("cogvideox")
class CogvideoxRunner(DefaultRunner):
def __init__(self, config):
super().__init__(config)
@ProfilingContext("Load models")
def load_model(self):
text_encoder = T5EncoderModel_v1_1_xxl(self.config)
text_encoders = [text_encoder]
model = CogvideoxModel(self.config)
vae_model = CogvideoxVAE(self.config)
image_encoder = None
return model, text_encoders, vae_model, image_encoder
def init_scheduler(self):
scheduler = CogvideoxXDPMScheduler(self.config)
self.model.set_scheduler(scheduler)
def run_text_encoder(self, text, text_encoders, config, image_encoder_output):
text_encoder_output = {}
n_prompt = config.get("negative_prompt", "")
context = text_encoders[0].infer([text], config)
context_null = text_encoders[0].infer([n_prompt if n_prompt else ""], config)
text_encoder_output["context"] = context
text_encoder_output["context_null"] = context_null
return text_encoder_output
def set_target_shape(self):
num_frames = self.config.target_video_length
latent_frames = (num_frames - 1) // self.config.vae_scale_factor_temporal + 1
additional_frames = 0
patch_size_t = self.config.patch_size_t
if patch_size_t is not None and latent_frames % patch_size_t != 0:
additional_frames = patch_size_t - latent_frames % patch_size_t
num_frames += additional_frames * self.config.vae_scale_factor_temporal
self.config.target_shape = (
self.config.batch_size,
(num_frames - 1) // self.config.vae_scale_factor_temporal + 1,
self.config.latent_channels,
self.config.height // self.config.vae_scale_factor_spatial,
self.config.width // self.config.vae_scale_factor_spatial,
)
def save_video(self, images):
with imageio.get_writer(self.config.save_video_path, fps=16) as writer:
for pil_image in images:
frame_np = np.array(pil_image, dtype=np.uint8)
writer.append_data(frame_np)
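The frame and shape arithmetic in `set_target_shape` can be checked in isolation; with the values from the t2v config above (81 frames, temporal VAE factor 4, `patch_size_t` 2, 768x1360 at spatial factor 8) it yields a 22-frame latent, as the standalone sketch below shows.

# Standalone sketch of set_target_shape with the t2v config values.
num_frames, vae_t, vae_s, p_t = 81, 4, 8, 2
latent_frames = (num_frames - 1) // vae_t + 1          # 21
if latent_frames % p_t != 0:
    num_frames += (p_t - latent_frames % p_t) * vae_t  # pad to 85 frames
target_shape = (1, (num_frames - 1) // vae_t + 1, 16, 768 // vae_s, 1360 // vae_s)
print(target_shape)  # (1, 22, 16, 96, 170)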
import torch
from diffusers.utils.torch_utils import randn_tensor
from diffusers.models.embeddings import get_3d_rotary_pos_embed
import numpy as np
from lightx2v.models.schedulers.scheduler import BaseScheduler
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
tw = tgt_width
th = tgt_height
h, w = src
r = h / w
if r > (th / tw):
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h))
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
def rescale_zero_terminal_snr(alphas_cumprod):
"""
Rescales the cumulative alphas to have zero terminal SNR, based on https://arxiv.org/pdf/2305.08891.pdf
(Algorithm 1), operating directly on `alphas_cumprod` rather than on the betas.
Args:
    alphas_cumprod (`torch.Tensor`):
        the cumulative product of alphas that the scheduler is being initialized with.
Returns:
    `torch.Tensor`: rescaled `alphas_cumprod` with zero terminal SNR
"""
alphas_bar_sqrt = alphas_cumprod.sqrt()
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
# Shift so the last timestep is zero.
alphas_bar_sqrt -= alphas_bar_sqrt_T
# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
# Square to recover alphas_bar
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
return alphas_bar
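# Quick standalone check of the rescaling above (a hedged sketch; it mirrors the
# scaled_linear schedule built in __init__ below): after rescaling, the last
# cumulative alpha is exactly zero (zero terminal SNR) while the first is preserved.
#
#     betas = torch.linspace(0.00085**0.5, 0.012**0.5, 1000, dtype=torch.float64) ** 2
#     ac = torch.cumprod(1.0 - betas, dim=0)
#     rescaled = rescale_zero_terminal_snr(ac)
#     print(float(rescaled[-1]), float(rescaled[0]))  # 0.0 0.99915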
class CogvideoxXDPMScheduler(BaseScheduler):
def __init__(self, config):
self.config = config
self.set_timesteps()
self.generator = torch.Generator().manual_seed(config.seed)
self.noise_pred = None
if self.config.beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = torch.linspace(self.config.scheduler_beta_start**0.5, self.config.scheduler_beta_end**0.5, self.config.num_train_timesteps, dtype=torch.float64) ** 2
else:
raise NotImplementedError(f"{self.config.beta_schedule} is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0).to(torch.device("cuda"))
# Modify: SNR shift following SD3
self.alphas_cumprod = self.alphas_cumprod / (self.config.scheduler_snr_shift_scale + (1 - self.config.scheduler_snr_shift_scale) * self.alphas_cumprod)
# Rescale for zero SNR
if self.config.scheduler_rescale_betas_zero_snr:
self.alphas_cumprod = rescale_zero_terminal_snr(self.alphas_cumprod)
# At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0
# `set_alpha_to_one` decides whether we set this parameter simply to one or
# whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = torch.tensor(1.0) if self.config.scheduler_set_alpha_to_one else self.alphas_cumprod[0]
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
def scale_model_input(self, sample, timestep=None):
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.Tensor`):
The input sample.
timestep (`int`, *optional*):
The current timestep in the diffusion chain.
Returns:
`torch.Tensor`:
A scaled input sample.
"""
return sample
def set_timesteps(self):
if self.config.num_inference_steps > self.config.num_train_timesteps:
raise ValueError(
f"`num_inference_steps`: {self.config.num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
f" maximal {self.config.num_train_timesteps} timesteps."
)
self.infer_steps = self.config.num_inference_steps
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, self.config.num_inference_steps).round()[::-1].copy().astype(np.int64)
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.config.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, self.infer_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.config.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
timesteps -= 1
else:
raise ValueError(f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'.")
self.timesteps = torch.Tensor(timesteps).to(torch.device("cuda")).int()
def prepare(self, image_encoder_output):
self.image_encoder_output = image_encoder_output
self.prepare_latents(shape=self.config.target_shape, dtype=torch.bfloat16)
self.prepare_guidance()
self.prepare_rotary_pos_embedding()
def prepare_latents(self, shape, dtype):
latents = randn_tensor(shape, generator=self.generator, device=torch.device("cuda"), dtype=dtype)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.init_noise_sigma
self.latents = latents
self.old_pred_original_sample = None
def prepare_guidance(self):
self.guidance_scale = self.config.guidance_scale
def prepare_rotary_pos_embedding(self):
grid_height = self.config.height // (self.config.vae_scale_factor_spatial * self.config.patch_size)
grid_width = self.config.width // (self.config.vae_scale_factor_spatial * self.config.patch_size)
p = self.config.patch_size
p_t = self.config.patch_size_t
base_size_width = self.config.transformer_sample_width // p
base_size_height = self.config.transformer_sample_height // p
num_frames = self.latents.size(1)
if p_t is None:
# CogVideoX 1.0
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=self.config.transformer_attention_head_dim,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
device=torch.device("cuda"),
)
else:
# CogVideoX 1.5
base_num_frames = (num_frames + p_t - 1) // p_t
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=self.config.transformer_attention_head_dim,
crops_coords=None,
grid_size=(grid_height, grid_width),
temporal_size=base_num_frames,
grid_type="slice",
max_size=(base_size_height, base_size_width),
device=torch.device("cuda"),
)
self.freqs_cos = freqs_cos
self.freqs_sin = freqs_sin
self.image_rotary_emb = (freqs_cos, freqs_sin) if self.config.use_rotary_positional_embeddings else None
def get_variables(self, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back=None):
lamb = ((alpha_prod_t / (1 - alpha_prod_t)) ** 0.5).log()
lamb_next = ((alpha_prod_t_prev / (1 - alpha_prod_t_prev)) ** 0.5).log()
h = lamb_next - lamb
if alpha_prod_t_back is not None:
lamb_previous = ((alpha_prod_t_back / (1 - alpha_prod_t_back)) ** 0.5).log()
h_last = lamb - lamb_previous
r = h_last / h
return h, r, lamb, lamb_next
else:
return h, None, lamb, lamb_next
def get_mult(self, h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back):
mult1 = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 * (-h).exp()
mult2 = (-2 * h).expm1() * alpha_prod_t_prev**0.5
if alpha_prod_t_back is not None:
mult3 = 1 + 1 / (2 * r)
mult4 = 1 / (2 * r)
return mult1, mult2, mult3, mult4
else:
return mult1, mult2
def step_post(self):
if self.infer_steps is None:
raise ValueError("Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler")
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read DDIM paper in-detail understanding
# Notation (<variable name> -> <name in paper>
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_sample -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_sample_direction -> "direction pointing to x_t"
# - pred_prev_sample -> "x_t-1"
timestep = self.timesteps[self.step_index]
timestep_back = self.timesteps[self.step_index - 1] if self.step_index > 0 else None
# 1. get previous step value (=t-1)
prev_timestep = timestep - self.config.num_train_timesteps // self.infer_steps
# 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
alpha_prod_t_back = self.alphas_cumprod[timestep_back] if timestep_back is not None else None
beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
# To make style tests pass, commented out `pred_epsilon` as it is an unused variable
if self.config.scheduler_prediction_type == "epsilon":
pred_original_sample = (self.latents - beta_prod_t ** (0.5) * self.noise_pred) / alpha_prod_t ** (0.5)
# pred_epsilon = model_output
elif self.config.scheduler_prediction_type == "sample":
pred_original_sample = self.noise_pred
# pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
elif self.config.scheduler_prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * self.latents - (beta_prod_t**0.5) * self.noise_pred
# pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
else:
raise ValueError(f"prediction_type given as {self.config.scheduler_prediction_type} must be one of `epsilon`, `sample`, or `v_prediction`")
h, r, lamb, lamb_next = self.get_variables(alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back)
mult = list(self.get_mult(h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back))
mult_noise = (1 - alpha_prod_t_prev) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5
noise = randn_tensor(self.latents.shape, generator=self.generator, device=self.latents.device, dtype=self.latents.dtype)
prev_sample = mult[0] * self.latents - mult[1] * pred_original_sample + mult_noise * noise
if self.old_pred_original_sample is None or prev_timestep < 0:
# Save a network evaluation if all noise levels are 0 or on the first step
self.latents = prev_sample
self.old_pred_original_sample = pred_original_sample
else:
denoised_d = mult[2] * pred_original_sample - mult[3] * self.old_pred_original_sample
noise = randn_tensor(self.latents.shape, generator=self.generator, device=self.latents.device, dtype=self.latents.dtype)
x_advanced = mult[0] * self.latents - mult[1] * denoised_d + mult_noise * noise
self.latents = x_advanced
self.old_pred_original_sample = pred_original_sample
self.latents = self.latents.to(torch.bfloat16)
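For reference, the "trailing" spacing selected in the t2v config produces the timestep grid below; this is a standalone sketch of the arithmetic in `set_timesteps` with 1000 training timesteps and 50 inference steps.

import numpy as np

step_ratio = 1000 / 50
timesteps = np.round(np.arange(1000, 0, -step_ratio)).astype(np.int64) - 1
print(len(timesteps), timesteps[:3], timesteps[-3:])  # 50 [999 979 959] [59 39 19]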
@@ -211,7 +211,7 @@ def get_1d_rotary_pos_embed_riflex(
     if isinstance(pos, int):
         pos = torch.arange(pos)
     if isinstance(pos, np.ndarray):
-        pos = torch.from_numpy(pos)  # type: ignore  # [S]
+        pos = torch.from_numpy(pos)  # [S]
 
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=pos.device)[: (dim // 2)].float() / dim))  # [D/2]
@@ -223,7 +223,7 @@ def get_1d_rotary_pos_embed_riflex(
         freqs[k - 1] = 0.9 * 2 * torch.pi / L_test
     # === Riflex modification end ===
 
-    freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
+    freqs = torch.outer(pos, freqs)  # [S, D/2]
     if use_real:
         freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
         freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
...
import os
import glob
import torch # type: ignore
from safetensors import safe_open # type: ignore
from diffusers.video_processor import VideoProcessor # type: ignore
from lightx2v.models.video_encoders.hf.cogvideox.autoencoder_ks_cogvidex import AutoencoderKLCogVideoX
class CogvideoxVAE:
def __init__(self, config):
self.config = config
self.load()
def _load_safetensor_to_dict(self, file_path):
with safe_open(file_path, framework="pt") as f:
tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16).cuda() for key in f.keys()}
return tensor_dict
def _load_ckpt(self, model_path):
safetensors_pattern = os.path.join(model_path, "*.safetensors")
safetensors_files = glob.glob(safetensors_pattern)
if not safetensors_files:
raise FileNotFoundError(f"No .safetensors files found in directory: {model_path}")
weight_dict = {}
for file_path in safetensors_files:
file_weights = self._load_safetensor_to_dict(file_path)
weight_dict.update(file_weights)
return weight_dict
def load(self):
vae_path = os.path.join(self.config.model_path, "vae")
self.vae_config = AutoencoderKLCogVideoX.load_config(vae_path)
self.model = AutoencoderKLCogVideoX.from_config(self.vae_config)
vae_ckpt = self._load_ckpt(vae_path)
self.vae_scale_factor_spatial = 2 ** (len(self.vae_config["block_out_channels"]) - 1) # 8
self.vae_scale_factor_temporal = self.vae_config["temporal_compression_ratio"] # 4
self.vae_scaling_factor_image = self.vae_config["scaling_factor"] # 0.7
self.model.load_state_dict(vae_ckpt)
self.model.to(torch.bfloat16).to(torch.device("cuda"))
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
@torch.no_grad()
def decode(self, latents, generator, config):
latents = latents.permute(0, 2, 1, 3, 4)
latents = 1 / self.config.vae_scaling_factor_image * latents
frames = self.model.decode(latents).sample
images = self.video_processor.postprocess_video(video=frames, output_type="pil")[0]
return images
@@ -39,6 +39,7 @@ def set_config(args):
         model_config = json.load(f)
     config.update(model_config)
 
+    if config.task == "i2v":
         if config.target_video_length % config.vae_stride[0] != 1:
             logger.warning(f"`num_frames - 1` has to be divisible by {config.vae_stride[0]}. Rounding to the nearest number.")
             config.target_video_length = config.target_video_length // config.vae_stride[0] * config.vae_stride[0] + 1
...