Commit 38d11b82 authored by gushiqiao, committed by GitHub

Reconstruct load model code and fix bugs

parents ab363647 fa88888a
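
Note: the diff below centralizes checkpoint lookup in two new helpers, find_torch_model_path and find_hf_model_path (added in the utils file at the bottom of this commit). A minimal sketch of the lookup order they implement, using a plain dict in place of the repo's config object; the helper name and paths here are illustrative, not part of the commit:

```python
import os

# Resolution order used by the new helpers:
# explicit config key -> <model_path>/<filename> -> <model_path>/<subdir>/<filename>
def resolve_torch_ckpt(config, key, filename, subdir=None):
    if config.get(key):  # an explicit override in the config always wins
        return config[key]
    candidates = [os.path.join(config["model_path"], filename)]
    if subdir:
        candidates.append(os.path.join(config["model_path"], subdir, filename))
    for path in candidates:
        if os.path.exists(path):
            return path
    raise FileNotFoundError(f"{filename} not found under {config['model_path']}")

# e.g. resolve_torch_ckpt({"model_path": "/models/Wan2.1-I2V-14B-480P"},
#                         "vae_pth", "Wan2.1_VAE.pth", "original")
# checks /models/Wan2.1-I2V-14B-480P/Wan2.1_VAE.pth,
# then  /models/Wan2.1-I2V-14B-480P/original/Wan2.1_VAE.pth.
```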
@@ -23,16 +23,8 @@ class WanDistillModel(WanModel):
         super().__init__(model_path, config, device)
 
     def _load_ckpt(self, use_bf16, skip_bf16):
-        enable_dynamic_cfg = self.config.get("enable_dynamic_cfg", False)
-        ckpt_folder = "distill_cfg_models" if enable_dynamic_cfg else "distill_models"
-        safetensors_path = os.path.join(self.model_path, f"{ckpt_folder}/distill_model.safetensors")
-        if os.path.exists(safetensors_path):
-            with safe_open(safetensors_path, framework="pt") as f:
-                weight_dict = {key: (f.get_tensor(key).to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()}
-                return weight_dict
-        ckpt_path = os.path.join(self.model_path, f"{ckpt_folder}/distill_model.pt")
+        # For the old t2v distill model: https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill
+        ckpt_path = os.path.join(self.model_path, "distill_model.pt")
         if os.path.exists(ckpt_path):
             logger.info(f"Loading weights from {ckpt_path}")
             weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
......
@@ -105,7 +105,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         return x
 
-    def _infer_with_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
+    def _infer_with_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         self.weights_stream_mgr.prefetch_weights_from_disk(weights.blocks)
         for block_idx in range(self.blocks_num):
......
@@ -22,6 +22,7 @@ from safetensors import safe_open
 import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
 import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
 from lightx2v.utils.envs import *
+from lightx2v.utils.utils import *
 from loguru import logger
@@ -34,13 +35,15 @@ class WanModel:
         self.model_path = model_path
         self.config = config
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
         self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
         if self.dit_quantized:
             dit_quant_scheme = self.config.mm_config.get("mm_type").split("-")[1]
-            self.dit_quantized_ckpt = self.config.get("dit_quantized_ckpt", os.path.join(model_path, dit_quant_scheme))
+            self.dit_quantized_ckpt = find_hf_model_path(config, "dit_quantized_ckpt", subdir=dit_quant_scheme)
         else:
             self.dit_quantized_ckpt = None
+            assert not self.config.get("lazy_load", False)
         self.config.dit_quantized_ckpt = self.dit_quantized_ckpt
         self.weight_auto_quant = self.config.mm_config.get("weight_auto_quant", False)
         if self.dit_quantized:
@@ -80,16 +83,8 @@ class WanModel:
             return {key: (f.get_tensor(key).to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()}
 
     def _load_ckpt(self, use_bf16, skip_bf16):
-        safetensors_pattern = os.path.join(self.model_path, "*.safetensors")
-        safetensors_files = glob.glob(safetensors_pattern)
-        if not safetensors_files:
-            original_pattern = os.path.join(self.model_path, "original", "*.safetensors")
-            safetensors_files = glob.glob(original_pattern)
-        if not safetensors_files:
-            raise FileNotFoundError(f"No .safetensors files found in directory: {self.model_path}")
+        safetensors_path = find_hf_model_path(self.config, "dit_original_ckpt", subdir="original")
+        safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
         weight_dict = {}
         for file_path in safetensors_files:
             file_weights = self._load_safetensor_to_dict(file_path, use_bf16, skip_bf16)
......
@@ -33,6 +33,8 @@ class DefaultRunner(BaseRunner):
         logger.info("Initializing runner modules...")
         if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
             self.load_model()
+        elif self.config.get("lazy_load", False):
+            assert self.config.get("cpu_offload", False)
         self.run_dit = self._run_dit_local
         self.run_vae_decoder = self._run_vae_decoder_local
         if self.config["task"] == "i2v":
......
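
The two asserts added above (one in WanModel.__init__, one in DefaultRunner) tighten when lazy_load may be used: it now requires cpu_offload, and the WanModel hunk only permits it when the DiT is quantized. A minimal sketch of that combined constraint, using a plain dict in place of the repo's config object; the mm_type value is a placeholder:

```python
def check_lazy_load_config(config):
    # Mirrors the asserts added in this commit: lazy_load implies cpu_offload
    # and a quantized DiT (mm_type != "Default").
    if config.get("lazy_load", False):
        assert config.get("cpu_offload", False), "lazy_load requires cpu_offload"
        mm_type = config.get("mm_config", {}).get("mm_type", "Default")
        assert mm_type != "Default", "lazy_load requires a quantized DiT"

check_lazy_load_config({
    "lazy_load": True,
    "cpu_offload": True,
    "mm_config": {"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm"},  # placeholder scheme
})
```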
@@ -17,6 +17,7 @@ from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
     WanSchedulerCustomCaching,
 )
 from lightx2v.utils.profiler import ProfilingContext
+from lightx2v.utils.utils import *
 from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
 from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
 from lightx2v.models.networks.wan.model import WanModel
@@ -58,28 +59,24 @@ class WanRunner(DefaultRunner):
             clip_quant_scheme = self.config.get("clip_quant_scheme", None)
             assert clip_quant_scheme is not None
             tmp_clip_quant_scheme = clip_quant_scheme.split("-")[0]
-            clip_quantized_ckpt = self.config.get(
-                "clip_quantized_ckpt",
-                os.path.join(
-                    os.path.join(self.config.model_path, tmp_clip_quant_scheme),
-                    f"clip-{tmp_clip_quant_scheme}.pth",
-                ),
-            )
+            clip_model_name = f"clip-{tmp_clip_quant_scheme}.pth"
+            clip_quantized_ckpt = find_torch_model_path(self.config, "clip_quantized_ckpt", clip_model_name, tmp_clip_quant_scheme)
+            clip_original_ckpt = None
         else:
             clip_quantized_ckpt = None
             clip_quant_scheme = None
+            clip_model_name = "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"
+            clip_original_ckpt = find_torch_model_path(self.config, "clip_original_ckpt", clip_model_name, "original")
 
         image_encoder = CLIPModel(
             dtype=torch.float16,
             device=self.init_device,
-            checkpoint_path=os.path.join(
-                self.config.model_path,
-                "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
-            ),
+            checkpoint_path=clip_original_ckpt,
             clip_quantized=clip_quantized,
             clip_quantized_ckpt=clip_quantized_ckpt,
             quant_scheme=clip_quant_scheme,
         )
         return image_encoder
 
     def load_text_encoder(self):
@@ -94,24 +91,22 @@ class WanRunner(DefaultRunner):
         t5_quantized = self.config.get("t5_quantized", False)
         if t5_quantized:
             t5_quant_scheme = self.config.get("t5_quant_scheme", None)
-            tmp_t5_quant_scheme = t5_quant_scheme.split("-")[0]
             assert t5_quant_scheme is not None
-            t5_quantized_ckpt = self.config.get(
-                "t5_quantized_ckpt",
-                os.path.join(
-                    os.path.join(self.config.model_path, tmp_t5_quant_scheme),
-                    f"models_t5_umt5-xxl-enc-{tmp_t5_quant_scheme}.pth",
-                ),
-            )
+            tmp_t5_quant_scheme = t5_quant_scheme.split("-")[0]
+            t5_model_name = f"models_t5_umt5-xxl-enc-{tmp_t5_quant_scheme}.pth"
+            t5_quantized_ckpt = find_torch_model_path(self.config, "t5_quantized_ckpt", t5_model_name, tmp_t5_quant_scheme)
+            t5_original_ckpt = None
         else:
             t5_quant_scheme = None
             t5_quantized_ckpt = None
+            t5_model_name = "models_t5_umt5-xxl-enc-bf16.pth"
+            t5_original_ckpt = find_torch_model_path(self.config, "t5_original_ckpt", t5_model_name, "original")
 
         text_encoder = T5EncoderModel(
             text_len=self.config["text_len"],
             dtype=torch.bfloat16,
             device=t5_device,
-            checkpoint_path=os.path.join(self.config.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
+            checkpoint_path=t5_original_ckpt,
             tokenizer_path=os.path.join(self.config.model_path, "google/umt5-xxl"),
             shard_fn=None,
             cpu_offload=t5_offload,
@@ -125,7 +120,7 @@ class WanRunner(DefaultRunner):
     def load_vae_encoder(self):
         vae_config = {
-            "vae_pth": os.path.join(self.config.model_path, "Wan2.1_VAE.pth"),
+            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth", "original"),
             "device": self.init_device,
             "parallel": self.config.parallel_vae,
             "use_tiling": self.config.get("use_tiling_vae", False),
@@ -137,13 +132,13 @@ class WanRunner(DefaultRunner):
     def load_vae_decoder(self):
         vae_config = {
-            "vae_pth": os.path.join(self.config.model_path, "Wan2.1_VAE.pth"),
+            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth", "original"),
             "device": self.init_device,
             "parallel": self.config.parallel_vae,
             "use_tiling": self.config.get("use_tiling_vae", False),
         }
         if self.config.get("use_tiny_vae", False):
-            tiny_vae_path = self.config.get("tiny_vae_path", os.path.join(self.config.model_path, "taew2_1.pth"))
+            tiny_vae_path = find_torch_model_path(self.config, "tiny_vae_path", "taew2_1.pth", "original")
             vae_decoder = WanVAE_tiny(
                 vae_pth=tiny_vae_path,
                 device=self.init_device,
@@ -216,7 +211,10 @@ class WanRunner(DefaultRunner):
         self.config.lat_h, self.config.lat_w = lat_h, lat_w
         vae_encode_out_list = []
         for i in range(len(self.config["resolution_rate"])):
-            lat_h, lat_w = int(self.config.lat_h * self.config.resolution_rate[i]) // 2 * 2, int(self.config.lat_w * self.config.resolution_rate[i]) // 2 * 2
+            lat_h, lat_w = (
+                int(self.config.lat_h * self.config.resolution_rate[i]) // 2 * 2,
+                int(self.config.lat_w * self.config.resolution_rate[i]) // 2 * 2,
+            )
             vae_encode_out_list.append(self.get_vae_encoder_output(img, lat_h, lat_w))
         vae_encode_out_list.append(self.get_vae_encoder_output(img, self.config.lat_h, self.config.lat_w))
         return vae_encode_out_list
......
 import os
 import random
 import subprocess
-from typing import Optional
-from einops import rearrange
+import glob
 import imageio
 import imageio_ffmpeg as ffmpeg
-from loguru import logger
 import numpy as np
 import torch
 import torchvision
+from typing import Optional
+from einops import rearrange
+from loguru import logger
 
 
 def seed_all(seed):
@@ -154,12 +154,14 @@ def save_to_video(
     if method == "imageio":
         # Convert to uint8
-        frames = (images * 255).cpu().numpy().astype(np.uint8)
+        # frames = (images * 255).cpu().numpy().astype(np.uint8)
+        frames = (images * 255).to(torch.uint8).cpu().numpy()
         imageio.mimsave(output_path, frames, fps=fps)  # type: ignore
     elif method == "ffmpeg":
         # Convert to numpy and scale to [0, 255]
-        frames = (images * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
+        # frames = (images * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
+        frames = (images * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
         # Convert RGB to BGR for OpenCV/FFmpeg
         frames = frames[..., ::-1].copy()
@@ -252,3 +254,36 @@ def save_to_video(
     else:
         raise ValueError(f"Unknown save method: {method}")
+
+
+def find_torch_model_path(config, ckpt_config_key=None, filename=None, subdir=None):
+    if ckpt_config_key and config.get(ckpt_config_key, None) is not None:
+        return config.get(ckpt_config_key)
+
+    paths_to_check = [
+        os.path.join(config.model_path, filename),
+    ]
+    if subdir:
+        paths_to_check.append(os.path.join(config.model_path, subdir, filename))
+
+    for path in paths_to_check:
+        if os.path.exists(path):
+            logger.info(f"Found PyTorch model checkpoint: {path}")
+            return path
+
+    raise FileNotFoundError(f"PyTorch model file '{filename}' not found.\nPlease download the model from https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")
+
+
+def find_hf_model_path(config, ckpt_config_key=None, subdir=None):
+    if ckpt_config_key and config.get(ckpt_config_key, None) is not None:
+        return config.get(ckpt_config_key)
+
+    paths_to_check = [config.model_path]
+    if subdir:
+        paths_to_check.append(os.path.join(config.model_path, subdir))
+
+    for path in paths_to_check:
+        safetensors_pattern = os.path.join(path, "*.safetensors")
+        safetensors_files = glob.glob(safetensors_pattern)
+        if safetensors_files:
+            logger.info(f"Found Hugging Face model files in: {path}")
+            return path
+
+    raise FileNotFoundError(f"No Hugging Face model files (.safetensors) found.\nPlease download the model from: https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")
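
For reference, these are the config keys the refactored loaders now accept as explicit overrides (all taken from the find_torch_model_path / find_hf_model_path calls in this diff); when a key is absent, the loader falls back to the default filename or subdirectory under model_path. The example values below are hypothetical:

```python
# Override keys recognized after this commit (values are illustrative only):
config_overrides = {
    "dit_quantized_ckpt": "/ckpts/wan/int8",          # directory of *.safetensors (find_hf_model_path)
    "dit_original_ckpt": "/ckpts/wan/original",       # directory of *.safetensors (find_hf_model_path)
    "clip_quantized_ckpt": "/ckpts/clip/clip-int8.pth",
    "clip_original_ckpt": "/ckpts/clip/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
    "t5_quantized_ckpt": "/ckpts/t5/models_t5_umt5-xxl-enc-int8.pth",
    "t5_original_ckpt": "/ckpts/t5/models_t5_umt5-xxl-enc-bf16.pth",
    "vae_pth": "/ckpts/vae/Wan2.1_VAE.pth",
    "tiny_vae_path": "/ckpts/vae/taew2_1.pth",
}
```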