distill_model.py 1.47 KB
Newer Older
1
import os
PengGao's avatar
PengGao committed
2

3
import torch
PengGao's avatar
PengGao committed
4
5
from loguru import logger

helloyongyang's avatar
helloyongyang committed
6
from lightx2v.models.networks.wan.model import WanModel
7
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
PengGao's avatar
PengGao committed
8
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
9
10
11
12
from lightx2v.models.networks.wan.weights.transformer_weights import (
    WanTransformerWeights,
)
from lightx2v.utils.envs import *
gushiqiao's avatar
gushiqiao committed
13
from lightx2v.utils.utils import *
14
15
16
17
18
19
20


class WanDistillModel(WanModel):
    """Wan model variant that can load a legacy single-file distill checkpoint.

    If ``distill_model.pt`` exists directly under ``model_path`` it is used as
    the weight source; otherwise weight loading is delegated to ``WanModel``.
    """

    pre_weight_class = WanPreWeights
    post_weight_class = WanPostWeights
    transformer_weight_class = WanTransformerWeights

    def __init__(self, model_path, config, device):
        """Forward the constructor arguments unchanged to ``WanModel``."""
        super().__init__(model_path, config, device)

    def _load_ckpt(self, unified_dtype, sensitive_layer):
        """Return a ``{name: tensor}`` dict of model weights on ``self.device``.

        Args:
            unified_dtype: When truthy, every tensor is cast with
                ``GET_DTYPE()`` and ``sensitive_layer`` is not consulted.
            sensitive_layer: Iterable of substrings; a weight whose key
                contains any of them is cast with ``GET_SENSITIVE_DTYPE()``
                instead (only when ``unified_dtype`` is falsy).

        Returns:
            The weights read from ``distill_model.pt`` when that file exists,
            otherwise whatever ``WanModel._load_ckpt`` returns.
        """
        # For the old t2v distill model: https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill
        legacy_ckpt = os.path.join(self.model_path, "distill_model.pt")
        if not os.path.exists(legacy_ckpt):
            # No legacy checkpoint present: use the parent's loading path.
            return super()._load_ckpt(unified_dtype, sensitive_layer)

        logger.info(f"Loading weights from {legacy_ckpt}")
        raw_weights = torch.load(legacy_ckpt, map_location="cpu", weights_only=True)

        prepared = {}
        for name, tensor in raw_weights.items():
            # `or` short-circuits, so sensitive_layer is only scanned when
            # dtypes are not unified.
            if unified_dtype or not any(marker in name for marker in sensitive_layer):
                converted = tensor.to(GET_DTYPE())
            else:
                converted = tensor.to(GET_SENSITIVE_DTYPE())
            # Pin the CPU copy, then transfer it to the target device.
            prepared[name] = converted.pin_memory().to(self.device)
        return prepared