distill_model.py 1.22 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import os
import sys
import torch
import glob
import json
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.transformer_weights import (
    WanTransformerWeights,
)
from lightx2v.utils.envs import *


class WanDistillModel(WanModel):
    """Wan model variant that prefers a distilled checkpoint when one exists.

    Checkpoint loading looks for ``distill_model.pt`` inside the model
    directory and falls back to the parent ``WanModel`` loader when that file
    is absent.
    """

    # Weight container classes for the pre-, post- and transformer stages.
    # NOTE(review): presumably consumed by WanModel when it builds its
    # weights -- confirm against the base class.
    pre_weight_class = WanPreWeights
    post_weight_class = WanPostWeights
    transformer_weight_class = WanTransformerWeights

    def __init__(self, model_path, config, device):
        """Forward all arguments unchanged to ``WanModel.__init__``."""
        super().__init__(model_path, config, device)

    def _load_ckpt(self):
        """Load the model weights, preferring the distilled checkpoint.

        Returns:
            When ``<model_path>/distill_model.pt`` exists: a dict mapping
            each checkpoint key to its tensor, moved to ``self.device`` and
            cast to ``torch.bfloat16`` if ``GET_DTYPE()`` returns ``"BF16"``
            (otherwise each tensor keeps the dtype stored in the checkpoint).
            When the file does not exist: whatever the parent class's
            ``_load_ckpt`` returns.
        """
        ckpt_path = os.path.join(self.model_path, "distill_model.pt")
        if not os.path.exists(ckpt_path):
            # No distilled checkpoint on disk: defer to the parent loader.
            return super()._load_ckpt()

        # Deserialize on the CPU first; weights_only=True restricts
        # unpickling to tensors and plain containers.
        weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)

        # dtype=None tells Tensor.to() to leave the stored dtype untouched.
        dtype = torch.bfloat16 if GET_DTYPE() == "BF16" else None

        # Overwriting existing keys does not resize the dict, so replacing
        # entries in place is safe while iterating over it.
        for key, value in weight_dict.items():
            weight_dict[key] = value.to(device=self.device, dtype=dtype)

        return weight_dict