Commit 230e786a authored by gushiqiao's avatar gushiqiao
Browse files

Fix distill model load bug

parent 53e38505
import glob
import os
import torch
......@@ -10,6 +11,7 @@ from lightx2v.models.networks.wan.weights.transformer_weights import (
WanTransformerWeights,
)
from lightx2v.utils.envs import *
from lightx2v.utils.utils import *
class WanDistillModel(WanModel):
......@@ -21,15 +23,29 @@ class WanDistillModel(WanModel):
super().__init__(model_path, config, device)
def _load_ckpt(self, unified_dtype, sensitive_layer):
    """Load distill-model weights, preferring the legacy single-file checkpoint.

    Resolution order:
      1. Legacy ``distill_model.pt`` directly under ``self.model_path`` (the old
         t2v distill release); loaded with ``torch.load``.
      2. Otherwise, a directory of ``*.safetensors`` shards located via
         ``find_hf_model_path`` — ``distill_cfg_models`` when dynamic CFG is
         enabled, else ``distill_models``.

    Args:
        unified_dtype: When truthy, cast every tensor to ``GET_DTYPE()``;
            otherwise tensors whose key contains any ``sensitive_layer``
            substring keep ``GET_SENSITIVE_DTYPE()``.
        sensitive_layer: Iterable of key substrings that mark dtype-sensitive
            layers.

    Returns:
        dict: parameter name -> tensor, moved to ``self.device``.
    """
    # For the old t2v distill model:
    # https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill
    legacy_ckpt_path = os.path.join(self.model_path, "distill_model.pt")
    if os.path.exists(legacy_ckpt_path):
        logger.info(f"Loading weights from {legacy_ckpt_path}")
        # weights_only=True: refuse arbitrary pickled objects in the checkpoint.
        weight_dict = torch.load(legacy_ckpt_path, map_location="cpu", weights_only=True)
        weight_dict = {
            key: (
                weight_dict[key].to(GET_DTYPE())
                if unified_dtype or all(s not in key for s in sensitive_layer)
                else weight_dict[key].to(GET_SENSITIVE_DTYPE())
            )
            .pin_memory()  # pinned host memory for faster async H2D copies
            .to(self.device)
            for key in weight_dict.keys()
        }
        return weight_dict

    # New-style layout: resolve the safetensors directory (may be overridden
    # by the "dit_distill_ckpt" config key inside find_hf_model_path).
    if self.config.get("enable_dynamic_cfg", False):
        safetensors_path = find_hf_model_path(self.config, self.model_path, "dit_distill_ckpt", subdir="distill_cfg_models")
    else:
        safetensors_path = find_hf_model_path(self.config, self.model_path, "dit_distill_ckpt", subdir="distill_models")

    # Merge every shard into one flat state dict.
    safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
    weight_dict = {}
    for file_path in safetensors_files:
        file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
        weight_dict.update(file_weights)
    return weight_dict
class Wan22MoeDistillModel(WanDistillModel, WanModel):
......
......@@ -52,6 +52,8 @@ class WanModel:
if self.dit_quantized:
dit_quant_scheme = self.config.mm_config.get("mm_type").split("-")[1]
if self.config.model_cls == "wan2.1_distill":
dit_quant_scheme = "distill_" + dit_quant_scheme
if dit_quant_scheme == "gguf":
self.dit_quantized_ckpt = find_gguf_model_path(config, "dit_quantized_ckpt", subdir=dit_quant_scheme)
self.config.use_gguf = True
......
......@@ -286,7 +286,7 @@ def find_hf_model_path(config, model_path, ckpt_config_key=None, subdir=["origin
for sub in subdir:
paths_to_check.append(os.path.join(model_path, sub))
else:
paths_to_check.append(os.path.join(config.model_path, subdir))
paths_to_check.append(os.path.join(model_path, subdir))
for path in paths_to_check:
safetensors_pattern = os.path.join(path, "*.safetensors")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment