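# HunyuanModel: top-level wrapper around the Hunyuan video DiT, covering weight
# loading (plain, pre-quantized, or auto-quantized), optional CPU offload, and
# the pre / transformer / post inference pipeline.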
import json
import os

import torch
from loguru import logger
from safetensors import safe_open

from lightx2v.models.networks.hunyuan.infer.feature_caching.transformer_infer import (
    HunyuanTransformerInferAdaCaching,
    HunyuanTransformerInferCustomCaching,
    HunyuanTransformerInferTaylorCaching,
    HunyuanTransformerInferTeaCaching,
)
from lightx2v.models.networks.hunyuan.infer.post_infer import HunyuanPostInfer
from lightx2v.models.networks.hunyuan.infer.pre_infer import HunyuanPreInfer
from lightx2v.models.networks.hunyuan.infer.transformer_infer import HunyuanTransformerInfer
from lightx2v.models.networks.hunyuan.weights.post_weights import HunyuanPostWeights
from lightx2v.models.networks.hunyuan.weights.pre_weights import HunyuanPreWeights
from lightx2v.models.networks.hunyuan.weights.transformer_weights import HunyuanTransformerWeights
from lightx2v.utils.envs import *


class HunyuanModel:
    pre_weight_class = HunyuanPreWeights
    post_weight_class = HunyuanPostWeights
    transformer_weight_class = HunyuanTransformerWeights

    def __init__(self, model_path, config, device, args):
        self.model_path = model_path
        self.config = config
        self.device = device
        self.args = args

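        # A non-default mm_type marks the DiT as quantized. Weights then either
        # come from a pre-quantized checkpoint (dit_quantized_ckpt) or are
        # quantized on the fly from the original weights (weight_auto_quant);
        # the assert below enforces that at least one option is available.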
        self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
        self.dit_quantized_ckpt = self.config.get("dit_quantized_ckpt", None)
        self.weight_auto_quant = self.config.mm_config.get("weight_auto_quant", False)
        if self.dit_quantized:
            assert self.weight_auto_quant or self.dit_quantized_ckpt is not None

        self._init_infer_class()
        self._init_weights()
        self._init_infer()

        if self.config["cpu_offload"]:
            self.to_cpu()

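    # Resolves the original (unquantized) checkpoint path from the task type
    # (t2v vs. i2v) and loads its "module" state dict.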
    def _load_ckpt(self):
        if self.args.task == "t2v":
            ckpt_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt")
        else:
            ckpt_path = os.path.join(self.model_path, "hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt")
        weight_dict = torch.load(ckpt_path, map_location=self.device, weights_only=True)["module"]
        return weight_dict

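    # A quantized checkpoint is either a single .pth state dict or a directory
    # of sharded safetensors described by a *.index.json file (the same layout
    # used by Hugging Face sharded checkpoints).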
    def _load_quant_ckpt(self):
        ckpt_path = self.config.dit_quantized_ckpt
        logger.info(f"Loading quant dit model from {ckpt_path}")

        if ckpt_path.endswith(".pth"):
            logger.info(f"Loading {ckpt_path} as PyTorch model.")
            weight_dict = torch.load(ckpt_path, map_location=self.device, weights_only=True)
        else:
            index_files = [f for f in os.listdir(ckpt_path) if f.endswith(".index.json")]
            if not index_files:
                raise FileNotFoundError(f"No .pth file or *.index.json found in {ckpt_path}")

            index_path = os.path.join(ckpt_path, index_files[0])
            logger.info(f" Using safetensors index: {index_path}")

            with open(index_path, "r") as f:
                index_data = json.load(f)

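            # index_data["weight_map"] maps each tensor name to the shard file
            # that stores it; iterating over the unique shard filenames below
            # opens each shard only once.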
            weight_dict = {}
            for filename in set(index_data["weight_map"].values()):
                safetensor_path = os.path.join(ckpt_path, filename)
                with safe_open(safetensor_path, framework="pt", device=str(self.device)) as f:
                    logger.info(f"Loading weights from {safetensor_path}")
                    for k in f.keys():
                        weight_dict[k] = f.get_tensor(k)
                        if weight_dict[k].dtype == torch.float:
                            weight_dict[k] = weight_dict[k].to(torch.bfloat16)

        return weight_dict

    def _init_weights(self):
        if not self.dit_quantized or self.weight_auto_quant:
            weight_dict = self._load_ckpt()
        else:
            weight_dict = self._load_quant_ckpt()
        # init weights
        self.pre_weight = self.pre_weight_class(self.config)
        self.post_weight = self.post_weight_class(self.config)
        self.transformer_weights = self.transformer_weight_class(self.config)
        # load weights
        self.pre_weight.load(weight_dict)
        self.post_weight.load(weight_dict)
        self.transformer_weights.load(weight_dict)

    def _init_infer(self):
        self.pre_infer = self.pre_infer_class(self.config)
        self.post_infer = self.post_infer_class(self.config)
        self.transformer_infer = self.transformer_infer_class(self.config)

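    # Merges the pre-, post- and transformer state dicts into a single .pth
    # file (e.g. after auto-quantization) so it can later be reloaded via
    # _load_quant_ckpt.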
    def save_weights(self, save_path):
        os.makedirs(save_path, exist_ok=True)

        pre_state_dict = self.pre_weight.state_dict()
        logger.info(f"Pre weight keys: {pre_state_dict.keys()}")

        post_state_dict = self.post_weight.state_dict()
        logger.info(f"Post weight keys: {post_state_dict.keys()}")

        transformer_state_dict = self.transformer_weights.state_dict()
        logger.info(f"Transformer weight keys: {transformer_state_dict.keys()}")

        save_dict = {}
        save_dict.update(pre_state_dict)
        save_dict.update(post_state_dict)
        save_dict.update(transformer_state_dict)

        save_path = os.path.join(save_path, "quant_weights.pth")
        torch.save(save_dict, save_path)
        logger.info(f"Save weights to {save_path}")

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler
        self.pre_infer.set_scheduler(scheduler)
        self.post_infer.set_scheduler(scheduler)
        self.transformer_infer.set_scheduler(scheduler)

    def to_cpu(self):
        self.pre_weight.to_cpu()
        self.post_weight.to_cpu()
        self.transformer_weights.to_cpu()

    def to_cuda(self):
        self.pre_weight.to_cuda()
        self.post_weight.to_cuda()
        self.transformer_weights.to_cuda()

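    # With cpu_offload enabled, only the small pre/post weights are staged onto
    # the GPU around each call; the transformer weights are assumed to be
    # offloaded block-wise inside transformer_infer.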
    @torch.no_grad()
    def infer(self, inputs):
        if self.config["cpu_offload"]:
            self.pre_weight.to_cuda()
            self.post_weight.to_cuda()

        inputs = self.pre_infer.infer(self.pre_weight, inputs)
        inputs = self.transformer_infer.infer(self.transformer_weights, *inputs)
        self.scheduler.noise_pred = self.post_infer.infer(self.post_weight, *inputs)

        if self.config["cpu_offload"]:
            self.pre_weight.to_cpu()
            self.post_weight.to_cpu()

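    # Picks the transformer inference implementation from the feature_caching
    # config: "NoCaching", "TaylorSeer", "Tea", "Ada" or "Custom".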
    def _init_infer_class(self):
        self.pre_infer_class = HunyuanPreInfer
        self.post_infer_class = HunyuanPostInfer
        if self.config["feature_caching"] == "NoCaching":
            self.transformer_infer_class = HunyuanTransformerInfer
        elif self.config["feature_caching"] == "TaylorSeer":
            self.transformer_infer_class = HunyuanTransformerInferTaylorCaching
        elif self.config["feature_caching"] == "Tea":
            self.transformer_infer_class = HunyuanTransformerInferTeaCaching
        elif self.config["feature_caching"] == "Ada":
            self.transformer_infer_class = HunyuanTransformerInferAdaCaching
        elif self.config["feature_caching"] == "Custom":
            self.transformer_infer_class = HunyuanTransformerInferCustomCaching
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
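

# Minimal usage sketch (illustrative): `config`, `device`, `args`, `scheduler`
# and `inputs` are assumed to be prepared by the caller's runner flow.
#
#     model = HunyuanModel(model_path, config, device, args)
#     model.set_scheduler(scheduler)
#     model.infer(inputs)  # the result is written to scheduler.noise_pred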