matrix_game2_model.py 2.22 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import json
import os

import torch
from safetensors import safe_open

from lightx2v.models.networks.wan.infer.matrix_game2.pre_infer import WanMtxg2PreInfer
from lightx2v.models.networks.wan.infer.matrix_game2.transformer_infer import WanMtxg2TransformerInfer
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.models.networks.wan.sf_model import WanSFModel
from lightx2v.models.networks.wan.weights.matrix_game2.pre_weights import WanMtxg2PreWeights
from lightx2v.models.networks.wan.weights.matrix_game2.transformer_weights import WanActionTransformerWeights
from lightx2v.utils.envs import *
from lightx2v.utils.utils import *


class WanSFMtxg2Model(WanSFModel):
    """``WanSFModel`` variant wired to the ``matrix_game2`` weight and inference classes.

    The class swaps in ``WanMtxg2PreWeights`` / ``WanActionTransformerWeights``
    as weight containers, overrides checkpoint loading to read a single
    safetensors file, and selects the ``matrix_game2`` pre/transformer
    inference classes together with the stock ``WanPostInfer``.
    """

    pre_weight_class = WanMtxg2PreWeights
    transformer_weight_class = WanActionTransformerWeights

    def __init__(self, model_path, config, device):
        """Forward all arguments unchanged to the ``WanSFModel`` constructor."""
        super().__init__(model_path, config, device)

    def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
        """Read every tensor of a safetensors file into a ``{key: tensor}`` dict.

        Args:
            file_path: Path of the safetensors file to open.
            unified_dtype: When truthy, every tensor is cast to ``GET_DTYPE()``
                and ``sensitive_layer`` is never inspected.
            sensitive_layer: Iterable of substrings; a key containing any of
                them is cast to ``GET_SENSITIVE_DTYPE()`` instead (only
                consulted when ``unified_dtype`` is falsy).

        Returns:
            Dict mapping each key of the file to its tensor, opened on
            ``str(self.device)`` and cast as described above.
        """
        tensors = {}
        with safe_open(file_path, framework="pt", device=str(self.device)) as handle:
            for key in handle.keys():
                # Keep the short-circuit: sensitive_layer is only scanned when
                # a unified dtype was not requested.
                if unified_dtype or not any(marker in key for marker in sensitive_layer):
                    dtype_getter = GET_DTYPE
                else:
                    dtype_getter = GET_SENSITIVE_DTYPE
                tensors[key] = handle.get_tensor(key).to(dtype_getter())
        return tensors

    def _load_ckpt(self, unified_dtype, sensitive_layer):
        """Load the sub-model checkpoint and return it with shortened keys.

        The file is located at
        ``<model_path>/<sub_model_folder>/<sub_model_name>`` (all three read
        from ``self.config``). Every key loses its first six characters and
        every tensor ends up as ``torch.bfloat16`` on ``self.device``,
        regardless of the dtype picked while reading the file.

        Args:
            unified_dtype: Passed through to ``_load_safetensor_to_dict``.
            sensitive_layer: Passed through to ``_load_safetensor_to_dict``.

        Returns:
            Dict mapping the shortened key names to bfloat16 tensors.
        """
        model_root = self.config["model_path"]
        relative_path = f"{self.config['sub_model_folder']}/{self.config['sub_model_name']}"
        raw_weights = self._load_safetensor_to_dict(os.path.join(model_root, relative_path), unified_dtype, sensitive_layer)

        # NOTE(review): the slice drops six characters unconditionally;
        # presumably a fixed key prefix in the checkpoint -- confirm against the file.
        renamed = {key[6:]: tensor.to(torch.bfloat16).to(self.device) for key, tensor in raw_weights.items()}

        # Drop the reference to the intermediate dict before returning.
        del raw_weights
        return renamed

    def _init_infer_class(self):
        """Overlay the sub-model's ``config.json`` onto ``self.config`` and pick the inference classes.

        Every top-level entry of
        ``<model_path>/<sub_model_folder>/config.json`` overwrites the entry
        with the same key in ``self.config``.
        """
        # update config by real model config
        config_path = os.path.join(self.config["model_path"], self.config["sub_model_folder"], "config.json")
        with open(config_path) as config_file:
            model_config = json.load(config_file)
        for key, value in model_config.items():
            self.config[key] = value

        self.pre_infer_class = WanMtxg2PreInfer
        self.post_infer_class = WanPostInfer
        self.transformer_infer_class = WanMtxg2TransformerInfer