Commit fa88888a authored by gushiqiao

Refactor the model-loading code and fix bugs

parent 87bbed1c
@@ -23,16 +23,8 @@ class WanDistillModel(WanModel):
         super().__init__(model_path, config, device)

     def _load_ckpt(self, use_bf16, skip_bf16):
-        enable_dynamic_cfg = self.config.get("enable_dynamic_cfg", False)
-        ckpt_folder = "distill_cfg_models" if enable_dynamic_cfg else "distill_models"
-        safetensors_path = os.path.join(self.model_path, f"{ckpt_folder}/distill_model.safetensors")
-        if os.path.exists(safetensors_path):
-            with safe_open(safetensors_path, framework="pt") as f:
-                weight_dict = {key: (f.get_tensor(key).to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()}
-            return weight_dict
-        ckpt_path = os.path.join(self.model_path, f"{ckpt_folder}/distill_model.pt")
+        # For the old t2v distill model: https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill
+        ckpt_path = os.path.join(self.model_path, "distill_model.pt")
         if os.path.exists(ckpt_path):
             logger.info(f"Loading weights from {ckpt_path}")
             weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
......
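Note: the rewritten _load_ckpt above keeps a fallback to the flat distill_model.pt layout of the older step-distill release linked in the comment. Loading it with weights_only=True (available since PyTorch 1.13) restricts torch.load to tensors and primitive containers, so a tampered pickle in a downloaded checkpoint cannot execute code. A minimal sketch of the same call, with a stand-in path:

import torch

# "distill_model.pt" is a stand-in path for illustration
state_dict = torch.load("distill_model.pt", map_location="cpu", weights_only=True)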
@@ -105,7 +105,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         return x

-    def _infer_with_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
+    def _infer_with_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         self.weights_stream_mgr.prefetch_weights_from_disk(weights.blocks)
         for block_idx in range(self.blocks_num):
......
@@ -22,6 +22,7 @@ from safetensors import safe_open
 import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
 import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
 from lightx2v.utils.envs import *
+from lightx2v.utils.utils import *
 from loguru import logger
@@ -34,13 +35,15 @@ class WanModel:
         self.model_path = model_path
         self.config = config
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
         self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
         if self.dit_quantized:
             dit_quant_scheme = self.config.mm_config.get("mm_type").split("-")[1]
-            self.dit_quantized_ckpt = self.config.get("dit_quantized_ckpt", os.path.join(model_path, dit_quant_scheme))
+            self.dit_quantized_ckpt = find_hf_model_path(config, "dit_quantized_ckpt", subdir=dit_quant_scheme)
         else:
             self.dit_quantized_ckpt = None
+            assert not self.config.get("lazy_load", False)
+        self.config.dit_quantized_ckpt = self.dit_quantized_ckpt

         self.weight_auto_quant = self.config.mm_config.get("weight_auto_quant", False)
         if self.dit_quantized:
@@ -80,16 +83,8 @@ class WanModel:
             return {key: (f.get_tensor(key).to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()}

     def _load_ckpt(self, use_bf16, skip_bf16):
-        safetensors_pattern = os.path.join(self.model_path, "*.safetensors")
-        safetensors_files = glob.glob(safetensors_pattern)
-        if not safetensors_files:
-            original_pattern = os.path.join(self.model_path, "original", "*.safetensors")
-            safetensors_files = glob.glob(original_pattern)
-            if not safetensors_files:
-                raise FileNotFoundError(f"No .safetensors files found in directory: {self.model_path}")
+        safetensors_path = find_hf_model_path(self.config, "dit_original_ckpt", subdir="original")
+        safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
         weight_dict = {}
         for file_path in safetensors_files:
             file_weights = self._load_safetensor_to_dict(file_path, use_bf16, skip_bf16)
......
@@ -33,6 +33,8 @@ class DefaultRunner(BaseRunner):
         logger.info("Initializing runner modules...")
         if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
             self.load_model()
+        elif self.config.get("lazy_load", False):
+            assert self.config.get("cpu_offload", False)
         self.run_dit = self._run_dit_local
         self.run_vae_decoder = self._run_vae_decoder_local
         if self.config["task"] == "i2v":
......
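Note: the new elif branch makes lazy_load legal only in combination with cpu_offload, which matches how lazily streamed weights are staged through host memory before reaching the GPU. A hypothetical config fragment that passes the assertion (key names taken from the checks above):

# hypothetical config fragment for illustration
config = {
    "lazy_load": True,    # stream weights from disk on demand
    "cpu_offload": True,  # required by the assert whenever lazy_load is set
}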
@@ -17,6 +17,7 @@ from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
     WanSchedulerCustomCaching,
 )
 from lightx2v.utils.profiler import ProfilingContext
+from lightx2v.utils.utils import *
 from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
 from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
 from lightx2v.models.networks.wan.model import WanModel
@@ -58,28 +59,24 @@ class WanRunner(DefaultRunner):
             clip_quant_scheme = self.config.get("clip_quant_scheme", None)
             assert clip_quant_scheme is not None
             tmp_clip_quant_scheme = clip_quant_scheme.split("-")[0]
-            clip_quantized_ckpt = self.config.get(
-                "clip_quantized_ckpt",
-                os.path.join(
-                    os.path.join(self.config.model_path, tmp_clip_quant_scheme),
-                    f"clip-{tmp_clip_quant_scheme}.pth",
-                ),
-            )
+            clip_model_name = f"clip-{tmp_clip_quant_scheme}.pth"
+            clip_quantized_ckpt = find_torch_model_path(self.config, "clip_quantized_ckpt", clip_model_name, tmp_clip_quant_scheme)
+            clip_original_ckpt = None
         else:
             clip_quantized_ckpt = None
             clip_quant_scheme = None
+            clip_model_name = "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"
+            clip_original_ckpt = find_torch_model_path(self.config, "clip_original_ckpt", clip_model_name, "original")

         image_encoder = CLIPModel(
             dtype=torch.float16,
             device=self.init_device,
-            checkpoint_path=os.path.join(
-                self.config.model_path,
-                "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
-            ),
+            checkpoint_path=clip_original_ckpt,
             clip_quantized=clip_quantized,
             clip_quantized_ckpt=clip_quantized_ckpt,
             quant_scheme=clip_quant_scheme,
         )
         return image_encoder

     def load_text_encoder(self):
@@ -94,24 +91,22 @@ class WanRunner(DefaultRunner):
         t5_quantized = self.config.get("t5_quantized", False)
         if t5_quantized:
             t5_quant_scheme = self.config.get("t5_quant_scheme", None)
-            tmp_t5_quant_scheme = t5_quant_scheme.split("-")[0]
             assert t5_quant_scheme is not None
-            t5_quantized_ckpt = self.config.get(
-                "t5_quantized_ckpt",
-                os.path.join(
-                    os.path.join(self.config.model_path, tmp_t5_quant_scheme),
-                    f"models_t5_umt5-xxl-enc-{tmp_t5_quant_scheme}.pth",
-                ),
-            )
+            tmp_t5_quant_scheme = t5_quant_scheme.split("-")[0]
+            t5_model_name = f"models_t5_umt5-xxl-enc-{tmp_t5_quant_scheme}.pth"
+            t5_quantized_ckpt = find_torch_model_path(self.config, "t5_quantized_ckpt", t5_model_name, tmp_t5_quant_scheme)
+            t5_original_ckpt = None
         else:
             t5_quant_scheme = None
             t5_quantized_ckpt = None
+            t5_model_name = "models_t5_umt5-xxl-enc-bf16.pth"
+            t5_original_ckpt = find_torch_model_path(self.config, "t5_original_ckpt", t5_model_name, "original")

         text_encoder = T5EncoderModel(
             text_len=self.config["text_len"],
             dtype=torch.bfloat16,
             device=t5_device,
-            checkpoint_path=os.path.join(self.config.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
+            checkpoint_path=t5_original_ckpt,
             tokenizer_path=os.path.join(self.config.model_path, "google/umt5-xxl"),
             shard_fn=None,
             cpu_offload=t5_offload,
@@ -125,7 +120,7 @@ class WanRunner(DefaultRunner):
     def load_vae_encoder(self):
         vae_config = {
-            "vae_pth": os.path.join(self.config.model_path, "Wan2.1_VAE.pth"),
+            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth", "original"),
             "device": self.init_device,
             "parallel": self.config.parallel_vae,
             "use_tiling": self.config.get("use_tiling_vae", False),
@@ -137,13 +132,13 @@ class WanRunner(DefaultRunner):
     def load_vae_decoder(self):
         vae_config = {
-            "vae_pth": os.path.join(self.config.model_path, "Wan2.1_VAE.pth"),
+            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth", "original"),
             "device": self.init_device,
             "parallel": self.config.parallel_vae,
             "use_tiling": self.config.get("use_tiling_vae", False),
         }
         if self.config.get("use_tiny_vae", False):
-            tiny_vae_path = self.config.get("tiny_vae_path", os.path.join(self.config.model_path, "taew2_1.pth"))
+            tiny_vae_path = find_torch_model_path(self.config, "tiny_vae_path", "taew2_1.pth", "original")
             vae_decoder = WanVAE_tiny(
                 vae_pth=tiny_vae_path,
                 device=self.init_device,
@@ -216,7 +211,10 @@ class WanRunner(DefaultRunner):
         self.config.lat_h, self.config.lat_w = lat_h, lat_w
         vae_encode_out_list = []
         for i in range(len(self.config["resolution_rate"])):
-            lat_h, lat_w = int(self.config.lat_h * self.config.resolution_rate[i]) // 2 * 2, int(self.config.lat_w * self.config.resolution_rate[i]) // 2 * 2
+            lat_h, lat_w = (
+                int(self.config.lat_h * self.config.resolution_rate[i]) // 2 * 2,
+                int(self.config.lat_w * self.config.resolution_rate[i]) // 2 * 2,
+            )
             vae_encode_out_list.append(self.get_vae_encoder_output(img, lat_h, lat_w))
         vae_encode_out_list.append(self.get_vae_encoder_output(img, self.config.lat_h, self.config.lat_w))
         return vae_encode_out_list
......
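Note: int(x * rate) // 2 * 2 rounds each scaled latent dimension down to the nearest even integer, which the VAE's spatial downsampling requires. For example, with lat_h = 90 and resolution_rate[i] = 0.75: int(90 * 0.75) = 67, and 67 // 2 * 2 = 66.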
 import os
 import random
 import subprocess
-from typing import Optional
-from einops import rearrange
+import glob
 import imageio
 import imageio_ffmpeg as ffmpeg
-from loguru import logger
 import numpy as np
 import torch
 import torchvision
+from typing import Optional
+from einops import rearrange
+from loguru import logger


 def seed_all(seed):
@@ -154,12 +154,14 @@ def save_to_video(
     if method == "imageio":
         # Convert to uint8
-        frames = (images * 255).cpu().numpy().astype(np.uint8)
+        # frames = (images * 255).cpu().numpy().astype(np.uint8)
+        frames = (images * 255).to(torch.uint8).cpu().numpy()
         imageio.mimsave(output_path, frames, fps=fps)  # type: ignore
     elif method == "ffmpeg":
         # Convert to numpy and scale to [0, 255]
-        frames = (images * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
+        # frames = (images * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
+        frames = (images * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
         # Convert RGB to BGR for OpenCV/FFmpeg
         frames = frames[..., ::-1].copy()
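Note: a plausible motivation for moving the uint8 cast from NumPy onto the tensor is dtype coverage: .numpy() raises a TypeError for dtypes NumPy does not support, such as torch.bfloat16, while the tensor-side clamp(0, 255).to(torch.uint8) saturates and converts before the hand-off. A small sketch, assuming frames may arrive as bfloat16:

import torch

frames_bf16 = torch.rand(2, 4, 4, 3, dtype=torch.bfloat16)
# frames_bf16.cpu().numpy()  # would raise TypeError: NumPy has no bfloat16
frames = (frames_bf16 * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()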
@@ -252,3 +254,36 @@ def save_to_video(
     else:
         raise ValueError(f"Unknown save method: {method}")
+
+
+def find_torch_model_path(config, ckpt_config_key=None, filename=None, subdir=None):
+    if ckpt_config_key and config.get(ckpt_config_key, None) is not None:
+        return config.get(ckpt_config_key)
+
+    paths_to_check = [
+        os.path.join(config.model_path, filename),
+    ]
+    if subdir:
+        paths_to_check.append(os.path.join(config.model_path, subdir, filename))
+
+    for path in paths_to_check:
+        if os.path.exists(path):
+            logger.info(f"Found PyTorch model checkpoint: {path}")
+            return path
+
+    raise FileNotFoundError(f"PyTorch model file '{filename}' not found.\nPlease download the model from https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")
+
+
+def find_hf_model_path(config, ckpt_config_key=None, subdir=None):
+    if ckpt_config_key and config.get(ckpt_config_key, None) is not None:
+        return config.get(ckpt_config_key)
+
+    paths_to_check = [config.model_path]
+    if subdir:
+        paths_to_check.append(os.path.join(config.model_path, subdir))
+
+    for path in paths_to_check:
+        safetensors_pattern = os.path.join(path, "*.safetensors")
+        safetensors_files = glob.glob(safetensors_pattern)
+        if safetensors_files:
+            logger.info(f"Found Hugging Face model files in: {path}")
+            return path
+
+    raise FileNotFoundError(f"No Hugging Face model files (.safetensors) found.\nPlease download the model from: https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")