Commit ec79c145 authored by helloyongyang

Support Wan2.2 MoE t2v model

parent 6e46224f
configs/wan22/wan_t2v.json

{
    "infer_steps": 40,
    "target_video_length": 81,
    "text_len": 512,
    "target_height": 720,
    "target_width": 1280,
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": [3.0, 4.0],
    "sample_shift": 12.0,
    "enable_cfg": true,
    "cpu_offload": true,
    "offload_granularity": "model",
    "boundary": 0.875
}
@@ -40,8 +40,9 @@ def main():
             "wan2.1_causvid",
             "wan2.1_skyreels_v2_df",
             "wan2.1_audio",
+            "wan2.2_moe",
         ],
-        default="hunyuan",
+        default="wan2.1",
     )
     parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
     parser.add_argument("--model_path", type=str, required=True)
@@ -10,7 +10,7 @@ from lightx2v.utils.set_config import set_config
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner
-from lightx2v.models.runners.wan.wan_runner import WanRunner
+from lightx2v.models.runners.wan.wan_runner import WanRunner, Wan22MoeRunner
 from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner
 from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner
 from lightx2v.models.runners.wan.wan_audio_runner import WanAudioRunner
@@ -42,7 +42,7 @@ def init_runner(config):
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio"], default="wan2.1"
+        "--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio", "wan2.2_moe"], default="wan2.1"
     )
     parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
@@ -46,22 +46,25 @@ class WanTransformerInfer(BaseTransformerInfer):
                     self.infer_func = self._infer_with_phases_offload
                 else:
                     self.infer_func = self._infer_with_phases_lazy_offload
-            elif offload_granularity == "model":
-                self.infer_func = self._infer_without_offload
-            if not self.config.get("lazy_load", False):
-                self.weights_stream_mgr = WeightAsyncStreamManager(
-                    blocks_num=self.blocks_num,
-                    offload_ratio=offload_ratio,
-                    phases_num=self.phases_num,
-                )
-            else:
-                self.weights_stream_mgr = LazyWeightAsyncStreamManager(
-                    blocks_num=self.blocks_num,
-                    offload_ratio=offload_ratio,
-                    phases_num=self.phases_num,
-                    num_disk_workers=self.config.get("num_disk_workers", 2),
-                    max_memory=self.config.get("max_memory", 2),
-                    offload_gra=offload_granularity,
-                )
+            if offload_granularity != "model":
+                if not self.config.get("lazy_load", False):
+                    self.weights_stream_mgr = WeightAsyncStreamManager(
+                        blocks_num=self.blocks_num,
+                        offload_ratio=offload_ratio,
+                        phases_num=self.phases_num,
+                    )
+                else:
+                    self.weights_stream_mgr = LazyWeightAsyncStreamManager(
+                        blocks_num=self.blocks_num,
+                        offload_ratio=offload_ratio,
+                        phases_num=self.phases_num,
+                        num_disk_workers=self.config.get("num_disk_workers", 2),
+                        max_memory=self.config.get("max_memory", 2),
+                        offload_gra=offload_granularity,
+                    )
+            else:
+                self.infer_func = self._infer_without_offload
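After this restructure, a weight-streaming manager is only created for the sub-model granularities ("block"/"phase"); with "offload_granularity": "model", as in the config above, blocks run without per-block streaming and whole experts are swapped between CPU and GPU at a higher level (see MultiModelStruct below). A minimal sketch of the resulting dispatch, with hypothetical names standing in for the real manager classes:

def pick_infer_path(offload_granularity: str, lazy_load: bool) -> str:
    # Non-"model" granularities stream weights on and off the GPU per block.
    if offload_granularity != "model":
        return "lazy_stream_manager" if lazy_load else "stream_manager"
    # "model" granularity: plain block inference; the runner swaps whole experts.
    return "infer_without_offload"

assert pick_infer_path("model", lazy_load=False) == "infer_without_offload"
assert pick_infer_path("phase", lazy_load=True) == "lazy_stream_manager"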
@@ -226,7 +226,7 @@ class WanModel:
             x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
             noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
-            self.scheduler.noise_pred = noise_pred_uncond + self.config.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)
+            self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)
         if self.config.get("cpu_offload", False):
             self.pre_weight.to_cpu()
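The changed line reads the guidance scale from the scheduler instead of the config, so each Wan2.2 expert can apply its own scale (sample_guide_scale is [3.0, 4.0] in the config above). The blend itself is standard classifier-free guidance; a self-contained sketch with an illustrative helper:

import torch

# Classifier-free guidance: move from the unconditional prediction toward
# the conditional one, scaled by the active expert's guide scale.
def cfg(cond: torch.Tensor, uncond: torch.Tensor, scale: float) -> torch.Tensor:
    return uncond + scale * (cond - uncond)

c, u = torch.randn(4), torch.randn(4)
assert torch.allclose(cfg(c, u, 1.0), c)  # scale 1.0 recovers the conditional
assert torch.allclose(cfg(c, u, 0.0), u)  # scale 0.0 recovers the unconditional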
@@ -235,3 +235,28 @@ class WanModel:
         if self.clean_cuda_cache:
             del x, embed, pre_infer_out, noise_pred_uncond, grid_sizes
             torch.cuda.empty_cache()
+
+
+class Wan22MoeModel(WanModel):
+    def _load_ckpt(self, use_bf16, skip_bf16):
+        safetensors_files = glob.glob(os.path.join(self.model_path, "*.safetensors"))
+        weight_dict = {}
+        for file_path in safetensors_files:
+            file_weights = self._load_safetensor_to_dict(file_path, use_bf16, skip_bf16)
+            weight_dict.update(file_weights)
+        return weight_dict
+
+    @torch.no_grad()
+    def infer(self, inputs):
+        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
+        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
+        noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
+        self.scheduler.noise_pred = noise_pred_cond
+
+        if self.config["enable_cfg"]:
+            embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
+            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
+            noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
+            self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)
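Wan22MoeModel._load_ckpt merges every *.safetensors shard in the expert's directory into one flat state dict. A standalone approximation using the safetensors package (the dtype handling done by use_bf16/skip_bf16 is omitted here):

import glob
import os

from safetensors.torch import load_file

# Merge all *.safetensors shards in a directory into a single state dict;
# later shards overwrite earlier ones on duplicate keys, as dict.update does.
def load_sharded_weights(model_dir: str) -> dict:
    weight_dict = {}
    for file_path in sorted(glob.glob(os.path.join(model_dir, "*.safetensors"))):
        weight_dict.update(load_file(file_path))
    return weight_dict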
@@ -18,7 +18,7 @@ from lightx2v.utils.profiler import ProfilingContext
 from lightx2v.utils.utils import *
 from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
 from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
-from lightx2v.models.networks.wan.model import WanModel
+from lightx2v.models.networks.wan.model import WanModel, Wan22MoeModel
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
 from lightx2v.models.video_encoders.hf.wan.vae_tiny import WanVAE_tiny
@@ -293,3 +293,69 @@ class WanRunner(DefaultRunner):
             normalize=True,
             value_range=(-1, 1),
         )
+
+
+class MultiModelStruct:
+    def __init__(self, model_list, config, boundary=0.875, num_train_timesteps=1000):
+        self.model = model_list  # [high_noise_model, low_noise_model]
+        assert len(self.model) == 2, "MultiModelStruct only supports 2 models now."
+        self.config = config
+        self.boundary = boundary
+        self.boundary_timestep = self.boundary * num_train_timesteps
+        self.cur_model_index = -1
+        logger.info(f"boundary: {self.boundary}, boundary_timestep: {self.boundary_timestep}")
+
+    def set_scheduler(self, shared_scheduler):
+        self.scheduler = shared_scheduler
+        for model in self.model:
+            model.set_scheduler(shared_scheduler)
+
+    def infer(self, inputs):
+        self.get_current_model_index()
+        self.model[self.cur_model_index].infer(inputs)
+
+    def get_current_model_index(self):
+        if self.scheduler.timesteps[self.scheduler.step_index] >= self.boundary_timestep:
+            logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}")
+            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[0]
+            if self.cur_model_index == -1:
+                self.to_cuda(model_index=0)
+            elif self.cur_model_index == 1:  # 1 -> 0
+                self.offload_cpu(model_index=1)
+                self.to_cuda(model_index=0)
+            self.cur_model_index = 0
+        else:
+            logger.info(f"using - LOW - noise model at step_index {self.scheduler.step_index + 1}")
+            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[1]
+            if self.cur_model_index == -1:
+                self.to_cuda(model_index=1)
+            elif self.cur_model_index == 0:  # 0 -> 1
+                self.offload_cpu(model_index=0)
+                self.to_cuda(model_index=1)
+            self.cur_model_index = 1
+
+    def offload_cpu(self, model_index):
+        self.model[model_index].to_cpu()
+
+    def to_cuda(self, model_index):
+        self.model[model_index].to_cuda()
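get_current_model_index routes each denoising step by comparing the current timestep with boundary * num_train_timesteps (0.875 * 1000 = 875 under the config above): steps at or above the boundary use the high-noise expert with sample_guide_scale[0], the rest use the low-noise expert with sample_guide_scale[1]. Because sampling timesteps decrease monotonically, the high-noise expert always runs first and the switch happens exactly once. A minimal standalone sketch of the rule (route is an illustrative name):

# Route a timestep to (model_index, guide_scale); values from the config above.
def route(timestep: float, boundary: float = 0.875, num_train_timesteps: int = 1000):
    boundary_timestep = boundary * num_train_timesteps  # 875.0
    if timestep >= boundary_timestep:
        return 0, 3.0  # high-noise expert, sample_guide_scale[0]
    return 1, 4.0  # low-noise expert, sample_guide_scale[1]

assert route(999.0) == (0, 3.0)  # early, high-noise steps
assert route(500.0) == (1, 4.0)  # later, low-noise steps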
+
+
+@RUNNER_REGISTER("wan2.2_moe")
+class Wan22MoeRunner(WanRunner):
+    def __init__(self, config):
+        super().__init__(config)
+
+    def load_transformer(self):
+        # encoder -> high_noise_model -> low_noise_model -> vae -> video_output
+        high_noise_model = Wan22MoeModel(
+            os.path.join(self.config.model_path, "high_noise_model"),
+            self.config,
+            self.init_device,
+        )
+        low_noise_model = Wan22MoeModel(
+            os.path.join(self.config.model_path, "low_noise_model"),
+            self.config,
+            self.init_device,
+        )
+        return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config.boundary)
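Together with the set_config change below, load_transformer implies a checkpoint directory split per expert, roughly as follows (layout inferred from the paths above, not spelled out in this diff):

model_path/
├── high_noise_model/    # *.safetensors shards for the high-noise expert
└── low_noise_model/     # *.safetensors shards plus config.json (read by set_config)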
@@ -18,6 +18,7 @@ class WanScheduler(BaseScheduler):
         self.disable_corrector = []
         self.solver_order = 2
         self.noise_pred = None
+        self.sample_guide_scale = self.config.sample_guide_scale
         self.caching_records_2 = [True] * self.config.infer_steps
@@ -37,6 +37,10 @@ def set_config(args):
         with open(os.path.join(config.model_path, "config.json"), "r") as f:
             model_config = json.load(f)
         config.update(model_config)
+    elif os.path.exists(os.path.join(config.model_path, "low_noise_model", "config.json")):  # TODO: needs a more elegant update method
+        with open(os.path.join(config.model_path, "low_noise_model", "config.json"), "r") as f:
+            model_config = json.load(f)
+        config.update(model_config)
     elif os.path.exists(os.path.join(config.model_path, "original", "config.json")):
         with open(os.path.join(config.model_path, "original", "config.json"), "r") as f:
             model_config = json.load(f)
#!/bin/bash
# set paths first
lightx2v_path=
model_path=

# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=0
    echo "Warning: CUDA_VISIBLE_DEVICES is not set; using default value ${cuda_devices}. Change it in this script or set the environment variable."
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi

if [ -z "${lightx2v_path}" ]; then
    echo "Error: lightx2v_path is not set. Please set this variable first."
    exit 1
fi

if [ -z "${model_path}" ]; then
    echo "Error: model_path is not set. Please set this variable first."
    exit 1
fi

export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH

export DTYPE=BF16
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false

python -m lightx2v.infer \
    --model_cls wan2.2_moe \
    --task t2v \
    --model_path $model_path \
    --config_json ${lightx2v_path}/configs/wan22/wan_t2v.json \
    --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
    --negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
    --save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan22_moe_t2v.mp4