Commit 194e3bf4 authored by helloyongyang

Merge branch 'main' of https://github.com/ModelTC/LightX2V into main

parents d7206e69 a7d70484
{
    "infer_steps": 6,
    "target_video_length": 81,
    "text_len": 512,
    "target_height": 480,
    "target_width": 832,
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": [4.0, 3.0],
    "sample_shift": 5.0,
    "enable_cfg": false,
    "cpu_offload": true,
    "offload_granularity": "model",
    "boundary_step_index": 4,
    "denoising_step_list": [1000, 875, 750, 625, 500, 250],
    "lora_configs": [
        {
            "name": "low_noise_model",
            "path": "Wan2.1-T2V-14B/loras/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors",
            "strength": 1.0
        }
    ]
}
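For orientation: boundary_step_index splits the six denoising steps between the two MoE experts, and sample_guide_scale carries one scale per expert (enable_cfg is false here, so the scale is mostly bookkeeping). A minimal sketch of the step-to-expert mapping this config produces, pure illustration rather than LightX2V API:

infer_steps = 6
boundary_step_index = 4
denoising_step_list = [1000, 875, 750, 625, 500, 250]
sample_guide_scale = [4.0, 3.0]  # [high-noise expert, low-noise expert]

for step_index, t in enumerate(denoising_step_list):
    use_high = step_index < boundary_step_index  # same test as MultiDistillModelStruct below
    expert = "high_noise_model" if use_high else "low_noise_model"
    scale = sample_guide_scale[0 if use_high else 1]
    print(f"step {step_index}: t={t} -> {expert} (guide scale {scale})")
# steps 0-3 -> high_noise_model, steps 4-5 -> low_noise_model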
......@@ -41,6 +41,7 @@ def main():
"wan2.1_skyreels_v2_df",
"wan2.1_audio",
"wan2.2_moe",
"wan2.2_moe_distill",
],
default="wan2.1",
)
......
......@@ -32,6 +32,11 @@ try:
 except ModuleNotFoundError:
     quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None
 
+try:
+    import gguf
+except ImportError:
+    gguf = None
+
 
 class MMWeightTemplate(metaclass=ABCMeta):
     def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
......@@ -661,6 +666,23 @@ class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
         return output_tensor
 
 
+class MMWeightGGUFTemplate(MMWeightQuantTemplate):
+    # Guard against gguf being None (the import above is optional); referencing
+    # gguf.GGMLQuantizationType unconditionally would raise at import time.
+    TORCH_COMPATIBLE_QTYPES = (None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16) if gguf is not None else (None,)
+
+    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
+        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
+
+    def dequantize_func(self):
+        # TODO: implement dequantize_func
+        pass
+
+
+@MM_WEIGHT_REGISTER("W-gguf-Q4_K")
+class MMWeightGGUFQ4K(MMWeightGGUFTemplate):
+    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
+        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
+
 
 if __name__ == "__main__":
     weight_dict = {
         "xx.weight": torch.randn(8192, 4096).to(torch.float8_e4m3fn),
......
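The dequantize_func above is still a TODO. A minimal sketch of what it could look like built on the gguf-py helpers, assuming the template keeps the raw quantized blocks plus their GGML type around (the argument names here are hypothetical, not LightX2V attributes):

import gguf
import torch

def dequantize_gguf_weight(raw_blocks, qtype, logical_shape):
    # gguf.quants.dequantize understands the GGML block formats (Q4_K, Q8_0, ...)
    # and returns a float32 numpy array; GGUF stores shapes reversed, so reshape
    # to the logical (out_features, in_features) before use.
    arr = gguf.quants.dequantize(raw_blocks, qtype)
    return torch.from_numpy(arr).reshape(logical_shape)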
......@@ -38,7 +38,7 @@ def main():
"--model_cls",
type=str,
required=True,
choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio", "wan2.2_moe", "wan2.2_moe_audio", "wan2.2"],
choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio", "wan2.2_moe", "wan2.2_moe_audio", "wan2.2", "wan2.2_moe_distill"],
default="wan2.1",
)
......
......@@ -3,7 +3,7 @@ import os
 import torch
 from loguru import logger
 
-from lightx2v.models.networks.wan.model import WanModel
+from lightx2v.models.networks.wan.model import Wan22MoeModel, WanModel
 from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
 from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
 from lightx2v.models.networks.wan.weights.transformer_weights import (
......@@ -21,14 +21,27 @@ class WanDistillModel(WanModel):
         super().__init__(model_path, config, device)
 
     def _load_ckpt(self, use_bf16, skip_bf16):
         # For the old t2v distill model: https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill
-        ckpt_path = os.path.join(self.model_path, "distill_model.pt")
+        if self.config.get("enable_dynamic_cfg", False):
+            ckpt_path = os.path.join(self.model_path, "distill_cfg_models", "distill_model.safetensors")
+        else:
+            ckpt_path = os.path.join(self.model_path, "distill_models", "distill_model.safetensors")
         if os.path.exists(ckpt_path):
             logger.info(f"Loading weights from {ckpt_path}")
-            weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
-            weight_dict = {
-                key: (weight_dict[key].to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else weight_dict[key]).pin_memory().to(self.device) for key in weight_dict.keys()
-            }
-            return weight_dict
+            return self._load_safetensor_to_dict(ckpt_path, use_bf16, skip_bf16)
         return super()._load_ckpt(use_bf16, skip_bf16)
+
+
+class Wan22MoeDistillModel(WanDistillModel, Wan22MoeModel):
+    def __init__(self, model_path, config, device):
+        WanDistillModel.__init__(self, model_path, config, device)
+
+    def _load_ckpt(self, use_bf16, skip_bf16):
+        ckpt_path = os.path.join(self.model_path, "distill_model.safetensors")
+        if os.path.exists(ckpt_path):
+            logger.info(f"Loading weights from {ckpt_path}")
+            return self._load_safetensor_to_dict(ckpt_path, use_bf16, skip_bf16)
+
+    @torch.no_grad()
+    def infer(self, inputs):
+        return Wan22MoeModel.infer(self, inputs)
......@@ -140,15 +140,15 @@ def guidance_scale_embedding(w, embedding_dim=256, cfg_range=(1.0, 6.0), target_
"""
assert len(w.shape) == 1
cfg_min, cfg_max = cfg_range
w = torch.round(w)
w = torch.clamp(w, min=cfg_min, max=cfg_max)
# w = torch.round(w)
# w = torch.clamp(w, min=cfg_min, max=cfg_max)
w = (w - cfg_min) / (cfg_max - cfg_min) # [0, 1]
w = w * target_range
half_dim = embedding_dim // 2
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=dtype).to(w.device) * -emb).to(w.device)
emb = w.to(dtype)[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1).to(w.device))
assert emb.shape == (w.shape[0], embedding_dim)
......
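Two things change above: the rounding/clamping of w is commented out, so guidance scales are no longer snapped to integers or clipped to cfg_range before normalization, and the sin/cos halves are swapped, so the first half of the embedding now holds the cosine terms. The layout has to match whatever the distilled checkpoint was trained with, which presumably expects cos-first. A quick shape check (target_range=1000 is an arbitrary value for the example; the real default is truncated in the hunk header above):

w = torch.tensor([3.0, 5.0])  # raw guidance scales, one per sample
emb = guidance_scale_embedding(w, embedding_dim=256, cfg_range=(1.0, 6.0), target_range=1000, dtype=torch.float32)
print(emb.shape)     # torch.Size([2, 256])
print(emb[:, :128])  # cosine half (after this commit)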
......@@ -31,6 +31,11 @@ from lightx2v.models.networks.wan.weights.transformer_weights import (
 from lightx2v.utils.envs import *
 from lightx2v.utils.utils import *
 
+try:
+    import gguf
+except ImportError:
+    gguf = None
+
 
 class WanModel:
     pre_weight_class = WanPreWeights
......@@ -48,7 +53,11 @@ class WanModel:
         if self.dit_quantized:
             dit_quant_scheme = self.config.mm_config.get("mm_type").split("-")[1]
-            self.dit_quantized_ckpt = find_hf_model_path(config, "dit_quantized_ckpt", subdir=dit_quant_scheme)
+            if dit_quant_scheme == "gguf":
+                self.dit_quantized_ckpt = find_gguf_model_path(config, "dit_quantized_ckpt", subdir=dit_quant_scheme)
+                self.config.use_gguf = True
+            else:
+                self.dit_quantized_ckpt = find_hf_model_path(config, "dit_quantized_ckpt", subdir=dit_quant_scheme)
             quant_config_path = os.path.join(self.dit_quantized_ckpt, "config.json")
             if os.path.exists(quant_config_path):
                 with open(quant_config_path, "r") as f:
......@@ -155,6 +164,14 @@ class WanModel:
         return pre_post_weight_dict
 
+    def _load_gguf_ckpt(self):
+        gguf_path = self.dit_quantized_ckpt
+        logger.info(f"Loading gguf-quant dit model from {gguf_path}")
+        reader = gguf.GGUFReader(gguf_path)
+        for tensor in reader.tensors:
+            # TODO: implement _load_gguf_ckpt
+            pass
+
     def _init_weights(self, weight_dict=None):
         use_bf16 = GET_DTYPE() == "BF16"
         # Some layers run with float32 to achieve high accuracy
......@@ -169,6 +186,8 @@ class WanModel:
         if weight_dict is None:
             if not self.dit_quantized or self.weight_auto_quant:
                 self.original_weight_dict = self._load_ckpt(use_bf16, skip_bf16)
+            elif self.config.get("use_gguf", False):
+                self.original_weight_dict = self._load_gguf_ckpt()
             else:
                 if not self.config.get("lazy_load", False):
                     self.original_weight_dict = self._load_quant_ckpt(use_bf16, skip_bf16)
......
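The _load_gguf_ckpt body above is a stub that only walks reader.tensors. A minimal sketch of the naive version that eagerly dequantizes everything into an ordinary state dict (this forfeits the memory savings of GGUF; presumably the real implementation will keep the blocks quantized and let MMWeightGGUFTemplate dequantize at matmul time):

import gguf
import torch

def load_gguf_state_dict(gguf_path):
    reader = gguf.GGUFReader(gguf_path)
    weight_dict = {}
    for tensor in reader.tensors:
        # ReaderTensor.shape is stored in GGML (reversed) order; .data holds the raw blocks.
        logical_shape = tuple(reversed(tensor.shape.tolist()))
        arr = gguf.quants.dequantize(tensor.data, tensor.tensor_type)
        weight_dict[tensor.name] = torch.from_numpy(arr.copy()).reshape(logical_shape)
    return weight_dict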
 import os
 
 from loguru import logger
 
-from lightx2v.models.networks.wan.distill_model import WanDistillModel
+from lightx2v.models.networks.wan.distill_model import Wan22MoeDistillModel, WanDistillModel
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
-from lightx2v.models.networks.wan.model import WanModel
-from lightx2v.models.runners.wan.wan_runner import WanRunner
-from lightx2v.models.schedulers.wan.step_distill.scheduler import WanStepDistillScheduler
+from lightx2v.models.networks.wan.model import Wan22MoeModel, WanModel
+from lightx2v.models.runners.wan.wan_runner import MultiModelStruct, WanRunner
+from lightx2v.models.schedulers.wan.step_distill.scheduler import Wan22StepDistillScheduler, WanStepDistillScheduler
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
......@@ -37,3 +39,99 @@ class WanDistillRunner(WanRunner):
         else:
             raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
         self.model.set_scheduler(scheduler)
+
+
+class MultiDistillModelStruct(MultiModelStruct):
+    def __init__(self, model_list, config, boundary_step_index=2):
+        self.model = model_list  # [high_noise_model, low_noise_model]
+        assert len(self.model) == 2, "MultiModelStruct only supports 2 models now."
+        self.config = config
+        self.boundary_step_index = boundary_step_index
+        self.cur_model_index = -1
+        logger.info(f"boundary step index: {self.boundary_step_index}")
+
+    def get_current_model_index(self):
+        if self.scheduler.step_index < self.boundary_step_index:
+            logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}")
+            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[0]
+            if self.cur_model_index == -1:
+                self.to_cuda(model_index=0)
+            elif self.cur_model_index == 1:  # 1 -> 0
+                self.offload_cpu(model_index=1)
+                self.to_cuda(model_index=0)
+            self.cur_model_index = 0
+        else:
+            logger.info(f"using - LOW - noise model at step_index {self.scheduler.step_index + 1}")
+            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[1]
+            if self.cur_model_index == -1:
+                self.to_cuda(model_index=1)
+            elif self.cur_model_index == 0:  # 0 -> 1
+                self.offload_cpu(model_index=0)
+                self.to_cuda(model_index=1)
+            self.cur_model_index = 1
+
+
+@RUNNER_REGISTER("wan2.2_moe_distill")
+class Wan22MoeDistillRunner(WanDistillRunner):
+    def __init__(self, config):
+        super().__init__(config)
+
+    def load_transformer(self):
+        use_high_lora, use_low_lora = False, False
+        if self.config.get("lora_configs") and self.config.lora_configs:
+            for lora_config in self.config.lora_configs:
+                if lora_config.get("name", "") == "high_noise_model":
+                    use_high_lora = True
+                elif lora_config.get("name", "") == "low_noise_model":
+                    use_low_lora = True
+
+        if use_high_lora:
+            high_noise_model = Wan22MoeModel(
+                os.path.join(self.config.model_path, "high_noise_model"),
+                self.config,
+                self.init_device,
+            )
+            high_lora_wrapper = WanLoraWrapper(high_noise_model)
+            for lora_config in self.config.lora_configs:
+                if lora_config.get("name", "") == "high_noise_model":
+                    lora_path = lora_config["path"]
+                    strength = lora_config.get("strength", 1.0)
+                    lora_name = high_lora_wrapper.load_lora(lora_path)
+                    high_lora_wrapper.apply_lora(lora_name, strength)
+                    logger.info(f"High noise model loaded LoRA: {lora_name} with strength: {strength}")
+        else:
+            high_noise_model = Wan22MoeDistillModel(
+                os.path.join(self.config.model_path, "distill_models", "high_noise_model"),
+                self.config,
+                self.init_device,
+            )
+
+        if use_low_lora:
+            low_noise_model = Wan22MoeModel(
+                os.path.join(self.config.model_path, "low_noise_model"),
+                self.config,
+                self.init_device,
+            )
+            low_lora_wrapper = WanLoraWrapper(low_noise_model)
+            for lora_config in self.config.lora_configs:
+                if lora_config.get("name", "") == "low_noise_model":
+                    lora_path = lora_config["path"]
+                    strength = lora_config.get("strength", 1.0)
+                    lora_name = low_lora_wrapper.load_lora(lora_path)
+                    low_lora_wrapper.apply_lora(lora_name, strength)
+                    logger.info(f"Low noise model loaded LoRA: {lora_name} with strength: {strength}")
+        else:
+            low_noise_model = Wan22MoeDistillModel(
+                os.path.join(self.config.model_path, "distill_models", "low_noise_model"),
+                self.config,
+                self.init_device,
+            )
+
+        return MultiDistillModelStruct([high_noise_model, low_noise_model], self.config, self.config.boundary_step_index)
+
+    def init_scheduler(self):
+        if self.config.feature_caching == "NoCaching":
+            scheduler = Wan22StepDistillScheduler(self.config)
+        else:
+            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
+        self.model.set_scheduler(scheduler)
......@@ -56,3 +56,24 @@ class WanStepDistillScheduler(WanScheduler):
         noise = torch.randn(noisy_image_or_video.shape, dtype=torch.float32, device=self.device, generator=self.generator)
         noisy_image_or_video = self.add_noise(noisy_image_or_video, noise=noise, sigma=self.sigmas[self.step_index + 1].item())
         self.latents = noisy_image_or_video.to(self.latents.dtype)
+
+
+class Wan22StepDistillScheduler(WanStepDistillScheduler):
+    def __init__(self, config):
+        super().__init__(config)
+        self.boundary_step_index = config.boundary_step_index
+
+    def set_denoising_timesteps(self, device: Union[str, torch.device] = None):
+        super().set_denoising_timesteps(device)
+        self.sigma_boundary = self.sigmas[self.boundary_step_index].item()
+
+    def step_post(self):
+        flow_pred = self.noise_pred.to(torch.float32)
+        sigma = self.sigmas[self.step_index].item()
+        noisy_image_or_video = self.latents.to(torch.float32) - sigma * flow_pred
+        if self.step_index < self.boundary_step_index:
+            noisy_image_or_video = noisy_image_or_video / self.sigma_boundary
+        if self.step_index < self.infer_steps - 1:
+            sigma = self.sigmas[self.step_index + 1].item()
+            noisy_image_or_video = self.add_noise(noisy_image_or_video, torch.randn_like(noisy_image_or_video), sigma)
+        self.latents = noisy_image_or_video.to(self.latents.dtype)
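In flow-matching terms, step_post above jumps to the clean estimate and re-noises it to the next sigma (assuming WanScheduler.add_noise is the usual interpolation sigma * noise + (1 - sigma) * sample, as the call in the context lines above suggests):

x0_pred  = x_t - sigma_t * v_pred        (model output treated as velocity/flow)
x0_pred  = x0_pred / sigma_boundary      (only while step_index < boundary_step_index)
x_{t+1}  = (1 - sigma_{t+1}) * x0_pred + sigma_{t+1} * eps,  eps ~ N(0, I)

i.e. a few-step distillation sampler rather than a standard Euler update.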
......@@ -297,6 +297,34 @@ def find_hf_model_path(config, ckpt_config_key=None, subdir=["original", "fp8",
     raise FileNotFoundError(f"No Hugging Face model files (.safetensors) found.\nPlease download the model from: https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")
+
+
+def find_gguf_model_path(config, ckpt_config_key=None, subdir=None):
+    gguf_path = config.get(ckpt_config_key, None)
+    if gguf_path is None:
+        raise ValueError(f"GGUF path not found in config with key '{ckpt_config_key}'")
+    if not isinstance(gguf_path, str) or not gguf_path.endswith(".gguf"):
+        raise ValueError(f"GGUF path must be a string ending with '.gguf', got: {gguf_path}")
+
+    if os.sep in gguf_path or (os.altsep and os.altsep in gguf_path):
+        if os.path.exists(gguf_path):
+            logger.info(f"Found GGUF model file in: {gguf_path}")
+            return os.path.abspath(gguf_path)
+        else:
+            raise FileNotFoundError(f"GGUF file not found at path: {gguf_path}")
+    else:
+        # It's just a filename, so search the predefined locations.
+        paths_to_check = [config.model_path]
+        if subdir:
+            paths_to_check.append(os.path.join(config.model_path, subdir))
+        for path in paths_to_check:
+            gguf_file_path = os.path.join(path, gguf_path)
+            gguf_file = glob.glob(gguf_file_path)
+            if gguf_file:
+                logger.info(f"Found GGUF model file in: {gguf_file_path}")
+                return gguf_file_path
+        raise FileNotFoundError(f"No GGUF model files (.gguf) found.\nPlease download the model from: https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")
+
+
 def masks_like(tensor, zero=False, generator=None, p=0.2):
     assert isinstance(tensor, torch.Tensor)
     out = torch.ones_like(tensor)
......
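Both accepted forms of the ckpt_config_key value for the helper above, shown as a hypothetical config fragment (file names invented for illustration; the real config object also exposes model_path as an attribute, a plain dict is used here only to sketch the shapes):

config = {
    "model_path": "/models/Wan2.2-T2V-A14B",
    # form 1: a value containing a path separator is used as-is and must exist:
    "dit_quantized_ckpt": "/models/gguf/wan2.2_dit_Q4_K.gguf",
    # form 2: a bare *.gguf filename is searched in model_path and model_path/<subdir>:
    # "dit_quantized_ckpt": "wan2.2_dit_Q4_K.gguf",
}
ckpt = find_gguf_model_path(config, "dit_quantized_ckpt", subdir="gguf")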
#!/bin/bash
# set paths first
lightx2v_path=
model_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
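# Note: --negative_prompt below is the standard Chinese Wan negative prompt, roughly:
# "garish colors, overexposed, static, blurry details, subtitles, style, artwork,
#  painting, frame, still, overall gray, worst quality, low quality, JPEG artifacts,
#  ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn face, deformed,
#  disfigured, malformed limbs, fused fingers, motionless frame, cluttered background,
#  three legs, many people in the background, walking backwards".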
python -m lightx2v.infer \
    --model_cls wan2.2_moe_distill \
    --task t2v \
    --model_path $model_path \
    --config_json ${lightx2v_path}/configs/wan22/wan_moe_t2v_distill.json \
    --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
    --negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
    --save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan22_moe_t2v_distill.mp4