distill_model.py
import glob
import os

import torch
from loguru import logger

from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.transformer_weights import (
    WanTransformerWeights,
)
from lightx2v.utils.envs import *  # provides GET_DTYPE() / GET_SENSITIVE_DTYPE()
from lightx2v.utils.utils import *  # provides find_hf_model_path()


class WanDistillModel(WanModel):
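    """WanModel variant that loads step-/CFG-distilled checkpoints."""
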
    pre_weight_class = WanPreWeights
    post_weight_class = WanPostWeights
    transformer_weight_class = WanTransformerWeights

    def __init__(self, model_path, config, device):
        super().__init__(model_path, config, device)

    def _load_ckpt(self, unified_dtype, sensitive_layer):
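        """Load distilled weights: prefer the legacy single-file .pt
        checkpoint, otherwise fall back to sharded safetensors files."""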
        # For the old t2v distill model: https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill
        ckpt_path = os.path.join(self.model_path, "distill_model.pt")
        if os.path.exists(ckpt_path):
            logger.info(f"Loading weights from {ckpt_path}")
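            # Cast each tensor: sensitive layers keep GET_SENSITIVE_DTYPE()
            # unless unified_dtype forces everything to GET_DTYPE(); pinning
            # host memory speeds up the subsequent copy to self.device.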
            weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
            weight_dict = {
                key: (
                    weight_dict[key].to(GET_DTYPE())
                    if unified_dtype or all(s not in key for s in sensitive_layer)
                    else weight_dict[key].to(GET_SENSITIVE_DTYPE())
                )
                .pin_memory()
                .to(self.device)
                for key in weight_dict
            }
            return weight_dict

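        # Newer distilled checkpoints are stored as safetensors shards; the
        # subdirectory depends on whether dynamic CFG is enabled in the config.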
        if self.config.get("enable_dynamic_cfg", False):
            safetensors_path = find_hf_model_path(self.config, self.model_path, "dit_distill_ckpt", subdir="distill_cfg_models")
        else:
            safetensors_path = find_hf_model_path(self.config, self.model_path, "dit_distill_ckpt", subdir="distill_models")

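        # Merge every shard into a single flat state dict.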
        safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
        weight_dict = {}
        for file_path in safetensors_files:
            file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
            weight_dict.update(file_weights)

        return weight_dict


class Wan22MoeDistillModel(WanDistillModel, WanModel):
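    """Distilled Wan2.2 MoE variant: checkpoint loading comes from
    WanDistillModel, inference delegates to WanModel."""
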
    def __init__(self, model_path, config, device):
        WanDistillModel.__init__(self, model_path, config, device)

    @torch.no_grad()
    def infer(self, inputs):
        return WanModel.infer(self, inputs)
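

# Minimal usage sketch (hypothetical paths and inputs; the real config schema
# and input pipeline are defined elsewhere in lightx2v):
#
#   config = {"enable_dynamic_cfg": False}
#   model = WanDistillModel("/path/to/distill_model_dir", config, torch.device("cuda"))
#   outputs = model.infer(inputs)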