distill_model.py 1.78 KB
Newer Older
1
import os
PengGao's avatar
PengGao committed
2

3
import torch
PengGao's avatar
PengGao committed
4
5
from loguru import logger

6
from lightx2v.models.networks.wan.model import Wan22MoeModel, WanModel
7
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
PengGao's avatar
PengGao committed
8
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from lightx2v.models.networks.wan.weights.transformer_weights import (
    WanTransformerWeights,
)
from lightx2v.utils.envs import *


class WanDistillModel(WanModel):
    """WanModel variant that prefers a distilled checkpoint when one exists.

    Only checkpoint discovery is customised here; construction is delegated
    unchanged to ``WanModel``.
    """

    # Weight container classes, declared as class attributes.
    # NOTE(review): presumably consumed by WanModel when building the
    # network; confirm against the base class.
    pre_weight_class = WanPreWeights
    post_weight_class = WanPostWeights
    transformer_weight_class = WanTransformerWeights

    def __init__(self, model_path, config, device):
        """Forward ``model_path``, ``config`` and ``device`` to ``WanModel``."""
        super().__init__(model_path, config, device)

    def _load_ckpt(self, use_bf16, skip_bf16):
        """Load the distilled checkpoint, or defer to the parent loader.

        The checkpoint is looked up at
        ``<model_path>/distill_cfg_models/distill_model.safetensors`` when the
        config entry ``enable_dynamic_cfg`` is truthy, and at
        ``<model_path>/distill_models/distill_model.safetensors`` otherwise.

        Args:
            use_bf16: Passed through unchanged to the underlying loader.
            skip_bf16: Passed through unchanged to the underlying loader.

        Returns:
            The result of ``_load_safetensor_to_dict`` when the distilled
            checkpoint file exists; otherwise the result of the parent
            class's ``_load_ckpt``.
        """
        if self.config.get("enable_dynamic_cfg", False):
            ckpt_path = os.path.join(self.model_path, "distill_cfg_models", "distill_model.safetensors")
        else:
            ckpt_path = os.path.join(self.model_path, "distill_models", "distill_model.safetensors")

        if os.path.exists(ckpt_path):
            logger.info(f"Loading weights from {ckpt_path}")
            return self._load_safetensor_to_dict(ckpt_path, use_bf16, skip_bf16)

        # No distilled checkpoint on disk: use the regular loading path.
        return super()._load_ckpt(use_bf16, skip_bf16)
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47


class Wan22MoeDistillModel(WanDistillModel, Wan22MoeModel):
    """Distilled model combining ``WanDistillModel`` and ``Wan22MoeModel``.

    The distilled checkpoint is expected directly under ``model_path`` and
    inference is routed to ``Wan22MoeModel.infer``.
    """

    def __init__(self, model_path, config, device):
        """Initialise through ``WanDistillModel.__init__`` explicitly."""
        WanDistillModel.__init__(self, model_path, config, device)

    def _load_ckpt(self, use_bf16, skip_bf16):
        """Load ``<model_path>/distill_model.safetensors`` or fall back.

        Args:
            use_bf16: Passed through unchanged to the underlying loader.
            skip_bf16: Passed through unchanged to the underlying loader.

        Returns:
            The result of ``_load_safetensor_to_dict`` when the checkpoint
            file exists; otherwise the result of the next ``_load_ckpt`` in
            the method resolution order.
        """
        ckpt_path = os.path.join(self.model_path, "distill_model.safetensors")
        if os.path.exists(ckpt_path):
            logger.info(f"Loading weights from {ckpt_path}")
            return self._load_safetensor_to_dict(ckpt_path, use_bf16, skip_bf16)

        # Previously this path fell off the end and returned None silently.
        # Mirror WanDistillModel._load_ckpt instead: report the miss and let
        # the next loader in the MRO (WanDistillModel first) try its own
        # locations.
        logger.warning(f"Distill checkpoint not found at {ckpt_path}, falling back to parent loader")
        return super()._load_ckpt(use_bf16, skip_bf16)

    @torch.no_grad()
    def infer(self, inputs):
        """Run ``Wan22MoeModel.infer`` on ``inputs`` with gradients disabled."""
        return Wan22MoeModel.infer(self, inputs)