Commit 2dc83b97 authored by gushiqiao, committed by GitHub

Merge pull request #195 from ModelTC/dev_distill

Dev distill
parents 6c328a16 4e895b2a
 import os
 import torch
-from safetensors import safe_open
+from LightX2V.lightx2v.utils.utils import find_torch_model_path
 from lightx2v.common.ops.attn.radial_attn import MaskMap
 from lightx2v.models.networks.wan.infer.causvid.transformer_infer import (
...
@@ -32,23 +32,12 @@ class WanCausVidModel(WanModel):
         self.transformer_infer_class = WanTransformerInferCausVid

     def _load_ckpt(self, unified_dtype, sensitive_layer):
-        ckpt_folder = "causvid_models"
-        safetensors_path = os.path.join(self.model_path, f"{ckpt_folder}/causal_model.safetensors")
-        if os.path.exists(safetensors_path):
-            with safe_open(safetensors_path, framework="pt") as f:
-                weight_dict = {
-                    key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE()))
-                    .pin_memory()
-                    .to(self.device)
-                    for key in f.keys()
-                }
-            return weight_dict
-        ckpt_path = os.path.join(self.model_path, f"{ckpt_folder}/causal_model.pt")
+        ckpt_path = find_torch_model_path(self.config, self.model_path, "causvid_model.pt")
         if os.path.exists(ckpt_path):
             weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
             weight_dict = {
-                key: (weight_dict[key].to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else weight_dict[key]).pin_memory().to(self.device) for key in weight_dict.keys()
+                key: (weight_dict[key].to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else weight_dict[key].to(GET_SENSITIVE_DTYPE())).pin_memory().to(self.device)
+                for key in weight_dict.keys()
             }
             return weight_dict
...
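The new _load_ckpt above resolves the checkpoint through find_torch_model_path and applies one dtype-casting pass: non-sensitive layers go to the unified dtype, sensitive layers to a higher-precision one, and every tensor is pinned before the host-to-device copy. A minimal standalone sketch of that pattern, with GET_DTYPE/GET_SENSITIVE_DTYPE stubbed here (in the repo they come from lightx2v.utils.envs) and load_torch_ckpt a hypothetical helper name:

import torch

def GET_DTYPE():
    return torch.bfloat16  # assumed unified compute dtype

def GET_SENSITIVE_DTYPE():
    return torch.float32  # assumed higher precision for sensitive layers

def load_torch_ckpt(ckpt_path, device, unified_dtype, sensitive_layer):
    # weights_only=True keeps torch.load from executing arbitrary pickled code
    weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    return {
        key: (
            tensor.to(GET_DTYPE())
            if unified_dtype or all(s not in key for s in sensitive_layer)
            else tensor.to(GET_SENSITIVE_DTYPE())
        )
        .pin_memory()  # pinned host memory speeds the copy; requires a CUDA build
        .to(device)
        for key, tensor in weight_dict.items()
    }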
+import glob
 import os
 import torch
...
@@ -10,6 +11,7 @@ from lightx2v.models.networks.wan.weights.transformer_weights import (
     WanTransformerWeights,
 )
 from lightx2v.utils.envs import *
+from lightx2v.utils.utils import *


 class WanDistillModel(WanModel):
@@ -21,15 +23,29 @@ class WanDistillModel(WanModel):
         super().__init__(model_path, config, device)

     def _load_ckpt(self, unified_dtype, sensitive_layer):
-        if self.config.get("enable_dynamic_cfg", False):
-            ckpt_path = os.path.join(self.model_path, "distill_cfg_models", "distill_model.safetensors")
-        else:
-            ckpt_path = os.path.join(self.model_path, "distill_models", "distill_model.safetensors")
+        # For the old t2v distill model: https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill
+        ckpt_path = os.path.join(self.model_path, "distill_model.pt")
         if os.path.exists(ckpt_path):
             logger.info(f"Loading weights from {ckpt_path}")
-            return self._load_safetensor_to_dict(ckpt_path, unified_dtype, sensitive_layer)
+            weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
+            weight_dict = {
+                key: (weight_dict[key].to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else weight_dict[key].to(GET_SENSITIVE_DTYPE())).pin_memory().to(self.device)
+                for key in weight_dict.keys()
+            }
+            return weight_dict

-        return super()._load_ckpt(unified_dtype, sensitive_layer)
+        if self.config.get("enable_dynamic_cfg", False):
+            safetensors_path = find_hf_model_path(self.config, self.model_path, "dit_distill_ckpt", subdir="distill_cfg_models")
+        else:
+            safetensors_path = find_hf_model_path(self.config, self.model_path, "dit_distill_ckpt", subdir="distill_models")
+        safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
+        weight_dict = {}
+        for file_path in safetensors_files:
+            file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
+            weight_dict.update(file_weights)
+        return weight_dict


 class Wan22MoeDistillModel(WanDistillModel, WanModel):
...
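The rewritten branch above now globs a directory of safetensors shards and merges them into one state dict instead of loading a single fixed file. A sketch of that merge, assuming disjoint keys across shards (safe_open is the real safetensors API; the directory layout is illustrative):

import glob
import os
from safetensors import safe_open

def load_sharded_safetensors(directory):
    weight_dict = {}
    for file_path in sorted(glob.glob(os.path.join(directory, "*.safetensors"))):
        with safe_open(file_path, framework="pt") as f:
            for key in f.keys():
                # shards are expected to hold disjoint keys; a duplicate key
                # would silently be overwritten by the later shard
                weight_dict[key] = f.get_tensor(key)
    return weight_dict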
...
@@ -28,14 +28,8 @@ class WanLoraWrapper:
         return lora_name

     def _load_lora_file(self, file_path):
-        use_bfloat16 = GET_DTYPE() == "BF16"
-        if self.model.config and hasattr(self.model.config, "get"):
-            use_bfloat16 = self.model.config.get("use_bfloat16", True)
         with safe_open(file_path, framework="pt") as f:
-            if use_bfloat16:
-                tensor_dict = {key: f.get_tensor(key).to(GET_DTYPE()) for key in f.keys()}
-            else:
-                tensor_dict = {key: f.get_tensor(key) for key in f.keys()}
+            tensor_dict = {key: f.get_tensor(key).to(GET_DTYPE()) for key in f.keys()}
         return tensor_dict

     def apply_lora(self, lora_name, alpha=1.0):
@@ -52,7 +46,7 @@ class WanLoraWrapper:
         self.model._init_weights(weight_dict)

         logger.info(f"Applied LoRA: {lora_name} with alpha={alpha}")
-        del lora_weights  # delete to free GPU memory
+        del lora_weights
         return True

     @torch.no_grad()
...
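After this change _load_lora_file always casts to the single runtime dtype, so the use_bfloat16 config flag and its branching disappear. A sketch of the simplified path, with the dtype parameter standing in for GET_DTYPE():

import torch
from safetensors import safe_open

def load_lora_file(file_path, dtype=torch.bfloat16):
    # cast every LoRA tensor up front so the merge code never branches on dtype
    with safe_open(file_path, framework="pt") as f:
        return {key: f.get_tensor(key).to(dtype) for key in f.keys()}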
...
@@ -52,6 +52,8 @@ class WanModel:
         if self.dit_quantized:
             dit_quant_scheme = self.config.mm_config.get("mm_type").split("-")[1]
+            if self.config.model_cls == "wan2.1_distill":
+                dit_quant_scheme = "distill_" + dit_quant_scheme
             if dit_quant_scheme == "gguf":
                 self.dit_quantized_ckpt = find_gguf_model_path(config, "dit_quantized_ckpt", subdir=dit_quant_scheme)
                 self.config.use_gguf = True
...
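The two added lines prefix the quantization scheme with "distill_" for the wan2.1_distill model class, so distilled quantized checkpoints resolve to their own subdirectory. A sketch of the derivation (the mm_type format shown is an assumption inferred from the split("-")[1] above):

def resolve_quant_scheme(mm_type: str, model_cls: str) -> str:
    scheme = mm_type.split("-")[1]  # second dash-separated field, e.g. "int8"
    if model_cls == "wan2.1_distill":
        scheme = "distill_" + scheme  # distill checkpoints live in their own subdir
    return scheme

assert resolve_quant_scheme("W-int8-channel-sym", "wan2.1_distill") == "distill_int8"
assert resolve_quant_scheme("W-fp8-channel-sym", "wan2.1_t2v") == "fp8"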
...
@@ -286,7 +286,7 @@ def find_hf_model_path(config, model_path, ckpt_config_key=None, subdir=["origin
         for sub in subdir:
             paths_to_check.append(os.path.join(model_path, sub))
     else:
-        paths_to_check.append(os.path.join(config.model_path, subdir))
+        paths_to_check.append(os.path.join(model_path, subdir))

     for path in paths_to_check:
         safetensors_pattern = os.path.join(path, "*.safetensors")
...
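This one-line fix makes the single-string subdir branch honor the model_path argument instead of silently falling back to config.model_path, matching the list branch above it. A simplified sketch of the corrected lookup (the function name, signature, and FileNotFoundError fallback here are illustrative, not the repo's exact code):

import glob
import os

def find_model_dir(model_path, subdir=("original",)):
    paths_to_check = []
    if isinstance(subdir, (list, tuple)):
        for sub in subdir:
            paths_to_check.append(os.path.join(model_path, sub))
    else:
        # before the fix this branch used config.model_path, ignoring the argument
        paths_to_check.append(os.path.join(model_path, subdir))
    for path in paths_to_check:
        if glob.glob(os.path.join(path, "*.safetensors")):
            return path
    raise FileNotFoundError(f"no *.safetensors under any of {paths_to_check}")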