Commit 86f7f033 authored by helloyongyang, committed by Yang Yong (雍洋)

support hunyuan i2v

parent 18532cd2
......@@ -5,12 +5,14 @@ import os
import time
import gc
import json
import torchvision
import torchvision.transforms.functional as TF
import numpy as np
from PIL import Image
from lightx2v.text2v.models.text_encoders.hf.llama.model import TextEncoderHFLlamaModel
from lightx2v.text2v.models.text_encoders.hf.clip.model import TextEncoderHFClipModel
from lightx2v.text2v.models.text_encoders.hf.t5.model import T5EncoderModel
from lightx2v.text2v.models.text_encoders.hf.llava.model import TextEncoderHFLlavaModel
from lightx2v.text2v.models.schedulers.hunyuan.scheduler import HunyuanScheduler
from lightx2v.text2v.models.schedulers.hunyuan.feature_caching.scheduler import HunyuanSchedulerFeatureCaching
......@@ -38,11 +40,14 @@ def load_models(args, model_config):
init_device = torch.device("cuda")
if args.model_cls == "hunyuan":
text_encoder_1 = TextEncoderHFLlamaModel(os.path.join(args.model_path, "text_encoder"), init_device)
if args.task == "t2v":
text_encoder_1 = TextEncoderHFLlamaModel(os.path.join(args.model_path, "text_encoder"), init_device)
else:
text_encoder_1 = TextEncoderHFLlavaModel(os.path.join(args.model_path, "text_encoder_i2v"), init_device)
text_encoder_2 = TextEncoderHFClipModel(os.path.join(args.model_path, "text_encoder_2"), init_device)
text_encoders = [text_encoder_1, text_encoder_2]
model = HunyuanModel(args.model_path, model_config, init_device)
vae_model = VideoEncoderKLCausal3DModel(args.model_path, dtype=torch.float16, device=init_device)
model = HunyuanModel(args.model_path, model_config, init_device, args)
vae_model = VideoEncoderKLCausal3DModel(args.model_path, dtype=torch.float16, device=init_device, args=args)
elif args.model_cls == "wan2.1":
text_encoder = T5EncoderModel(
......@@ -69,16 +74,26 @@ def load_models(args, model_config):
return model, text_encoders, vae_model, image_encoder
def set_target_shape(args):
def set_target_shape(args, image_encoder_output):
if args.model_cls == "hunyuan":
vae_scale_factor = 2 ** (4 - 1)
args.target_shape = (
1,
16,
(args.target_video_length - 1) // 4 + 1,
int(args.target_height) // vae_scale_factor,
int(args.target_width) // vae_scale_factor,
)
if args.task == "t2v":
vae_scale_factor = 2 ** (4 - 1)
args.target_shape = (
1,
16,
(args.target_video_length - 1) // 4 + 1,
int(args.target_height) // vae_scale_factor,
int(args.target_width) // vae_scale_factor,
)
elif args.task == "i2v":
vae_scale_factor = 2 ** (4 - 1)
args.target_shape = (
1,
16,
(args.target_video_length - 1) // 4 + 1,
int(image_encoder_output["target_height"]) // vae_scale_factor,
int(image_encoder_output["target_width"]) // vae_scale_factor,
)
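# Shape arithmetic (illustrative, following the 8x spatial / 4x temporal compression used above):
# for a 720x1280 t2v run with target_video_length=33, vae_scale_factor = 2**(4-1) = 8,
# latent frames = (33 - 1)//4 + 1 = 9, so target_shape = (1, 16, 9, 720//8, 1280//8) = (1, 16, 9, 90, 160).
# For i2v the spatial size comes from the bucket chosen in run_image_encoder rather than the CLI args.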
elif args.model_cls == "wan2.1":
if args.task == "i2v":
args.target_shape = (16, 21, args.lat_h, args.lat_w)
......@@ -91,9 +106,75 @@ def set_target_shape(args):
)
def generate_crop_size_list(base_size=256, patch_size=32, max_ratio=4.0):
num_patches = round((base_size / patch_size) ** 2)
assert max_ratio >= 1.0
crop_size_list = []
wp, hp = num_patches, 1
while wp > 0:
if max(wp, hp) / min(wp, hp) <= max_ratio:
crop_size_list.append((wp * patch_size, hp * patch_size))
if (hp + 1) * wp <= num_patches:
hp += 1
else:
wp -= 1
return crop_size_list
def get_closest_ratio(height: float, width: float, ratios: list, buckets: list):
aspect_ratio = float(height) / float(width)
diff_ratios = ratios - aspect_ratio
if aspect_ratio >= 1:
indices = [(index, x) for index, x in enumerate(diff_ratios) if x <= 0]
else:
indices = [(index, x) for index, x in enumerate(diff_ratios) if x > 0]
closest_ratio_id = min(indices, key=lambda pair: abs(pair[1]))[0]
closest_size = buckets[closest_ratio_id]
closest_ratio = ratios[closest_ratio_id]
return closest_size, closest_ratio
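# Bucket selection sketch (values are illustrative): generate_crop_size_list(960, 32) walks the
# staircase wp * hp <= round((960/32)**2) = 900, so every bucket is a multiple of 32 on each side,
# keeps at most ~900 32x32 patches (roughly a 960x960 pixel budget), and has an aspect ratio of at
# most 4; (960, 960) itself is one of the buckets. get_closest_ratio then picks the bucket whose
# aspect ratio is closest to the reference image's, from above or below depending on orientation.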
def run_image_encoder(args, image_encoder, vae_model):
if args.model_cls == "hunyuan":
return None
img = Image.open(args.image_path).convert("RGB")
origin_size = img.size
i2v_resolution = "720p"
if i2v_resolution == "720p":
bucket_hw_base_size = 960
elif i2v_resolution == "540p":
bucket_hw_base_size = 720
elif i2v_resolution == "360p":
bucket_hw_base_size = 480
else:
raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]")
crop_size_list = generate_crop_size_list(bucket_hw_base_size, 32)
aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
closest_size, closest_ratio = get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)
resize_param = min(closest_size)
center_crop_param = closest_size
ref_image_transform = torchvision.transforms.Compose(
[torchvision.transforms.Resize(resize_param), torchvision.transforms.CenterCrop(center_crop_param), torchvision.transforms.ToTensor(), torchvision.transforms.Normalize([0.5], [0.5])]
)
semantic_image_pixel_values = [ref_image_transform(img)]
semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(torch.float16).to(torch.device("cuda"))
img_latents = vae_model.encode(semantic_image_pixel_values, args).mode()
scaling_factor = 0.476986
img_latents.mul_(scaling_factor)
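# 0.476986 is the scaling_factor from the Hunyuan causal-3D VAE config; a single reference frame
# encodes to latents of shape (1, 16, 1, target_height // 8, target_width // 8).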
target_height, target_width = closest_size
return {"img": img, "img_latents": img_latents, "target_height": target_height, "target_width": target_width}
elif args.model_cls == "wan2.1":
img = Image.open(args.image_path).convert("RGB")
img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
......@@ -124,11 +205,14 @@ def run_image_encoder(args, image_encoder, vae_model):
raise NotImplementedError(f"Unsupported model class: {args.model_cls}")
def run_text_encoder(args, text, text_encoders, model_config):
def run_text_encoder(args, text, text_encoders, model_config, image_encoder_output):
text_encoder_output = {}
if args.model_cls == "hunyuan":
for i, encoder in enumerate(text_encoders):
text_state, attention_mask = encoder.infer(text, args)
if args.task == "i2v" and i == 0:
text_state, attention_mask = encoder.infer(text, image_encoder_output["img"], args)
else:
text_state, attention_mask = encoder.infer(text, args)
text_encoder_output[f"text_encoder_{i + 1}_text_states"] = text_state.to(dtype=torch.bfloat16)
text_encoder_output[f"text_encoder_{i + 1}_attention_mask"] = attention_mask
......@@ -145,12 +229,12 @@ def run_text_encoder(args, text, text_encoders, model_config):
return text_encoder_output
def init_scheduler(args):
def init_scheduler(args, image_encoder_output):
if args.model_cls == "hunyuan":
if args.feature_caching == "NoCaching":
scheduler = HunyuanScheduler(args)
scheduler = HunyuanScheduler(args, image_encoder_output)
elif args.feature_caching == "TaylorSeer":
scheduler = HunyuanSchedulerFeatureCaching(args)
scheduler = HunyuanSchedulerFeatureCaching(args, image_encoder_output)
else:
raise NotImplementedError(f"Unsupported feature_caching type: {args.feature_caching}")
......@@ -269,10 +353,10 @@ if __name__ == "__main__":
else:
image_encoder_output = {"clip_encoder_out": None, "vae_encode_out": None}
text_encoder_output = run_text_encoder(args, args.prompt, text_encoders, model_config)
text_encoder_output = run_text_encoder(args, args.prompt, text_encoders, model_config, image_encoder_output)
set_target_shape(args)
scheduler = init_scheduler(args)
set_target_shape(args, image_encoder_output)
scheduler = init_scheduler(args, image_encoder_output)
model.set_scheduler(scheduler)
......
......@@ -8,12 +8,23 @@ class HunyuanPreInfer:
def __init__(self):
self.heads_num = 24
def infer(self, weights, x, t, text_states, text_mask, text_states_2, freqs_cos, freqs_sin, guidance):
def infer(self, weights, x, t, text_states, text_mask, text_states_2, freqs_cos, freqs_sin, guidance, img_latents=None):
if img_latents is not None:
token_replace_t = torch.zeros_like(t)
token_replace_vec = self.infer_time_in(weights, token_replace_t)
th = x.shape[-2] // 2
tw = x.shape[-1] // 2
frist_frame_token_num = th * tw
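# With the 2x2 spatial patch embedding implied by the //2 above, the first latent frame contributes
# (H//16) * (W//16) tokens, e.g. 44 * 80 = 3520 tokens for a 704x1280 bucket.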
time_out = self.infer_time_in(weights, t)
img_out = self.infer_img_in(weights, x)
infer_text_out = self.infer_text_in(weights, text_states, text_mask, t)
infer_vector_out = self.infer_vector_in(weights, text_states_2)
vec = time_out + infer_vector_out
if img_latents is not None:
token_replace_vec = token_replace_vec + infer_vector_out
guidance_out = self.infer_guidance_in(weights, guidance)
vec = vec + guidance_out
......@@ -32,6 +43,8 @@ class HunyuanPreInfer:
cu_seqlens_qkv[2 * i + 2] = s2
max_seqlen_qkv = img_seq_len + txt_seq_len
if img_latents is not None:
return img_out[0], infer_text_out, vec, cu_seqlens_qkv, max_seqlen_qkv, (freqs_cos, freqs_sin), token_replace_vec, frist_frame_token_num
return img_out[0], infer_text_out, vec, cu_seqlens_qkv, max_seqlen_qkv, (freqs_cos, freqs_sin)
def infer_time_in(self, weights, t):
......
......@@ -25,10 +25,10 @@ class HunyuanTransformerInfer:
def set_scheduler(self, scheduler):
self.scheduler = scheduler
def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
return self.infer_func(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis)
def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
return self.infer_func(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
def _infer_with_offload(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
def _infer_with_offload(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num):
txt_seq_len = txt.shape[0]
img_seq_len = img.shape[0]
......@@ -75,38 +75,22 @@ class HunyuanTransformerInfer:
img = x[:img_seq_len, ...]
return img, vec
def _infer_without_offload(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
def _infer_without_offload(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num):
txt_seq_len = txt.shape[0]
img_seq_len = img.shape[0]
for i in range(self.double_blocks_num):
img, txt = self.infer_double_block(
weights.double_blocks_weights[i],
img,
txt,
vec,
cu_seqlens_qkv,
max_seqlen_qkv,
freqs_cis,
)
img, txt = self.infer_double_block(weights.double_blocks_weights[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
x = torch.cat((img, txt), 0)
for i in range(self.single_blocks_num):
x = self.infer_single_block(
weights.single_blocks_weights[i],
x,
vec,
txt_seq_len,
cu_seqlens_qkv,
max_seqlen_qkv,
freqs_cis,
)
x = self.infer_single_block(weights.single_blocks_weights[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
img = x[:img_seq_len, ...]
return img, vec
def infer_double_block(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
def infer_double_block(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num):
vec_silu = torch.nn.functional.silu(vec)
img_mod_out = weights.img_mod.apply(vec_silu)
......@@ -119,6 +103,13 @@ class HunyuanTransformerInfer:
img_mod2_gate,
) = img_mod_out.chunk(6, dim=-1)
if token_replace_vec is not None:
token_replace_vec_silu = torch.nn.functional.silu(token_replace_vec)
token_replace_vec_img_mod_out = weights.img_mod.apply(token_replace_vec_silu)
(tr_img_mod1_shift, tr_img_mod1_scale, tr_img_mod1_gate, tr_img_mod2_shift, tr_img_mod2_scale, tr_img_mod2_gate) = token_replace_vec_img_mod_out.chunk(6, dim=-1)
else:
(tr_img_mod1_shift, tr_img_mod1_scale, tr_img_mod1_gate, tr_img_mod2_shift, tr_img_mod2_scale, tr_img_mod2_gate) = None, None, None, None, None, None
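# Token-replace modulation (i2v): the first-frame image tokens are modulated with the t=0
# (token_replace) shift/scale/gate so the conditioning frame is treated as already denoised,
# while all remaining tokens keep the regular timestep modulation.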
txt_mod_out = weights.txt_mod.apply(vec_silu)
(
txt_mod1_shift,
......@@ -129,7 +120,7 @@ class HunyuanTransformerInfer:
txt_mod2_gate,
) = txt_mod_out.chunk(6, dim=-1)
img_q, img_k, img_v = self.infer_double_block_img_pre_atten(weights, img, img_mod1_scale, img_mod1_shift, freqs_cis)
img_q, img_k, img_v = self.infer_double_block_img_pre_atten(weights, img, img_mod1_scale, img_mod1_shift, tr_img_mod1_scale, tr_img_mod1_shift, frist_frame_token_num, freqs_cis)
txt_q, txt_k, txt_v = self.infer_double_block_txt_pre_atten(weights, txt, txt_mod1_scale, txt_mod1_shift)
q = torch.cat((img_q, txt_q), dim=0)
......@@ -162,28 +153,19 @@ class HunyuanTransformerInfer:
img_attn, txt_attn = attn[: img.shape[0]], attn[img.shape[0] :]
img = self.infer_double_block_img_post_atten(
weights,
img,
img_attn,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
)
txt = self.infer_double_block_txt_post_atten(
weights,
txt,
txt_attn,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
weights, img, img_attn, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate, tr_img_mod1_gate, tr_img_mod2_shift, tr_img_mod2_scale, tr_img_mod2_gate, frist_frame_token_num
)
txt = self.infer_double_block_txt_post_atten(weights, txt, txt_attn, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate)
return img, txt
def infer_double_block_img_pre_atten(self, weights, img, img_mod1_scale, img_mod1_shift, freqs_cis):
def infer_double_block_img_pre_atten(self, weights, img, img_mod1_scale, img_mod1_shift, tr_img_mod1_scale, tr_img_mod1_shift, frist_frame_token_num, freqs_cis):
img_modulated = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
img_modulated = img_modulated * (1 + img_mod1_scale) + img_mod1_shift
if tr_img_mod1_scale is not None:
x_zero = img_modulated[:frist_frame_token_num] * (1 + tr_img_mod1_scale) + tr_img_mod1_shift
x_orig = img_modulated[frist_frame_token_num:] * (1 + img_mod1_scale) + img_mod1_shift
img_modulated = torch.concat((x_zero, x_orig), dim=0)
else:
img_modulated = img_modulated * (1 + img_mod1_scale) + img_mod1_shift
img_qkv = weights.img_attn_qkv.apply(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
......@@ -206,21 +188,24 @@ class HunyuanTransformerInfer:
return txt_q, txt_k, txt_v
def infer_double_block_img_post_atten(
self,
weights,
img,
img_attn,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
self, weights, img, img_attn, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate, tr_img_mod1_gate, tr_img_mod2_shift, tr_img_mod2_scale, tr_img_mod2_gate, frist_frame_token_num
):
out = weights.img_attn_proj.apply(img_attn)
out = out * img_mod1_gate
if tr_img_mod1_gate is not None:
x_zero = out[:frist_frame_token_num] * tr_img_mod1_gate
x_orig = out[frist_frame_token_num:] * img_mod1_gate
out = torch.concat((x_zero, x_orig), dim=0)
else:
out = out * img_mod1_gate
img = img + out
out = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
out = out * (1 + img_mod2_scale) + img_mod2_shift
if tr_img_mod1_gate is not None:
x_zero = out[:frist_frame_token_num] * (1 + tr_img_mod2_scale) + tr_img_mod2_shift
x_orig = out[frist_frame_token_num:] * (1 + img_mod2_scale) + img_mod2_shift
out = torch.concat((x_zero, x_orig), dim=0)
else:
out = out * (1 + img_mod2_scale) + img_mod2_shift
out = weights.img_mlp_fc1.apply(out)
out = torch.nn.functional.gelu(out, approximate="tanh")
out = weights.img_mlp_fc2.apply(out)
......@@ -251,13 +236,23 @@ class HunyuanTransformerInfer:
txt = txt + out
return txt
def infer_single_block(self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
def infer_single_block(self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
out = torch.nn.functional.silu(vec)
out = weights.modulation.apply(out)
mod_shift, mod_scale, mod_gate = out.chunk(3, dim=-1)
if token_replace_vec is not None:
token_replace_vec_out = torch.nn.functional.silu(token_replace_vec)
token_replace_vec_out = weights.modulation.apply(token_replace_vec_out)
tr_mod_shift, tr_mod_scale, tr_mod_gate = token_replace_vec_out.chunk(3, dim=-1)
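# Same token-replace idea as in the double blocks: the tr_* modulation (derived from t=0) is
# applied to the first frist_frame_token_num tokens below, the regular modulation to the rest.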
out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
x_mod = out * (1 + mod_scale) + mod_shift
if token_replace_vec is not None:
x_zero = out[:frist_frame_token_num] * (1 + tr_mod_scale) + tr_mod_shift
x_orig = out[frist_frame_token_num:] * (1 + mod_scale) + mod_shift
x_mod = torch.concat((x_zero, x_orig), dim=0)
else:
x_mod = out * (1 + mod_scale) + mod_shift
x_mod = weights.linear1.apply(x_mod)
......@@ -301,6 +296,12 @@ class HunyuanTransformerInfer:
out = torch.nn.functional.gelu(mlp, approximate="tanh")
out = torch.cat((attn, out), 1)
out = weights.linear2.apply(out)
out = out * mod_gate
if token_replace_vec is not None:
x_zero = out[:frist_frame_token_num] * tr_mod_gate
x_orig = out[frist_frame_token_num:] * mod_gate
out = torch.concat((x_zero, x_orig), dim=0)
else:
out = out * mod_gate
x = x + out
return x
......@@ -17,10 +17,11 @@ class HunyuanModel:
post_weight_class = HunyuanPostWeights
transformer_weight_class = HunyuanTransformerWeights
def __init__(self, model_path, config, device):
def __init__(self, model_path, config, device, args):
self.model_path = model_path
self.config = config
self.device = device
self.args = args
self._init_infer_class()
self._init_weights()
self._init_infer()
......@@ -47,7 +48,10 @@ class HunyuanModel:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
def _load_ckpt(self):
ckpt_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt")
if self.args.task == "t2v":
ckpt_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt")
else:
ckpt_path = os.path.join(self.model_path, "hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt")
weight_dict = torch.load(ckpt_path, map_location=self.device, weights_only=True)["module"]
return weight_dict
......@@ -96,6 +100,7 @@ class HunyuanModel:
self.scheduler.freqs_cos,
self.scheduler.freqs_sin,
self.scheduler.guidance,
img_latents=image_encoder_output["img_latents"] if "img_latents" in image_encoder_output else None,
)
img, vec = self.transformer_infer.infer(self.transformer_weights, *pre_infer_out)
self.scheduler.noise_pred = self.post_infer.infer(self.post_weight, img, vec, self.scheduler.latents.shape)
......
import torch
import numpy as np
from diffusers.utils.torch_utils import randn_tensor
from typing import Union, Tuple, List
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
......@@ -174,35 +175,108 @@ def get_nd_rotary_pos_embed(
def set_timesteps_sigmas(num_inference_steps, shift, device, num_train_timesteps=1000):
sigmas = torch.linspace(1, 0, num_inference_steps + 1)
sigmas = (shift * sigmas) / (1 + (shift - 1) * sigmas)
timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.bfloat16, device=device)
timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32, device=device)
return timesteps, sigmas
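# The shift remaps sigma as shift*s / (1 + (shift - 1)*s); with shift=7, s=0.5 becomes
# 7*0.5 / (1 + 6*0.5) = 0.875, biasing the schedule toward high-noise steps.
# timesteps are the first num_inference_steps shifted sigmas scaled by num_train_timesteps.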
def get_1d_rotary_pos_embed_riflex(
dim: int,
pos: Union[np.ndarray, int],
theta: float = 10000.0,
use_real=False,
k: Optional[int] = None,
L_test: Optional[int] = None,
):
"""
RIFLEx: Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
This function computes a frequency tensor with complex exponentials for the given dimension 'dim' and
position indices 'pos'. The 'theta' parameter scales the frequencies. When 'use_real' is False, the
returned tensor contains complex values in complex64 data type.
Args:
dim (`int`): Dimension of the frequency tensor.
pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
theta (`float`, *optional*, defaults to 10000.0):
Scaling factor for frequency computation. Defaults to 10000.0.
use_real (`bool`, *optional*):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
k (`int`, *optional*, defaults to None): the index for the intrinsic frequency in RoPE
L_test (`int`, *optional*, defaults to None): the number of frames for inference
Returns:
`torch.Tensor`: precomputed frequency tensor with complex exponentials, [S, D/2]; when `use_real=True`, a (cos, sin) tuple of shape [S, D] instead.
"""
assert dim % 2 == 0
if isinstance(pos, int):
pos = torch.arange(pos)
if isinstance(pos, np.ndarray):
pos = torch.from_numpy(pos) # type: ignore # [S]
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=pos.device)[: (dim // 2)].float() / dim)) # [D/2]
# === Riflex modification start ===
# Reduce the intrinsic frequency to stay within a single period after extrapolation (see Eq. (8)).
# Empirical observations show that a few videos may exhibit repetition in the tail frames.
# To be conservative, we multiply by 0.9 to keep the extrapolated length below 90% of a single period.
if k is not None:
freqs[k - 1] = 0.9 * 2 * torch.pi / L_test
# === Riflex modification end ===
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
if use_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
return freqs_cos, freqs_sin
else:
# lumina
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs_cis
class HunyuanScheduler(BaseScheduler):
def __init__(self, args):
def __init__(self, args, image_encoder_output):
super().__init__(args)
self.infer_steps = self.args.infer_steps
self.image_encoder_output = image_encoder_output
self.shift = 7.0
self.timesteps, self.sigmas = set_timesteps_sigmas(self.infer_steps, self.shift, device=torch.device("cuda"))
assert len(self.timesteps) == self.infer_steps
self.embedded_guidance_scale = 6.0
self.generator = [torch.Generator("cuda").manual_seed(seed) for seed in [42]]
self.generator = [torch.Generator("cuda").manual_seed(seed) for seed in [self.args.seed]]
self.noise_pred = None
self.prepare_latents(shape=self.args.target_shape, dtype=torch.bfloat16)
self.prepare_latents(shape=self.args.target_shape, dtype=torch.float16)
self.prepare_guidance()
self.prepare_rotary_pos_embedding(video_length=self.args.target_video_length, height=self.args.target_height, width=self.args.target_width)
if self.args.task == "t2v":
target_height, target_width = self.args.target_height, self.args.target_width
else:
target_height, target_width = self.image_encoder_output["target_height"], self.image_encoder_output["target_width"]
self.prepare_rotary_pos_embedding(video_length=self.args.target_video_length, height=target_height, width=target_width)
def prepare_guidance(self):
self.guidance = torch.tensor([self.embedded_guidance_scale], dtype=torch.bfloat16, device=torch.device("cuda")) * 1000.0
def step_post(self):
sample = self.latents.to(torch.float32)
dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
prev_sample = sample + self.noise_pred.to(torch.float32) * dt
self.latents = prev_sample
if self.args.task == "t2v":
sample = self.latents.to(torch.float32)
dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
self.latents = sample + self.noise_pred.to(torch.float32) * dt
else:
sample = self.latents[:, :, 1:, :, :].to(torch.float32)
dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
latents = sample + self.noise_pred[:, :, 1:, :, :].to(torch.float32) * dt
self.latents = torch.concat([self.image_encoder_output["img_latents"], latents], dim=2)
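# i2v Euler step: dt = sigma_{i+1} - sigma_i is applied only to latent frames 1..T-1;
# latent frame 0 is re-pinned to the clean image latents after every step.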
def prepare_latents(self, shape, dtype):
self.latents = randn_tensor(shape, generator=self.generator, device=torch.device("cuda"), dtype=dtype)
if self.args.task == "t2v":
self.latents = randn_tensor(shape, generator=self.generator, device=torch.device("cuda"), dtype=dtype)
else:
x1 = self.image_encoder_output["img_latents"].repeat(1, 1, (self.args.target_video_length - 1) // 4 + 1, 1, 1)
x0 = randn_tensor(shape, generator=self.generator, device=torch.device("cuda"), dtype=dtype)
t = torch.tensor([0.999]).to(device=torch.device("cuda"))
self.latents = x0 * t + x1 * (1 - t)
self.latents = self.latents.to(dtype=dtype)
self.latents = torch.concat([self.image_encoder_output["img_latents"], self.latents[:, :, 1:, :, :]], dim=2)
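# i2v initialization: the noise is lightly blended toward the repeated image latents
# (0.999 * noise + 0.001 * image), then latent frame 0 is replaced outright by the clean
# image latents, matching the token-replace handling in the transformer blocks.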
def prepare_rotary_pos_embedding(self, video_length, height, width):
target_ndim = 3
......@@ -230,17 +304,62 @@ class HunyuanScheduler(BaseScheduler):
if len(rope_sizes) != target_ndim:
rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
head_dim = hidden_size // heads_num
rope_dim_list = rope_dim_list
if rope_dim_list is None:
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
self.freqs_cos, self.freqs_sin = get_nd_rotary_pos_embed(
rope_dim_list,
rope_sizes,
theta=rope_theta,
use_real=True,
theta_rescale_factor=1,
)
self.freqs_cos = self.freqs_cos.to(dtype=torch.bfloat16, device=torch.device("cuda"))
self.freqs_sin = self.freqs_sin.to(dtype=torch.bfloat16, device=torch.device("cuda"))
if self.args.task == "t2v":
head_dim = hidden_size // heads_num
rope_dim_list = rope_dim_list
if rope_dim_list is None:
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
self.freqs_cos, self.freqs_sin = get_nd_rotary_pos_embed(
rope_dim_list,
rope_sizes,
theta=rope_theta,
use_real=True,
theta_rescale_factor=1,
)
self.freqs_cos = self.freqs_cos.to(dtype=torch.bfloat16, device=torch.device("cuda"))
self.freqs_sin = self.freqs_sin.to(dtype=torch.bfloat16, device=torch.device("cuda"))
else:
L_test = rope_sizes[0] # Latent frames
L_train = 25 # Training length from HunyuanVideo
actual_num_frames = video_length # Use input video_length directly
head_dim = hidden_size // heads_num
rope_dim_list = rope_dim_list or [head_dim // target_ndim for _ in range(target_ndim)]
assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) must equal head_dim"
if actual_num_frames > 192:
k = 2 + ((actual_num_frames + 3) // (4 * L_train))
k = max(4, min(8, k))
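# Illustrative: for actual_num_frames = 261 with L_train = 25, k = 2 + (261 + 3)//100 = 4,
# then clamped to [4, 8]; this branch only triggers for videos longer than 192 frames.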
# Compute positional grids for RIFLEx
axes_grids = [torch.arange(size, device=torch.device("cuda"), dtype=torch.float32) for size in rope_sizes]
grid = torch.meshgrid(*axes_grids, indexing="ij")
grid = torch.stack(grid, dim=0) # [3, t, h, w]
pos = grid.reshape(3, -1).t() # [t * h * w, 3]
# Apply RIFLEx to temporal dimension
freqs = []
for i in range(3):
if i == 0: # Temporal with RIFLEx
freqs_cos, freqs_sin = get_1d_rotary_pos_embed_riflex(rope_dim_list[i], pos[:, i], theta=rope_theta, use_real=True, k=k, L_test=L_test)
else: # Spatial with default RoPE
freqs_cos, freqs_sin = get_1d_rotary_pos_embed_riflex(rope_dim_list[i], pos[:, i], theta=rope_theta, use_real=True, k=None, L_test=None)
freqs.append((freqs_cos, freqs_sin))
freqs_cos = torch.cat([f[0] for f in freqs], dim=1)
freqs_sin = torch.cat([f[1] for f in freqs], dim=1)
else:
# 20250316 pftq: Original code for <= 192 frames
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
rope_dim_list,
rope_sizes,
theta=rope_theta,
use_real=True,
theta_rescale_factor=1,
)
self.freqs_cos = freqs_cos.to(dtype=torch.bfloat16, device=torch.device("cuda"))
self.freqs_sin = freqs_sin.to(dtype=torch.bfloat16, device=torch.device("cuda"))
import torch
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
from transformers import LlavaForConditionalGeneration, CLIPImageProcessor, AutoTokenizer
def generate_crop_size_list(base_size=256, patch_size=32, max_ratio=4.0):
"""generate crop size list
Args:
base_size (int, optional): the base size for generate bucket. Defaults to 256.
patch_size (int, optional): the stride to generate bucket. Defaults to 32.
max_ratio (float, optional): th max ratio for h or w based on base_size . Defaults to 4.0.
Returns:
list: generate crop size list
"""
num_patches = round((base_size / patch_size) ** 2)
assert max_ratio >= 1.0
crop_size_list = []
wp, hp = num_patches, 1
while wp > 0:
if max(wp, hp) / min(wp, hp) <= max_ratio:
crop_size_list.append((wp * patch_size, hp * patch_size))
if (hp + 1) * wp <= num_patches:
hp += 1
else:
wp -= 1
return crop_size_list
def get_closest_ratio(height: float, width: float, ratios: list, buckets: list):
"""get the closest ratio in the buckets
Args:
height (float): video height
width (float): video width
ratios (list): video aspect ratio
buckets (list): buckets generate by `generate_crop_size_list`
Returns:
the closest ratio in the buckets and the corresponding ratio
"""
aspect_ratio = float(height) / float(width)
diff_ratios = ratios - aspect_ratio
if aspect_ratio >= 1:
indices = [(index, x) for index, x in enumerate(diff_ratios) if x <= 0]
else:
indices = [(index, x) for index, x in enumerate(diff_ratios) if x > 0]
closest_ratio_id = min(indices, key=lambda pair: abs(pair[1]))[0]
closest_size = buckets[closest_ratio_id]
closest_ratio = ratios[closest_ratio_id]
return closest_size, closest_ratio
class TextEncoderHFLlavaModel:
def __init__(self, model_path, device):
self.device = device
self.model_path = model_path
self.init()
self.load()
def init(self):
self.max_length = 359
self.hidden_state_skip_layer = 2
self.crop_start = 103
self.double_return_token_id = 271
self.image_emb_len = 576
self.text_crop_start = self.crop_start - 1 + self.image_emb_len
self.image_crop_start = 5
self.image_crop_end = 581
self.image_embed_interleave = 4
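# Assumed layout of the LLaVA sequence: image_emb_len = 576 corresponds to the 24x24 patch grid of
# a 336px CLIP vision tower; crop_start skips the fixed system-prompt tokens, and
# image_embed_interleave keeps every 4th image hidden state to shorten the conditioning.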
self.prompt_template = (
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
"1. The main content and theme of the video."
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
"4. background environment, light, style and atmosphere."
"5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
def load(self):
self.model = LlavaForConditionalGeneration.from_pretrained(self.model_path, low_cpu_mem_usage=True).to(torch.float16).to(self.device)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, padding_side="right")
self.processor = CLIPImageProcessor.from_pretrained(self.model_path)
def to_cpu(self):
self.model = self.model.to("cpu")
def to_cuda(self):
self.model = self.model.to("cuda")
@torch.no_grad()
def infer(self, text, img, args):
# if args.cpu_offload:
# self.to_cuda()
text = self.prompt_template.format(text)
print(f"text: {text}")
tokens = self.tokenizer(
text,
return_length=False,
return_overflowing_tokens=False,
return_attention_mask=True,
truncation=True,
max_length=self.max_length,
padding="max_length",
return_tensors="pt",
).to("cuda")
image_outputs = self.processor(img, return_tensors="pt")["pixel_values"].to(self.device)
attention_mask = tokens["attention_mask"].to(self.device)
outputs = self.model(input_ids=tokens["input_ids"], attention_mask=attention_mask, output_hidden_states=True, pixel_values=image_outputs)
last_hidden_state = outputs.hidden_states[-(self.hidden_state_skip_layer + 1)]
batch_indices, last_double_return_token_indices = torch.where(tokens["input_ids"] == self.double_return_token_id)
last_double_return_token_indices = last_double_return_token_indices.reshape(1, -1)[:, -1]
assistant_crop_start = last_double_return_token_indices - 1 + self.image_emb_len - 4
assistant_crop_end = last_double_return_token_indices - 1 + self.image_emb_len
attention_mask_assistant_crop_start = last_double_return_token_indices - 4
attention_mask_assistant_crop_end = last_double_return_token_indices
text_last_hidden_state = torch.cat([last_hidden_state[0, self.text_crop_start : assistant_crop_start[0].item()], last_hidden_state[0, assistant_crop_end[0].item() :]])
text_attention_mask = torch.cat([attention_mask[0, self.crop_start : attention_mask_assistant_crop_start[0].item()], attention_mask[0, attention_mask_assistant_crop_end[0].item() :]])
image_last_hidden_state = last_hidden_state[0, self.image_crop_start : self.image_crop_end]
image_attention_mask = torch.ones(image_last_hidden_state.shape[0]).to(last_hidden_state.device).to(attention_mask.dtype)
text_last_hidden_state.unsqueeze_(0)
text_attention_mask.unsqueeze_(0)
image_last_hidden_state.unsqueeze_(0)
image_attention_mask.unsqueeze_(0)
image_last_hidden_state = image_last_hidden_state[:, :: self.image_embed_interleave, :]
image_attention_mask = image_attention_mask[:, :: self.image_embed_interleave]
last_hidden_state = torch.cat([image_last_hidden_state, text_last_hidden_state], dim=1)
attention_mask = torch.cat([image_attention_mask, text_attention_mask], dim=1)
# if args.cpu_offload:
# self.to_cpu()
return last_hidden_state, attention_mask
if __name__ == "__main__":
model = TextEncoderHFLlavaModel("/mtc/yongyang/models/x2v_models/hunyuan/lightx2v_format/i2v/text_encoder_i2v", torch.device("cuda"))
text = "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick."
img_path = "/mtc/yongyang/projects/lightx2v/assets/inputs/imgs/img_1.jpg"
img = Image.open(img_path).convert("RGB")
outputs = model.infer(text, img, None)
print(outputs)
import os
import torch
from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D
from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D, DiagonalGaussianDistribution
class VideoEncoderKLCausal3DModel:
def __init__(self, model_path, dtype, device):
def __init__(self, model_path, dtype, device, args):
self.model_path = model_path
self.dtype = dtype
self.device = device
self.args = args
self.load()
def load(self):
self.vae_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/vae")
if self.args.task == "t2v":
self.vae_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/vae")
else:
self.vae_path = os.path.join(self.model_path, "hunyuan-video-i2v-720p/vae")
config = AutoencoderKLCausal3D.load_config(self.vae_path)
self.model = AutoencoderKLCausal3D.from_config(config)
ckpt = torch.load(os.path.join(self.vae_path, "pytorch_model.pt"), map_location="cpu", weights_only=True)
......@@ -39,6 +43,12 @@ class VideoEncoderKLCausal3DModel:
self.to_cpu()
return image
def encode(self, x, args):
h = self.model.encoder(x)
moments = self.model.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
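# encode() returns the DiagonalGaussianDistribution; callers take .mode() (the mean latent)
# and multiply by the VAE scaling_factor, as done in run_image_encoder.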
if __name__ == "__main__":
model_path = ""
......
#!/bin/bash
# set these paths first
lightx2v_path=""
model_path=""
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: 0, change at shell script or set env variable."
cuda_devices="0"
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
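# i2v run: --image_path supplies the reference image; in this i2v path the latent resolution is
# derived from the closest 32-aligned bucket chosen for the input image rather than directly from
# --target_height/--target_width.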
python ${lightx2v_path}/lightx2v/__main__.py \
--model_cls hunyuan \
--model_path $model_path \
--task i2v \
--prompt "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick." \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_1.jpg \
--infer_steps 20 \
--target_video_length 33 \
--target_height 720 \
--target_width 1280 \
--attention_type flash_attn2 \
--save_video_path ./output_lightx2v_hy_i2v.mp4 \
--seed 0