Commit 7ec70cbb authored by wangshankun

[feature]: Add CausalVid I2V

parent fc2468ce
{
"infer_steps": 20,
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"attention_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": false,
"cpu_offload": false,
"num_fragments": 3,
"num_frames": 21,
"num_frame_per_block": 7,
"num_blocks": 3,
"frame_seq_length": 1560,
"denoising_step_list": [999, 934, 862, 756, 603, 410, 250, 140, 74]
}
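Sanity check on the new config: the derived fields are mutually consistent. A quick worked example (assuming the standard Wan2.1 VAE strides of 8 spatial / 4 temporal and a 2x2 spatial patch size, which are not spelled out in this diff):

```python
# Rough sanity check, not code from this commit.
target_height, target_width, target_video_length = 480, 832, 81

lat_h, lat_w = target_height // 8, target_width // 8        # 60, 104
frame_seq_length = (lat_h // 2) * (lat_w // 2)               # 30 * 52 = 1560

num_latent_frames = (target_video_length - 1) // 4 + 1       # (81 - 1) / 4 + 1 = 21
num_frame_per_block, num_blocks = 7, 3
assert frame_seq_length == 1560
assert num_latent_frames == num_frame_per_block * num_blocks == 21
```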
@@ -34,7 +34,12 @@ class WanCausVidModel(WanModel):
     def _load_ckpt(self):
         use_bfloat16 = self.config.get("use_bfloat16", True)
-        weight_dict = torch.load(os.path.join(self.model_path, "causal_model.pt"), map_location="cpu", weights_only=True)
+        ckpt_path = os.path.join(self.model_path, "causal_model.pt")
+        if not os.path.exists(ckpt_path):
+            # Checkpoint not found; fall back to the parent class's _load_ckpt
+            return super()._load_ckpt()
+        weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
         dtype = torch.bfloat16 if use_bfloat16 else None
         for key, value in weight_dict.items():
@@ -48,7 +53,7 @@ class WanCausVidModel(WanModel):
         self.pre_weight.to_cuda()
         self.post_weight.to_cuda()
-        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
+        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True, kv_start=kv_start, kv_end=kv_end)
         x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out, kv_start, kv_end)
         self.scheduler.noise_pred = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
@@ -13,18 +13,18 @@ class WanTransformerInferCausVid(WanTransformerInfer):
         self.num_frame_per_block = config["num_frame_per_block"]
         self.frame_seq_length = config["frame_seq_length"]
         self.text_len = config["text_len"]
-        self.kv_size = self.num_frames * self.frame_seq_length
         self.kv_cache = None
         self.crossattn_cache = None

     def _init_kv_cache(self, dtype, device):
         kv_cache = []
+        kv_size = self.num_frames * self.frame_seq_length
         for _ in range(self.blocks_num):
             kv_cache.append(
                 {
-                    "k": torch.zeros([self.kv_size, self.num_heads, self.head_dim], dtype=dtype, device=device),
-                    "v": torch.zeros([self.kv_size, self.num_heads, self.head_dim], dtype=dtype, device=device),
+                    "k": torch.zeros([kv_size, self.num_heads, self.head_dim], dtype=dtype, device=device),
+                    "v": torch.zeros([kv_size, self.num_heads, self.head_dim], dtype=dtype, device=device),
                 }
             )
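Hoisting kv_size out of __init__ and into _init_kv_cache keeps the sizing next to the buffers it allocates. For a rough sense of what this cache costs, a back-of-the-envelope sketch; num_heads, head_dim, and blocks_num below are illustrative placeholders, not values from this commit:

```python
# Back-of-the-envelope KV-cache size; model dimensions are hypothetical.
num_frames, frame_seq_length = 21, 1560
num_heads, head_dim, blocks_num = 40, 128, 40
bytes_per_elem = 2                                     # bf16

kv_size = num_frames * frame_seq_length                # 32760 cached tokens per block
per_block = 2 * kv_size * num_heads * head_dim * bytes_per_elem    # k + v buffers
print(kv_size, round(blocks_num * per_block / 1024**3, 1))         # 32760  ~25.0 GiB
```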
@@ -144,9 +144,9 @@ class WanTransformerInferCausVid(WanTransformerInfer):
         norm3_out = weights.norm3.apply(x)

-        # TODO: Implement I2V inference for causvid model
         if self.task == "i2v":
-            raise NotImplementedError("I2V inference for causvid model is not implemented")
+            context_img = context[:257]
+            context = context[257:]

         n, d = self.num_heads, self.head_dim
         q = weights.cross_attn_norm_q.apply(weights.cross_attn_q.apply(norm3_out)).view(-1, n, d)
@@ -173,9 +173,28 @@ class WanTransformerInferCausVid(WanTransformerInfer):
             model_cls=self.config["model_cls"],
         )

-        # TODO: Implement I2V inference for causvid model
         if self.task == "i2v":
-            raise NotImplementedError("I2V inference for causvid model is not implemented")
+            k_img = weights.cross_attn_norm_k_img.apply(weights.cross_attn_k_img.apply(context_img)).view(-1, n, d)
+            v_img = weights.cross_attn_v_img.apply(context_img).view(-1, n, d)
+            cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(
+                q,
+                k_img,
+                k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device),
+            )
+            img_attn_out = weights.cross_attn_2.apply(
+                q=q,
+                k=k_img,
+                v=v_img,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_kv=cu_seqlens_k,
+                max_seqlen_q=lq,
+                max_seqlen_kv=lk,
+                model_cls=self.config["model_cls"],
+            )
+            attn_out = attn_out + img_attn_out

         attn_out = weights.cross_attn_o.apply(attn_out)
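This mirrors the decoupled image/text cross-attention of the Wan i2v blocks: the first 257 context tokens (the CLIP-derived image tokens) get their own k/v projections, and their attention output is summed with the text branch. A minimal, self-contained sketch of that pattern, with plain SDPA standing in for the repo's weight wrappers and flash-attn kernels:

```python
import torch
import torch.nn.functional as F

# Illustrative single-head sketch; shapes and projections are placeholders.
d, n_img, n_txt, n_q = 64, 257, 512, 1560
k_txt, v_txt = torch.nn.Linear(d, d), torch.nn.Linear(d, d)   # text k/v projections
k_img, v_img = torch.nn.Linear(d, d), torch.nn.Linear(d, d)   # cross_attn_k_img / v_img analogues

q = torch.randn(1, n_q, d)
context = torch.randn(1, n_img + n_txt, d)
ctx_img, ctx_txt = context[:, :257], context[:, 257:]          # same split as context_img / context

txt_out = F.scaled_dot_product_attention(q, k_txt(ctx_txt), v_txt(ctx_txt))
img_out = F.scaled_dot_product_attention(q, k_img(ctx_img), v_img(ctx_img))
attn_out = txt_out + img_out                                   # summed like attn_out + img_attn_out
```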
@@ -25,7 +25,7 @@ class WanPreInfer:
     def set_scheduler(self, scheduler):
         self.scheduler = scheduler

-    def infer(self, weights, inputs, positive):
+    def infer(self, weights, inputs, positive, kv_start=0, kv_end=0):
         x = [self.scheduler.latents]
         if self.scheduler.flag_df:
@@ -42,7 +42,14 @@ class WanPreInfer:
         if self.task == "i2v":
             clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
-            y = [inputs["image_encoder_output"]["vae_encode_out"]]
+            image_encoder = inputs["image_encoder_output"]["vae_encode_out"]
+            frame_seq_length = (image_encoder.size(2) // 2) * (image_encoder.size(3) // 2)
+            if kv_end - kv_start >= frame_seq_length:  # CausVid path: keep only the image_encoder fragment covered by this KV window
+                idx_s = kv_start // frame_seq_length
+                idx_e = kv_end // frame_seq_length
+                image_encoder = image_encoder[:, idx_s:idx_e, :, :]
+            y = [image_encoder]
             x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

         # embeddings
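With the default kv_start=kv_end=0 the condition is false and the full vae_encode_out is used as before; only the CausVid runner passes a real KV window, in which case the token offsets are converted back to latent-frame indices. A worked example using the block sizes from the config above:

```python
# Worked example of the slicing above, with frame_seq_length = 1560 and
# num_frame_per_block = 7 taken from the config in this commit.
frame_seq_length, num_frame_per_block = 1560, 7

for block in range(3):
    kv_start = block * num_frame_per_block * frame_seq_length
    kv_end = kv_start + num_frame_per_block * frame_seq_length
    idx_s, idx_e = kv_start // frame_seq_length, kv_end // frame_seq_length
    print(block, (kv_start, kv_end), (idx_s, idx_e))
# 0 (0, 10920)     (0, 7)   -> image_encoder[:, 0:7]
# 1 (10920, 21840) (7, 14)  -> image_encoder[:, 7:14]
# 2 (21840, 32760) (14, 21) -> image_encoder[:, 14:21]
```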
@@ -23,7 +23,6 @@ import torch.distributed as dist

 class WanCausVidRunner(WanRunner):
     def __init__(self, config):
         super().__init__(config)
-        self.denoising_step_list = self.model.config.denoising_step_list
         self.num_frame_per_block = self.model.config.num_frame_per_block
         self.num_frames = self.model.config.num_frames
         self.frame_seq_length = self.model.config.frame_seq_length
@@ -80,7 +79,11 @@ class WanCausVidRunner(WanRunner):
     def set_target_shape(self):
         if self.config.task == "i2v":
-            self.config.target_shape = (16, 3, self.config.lat_h, self.config.lat_w)
+            self.config.target_shape = (16, self.config.num_frame_per_block, self.config.lat_h, self.config.lat_w)
+            # For i2v, frame_seq_length must be recomputed from the input shape
+            frame_seq_length = (self.config.lat_h // 2) * (self.config.lat_w // 2)
+            self.model.transformer_infer.frame_seq_length = frame_seq_length
+            self.frame_seq_length = frame_seq_length
         elif self.config.task == "t2v":
             self.config.target_shape = (
                 16,
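The recompute matters because the frame_seq_length shipped in the config only holds for 480x832; for i2v it depends on the actual latent size of the resized input image. A small sketch (VAE stride 8 and 2x2 patching assumed, resolutions purely illustrative):

```python
# frame_seq_length recomputed from lat_h / lat_w as in set_target_shape above.
def frame_seq_length(height, width, vae_stride=8, patch=2):
    lat_h, lat_w = height // vae_stride, width // vae_stride
    return (lat_h // patch) * (lat_w // patch)

print(frame_seq_length(480, 832))    # 1560 -> matches the shipped config
print(frame_seq_length(720, 1280))   # 3600 -> the config value would be stale here
```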
@@ -9,6 +9,8 @@ class WanCausVidScheduler(WanScheduler):
     def __init__(self, config):
         super().__init__(config)
         self.denoising_step_list = config.denoising_step_list
+        self.infer_steps = self.config.infer_steps
+        self.sample_shift = self.config.sample_shift

     def prepare(self, image_encoder_output):
         self.generator = torch.Generator(device=self.device)
@@ -19,7 +21,7 @@ class WanCausVidScheduler(WanScheduler):
         if self.config.task in ["t2v"]:
             self.seq_len = math.ceil((self.config.target_shape[2] * self.config.target_shape[3]) / (self.config.patch_size[1] * self.config.patch_size[2]) * self.config.target_shape[1])
         elif self.config.task in ["i2v"]:
-            self.seq_len = ((self.config.target_video_length - 1) // self.config.vae_stride[0] + 1) * self.config.lat_h * self.config.lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
+            self.seq_len = self.config.lat_h * self.config.lat_w // (self.config.patch_size[1] * self.config.patch_size[2]) * self.config.target_shape[1]

         alphas = np.linspace(1, 1 / self.num_train_timesteps, self.num_train_timesteps)[::-1].copy()
         sigmas = 1.0 - alphas
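The i2v seq_len now counts tokens for a single block (target_shape[1] is num_frame_per_block after the runner change above) instead of the whole video. Plugging in the config values for comparison (lat_h/lat_w for 480x832 and a 2x2 patch are assumed):

```python
# Old vs new i2v seq_len with the values implied by the config in this commit.
lat_h, lat_w, patch = 60, 104, 2
target_video_length, vae_stride_t = 81, 4
num_frame_per_block = 7                      # target_shape[1] after set_target_shape

old_seq_len = ((target_video_length - 1) // vae_stride_t + 1) * lat_h * lat_w // (patch * patch)
new_seq_len = lat_h * lat_w // (patch * patch) * num_frame_per_block
print(old_seq_len, new_seq_len)              # 32760 (whole video) vs 10920 (one block)
```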
@@ -38,7 +40,10 @@ class WanCausVidScheduler(WanScheduler):
         self.sigma_min = self.sigmas[-1].item()
         self.sigma_max = self.sigmas[0].item()

-        self.set_denoising_timesteps(device=self.device)
+        if len(self.denoising_step_list) == self.infer_steps:  # use denoising_step_list only if it matches infer_steps
+            self.set_denoising_timesteps(device=self.device)
+        else:
+            self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)

     def set_denoising_timesteps(self, device: Union[str, torch.device] = None):
         self.timesteps = torch.tensor(self.denoising_step_list, device=device, dtype=torch.int64)
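So the fixed denoising_step_list is only honored when its length equals infer_steps; otherwise the regular shifted schedule is generated. With the shipped config (a 9-entry list and infer_steps=20) the fallback branch is taken. A rough sketch of that selection; the shifted-sigma construction in the else-branch is the generic flow-matching time shift and is an assumption, not code from this diff:

```python
import numpy as np

# Selection logic with the shipped config values; the shift formula is illustrative.
denoising_step_list = [999, 934, 862, 756, 603, 410, 250, 140, 74]
infer_steps, sample_shift, num_train_timesteps = 20, 8, 1000

if len(denoising_step_list) == infer_steps:
    timesteps = np.array(denoising_step_list)                      # fixed CausVid schedule
else:
    sigmas = np.linspace(1.0, 1.0 / num_train_timesteps, infer_steps)
    sigmas = sample_shift * sigmas / (1 + (sample_shift - 1) * sigmas)
    timesteps = sigmas * num_train_timesteps                       # shifted schedule path
print(timesteps.astype(int))
```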
#!/bin/bash
# set paths first
lightx2v_path="/mnt/Text2Video/wangshankun/lightx2v/"
model_path="/mnt/Text2Video/wangshankun/HF_Cache/Wan2.1-I2V-14B-CausVid/"
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python -m lightx2v.infer \
--model_cls wan2.1_causvid \
--task i2v \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_i2v_causvid.json \
--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_causvid.mp4