"tests/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "3a42ebbf5781d6c6408324edeac9d704ca41e6b6"
Commit 230e786a authored by gushiqiao

Fix distill model load bug

parent 53e38505
+import glob
 import os
 import torch
...
@@ -10,6 +11,7 @@ from lightx2v.models.networks.wan.weights.transformer_weights import (
     WanTransformerWeights,
 )
 from lightx2v.utils.envs import *
+from lightx2v.utils.utils import *


 class WanDistillModel(WanModel):
@@ -21,15 +23,29 @@ class WanDistillModel(WanModel):
         super().__init__(model_path, config, device)

     def _load_ckpt(self, unified_dtype, sensitive_layer):
-        if self.config.get("enable_dynamic_cfg", False):
-            ckpt_path = os.path.join(self.model_path, "distill_cfg_models", "distill_model.safetensors")
-        else:
-            ckpt_path = os.path.join(self.model_path, "distill_models", "distill_model.safetensors")
+        # For the old t2v distill model: https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill
+        ckpt_path = os.path.join(self.model_path, "distill_model.pt")
         if os.path.exists(ckpt_path):
             logger.info(f"Loading weights from {ckpt_path}")
-            return self._load_safetensor_to_dict(ckpt_path, unified_dtype, sensitive_layer)
+            weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
+            weight_dict = {
+                key: (weight_dict[key].to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else weight_dict[key].to(GET_SENSITIVE_DTYPE())).pin_memory().to(self.device)
+                for key in weight_dict.keys()
+            }
+            return weight_dict

-        return super()._load_ckpt(unified_dtype, sensitive_layer)
+        if self.config.get("enable_dynamic_cfg", False):
+            safetensors_path = find_hf_model_path(self.config, self.model_path, "dit_distill_ckpt", subdir="distill_cfg_models")
+        else:
+            safetensors_path = find_hf_model_path(self.config, self.model_path, "dit_distill_ckpt", subdir="distill_models")
+        safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
+        weight_dict = {}
+        for file_path in safetensors_files:
+            file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
+            weight_dict.update(file_weights)
+
+        return weight_dict


 class Wan22MoeDistillModel(WanDistillModel, WanModel):
......
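Taken together, the new _load_ckpt tries the legacy single-file checkpoint first and only then falls back to sharded safetensors discovered under a distill subdirectory. Below is a minimal standalone sketch of that fallback order; load_distill_weights is a hypothetical helper (not lightx2v API), it loads shards with plain safetensors, and it omits the dtype casting and pinned-memory transfer done via GET_DTYPE()/GET_SENSITIVE_DTYPE() in the actual diff.

import glob
import os

import torch
from safetensors.torch import load_file


def load_distill_weights(model_path, enable_dynamic_cfg=False):
    # Hypothetical helper mirroring the fallback order introduced above.
    # 1) Legacy single-file release (e.g. Wan2.1-T2V-14B-StepDistill-CfgDistill).
    legacy_ckpt = os.path.join(model_path, "distill_model.pt")
    if os.path.exists(legacy_ckpt):
        # weights_only=True restricts unpickling to tensors and plain containers.
        return torch.load(legacy_ckpt, map_location="cpu", weights_only=True)

    # 2) Newer releases: merge every *.safetensors shard in the distill subdirectory.
    subdir = "distill_cfg_models" if enable_dynamic_cfg else "distill_models"
    weight_dict = {}
    for shard in sorted(glob.glob(os.path.join(model_path, subdir, "*.safetensors"))):
        weight_dict.update(load_file(shard, device="cpu"))
    return weight_dict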
@@ -52,6 +52,8 @@ class WanModel:
         if self.dit_quantized:
             dit_quant_scheme = self.config.mm_config.get("mm_type").split("-")[1]
+            if self.config.model_cls == "wan2.1_distill":
+                dit_quant_scheme = "distill_" + dit_quant_scheme
             if dit_quant_scheme == "gguf":
                 self.dit_quantized_ckpt = find_gguf_model_path(config, "dit_quantized_ckpt", subdir=dit_quant_scheme)
                 self.config.use_gguf = True
......
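The two added lines in WanModel only change which subdirectory the quantized distill checkpoint is looked up under. A rough sketch of the resulting name derivation, using an assumed mm_type value purely for illustration:

def quant_subdir(mm_type: str, model_cls: str) -> str:
    # Mirrors the branch above; the mm_type string passed below is an assumed example.
    scheme = mm_type.split("-")[1]
    if model_cls == "wan2.1_distill":
        scheme = "distill_" + scheme  # distill checkpoints live under a prefixed folder
    return scheme


print(quant_subdir("W-fp8-channel-sym", "wan2.1_distill"))  # -> "distill_fp8"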
@@ -286,7 +286,7 @@ def find_hf_model_path(config, model_path, ckpt_config_key=None, subdir=["origin
         for sub in subdir:
             paths_to_check.append(os.path.join(model_path, sub))
     else:
-        paths_to_check.append(os.path.join(config.model_path, subdir))
+        paths_to_check.append(os.path.join(model_path, subdir))

     for path in paths_to_check:
         safetensors_pattern = os.path.join(path, "*.safetensors")
......
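The one-line fix in find_hf_model_path matters because the string-subdir branch previously built the path from config.model_path and ignored the model_path argument that callers such as WanDistillModel now pass. A simplified sketch of the corrected path collection (not the full helper; the initial entry and tuple handling are assumptions):

import os


def candidate_paths(model_path, subdir):
    # Simplified: collect directories that will later be scanned for *.safetensors files.
    paths_to_check = [model_path]
    if isinstance(subdir, (list, tuple)):
        paths_to_check.extend(os.path.join(model_path, sub) for sub in subdir)
    else:
        # Fixed branch: derive from the model_path argument, not config.model_path.
        paths_to_check.append(os.path.join(model_path, subdir))
    return paths_to_check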