distill_model.py
import os
import sys
import torch
import glob
import json
from safetensors import safe_open
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.transformer_weights import (
    WanTransformerWeights,
)
from lightx2v.utils.envs import *
from loguru import logger


class WanDistillModel(WanModel):
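    """Distilled variant of WanModel.

    Reuses the standard pre/post/transformer weight classes and only overrides
    checkpoint loading so that distilled checkpoints are looked up first.
    """
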
    pre_weight_class = WanPreWeights
    post_weight_class = WanPostWeights
    transformer_weight_class = WanTransformerWeights

    def __init__(self, model_path, config, device):
        super().__init__(model_path, config, device)

    def _load_ckpt(self, use_bf16, skip_bf16):
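        """Load distilled checkpoint weights.

        Checks, in order:
          1. ``<model_path>/<ckpt_folder>/distill_model.safetensors``
          2. ``<model_path>/<ckpt_folder>/distill_model.pt``
        and falls back to the parent loader when neither file exists.
        ``ckpt_folder`` is ``distill_cfg_models`` when ``enable_dynamic_cfg``
        is set in the config, otherwise ``distill_models``.
        """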
        enable_dynamic_cfg = self.config.get("enable_dynamic_cfg", False)
        ckpt_folder = "distill_cfg_models" if enable_dynamic_cfg else "distill_models"
        safetensors_path = os.path.join(self.model_path, f"{ckpt_folder}/distill_model.safetensors")
        if os.path.exists(safetensors_path):
            with safe_open(safetensors_path, framework="pt") as f:
                # Cast to bf16 unless the key matches a skip pattern, then pin and move to device.
                weight_dict = {
                    key: (
                        f.get_tensor(key).to(torch.bfloat16)
                        if use_bf16 or all(s not in key for s in skip_bf16)
                        else f.get_tensor(key)
                    )
                    .pin_memory()
                    .to(self.device)
                    for key in f.keys()
                }
                return weight_dict

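        # Fall back to a legacy .pt checkpoint in the same folder.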
        ckpt_path = os.path.join(self.model_path, f"{ckpt_folder}/distill_model.pt")

        if os.path.exists(ckpt_path):
            logger.info(f"Loading weights from {ckpt_path}")
            weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
            logger.info(f"Loaded checkpoint keys: {list(weight_dict.keys())}")
            # Same bf16 cast / pin / device-transfer logic as the safetensors path.
            weight_dict = {
                key: (
                    tensor.to(torch.bfloat16)
                    if use_bf16 or all(s not in key for s in skip_bf16)
                    else tensor
                )
                .pin_memory()
                .to(self.device)
                for key, tensor in weight_dict.items()
            }
            return weight_dict

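        # Neither distilled checkpoint exists; defer to the base WanModel loader.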
        return super()._load_ckpt(use_bf16, skip_bf16)
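

# Hypothetical usage sketch (not part of the original module). The names and
# values below are placeholders/assumptions: `config` only needs to be
# dict-like, and `device` is whatever WanModel expects (e.g. a torch device
# string).
#
#     config = {"enable_dynamic_cfg": False}
#     model = WanDistillModel("/path/to/wan_model", config, device="cuda")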