Commit c37065b1 authored by Zhuguanyu Wu, committed by GitHub

update scheduler for latest wan22_moe_distill models (#304)

parent e0645b37
{
    "infer_steps": 4,
    "target_video_length": 81,
    "text_len": 512,
    "target_height": 480,
    "target_width": 832,
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": [3.5, 3.5],
    "sample_shift": 5.0,
    "enable_cfg": false,
    "cpu_offload": true,
    "offload_granularity": "model",
    "use_image_encoder": false,
    "boundary_step_index": 2,
    "denoising_step_list": [1000, 750, 500, 250],
    "lora_configs": [
        {
            "name": "low_noise_model",
            "path": "Wan2.1-I2V-14B-480P/loras/Wan21_I2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors",
            "strength": 1.0
        }
    ]
}
 {
-    "infer_steps": 6,
+    "infer_steps": 4,
     "target_video_length": 81,
     "text_len": 512,
     "target_height": 480,
@@ -13,8 +13,8 @@
     "enable_cfg": false,
     "cpu_offload": true,
     "offload_granularity": "model",
-    "boundary_step_index": 4,
-    "denoising_step_list": [1000, 875, 750, 625, 500, 250],
+    "boundary_step_index": 2,
+    "denoising_step_list": [1000, 750, 500, 250],
     "lora_configs": [
         {
             "name": "low_noise_model",
...
@@ -65,13 +65,29 @@ class Wan22StepDistillScheduler(WanStepDistillScheduler):
     def set_denoising_timesteps(self, device: Union[str, torch.device] = None):
         super().set_denoising_timesteps(device)
-        self.sigma_boundary = self.sigmas[self.boundary_step_index].item()
+        self.sigma_bound = self.sigmas[self.boundary_step_index].item()

+    def calculate_alpha_beta_high(self, sigma):
+        alpha = (1 - sigma) / (1 - self.sigma_bound)
+        beta = math.sqrt(sigma**2 - (alpha * self.sigma_bound) ** 2)
+        return alpha, beta
+
     def step_post(self):
         flow_pred = self.noise_pred.to(torch.float32)
         sigma = self.sigmas[self.step_index].item()
-        noisy_image_or_video = self.latents.to(torch.float32) - sigma * flow_pred
-        if self.step_index < self.infer_steps - 1:
-            sigma_n = self.sigmas[self.step_index + 1].item()
-            noisy_image_or_video = noisy_image_or_video + flow_pred * sigma_n
+        # self.latents: x_t
+        if self.step_index < self.boundary_step_index:
+            # noisy_image_or_video: x_500
+            alpha, beta = self.calculate_alpha_beta_high(sigma)
+            noisy_image_or_video = (self.latents.to(torch.float32) - beta * (1 - self.sigma_bound) * flow_pred) / (alpha + beta)
+            if self.step_index < self.boundary_step_index - 1:
+                sigma_n = self.sigmas[self.step_index + 1].item()
+                alpha_n, beta_n = self.calculate_alpha_beta_high(sigma_n)
+                noisy_image_or_video = (alpha_n + beta_n) * noisy_image_or_video + (1 - self.sigma_bound) * beta_n * flow_pred
+        else:
+            # noisy_image_or_video: x_0
+            noisy_image_or_video = self.latents.to(torch.float32) - flow_pred * sigma
+            if self.step_index < self.infer_steps - 1:
+                sigma_n = self.sigmas[self.step_index + 1].item()
+                noisy_image_or_video = noisy_image_or_video + flow_pred * sigma_n
         self.latents = noisy_image_or_video.to(self.latents.dtype)
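In the new high-noise branch, a step no longer passes through a clean estimate x_0: it recovers the boundary latent (x_500 for sigma_bound = 0.5) and re-noises it toward the next sigma, while the low-noise branch keeps the original x_0 round trip. A hypothetical check, not part of the commit, of the two identities behind calculate_alpha_beta_high, assuming the usual rectified-flow forward process x_t = (1 - sigma) * x_0 + sigma * noise (the sigma values below are illustrative):

# Hypothetical check, not part of the commit:
#   alpha * (1 - sigma_bound) == 1 - sigma           (signal coefficient is preserved)
#   (alpha * sigma_bound)**2 + beta**2 == sigma**2   (total noise variance matches sigma)
import math

sigma_bound = 0.5              # sigmas[boundary_step_index], i.e. timestep 500
for sigma in (1.0, 0.75):      # high-noise steps satisfy sigma > sigma_bound
    alpha = (1 - sigma) / (1 - sigma_bound)
    beta = math.sqrt(sigma**2 - (alpha * sigma_bound) ** 2)
    assert abs(alpha * (1 - sigma_bound) - (1 - sigma)) < 1e-9
    assert abs((alpha * sigma_bound) ** 2 + beta**2 - sigma**2) < 1e-9
    print(f"sigma={sigma:.2f}: alpha={alpha:.3f}, beta={beta:.3f}")

At sigma = 1.0 this gives alpha = 0 and beta = 1, so the first step treats x_t as pure noise relative to the boundary latent.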
#!/bin/bash
# set these paths first
lightx2v_path=
model_path=
# sanity checks
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=0
    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}. Set it in this script or via an environment variable."
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi

if [ -z "${lightx2v_path}" ]; then
    echo "Error: lightx2v_path is not set. Please set this variable first."
    exit 1
fi

if [ -z "${model_path}" ]; then
    echo "Error: model_path is not set. Please set this variable first."
    exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python -m lightx2v.infer \
    --model_cls wan2.2_moe_distill \
    --task i2v \
    --model_path $model_path \
    --config_json ${lightx2v_path}/configs/wan22/wan_moe_i2v_distill.json \
    --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
    --negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
    --image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
    --save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan22_moe_i2v_distill.mp4