Commit 3e67df1c authored by helloyongyang
Browse files

[Feature]: Support wan i2v changing resolution

parent d9b5cedd
{
"infer_steps": 40,
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 442,
"sample_guide_scale": 5,
"sample_shift": 3,
"enable_cfg": true,
"cpu_offload": false,
"changing_resolution": true,
"resolution_rate": 0.75,
"changing_resolution_steps": 20
}
...@@ -6,6 +6,7 @@ from lightx2v.utils.envs import * ...@@ -6,6 +6,7 @@ from lightx2v.utils.envs import *
class WanPreInfer: class WanPreInfer:
def __init__(self, config): def __init__(self, config):
assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0 assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
self.config = config
d = config["dim"] // config["num_heads"] d = config["dim"] // config["num_heads"]
self.clean_cuda_cache = config.get("clean_cuda_cache", False) self.clean_cuda_cache = config.get("clean_cuda_cache", False)
self.task = config["task"] self.task = config["task"]
...@@ -43,7 +44,11 @@ class WanPreInfer: ...@@ -43,7 +44,11 @@ class WanPreInfer:
if self.task == "i2v": if self.task == "i2v":
clip_fea = inputs["image_encoder_output"]["clip_encoder_out"] clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
if self.config.get("changing_resolution", False) and self.scheduler.step_index > self.config.changing_resolution_steps - 1:
image_encoder = inputs["image_encoder_output"]["vae_encode_out_original_resolution"]
else:
image_encoder = inputs["image_encoder_output"]["vae_encode_out"] image_encoder = inputs["image_encoder_output"]["vae_encode_out"]
frame_seq_length = (image_encoder.size(2) // 2) * (image_encoder.size(3) // 2) frame_seq_length = (image_encoder.size(2) // 2) * (image_encoder.size(3) // 2)
if kv_end - kv_start >= frame_seq_length: # 如果是CausalVid, image_encoder取片段 if kv_end - kv_start >= frame_seq_length: # 如果是CausalVid, image_encoder取片段
idx_s = kv_start // frame_seq_length idx_s = kv_start // frame_seq_length
......
...@@ -208,10 +208,23 @@ class WanRunner(DefaultRunner): ...@@ -208,10 +208,23 @@ class WanRunner(DefaultRunner):
max_area = self.config.target_height * self.config.target_width max_area = self.config.target_height * self.config.target_width
lat_h = round(np.sqrt(max_area * aspect_ratio) // self.config.vae_stride[1] // self.config.patch_size[1] * self.config.patch_size[1]) lat_h = round(np.sqrt(max_area * aspect_ratio) // self.config.vae_stride[1] // self.config.patch_size[1] * self.config.patch_size[1])
lat_w = round(np.sqrt(max_area / aspect_ratio) // self.config.vae_stride[2] // self.config.patch_size[2] * self.config.patch_size[2]) lat_w = round(np.sqrt(max_area / aspect_ratio) // self.config.vae_stride[2] // self.config.patch_size[2] * self.config.patch_size[2])
h = lat_h * self.config.vae_stride[1]
w = lat_w * self.config.vae_stride[2]
if self.config.get("changing_resolution", False):
self.config.lat_h, self.config.lat_w = lat_h, lat_w
vae_encode_out_original_resolution = self.get_vae_encoder_output(img, lat_h, lat_w)
# get vae encode out at low resolution
lat_h, lat_w = int(self.config.lat_h * self.config.resolution_rate) // 2 * 2, int(self.config.lat_w * self.config.resolution_rate) // 2 * 2
vae_encode_out = self.get_vae_encoder_output(img, lat_h, lat_w)
return vae_encode_out, vae_encode_out_original_resolution # low resolution, original resolution
else:
self.config.lat_h, self.config.lat_w = lat_h, lat_w self.config.lat_h, self.config.lat_w = lat_h, lat_w
vae_encode_out = self.get_vae_encoder_output(img, lat_h, lat_w)
return vae_encode_out
def get_vae_encoder_output(self, img, lat_h, lat_w):
h = lat_h * self.config.vae_stride[1]
w = lat_w * self.config.vae_stride[2]
msk = torch.ones( msk = torch.ones(
1, 1,
...@@ -246,10 +259,18 @@ class WanRunner(DefaultRunner): ...@@ -246,10 +259,18 @@ class WanRunner(DefaultRunner):
return vae_encode_out return vae_encode_out
def get_encoder_output_i2v(self, clip_encoder_out, vae_encode_out, text_encoder_output, img): def get_encoder_output_i2v(self, clip_encoder_out, vae_encode_out, text_encoder_output, img):
if self.config.get("changing_resolution", False):
image_encoder_output = {
"clip_encoder_out": clip_encoder_out,
"vae_encode_out": vae_encode_out[0],
"vae_encode_out_original_resolution": vae_encode_out[1],
}
else:
image_encoder_output = { image_encoder_output = {
"clip_encoder_out": clip_encoder_out, "clip_encoder_out": clip_encoder_out,
"vae_encode_out": vae_encode_out, "vae_encode_out": vae_encode_out,
} }
return { return {
"text_encoder_output": text_encoder_output, "text_encoder_output": text_encoder_output,
"image_encoder_output": image_encoder_output, "image_encoder_output": image_encoder_output,
......
...@@ -30,7 +30,7 @@ class WanScheduler4ChangingResolution(WanScheduler): ...@@ -30,7 +30,7 @@ class WanScheduler4ChangingResolution(WanScheduler):
) )
def step_post(self): def step_post(self):
if self.step_index == self.changing_resolution_steps: if self.step_index == self.changing_resolution_steps - 1:
self.step_post_upsample() self.step_post_upsample()
else: else:
super().step_post() super().step_post()
......
#!/bin/bash

# Set these two paths first.
#   lightx2v_path: root of the lightx2v checkout (added to PYTHONPATH and used
#                  to locate the config and input image below).
#   model_path:    passed to the inference entry point as --model_path.
lightx2v_path=
model_path=

# Check section: pick a default GPU and make sure the required paths are set.
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=0
    # Diagnostics go to stderr so they do not mix with the program's stdout.
    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable." >&2
    export CUDA_VISIBLE_DEVICES="${cuda_devices}"
fi

if [ -z "${lightx2v_path}" ]; then
    echo "Error: lightx2v_path is not set. Please set this variable first." >&2
    exit 1
fi

if [ -z "${model_path}" ]; then
    echo "Error: model_path is not set. Please set this variable first." >&2
    exit 1
fi

export TOKENIZERS_PARALLELISM=false

# Prepend the checkout to PYTHONPATH. The ${VAR:+...} form appends the
# separator only when PYTHONPATH is already non-empty, which avoids a trailing
# ":" (an empty search-path entry) when it was previously unset.
export PYTHONPATH="${lightx2v_path}${PYTHONPATH:+:${PYTHONPATH}}"

export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
export DTYPE=BF16 # removing this line can yield higher-quality video
# Launch Wan2.1 image-to-video inference with the changing-resolution config.
# Every path argument is double-quoted so that directories containing spaces
# or glob characters reach the program as a single, unmodified argument.
python -m lightx2v.infer \
    --model_cls wan2.1 \
    --task i2v \
    --model_path "${model_path}" \
    --config_json "${lightx2v_path}/configs/changing_resolution/wan_i2v.json" \
    --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
    --negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
    --image_path "${lightx2v_path}/assets/inputs/imgs/img_0.jpg" \
    --save_video_path "${lightx2v_path}/save_results/output_lightx2v_wan_i2v_changing_resolution.mp4"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment