Commit 86f7f033 authored by helloyongyang, committed by Yang Yong (雍洋)

support hunyuan i2v

parent 18532cd2
@@ -5,12 +5,14 @@ import os
import time
import gc
import json
import torchvision
import torchvision.transforms.functional as TF
import numpy as np
from PIL import Image
from lightx2v.text2v.models.text_encoders.hf.llama.model import TextEncoderHFLlamaModel
from lightx2v.text2v.models.text_encoders.hf.clip.model import TextEncoderHFClipModel
from lightx2v.text2v.models.text_encoders.hf.t5.model import T5EncoderModel
from lightx2v.text2v.models.text_encoders.hf.llava.model import TextEncoderHFLlavaModel
from lightx2v.text2v.models.schedulers.hunyuan.scheduler import HunyuanScheduler
from lightx2v.text2v.models.schedulers.hunyuan.feature_caching.scheduler import HunyuanSchedulerFeatureCaching
@@ -38,11 +40,14 @@ def load_models(args, model_config):
    init_device = torch.device("cuda")
    if args.model_cls == "hunyuan":
        if args.task == "t2v":
            text_encoder_1 = TextEncoderHFLlamaModel(os.path.join(args.model_path, "text_encoder"), init_device)
        else:
            text_encoder_1 = TextEncoderHFLlavaModel(os.path.join(args.model_path, "text_encoder_i2v"), init_device)
        text_encoder_2 = TextEncoderHFClipModel(os.path.join(args.model_path, "text_encoder_2"), init_device)
        text_encoders = [text_encoder_1, text_encoder_2]
        model = HunyuanModel(args.model_path, model_config, init_device, args)
        vae_model = VideoEncoderKLCausal3DModel(args.model_path, dtype=torch.float16, device=init_device, args=args)
    elif args.model_cls == "wan2.1":
        text_encoder = T5EncoderModel(
@@ -69,16 +74,26 @@ def load_models(args, model_config):
    return model, text_encoders, vae_model, image_encoder


def set_target_shape(args, image_encoder_output):
    if args.model_cls == "hunyuan":
        if args.task == "t2v":
            vae_scale_factor = 2 ** (4 - 1)
            args.target_shape = (
                1,
                16,
                (args.target_video_length - 1) // 4 + 1,
                int(args.target_height) // vae_scale_factor,
                int(args.target_width) // vae_scale_factor,
            )
        elif args.task == "i2v":
            vae_scale_factor = 2 ** (4 - 1)
            args.target_shape = (
                1,
                16,
                (args.target_video_length - 1) // 4 + 1,
                int(image_encoder_output["target_height"]) // vae_scale_factor,
                int(image_encoder_output["target_width"]) // vae_scale_factor,
            )
    elif args.model_cls == "wan2.1":
        if args.task == "i2v":
            args.target_shape = (16, 21, args.lat_h, args.lat_w)
@@ -91,9 +106,75 @@ def set_target_shape(args):
            )


def generate_crop_size_list(base_size=256, patch_size=32, max_ratio=4.0):
    num_patches = round((base_size / patch_size) ** 2)
    assert max_ratio >= 1.0
    crop_size_list = []
    wp, hp = num_patches, 1
    while wp > 0:
        if max(wp, hp) / min(wp, hp) <= max_ratio:
            crop_size_list.append((wp * patch_size, hp * patch_size))
        if (hp + 1) * wp <= num_patches:
            hp += 1
        else:
            wp -= 1
    return crop_size_list


def get_closest_ratio(height: float, width: float, ratios: list, buckets: list):
    aspect_ratio = float(height) / float(width)
    diff_ratios = ratios - aspect_ratio
    if aspect_ratio >= 1:
        indices = [(index, x) for index, x in enumerate(diff_ratios) if x <= 0]
    else:
        indices = [(index, x) for index, x in enumerate(diff_ratios) if x > 0]
    closest_ratio_id = min(indices, key=lambda pair: abs(pair[1]))[0]
    closest_size = buckets[closest_ratio_id]
    closest_ratio = ratios[closest_ratio_id]
    return closest_size, closest_ratio
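
These two helpers drive the bucket selection used by run_image_encoder below. As a quick sanity check, a minimal sketch of how they behave for a hypothetical 1280x720 reference image at the 960 base size used for 720p (the input size is illustrative only):

import numpy as np

# Buckets are (h, w) tuples on the staircase wp * hp <= (960 / 32) ** 2.
crop_size_list = generate_crop_size_list(base_size=960, patch_size=32)
aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])

# get_closest_ratio takes (height, width) and returns the bucket whose h/w ratio
# is closest to the input without flipping orientation.
closest_size, closest_ratio = get_closest_ratio(720, 1280, aspect_ratios, crop_size_list)
print(closest_size, closest_ratio)  # expected: a landscape bucket, here (704, 1248) with ratio ~0.5641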
def run_image_encoder(args, image_encoder, vae_model):
    if args.model_cls == "hunyuan":
        img = Image.open(args.image_path).convert("RGB")
        origin_size = img.size

        i2v_resolution = "720p"
        if i2v_resolution == "720p":
            bucket_hw_base_size = 960
        elif i2v_resolution == "540p":
            bucket_hw_base_size = 720
        elif i2v_resolution == "360p":
            bucket_hw_base_size = 480
        else:
            raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]")

        crop_size_list = generate_crop_size_list(bucket_hw_base_size, 32)
        aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
        closest_size, closest_ratio = get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)

        resize_param = min(closest_size)
        center_crop_param = closest_size

        ref_image_transform = torchvision.transforms.Compose(
            [torchvision.transforms.Resize(resize_param), torchvision.transforms.CenterCrop(center_crop_param), torchvision.transforms.ToTensor(), torchvision.transforms.Normalize([0.5], [0.5])]
        )

        semantic_image_pixel_values = [ref_image_transform(img)]
        semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(torch.float16).to(torch.device("cuda"))

        img_latents = vae_model.encode(semantic_image_pixel_values, args).mode()
        scaling_factor = 0.476986
        img_latents.mul_(scaling_factor)

        target_height, target_width = closest_size

        return {"img": img, "img_latents": img_latents, "target_height": target_height, "target_width": target_width}
    elif args.model_cls == "wan2.1":
        img = Image.open(args.image_path).convert("RGB")
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
@@ -124,11 +205,14 @@ def run_image_encoder(args, image_encoder, vae_model):
        raise NotImplementedError(f"Unsupported model class: {args.model_cls}")
def run_text_encoder(args, text, text_encoders, model_config, image_encoder_output):
    text_encoder_output = {}
    if args.model_cls == "hunyuan":
        for i, encoder in enumerate(text_encoders):
            if args.task == "i2v" and i == 0:
                text_state, attention_mask = encoder.infer(text, image_encoder_output["img"], args)
            else:
                text_state, attention_mask = encoder.infer(text, args)
            text_encoder_output[f"text_encoder_{i + 1}_text_states"] = text_state.to(dtype=torch.bfloat16)
            text_encoder_output[f"text_encoder_{i + 1}_attention_mask"] = attention_mask
@@ -145,12 +229,12 @@ def run_text_encoder(args, text, text_encoders, model_config):
    return text_encoder_output
def init_scheduler(args, image_encoder_output):
    if args.model_cls == "hunyuan":
        if args.feature_caching == "NoCaching":
            scheduler = HunyuanScheduler(args, image_encoder_output)
        elif args.feature_caching == "TaylorSeer":
            scheduler = HunyuanSchedulerFeatureCaching(args, image_encoder_output)
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {args.feature_caching}")
@@ -269,10 +353,10 @@ if __name__ == "__main__":
    else:
        image_encoder_output = {"clip_encoder_out": None, "vae_encode_out": None}

    text_encoder_output = run_text_encoder(args, args.prompt, text_encoders, model_config, image_encoder_output)
    set_target_shape(args, image_encoder_output)
    scheduler = init_scheduler(args, image_encoder_output)
    model.set_scheduler(scheduler)
......
@@ -8,12 +8,23 @@ class HunyuanPreInfer:
    def __init__(self):
        self.heads_num = 24

    def infer(self, weights, x, t, text_states, text_mask, text_states_2, freqs_cos, freqs_sin, guidance, img_latents=None):
        if img_latents is not None:
            token_replace_t = torch.zeros_like(t)
            token_replace_vec = self.infer_time_in(weights, token_replace_t)
            th = x.shape[-2] // 2
            tw = x.shape[-1] // 2
            frist_frame_token_num = th * tw

        time_out = self.infer_time_in(weights, t)
        img_out = self.infer_img_in(weights, x)
        infer_text_out = self.infer_text_in(weights, text_states, text_mask, t)
        infer_vector_out = self.infer_vector_in(weights, text_states_2)
        vec = time_out + infer_vector_out
        if img_latents is not None:
            token_replace_vec = token_replace_vec + infer_vector_out
        guidance_out = self.infer_guidance_in(weights, guidance)
        vec = vec + guidance_out
@@ -32,6 +43,8 @@ class HunyuanPreInfer:
            cu_seqlens_qkv[2 * i + 2] = s2
        max_seqlen_qkv = img_seq_len + txt_seq_len

        if img_latents is not None:
            return img_out[0], infer_text_out, vec, cu_seqlens_qkv, max_seqlen_qkv, (freqs_cos, freqs_sin), token_replace_vec, frist_frame_token_num
        return img_out[0], infer_text_out, vec, cu_seqlens_qkv, max_seqlen_qkv, (freqs_cos, freqs_sin)

    def infer_time_in(self, weights, t):
......
@@ -25,10 +25,10 @@ class HunyuanTransformerInfer:
    def set_scheduler(self, scheduler):
        self.scheduler = scheduler

    def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
        return self.infer_func(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)

    def _infer_with_offload(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num):
        txt_seq_len = txt.shape[0]
        img_seq_len = img.shape[0]
@@ -75,38 +75,22 @@ class HunyuanTransformerInfer:
        img = x[:img_seq_len, ...]
        return img, vec

    def _infer_without_offload(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num):
        txt_seq_len = txt.shape[0]
        img_seq_len = img.shape[0]

        for i in range(self.double_blocks_num):
            img, txt = self.infer_double_block(weights.double_blocks_weights[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)

        x = torch.cat((img, txt), 0)

        for i in range(self.single_blocks_num):
            x = self.infer_single_block(weights.single_blocks_weights[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)

        img = x[:img_seq_len, ...]
        return img, vec
    def infer_double_block(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num):
        vec_silu = torch.nn.functional.silu(vec)

        img_mod_out = weights.img_mod.apply(vec_silu)
@@ -119,6 +103,13 @@ class HunyuanTransformerInfer:
            img_mod2_gate,
        ) = img_mod_out.chunk(6, dim=-1)

        if token_replace_vec is not None:
            token_replace_vec_silu = torch.nn.functional.silu(token_replace_vec)
            token_replace_vec_img_mod_out = weights.img_mod.apply(token_replace_vec_silu)
            (tr_img_mod1_shift, tr_img_mod1_scale, tr_img_mod1_gate, tr_img_mod2_shift, tr_img_mod2_scale, tr_img_mod2_gate) = token_replace_vec_img_mod_out.chunk(6, dim=-1)
        else:
            (tr_img_mod1_shift, tr_img_mod1_scale, tr_img_mod1_gate, tr_img_mod2_shift, tr_img_mod2_scale, tr_img_mod2_gate) = None, None, None, None, None, None

        txt_mod_out = weights.txt_mod.apply(vec_silu)
        (
            txt_mod1_shift,
@@ -129,7 +120,7 @@ class HunyuanTransformerInfer:
            txt_mod2_gate,
        ) = txt_mod_out.chunk(6, dim=-1)

        img_q, img_k, img_v = self.infer_double_block_img_pre_atten(weights, img, img_mod1_scale, img_mod1_shift, tr_img_mod1_scale, tr_img_mod1_shift, frist_frame_token_num, freqs_cis)
        txt_q, txt_k, txt_v = self.infer_double_block_txt_pre_atten(weights, txt, txt_mod1_scale, txt_mod1_shift)

        q = torch.cat((img_q, txt_q), dim=0)
@@ -162,28 +153,19 @@ class HunyuanTransformerInfer:
        img_attn, txt_attn = attn[: img.shape[0]], attn[img.shape[0] :]

        img = self.infer_double_block_img_post_atten(
            weights, img, img_attn, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate, tr_img_mod1_gate, tr_img_mod2_shift, tr_img_mod2_scale, tr_img_mod2_gate, frist_frame_token_num
        )
        txt = self.infer_double_block_txt_post_atten(weights, txt, txt_attn, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate)
        return img, txt
    def infer_double_block_img_pre_atten(self, weights, img, img_mod1_scale, img_mod1_shift, tr_img_mod1_scale, tr_img_mod1_shift, frist_frame_token_num, freqs_cis):
        img_modulated = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
        if tr_img_mod1_scale is not None:
            x_zero = img_modulated[:frist_frame_token_num] * (1 + tr_img_mod1_scale) + tr_img_mod1_shift
            x_orig = img_modulated[frist_frame_token_num:] * (1 + img_mod1_scale) + img_mod1_shift
            img_modulated = torch.concat((x_zero, x_orig), dim=0)
        else:
            img_modulated = img_modulated * (1 + img_mod1_scale) + img_mod1_shift
        img_qkv = weights.img_attn_qkv.apply(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
@@ -206,21 +188,24 @@ class HunyuanTransformerInfer:
        return txt_q, txt_k, txt_v
    def infer_double_block_img_post_atten(
        self, weights, img, img_attn, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate, tr_img_mod1_gate, tr_img_mod2_shift, tr_img_mod2_scale, tr_img_mod2_gate, frist_frame_token_num
    ):
        out = weights.img_attn_proj.apply(img_attn)
        if tr_img_mod1_gate is not None:
            x_zero = out[:frist_frame_token_num] * tr_img_mod1_gate
            x_orig = out[frist_frame_token_num:] * img_mod1_gate
            out = torch.concat((x_zero, x_orig), dim=0)
        else:
            out = out * img_mod1_gate
        img = img + out

        out = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
        if tr_img_mod1_gate is not None:
            x_zero = out[:frist_frame_token_num] * (1 + tr_img_mod2_scale) + tr_img_mod2_shift
            x_orig = out[frist_frame_token_num:] * (1 + img_mod2_scale) + img_mod2_shift
            out = torch.concat((x_zero, x_orig), dim=0)
        else:
            out = out * (1 + img_mod2_scale) + img_mod2_shift
        out = weights.img_mlp_fc1.apply(out)
        out = torch.nn.functional.gelu(out, approximate="tanh")
        out = weights.img_mlp_fc2.apply(out)
@@ -251,13 +236,23 @@ class HunyuanTransformerInfer:
        txt = txt + out
        return txt

    def infer_single_block(self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
        out = torch.nn.functional.silu(vec)
        out = weights.modulation.apply(out)
        mod_shift, mod_scale, mod_gate = out.chunk(3, dim=-1)

        if token_replace_vec is not None:
            token_replace_vec_out = torch.nn.functional.silu(token_replace_vec)
            token_replace_vec_out = weights.modulation.apply(token_replace_vec_out)
            tr_mod_shift, tr_mod_scale, tr_mod_gate = token_replace_vec_out.chunk(3, dim=-1)

        out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
        if token_replace_vec is not None:
            x_zero = out[:frist_frame_token_num] * (1 + tr_mod_scale) + tr_mod_shift
            x_orig = out[frist_frame_token_num:] * (1 + mod_scale) + mod_shift
            x_mod = torch.concat((x_zero, x_orig), dim=0)
        else:
            x_mod = out * (1 + mod_scale) + mod_shift

        x_mod = weights.linear1.apply(x_mod)
@@ -301,6 +296,12 @@ class HunyuanTransformerInfer:
        out = torch.nn.functional.gelu(mlp, approximate="tanh")
        out = torch.cat((attn, out), 1)
        out = weights.linear2.apply(out)
        if token_replace_vec is not None:
            x_zero = out[:frist_frame_token_num] * tr_mod_gate
            x_orig = out[frist_frame_token_num:] * mod_gate
            out = torch.concat((x_zero, x_orig), dim=0)
        else:
            out = out * mod_gate
        x = x + out
        return x
@@ -17,10 +17,11 @@ class HunyuanModel:
    post_weight_class = HunyuanPostWeights
    transformer_weight_class = HunyuanTransformerWeights

    def __init__(self, model_path, config, device, args):
        self.model_path = model_path
        self.config = config
        self.device = device
        self.args = args
        self._init_infer_class()
        self._init_weights()
        self._init_infer()
@@ -47,7 +48,10 @@ class HunyuanModel:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")

    def _load_ckpt(self):
        if self.args.task == "t2v":
            ckpt_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt")
        else:
            ckpt_path = os.path.join(self.model_path, "hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt")
        weight_dict = torch.load(ckpt_path, map_location=self.device, weights_only=True)["module"]
        return weight_dict
@@ -96,6 +100,7 @@ class HunyuanModel:
            self.scheduler.freqs_cos,
            self.scheduler.freqs_sin,
            self.scheduler.guidance,
            img_latents=image_encoder_output["img_latents"] if "img_latents" in image_encoder_output else None,
        )
        img, vec = self.transformer_infer.infer(self.transformer_weights, *pre_infer_out)
        self.scheduler.noise_pred = self.post_infer.infer(self.post_weight, img, vec, self.scheduler.latents.shape)
......
import torch
import numpy as np
from diffusers.utils.torch_utils import randn_tensor
from typing import Union, Tuple, List
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
@@ -174,35 +175,108 @@ def get_nd_rotary_pos_embed(
def set_timesteps_sigmas(num_inference_steps, shift, device, num_train_timesteps=1000):
    sigmas = torch.linspace(1, 0, num_inference_steps + 1)
    sigmas = (shift * sigmas) / (1 + (shift - 1) * sigmas)
    timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32, device=device)
    return timesteps, sigmas
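
The shift above biases the schedule toward high noise levels. A minimal standalone sketch of what the formula produces (assuming only torch; 20 steps and shift 7.0 mirror the defaults used elsewhere in this commit):

import torch

# sigma' = shift * sigma / (1 + (shift - 1) * sigma), as in set_timesteps_sigmas above
sigmas = torch.linspace(1, 0, 20 + 1)
shifted = (7.0 * sigmas) / (1 + (7.0 - 1) * sigmas)
print(shifted[:3])   # ~[1.0000, 0.9925, 0.9844]: early steps stay near pure noise
print(shifted[-3:])  # ~[0.4375, 0.2692, 0.0000]: the schedule only drops quickly at the end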
def get_1d_rotary_pos_embed_riflex(
    dim: int,
    pos: Union[np.ndarray, int],
    theta: float = 10000.0,
    use_real=False,
    k: Optional[int] = None,
    L_test: Optional[int] = None,
):
    """
    RIFLEx: Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
    index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
    data type.

    Args:
        dim (`int`): Dimension of the frequency tensor.
        pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
        theta (`float`, *optional*, defaults to 10000.0):
            Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (`bool`, *optional*):
            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
        k (`int`, *optional*, defaults to None): the index for the intrinsic frequency in RoPE
        L_test (`int`, *optional*, defaults to None): the number of frames for inference
    Returns:
        `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
    """
    assert dim % 2 == 0

    if isinstance(pos, int):
        pos = torch.arange(pos)
    if isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos)  # type: ignore  # [S]

    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=pos.device)[: (dim // 2)].float() / dim))  # [D/2]

    # === RIFLEx modification start ===
    # Reduce the intrinsic frequency to stay within a single period after extrapolation (see Eq. (8)).
    # Empirical observations show that a few videos may exhibit repetition in the tail frames.
    # To be conservative, we multiply by 0.9 to keep the extrapolated length below 90% of a single period.
    if k is not None:
        freqs[k - 1] = 0.9 * 2 * torch.pi / L_test
    # === RIFLEx modification end ===

    freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
    if use_real:
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
        return freqs_cos, freqs_sin
    else:
        # lumina
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
        return freqs_cis
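
For reference, a minimal call of the helper above outside the scheduler (the dimension, frame count, and k value here are illustrative only; in the scheduler they come from rope_dim_list, the flattened position grid, and the k/L_test computed for long videos):

import torch

# Illustrative: 16-dim temporal RoPE over 32 latent frames; RIFLEx caps the
# k-th intrinsic frequency so one period covers 90% of the L_test positions.
pos = torch.arange(32, dtype=torch.float32)
freqs_cos, freqs_sin = get_1d_rotary_pos_embed_riflex(dim=16, pos=pos, theta=10000.0, use_real=True, k=4, L_test=32)
print(freqs_cos.shape, freqs_sin.shape)  # torch.Size([32, 16]) torch.Size([32, 16])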
class HunyuanScheduler(BaseScheduler):
    def __init__(self, args, image_encoder_output):
        super().__init__(args)
        self.infer_steps = self.args.infer_steps
        self.image_encoder_output = image_encoder_output
        self.shift = 7.0
        self.timesteps, self.sigmas = set_timesteps_sigmas(self.infer_steps, self.shift, device=torch.device("cuda"))
        assert len(self.timesteps) == self.infer_steps
        self.embedded_guidance_scale = 6.0
        self.generator = [torch.Generator("cuda").manual_seed(seed) for seed in [self.args.seed]]
        self.noise_pred = None
        self.prepare_latents(shape=self.args.target_shape, dtype=torch.float16)
        self.prepare_guidance()
        if self.args.task == "t2v":
            target_height, target_width = self.args.target_height, self.args.target_width
        else:
            target_height, target_width = self.image_encoder_output["target_height"], self.image_encoder_output["target_width"]
        self.prepare_rotary_pos_embedding(video_length=self.args.target_video_length, height=target_height, width=target_width)

    def prepare_guidance(self):
        self.guidance = torch.tensor([self.embedded_guidance_scale], dtype=torch.bfloat16, device=torch.device("cuda")) * 1000.0

    def step_post(self):
        if self.args.task == "t2v":
            sample = self.latents.to(torch.float32)
            dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
            self.latents = sample + self.noise_pred.to(torch.float32) * dt
        else:
            sample = self.latents[:, :, 1:, :, :].to(torch.float32)
            dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
            latents = sample + self.noise_pred[:, :, 1:, :, :].to(torch.float32) * dt
            self.latents = torch.concat([self.image_encoder_output["img_latents"], latents], dim=2)

    def prepare_latents(self, shape, dtype):
        if self.args.task == "t2v":
            self.latents = randn_tensor(shape, generator=self.generator, device=torch.device("cuda"), dtype=dtype)
        else:
            x1 = self.image_encoder_output["img_latents"].repeat(1, 1, (self.args.target_video_length - 1) // 4 + 1, 1, 1)
            x0 = randn_tensor(shape, generator=self.generator, device=torch.device("cuda"), dtype=dtype)
            t = torch.tensor([0.999]).to(device=torch.device("cuda"))
            self.latents = x0 * t + x1 * (1 - t)
            self.latents = self.latents.to(dtype=dtype)
            self.latents = torch.concat([self.image_encoder_output["img_latents"], self.latents[:, :, 1:, :, :]], dim=2)

    def prepare_rotary_pos_embedding(self, video_length, height, width):
        target_ndim = 3
@@ -230,17 +304,62 @@ class HunyuanScheduler(BaseScheduler):
        if len(rope_sizes) != target_ndim:
            rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # time axis

        if self.args.task == "t2v":
            head_dim = hidden_size // heads_num
            rope_dim_list = rope_dim_list
            if rope_dim_list is None:
                rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
            assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
            self.freqs_cos, self.freqs_sin = get_nd_rotary_pos_embed(
                rope_dim_list,
                rope_sizes,
                theta=rope_theta,
                use_real=True,
                theta_rescale_factor=1,
            )
            self.freqs_cos = self.freqs_cos.to(dtype=torch.bfloat16, device=torch.device("cuda"))
            self.freqs_sin = self.freqs_sin.to(dtype=torch.bfloat16, device=torch.device("cuda"))
        else:
            L_test = rope_sizes[0]  # Latent frames
            L_train = 25  # Training length from HunyuanVideo
            actual_num_frames = video_length  # Use input video_length directly

            head_dim = hidden_size // heads_num
            rope_dim_list = rope_dim_list or [head_dim // target_ndim for _ in range(target_ndim)]
            assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) must equal head_dim"

            if actual_num_frames > 192:
                k = 2 + ((actual_num_frames + 3) // (4 * L_train))
                k = max(4, min(8, k))

                # Compute positional grids for RIFLEx
                axes_grids = [torch.arange(size, device=torch.device("cuda"), dtype=torch.float32) for size in rope_sizes]
                grid = torch.meshgrid(*axes_grids, indexing="ij")
                grid = torch.stack(grid, dim=0)  # [3, t, h, w]
                pos = grid.reshape(3, -1).t()  # [t * h * w, 3]

                # Apply RIFLEx to temporal dimension
                freqs = []
                for i in range(3):
                    if i == 0:  # Temporal with RIFLEx
                        freqs_cos, freqs_sin = get_1d_rotary_pos_embed_riflex(rope_dim_list[i], pos[:, i], theta=rope_theta, use_real=True, k=k, L_test=L_test)
                    else:  # Spatial with default RoPE
                        freqs_cos, freqs_sin = get_1d_rotary_pos_embed_riflex(rope_dim_list[i], pos[:, i], theta=rope_theta, use_real=True, k=None, L_test=None)
                    freqs.append((freqs_cos, freqs_sin))
                freqs_cos = torch.cat([f[0] for f in freqs], dim=1)
                freqs_sin = torch.cat([f[1] for f in freqs], dim=1)
            else:
                # 20250316 pftq: Original code for <= 192 frames
                freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
                    rope_dim_list,
                    rope_sizes,
                    theta=rope_theta,
                    use_real=True,
                    theta_rescale_factor=1,
                )
            self.freqs_cos = freqs_cos.to(dtype=torch.bfloat16, device=torch.device("cuda"))
            self.freqs_sin = freqs_sin.to(dtype=torch.bfloat16, device=torch.device("cuda"))
import torch
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
from transformers import LlavaForConditionalGeneration, CLIPImageProcessor, AutoTokenizer
def generate_crop_size_list(base_size=256, patch_size=32, max_ratio=4.0):
    """generate crop size list

    Args:
        base_size (int, optional): the base size used to generate buckets. Defaults to 256.
        patch_size (int, optional): the stride used to generate buckets. Defaults to 32.
        max_ratio (float, optional): the max ratio for h or w relative to base_size. Defaults to 4.0.

    Returns:
        list: the generated crop size list
    """
    num_patches = round((base_size / patch_size) ** 2)
    assert max_ratio >= 1.0
    crop_size_list = []
    wp, hp = num_patches, 1
    while wp > 0:
        if max(wp, hp) / min(wp, hp) <= max_ratio:
            crop_size_list.append((wp * patch_size, hp * patch_size))
        if (hp + 1) * wp <= num_patches:
            hp += 1
        else:
            wp -= 1
    return crop_size_list
def get_closest_ratio(height: float, width: float, ratios: list, buckets: list):
    """get the closest ratio in the buckets

    Args:
        height (float): video height
        width (float): video width
        ratios (list): video aspect ratios
        buckets (list): buckets generated by `generate_crop_size_list`

    Returns:
        the closest size in the buckets and the corresponding ratio
    """
    aspect_ratio = float(height) / float(width)
    diff_ratios = ratios - aspect_ratio

    if aspect_ratio >= 1:
        indices = [(index, x) for index, x in enumerate(diff_ratios) if x <= 0]
    else:
        indices = [(index, x) for index, x in enumerate(diff_ratios) if x > 0]

    closest_ratio_id = min(indices, key=lambda pair: abs(pair[1]))[0]
    closest_size = buckets[closest_ratio_id]
    closest_ratio = ratios[closest_ratio_id]
    return closest_size, closest_ratio
class TextEncoderHFLlavaModel:
    def __init__(self, model_path, device):
        self.device = device
        self.model_path = model_path
        self.init()
        self.load()

    def init(self):
        self.max_length = 359
        self.hidden_state_skip_layer = 2
        self.crop_start = 103
        self.double_return_token_id = 271
        self.image_emb_len = 576
        self.text_crop_start = self.crop_start - 1 + self.image_emb_len
        self.image_crop_start = 5
        self.image_crop_end = 581
        self.image_embed_interleave = 4
        self.prompt_template = (
            "<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
            "1. The main content and theme of the video."
            "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
            "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
            "4. background environment, light, style and atmosphere."
            "5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
            "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
            "<|start_header_id|>assistant<|end_header_id|>\n\n"
        )
    def load(self):
        self.model = LlavaForConditionalGeneration.from_pretrained(self.model_path, low_cpu_mem_usage=True).to(torch.float16).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, padding_side="right")
        self.processor = CLIPImageProcessor.from_pretrained(self.model_path)

    def to_cpu(self):
        self.model = self.model.to("cpu")

    def to_cuda(self):
        self.model = self.model.to("cuda")

    @torch.no_grad()
    def infer(self, text, img, args):
        # if args.cpu_offload:
        #     self.to_cuda()

        text = self.prompt_template.format(text)
        print(f"text: {text}")

        tokens = self.tokenizer(
            text,
            return_length=False,
            return_overflowing_tokens=False,
            return_attention_mask=True,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        ).to("cuda")

        image_outputs = self.processor(img, return_tensors="pt")["pixel_values"].to(self.device)

        attention_mask = tokens["attention_mask"].to(self.device)
        outputs = self.model(input_ids=tokens["input_ids"], attention_mask=attention_mask, output_hidden_states=True, pixel_values=image_outputs)

        last_hidden_state = outputs.hidden_states[-(self.hidden_state_skip_layer + 1)]

        batch_indices, last_double_return_token_indices = torch.where(tokens["input_ids"] == self.double_return_token_id)
        last_double_return_token_indices = last_double_return_token_indices.reshape(1, -1)[:, -1]

        assistant_crop_start = last_double_return_token_indices - 1 + self.image_emb_len - 4
        assistant_crop_end = last_double_return_token_indices - 1 + self.image_emb_len
        attention_mask_assistant_crop_start = last_double_return_token_indices - 4
        attention_mask_assistant_crop_end = last_double_return_token_indices

        text_last_hidden_state = torch.cat([last_hidden_state[0, self.text_crop_start : assistant_crop_start[0].item()], last_hidden_state[0, assistant_crop_end[0].item() :]])
        text_attention_mask = torch.cat([attention_mask[0, self.crop_start : attention_mask_assistant_crop_start[0].item()], attention_mask[0, attention_mask_assistant_crop_end[0].item() :]])

        image_last_hidden_state = last_hidden_state[0, self.image_crop_start : self.image_crop_end]
        image_attention_mask = torch.ones(image_last_hidden_state.shape[0]).to(last_hidden_state.device).to(attention_mask.dtype)

        text_last_hidden_state.unsqueeze_(0)
        text_attention_mask.unsqueeze_(0)
        image_last_hidden_state.unsqueeze_(0)
        image_attention_mask.unsqueeze_(0)

        image_last_hidden_state = image_last_hidden_state[:, :: self.image_embed_interleave, :]
        image_attention_mask = image_attention_mask[:, :: self.image_embed_interleave]

        last_hidden_state = torch.cat([image_last_hidden_state, text_last_hidden_state], dim=1)
        attention_mask = torch.cat([image_attention_mask, text_attention_mask], dim=1)

        # if args.cpu_offload:
        #     self.to_cpu()

        return last_hidden_state, attention_mask
if __name__ == "__main__":
model = TextEncoderHFLlavaModel("/mtc/yongyang/models/x2v_models/hunyuan/lightx2v_format/i2v/text_encoder_i2v", torch.device("cuda"))
text = "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick."
img_path = "/mtc/yongyang/projects/lightx2v/assets/inputs/imgs/img_1.jpg"
img = Image.open(img_path).convert("RGB")
outputs = model.infer(text, img, None)
print(outputs)
import os
import torch
from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D, DiagonalGaussianDistribution


class VideoEncoderKLCausal3DModel:
    def __init__(self, model_path, dtype, device, args):
        self.model_path = model_path
        self.dtype = dtype
        self.device = device
        self.args = args
        self.load()

    def load(self):
        if self.args.task == "t2v":
            self.vae_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/vae")
        else:
            self.vae_path = os.path.join(self.model_path, "hunyuan-video-i2v-720p/vae")
        config = AutoencoderKLCausal3D.load_config(self.vae_path)
        self.model = AutoencoderKLCausal3D.from_config(config)
        ckpt = torch.load(os.path.join(self.vae_path, "pytorch_model.pt"), map_location="cpu", weights_only=True)
@@ -39,6 +43,12 @@ class VideoEncoderKLCausal3DModel:
        self.to_cpu()
        return image

    def encode(self, x, args):
        h = self.model.encoder(x)
        moments = self.model.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior


if __name__ == "__main__":
    model_path = ""
......
#!/bin/bash
# set these paths first
lightx2v_path=""
model_path=""
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: 0, change at shell script or set env variable."
cuda_devices="0"
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
python ${lightx2v_path}/lightx2v/__main__.py \
--model_cls hunyuan \
--model_path $model_path \
--task i2v \
--prompt "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick." \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_1.jpg \
--infer_steps 20 \
--target_video_length 33 \
--target_height 720 \
--target_width 1280 \
--attention_type flash_attn2 \
--save_video_path ./output_lightx2v_hy_i2v.mp4 \
--seed 0