Commit 51bdee9e authored by helloyongyang

Support wan2.2 moe i2v model

parent 93b4f6a4
{
"infer_steps": 40,
"target_video_length": 81,
"text_len": 512,
"target_height": 720,
"target_width": 1280,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": [3.5, 3.5],
"sample_shift": 5.0,
"enable_cfg": true,
"cpu_offload": true,
"offload_granularity": "model",
"boundary": 0.900,
"use_image_encoder": false
}
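
The two-element sample_guide_scale and the boundary field are what mark this as a wan2.2 MoE config: the model runs two DiT experts, a high-noise one for the early denoising steps and a low-noise one for the late steps, switching at the boundary timestep, with one guidance scale per expert. Below is a minimal sketch of that routing, assuming a timestep normalized to [0, 1]; select_expert, the expert labels, and the pairing of scales to experts are illustrative assumptions, not the repo's API.

# Illustrative sketch only: the real expert dispatch lives in the
# wan2.2 scheduler/model code. select_expert and the expert labels
# are made-up names for illustration.
def select_expert(t_normalized, config):
    # t_normalized: current denoising timestep scaled to [0, 1],
    # where 1.0 is pure noise.
    boundary = config.get("boundary", 0.900)
    high_scale, low_scale = config["sample_guide_scale"]
    if t_normalized >= boundary:
        return "high_noise_expert", high_scale
    return "low_noise_expert", low_scale

cfg = {"boundary": 0.900, "sample_guide_scale": [3.5, 3.5]}
print(select_expert(0.95, cfg))  # ('high_noise_expert', 3.5)
print(select_expert(0.40, cfg))  # ('low_noise_expert', 3.5)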
@@ -43,7 +43,8 @@ class WanPreInfer:
         context = inputs["text_encoder_output"]["context_null"]
         if self.task == "i2v":
-            clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
+            if self.config.get("use_image_encoder", True):
+                clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
             if self.config.get("changing_resolution", False):
                 image_encoder = inputs["image_encoder_output"]["vae_encode_out"][self.scheduler.changing_resolution_index]
@@ -103,7 +104,7 @@ class WanPreInfer:
             del out, stacked
             torch.cuda.empty_cache()
-        if self.task == "i2v":
+        if self.task == "i2v" and self.config.get("use_image_encoder", True):
             context_clip = weights.proj_0.apply(clip_fea)
             if self.clean_cuda_cache:
                 del clip_fea
@@ -116,7 +117,8 @@ class WanPreInfer:
             context_clip = weights.proj_4.apply(context_clip)
             context = torch.concat([context_clip, context], dim=0)
         if self.clean_cuda_cache:
-            del context_clip
+            if self.config.get("use_image_encoder", True):
+                del context_clip
             torch.cuda.empty_cache()
         return (
             embed,
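
The new flag is read with config.get("use_image_encoder", True), so older wan2.1 i2v configs that never mention it keep their CLIP branch, while the wan2.2 MoE config above opts out explicitly. A toy demonstration of that default (the config dicts are illustrative):

# Legacy wan2.1 i2v configs have no "use_image_encoder" key, so
# .get(..., True) keeps their CLIP branch active; the wan2.2 moe
# config above disables it explicitly.
wan21_cfg = {"task": "i2v"}  # flag absent -> defaults to True
wan22_cfg = {"task": "i2v", "use_image_encoder": False}

for cfg in (wan21_cfg, wan22_cfg):
    if cfg["task"] == "i2v" and cfg.get("use_image_encoder", True):
        print("CLIP path: encode image, prepend projected tokens to context")
    else:
        print("text-only context: skip CLIP encode and projection")

Defaulting to True is the backward-compatible choice: no existing config file has to change for this commit.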
@@ -403,7 +403,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         x.add_(y_out * gate_msa.squeeze(0))
         norm3_out = weights.norm3.apply(x)
-        if self.task == "i2v":
+        if self.task == "i2v" and self.config.get("use_image_encoder", True):
             context_img = context[:257]
             context = context[257:]
         else:
@@ -411,7 +411,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         if GET_DTYPE() != "BF16":
             context = context.to(torch.bfloat16)
-            if self.task == "i2v":
+            if self.task == "i2v" and self.config.get("use_image_encoder", True):
                 context_img = context_img.to(torch.bfloat16)
         n, d = self.num_heads, self.head_dim
@@ -434,7 +434,7 @@ class WanTransformerInfer(BaseTransformerInfer):
             model_cls=self.config["model_cls"],
         )
-        if self.task == "i2v" and context_img is not None:
+        if self.task == "i2v" and self.config.get("use_image_encoder", True) and context_img is not None:
             k_img = weights.cross_attn_norm_k_img.apply(weights.cross_attn_k_img.apply(context_img)).view(-1, n, d)
             v_img = weights.cross_attn_v_img.apply(context_img).view(-1, n, d)
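
The 257 constant is the length of the image prefix: when the CLIP branch is on, pre_infer prepends 257 projected image tokens (assumed here to be 1 CLS token plus 256 patch tokens from the ViT-H/14 vision tower) to the 512 text tokens, and the transformer block slices them apart again to feed a separate image cross-attention. A shape-level sketch of both paths (illustrative, torch assumed):

import torch

# With the branch on: context = [257 image tokens; 512 text tokens],
# split again before cross-attention. With the branch off: context is
# text-only and there is nothing to slice.
text_len, dim = 512, 64

for use_image_encoder in (True, False):
    context = torch.randn(text_len, dim)
    if use_image_encoder:
        context = torch.concat([torch.randn(257, dim), context], dim=0)

    context_img = None
    if use_image_encoder:
        context_img = context[:257]  # image tokens -> k_img / v_img
        context = context[257:]      # text tokens -> regular cross-attn

    print(use_image_encoder, tuple(context.shape),
          None if context_img is None else tuple(context_img.shape))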
@@ -39,7 +39,7 @@ class WanPreWeights(WeightModule):
             MM_WEIGHT_REGISTER["Default"]("time_projection.1.weight", "time_projection.1.bias"),
         )
-        if config.task == "i2v":
+        if config.task == "i2v" and config.get("use_image_encoder", True):
             self.add_module(
                 "proj_0",
                 LN_WEIGHT_REGISTER["Default"]("img_emb.proj.0.weight", "img_emb.proj.0.bias"),
@@ -286,7 +286,7 @@ class WanCrossAttention(WeightModule):
         )
         self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["cross_attn_1_type"]]())
-        if self.config.task == "i2v":
+        if self.config.task == "i2v" and self.config.get("use_image_encoder", True):
             self.add_module(
                 "cross_attn_k_img",
                 MM_WEIGHT_REGISTER[self.mm_type](
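
Gating the add_module calls matters on the weight side too: a module that is never registered is never looked up or loaded, which is presumably what lets wan2.2 MoE checkpoints that ship without img_emb and the *_img cross-attention tensors load cleanly. A toy version of the pattern (TinyWeightModule and the weight names mimic the real code but are illustrative):

class TinyWeightModule:
    # Skipping registration means these weight keys are simply absent
    # from the module map, so loading never touches them.
    def __init__(self, config):
        self.modules = {}
        if config["task"] == "i2v" and config.get("use_image_encoder", True):
            self.add_module("proj_0", "img_emb.proj.0.weight")
            self.add_module("cross_attn_k_img", "blocks.0.cross_attn.k_img.weight")

    def add_module(self, name, weight_key):
        self.modules[name] = weight_key

print(TinyWeightModule({"task": "i2v", "use_image_encoder": False}).modules)  # {}
print(TinyWeightModule({"task": "i2v"}).modules)  # both img modules registered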
@@ -152,7 +152,7 @@ class DefaultRunner(BaseRunner):
     def _run_input_encoder_local_i2v(self):
         prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
         img = Image.open(self.config["image_path"]).convert("RGB")
-        clip_encoder_out = self.run_image_encoder(img)
+        clip_encoder_out = self.run_image_encoder(img) if self.config.get("use_image_encoder", True) else None
         vae_encode_out = self.run_vae_encoder(img)
         text_encoder_output = self.run_text_encoder(prompt, img)
         torch.cuda.empty_cache()
@@ -50,7 +50,7 @@ class WanRunner(DefaultRunner):
     def load_image_encoder(self):
         image_encoder = None
-        if self.config.task == "i2v":
+        if self.config.task == "i2v" and self.config.get("use_image_encoder", True):
             # quant_config
             clip_quantized = self.config.get("clip_quantized", False)
             if clip_quantized:
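
Taken together, the runner changes establish a contract: with the flag off, the CLIP model is never loaded and clip_encoder_out is None, so anything consuming image_encoder_output must treat that key as optional. A sketch of that contract (build_image_encoder_output is a hypothetical helper, not the repo's API):

def build_image_encoder_output(img, config, run_image_encoder, run_vae_encoder):
    # clip_encoder_out is optional; vae_encode_out is always produced.
    return {
        "clip_encoder_out": run_image_encoder(img)
        if config.get("use_image_encoder", True)
        else None,
        "vae_encode_out": run_vae_encoder(img),
    }

out = build_image_encoder_output(
    img="<PIL image>",
    config={"use_image_encoder": False},
    run_image_encoder=lambda im: "clip_features",
    run_vae_encoder=lambda im: "vae_latents",
)
print(out)  # {'clip_encoder_out': None, 'vae_encode_out': 'vae_latents'}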
#!/bin/bash
# set these paths first
lightx2v_path=
model_path=
# sanity-check required variables
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=1
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python -m lightx2v.infer \
--model_cls wan2.2_moe \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan22/wan_i2v.json \
--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan22_moe_i2v.mp4