Commit 51bdee9e authored by helloyongyang
Browse files

Support wan2.2 moe i2v model

parent 93b4f6a4
{
"infer_steps": 40,
"target_video_length": 81,
"text_len": 512,
"target_height": 720,
"target_width": 1280,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": [3.5, 3.5],
"sample_shift": 5.0,
"enable_cfg": true,
"cpu_offload": true,
"offload_granularity": "model",
"boundary": 0.900,
"use_image_encoder": false
}
......@@ -43,6 +43,7 @@ class WanPreInfer:
context = inputs["text_encoder_output"]["context_null"]
if self.task == "i2v":
if self.config.get("use_image_encoder", True):
clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
if self.config.get("changing_resolution", False):
......@@ -103,7 +104,7 @@ class WanPreInfer:
del out, stacked
torch.cuda.empty_cache()
if self.task == "i2v":
if self.task == "i2v" and self.config.get("use_image_encoder", True):
context_clip = weights.proj_0.apply(clip_fea)
if self.clean_cuda_cache:
del clip_fea
......@@ -116,6 +117,7 @@ class WanPreInfer:
context_clip = weights.proj_4.apply(context_clip)
context = torch.concat([context_clip, context], dim=0)
if self.clean_cuda_cache:
if self.config.get("use_image_encoder", True):
del context_clip
torch.cuda.empty_cache()
return (
......
......@@ -403,7 +403,7 @@ class WanTransformerInfer(BaseTransformerInfer):
x.add_(y_out * gate_msa.squeeze(0))
norm3_out = weights.norm3.apply(x)
if self.task == "i2v":
if self.task == "i2v" and self.config.get("use_image_encoder", True):
context_img = context[:257]
context = context[257:]
else:
......@@ -411,7 +411,7 @@ class WanTransformerInfer(BaseTransformerInfer):
if GET_DTYPE() != "BF16":
context = context.to(torch.bfloat16)
if self.task == "i2v":
if self.task == "i2v" and self.config.get("use_image_encoder", True):
context_img = context_img.to(torch.bfloat16)
n, d = self.num_heads, self.head_dim
......@@ -434,7 +434,7 @@ class WanTransformerInfer(BaseTransformerInfer):
model_cls=self.config["model_cls"],
)
if self.task == "i2v" and context_img is not None:
if self.task == "i2v" and self.config.get("use_image_encoder", True) and context_img is not None:
k_img = weights.cross_attn_norm_k_img.apply(weights.cross_attn_k_img.apply(context_img)).view(-1, n, d)
v_img = weights.cross_attn_v_img.apply(context_img).view(-1, n, d)
......
......@@ -39,7 +39,7 @@ class WanPreWeights(WeightModule):
MM_WEIGHT_REGISTER["Default"]("time_projection.1.weight", "time_projection.1.bias"),
)
if config.task == "i2v":
if config.task == "i2v" and config.get("use_image_encoder", True):
self.add_module(
"proj_0",
LN_WEIGHT_REGISTER["Default"]("img_emb.proj.0.weight", "img_emb.proj.0.bias"),
......
......@@ -286,7 +286,7 @@ class WanCrossAttention(WeightModule):
)
self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["cross_attn_1_type"]]())
if self.config.task == "i2v":
if self.config.task == "i2v" and self.config.get("use_image_encoder", True):
self.add_module(
"cross_attn_k_img",
MM_WEIGHT_REGISTER[self.mm_type](
......
......@@ -152,7 +152,7 @@ class DefaultRunner(BaseRunner):
def _run_input_encoder_local_i2v(self):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
img = Image.open(self.config["image_path"]).convert("RGB")
clip_encoder_out = self.run_image_encoder(img)
clip_encoder_out = self.run_image_encoder(img) if self.config.get("use_image_encoder", True) else None
vae_encode_out = self.run_vae_encoder(img)
text_encoder_output = self.run_text_encoder(prompt, img)
torch.cuda.empty_cache()
......
......@@ -50,7 +50,7 @@ class WanRunner(DefaultRunner):
def load_image_encoder(self):
image_encoder = None
if self.config.task == "i2v":
if self.config.task == "i2v" and self.config.get("use_image_encoder", True):
# quant_config
clip_quantized = self.config.get("clip_quantized", False)
if clip_quantized:
......
#!/bin/bash
# Launch Wan2.2 MoE image-to-video (i2v) inference with lightx2v.
#
# Usage: set the two paths below, then run the script. CUDA_VISIBLE_DEVICES
# may be set in the environment; otherwise a default device is used.

# Required paths — must be filled in before running.
lightx2v_path=
model_path=

# --- Sanity checks -----------------------------------------------------------

# Fall back to a default GPU when the caller did not pin one.
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=1
    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi

if [ -z "${lightx2v_path}" ]; then
    echo "Error: lightx2v_path is not set. Please set this variable first."
    exit 1
fi

if [ -z "${model_path}" ]; then
    echo "Error: model_path is not set. Please set this variable first."
    exit 1
fi

# --- Environment -------------------------------------------------------------

export TOKENIZERS_PARALLELISM=false
# Quoted so paths containing spaces survive word splitting.
export PYTHONPATH="${lightx2v_path}:${PYTHONPATH}"
export DTYPE=BF16
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false

# --- Inference ---------------------------------------------------------------
# All path arguments are quoted to be safe against whitespace in paths.
python -m lightx2v.infer \
    --model_cls wan2.2_moe \
    --task i2v \
    --model_path "${model_path}" \
    --config_json "${lightx2v_path}/configs/wan22/wan_i2v.json" \
    --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
    --negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
    --image_path "${lightx2v_path}/assets/inputs/imgs/img_0.jpg" \
    --save_video_path "${lightx2v_path}/save_results/output_lightx2v_wan22_moe_i2v.mp4"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment