Commit 978e3b32 authored by helloyongyang

update wan2.2moe

parent d50b8884
{
    "infer_steps": 40,
    "target_video_length": 81,
    "text_len": 512,
    "target_height": 720,
    "target_width": 1280,
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": [4.0, 3.0],
    "sample_shift": 12.0,
    "enable_cfg": true,
    "cpu_offload": true,
    "offload_granularity": "model",
    "boundary": 0.875,
    "parallel": {
        "cfg_p_size": 2
    }
}
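For orientation, "boundary": 0.875 is the timestep ratio at which Wan2.2 MoE hands off from the high-noise expert to the low-noise expert, and "sample_guide_scale": [4.0, 3.0] carries one CFG scale per phase. A minimal sketch of that selection logic, assuming the first scale belongs to the high-noise phase; the function and variable names below are illustrative, not lightx2v APIs:

# Illustrative sketch only; select_expert, t_ratio and the model arguments are
# hypothetical names that do not appear in the diff below.
def select_expert(t_ratio, config, high_noise_model, low_noise_model):
    # Early, noisy timesteps (t_ratio >= boundary) go to the high-noise expert;
    # the rest of the denoising schedule uses the low-noise expert.
    if t_ratio >= config["boundary"]:                              # 0.875 in this config
        return high_noise_model, config["sample_guide_scale"][0]   # assumed high-noise scale
    return low_noise_model, config["sample_guide_scale"][1]        # assumed low-noise scale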
@@ -3,7 +3,7 @@ import os
 import torch
 from loguru import logger
-from lightx2v.models.networks.wan.model import Wan22MoeModel, WanModel
+from lightx2v.models.networks.wan.model import WanModel
 from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
 from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
 from lightx2v.models.networks.wan.weights.transformer_weights import (
@@ -32,16 +32,10 @@ class WanDistillModel(WanModel):
         return super()._load_ckpt(unified_dtype, sensitive_layer)


-class Wan22MoeDistillModel(WanDistillModel, Wan22MoeModel):
+class Wan22MoeDistillModel(WanDistillModel, WanModel):
     def __init__(self, model_path, config, device):
         WanDistillModel.__init__(self, model_path, config, device)

-    def _load_ckpt(self, unified_dtype, sensitive_layer):
-        ckpt_path = os.path.join(self.model_path, "distill_model.safetensors")
-        if os.path.exists(ckpt_path):
-            logger.info(f"Loading weights from {ckpt_path}")
-            return self._load_safetensor_to_dict(ckpt_path, unified_dtype, sensitive_layer)
-
     @torch.no_grad()
     def infer(self, inputs):
-        return Wan22MoeModel.infer(self, inputs)
+        return WanModel.infer(self, inputs)
@@ -55,7 +55,7 @@ class WanModel:
             self.dit_quantized_ckpt = find_gguf_model_path(config, "dit_quantized_ckpt", subdir=dit_quant_scheme)
             self.config.use_gguf = True
         else:
-            self.dit_quantized_ckpt = find_hf_model_path(config, "dit_quantized_ckpt", subdir=dit_quant_scheme)
+            self.dit_quantized_ckpt = find_hf_model_path(config, self.model_path, "dit_quantized_ckpt", subdir=dit_quant_scheme)
         quant_config_path = os.path.join(self.dit_quantized_ckpt, "config.json")
         if os.path.exists(quant_config_path):
             with open(quant_config_path, "r") as f:
@@ -106,7 +106,7 @@ class WanModel:
             return {key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()}

     def _load_ckpt(self, unified_dtype, sensitive_layer):
-        safetensors_path = find_hf_model_path(self.config, "dit_original_ckpt", subdir="original")
+        safetensors_path = find_hf_model_path(self.config, self.model_path, "dit_original_ckpt", subdir="original")
         safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
         weight_dict = {}
         for file_path in safetensors_files:
@@ -293,36 +293,3 @@ class WanModel:
             noise_pred_cond = noise_pred_list[0]  # cfg_p_rank == 0
             noise_pred_uncond = noise_pred_list[1]  # cfg_p_rank == 1
             self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
-
-
-class Wan22MoeModel(WanModel):
-    def _load_ckpt(self, unified_dtype, sensitive_layer):
-        safetensors_files = glob.glob(os.path.join(self.model_path, "*.safetensors"))
-        weight_dict = {}
-        for file_path in safetensors_files:
-            file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
-            weight_dict.update(file_weights)
-        return weight_dict
-
-    @torch.no_grad()
-    def infer(self, inputs):
-        if self.cpu_offload and self.offload_granularity != "model":
-            self.pre_weight.to_cuda()
-            self.post_weight.to_cuda()
-
-        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
-        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
-        noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
-        self.scheduler.noise_pred = noise_pred_cond
-
-        if self.config["enable_cfg"]:
-            embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
-            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
-            noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
-            self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)
-
-        if self.cpu_offload and self.offload_granularity != "model":
-            self.pre_weight.to_cpu()
-            self.post_weight.to_cpu()
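The deleted Wan22MoeModel.infer above ran the conditional and unconditional passes back to back on one device; the surviving CFG-parallel path (see "cfg_p_size": 2 in the config and the cfg_p_rank context lines in the previous hunk) applies the same guidance formula but places each pass on its own rank. A rough sketch of that pattern, assuming a torch.distributed process group; the helper below is illustrative, not the actual lightx2v implementation:

import torch
import torch.distributed as dist

# Illustrative sketch only; the real collective code is not part of this diff.
def cfg_parallel_noise_pred(forward_fn, inputs, guide_scale, cfg_group):
    rank = dist.get_rank(group=cfg_group)            # 0 -> conditional, 1 -> unconditional
    local_pred = forward_fn(inputs, positive=(rank == 0))
    gathered = [torch.empty_like(local_pred) for _ in range(2)]
    dist.all_gather(gathered, local_pred, group=cfg_group)
    noise_pred_cond, noise_pred_uncond = gathered[0], gathered[1]
    # Same combination as in the retained WanModel code above.
    return noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)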
@@ -4,7 +4,7 @@ from loguru import logger
 from lightx2v.models.networks.wan.distill_model import Wan22MoeDistillModel, WanDistillModel
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
-from lightx2v.models.networks.wan.model import Wan22MoeModel, WanModel
+from lightx2v.models.networks.wan.model import WanModel
 from lightx2v.models.runners.wan.wan_runner import MultiModelStruct, WanRunner
 from lightx2v.models.schedulers.wan.step_distill.scheduler import Wan22StepDistillScheduler, WanStepDistillScheduler
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
...@@ -86,7 +86,7 @@ class Wan22MoeDistillRunner(WanDistillRunner): ...@@ -86,7 +86,7 @@ class Wan22MoeDistillRunner(WanDistillRunner):
use_low_lora = True use_low_lora = True
if use_high_lora: if use_high_lora:
high_noise_model = Wan22MoeModel( high_noise_model = WanModel(
os.path.join(self.config.model_path, "high_noise_model"), os.path.join(self.config.model_path, "high_noise_model"),
self.config, self.config,
self.init_device, self.init_device,
@@ -107,7 +107,7 @@ class Wan22MoeDistillRunner(WanDistillRunner):
             )
         if use_low_lora:
-            low_noise_model = Wan22MoeModel(
+            low_noise_model = WanModel(
                 os.path.join(self.config.model_path, "low_noise_model"),
                 self.config,
                 self.init_device,
...
@@ -11,7 +11,7 @@ from loguru import logger
 from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
 from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
-from lightx2v.models.networks.wan.model import Wan22MoeModel, WanModel
+from lightx2v.models.networks.wan.model import WanModel
 from lightx2v.models.runners.default_runner import DefaultRunner
 from lightx2v.models.schedulers.wan.changing_resolution.scheduler import (
     WanScheduler4ChangingResolutionInterface,
@@ -370,12 +370,12 @@ class Wan22MoeRunner(WanRunner):
     def load_transformer(self):
         # encoder -> high_noise_model -> low_noise_model -> vae -> video_output
-        high_noise_model = Wan22MoeModel(
+        high_noise_model = WanModel(
             os.path.join(self.config.model_path, "high_noise_model"),
             self.config,
             self.init_device,
         )
-        low_noise_model = Wan22MoeModel(
+        low_noise_model = WanModel(
             os.path.join(self.config.model_path, "low_noise_model"),
             self.config,
             self.init_device,
...
@@ -277,14 +277,14 @@ def find_torch_model_path(config, ckpt_config_key=None, filename=None, subdir=["
         raise FileNotFoundError(f"PyTorch model file '{filename}' not found.\nPlease download the model from https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")


-def find_hf_model_path(config, ckpt_config_key=None, subdir=["original", "fp8", "int8"]):
+def find_hf_model_path(config, model_path, ckpt_config_key=None, subdir=["original", "fp8", "int8"]):
     if ckpt_config_key and config.get(ckpt_config_key, None) is not None:
         return config.get(ckpt_config_key)
-    paths_to_check = [config.model_path]
+    paths_to_check = [model_path]
     if isinstance(subdir, list):
         for sub in subdir:
-            paths_to_check.append(os.path.join(config.model_path, sub))
+            paths_to_check.append(os.path.join(model_path, sub))
     else:
         paths_to_check.append(os.path.join(config.model_path, subdir))
...
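The net effect of the signature change above is that find_hf_model_path no longer assumes a single config.model_path: callers pass the directory they actually want to search, which is what lets the two Wan2.2 experts resolve weights under their own subfolders. A usage sketch under the new signature; high_dir, low_dir and the *_ckpt_dir names are illustrative:

# Usage sketch only; variable names are hypothetical, the call signature is from the diff above.
high_dir = os.path.join(config.model_path, "high_noise_model")
low_dir = os.path.join(config.model_path, "low_noise_model")

# Each expert resolves its own checkpoint directory instead of the shared
# config.model_path that the old signature always fell back to.
high_ckpt_dir = find_hf_model_path(config, high_dir, "dit_original_ckpt", subdir="original")
low_ckpt_dir = find_hf_model_path(config, low_dir, "dit_original_ckpt", subdir="original")

Note that in this revision the non-list subdir branch still joins against config.model_path rather than the new model_path argument.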
#!/bin/bash
# set these paths first
lightx2v_path=
model_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0,1
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
torchrun --nproc_per_node=2 -m lightx2v.infer \
--model_cls wan2.2_moe \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/dist_infer/wan22_moe_t2v_cfg.json \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan22_moe_t2v_parallel_cfg.mp4