import os
import sys

import torch

from lightx2v.models.networks.hunyuan.weights.pre_weights import HunyuanPreWeights
from lightx2v.models.networks.hunyuan.weights.post_weights import HunyuanPostWeights
from lightx2v.models.networks.hunyuan.weights.transformer_weights import HunyuanTransformerWeights
from lightx2v.models.networks.hunyuan.infer.pre_infer import HunyuanPreInfer
from lightx2v.models.networks.hunyuan.infer.post_infer import HunyuanPostInfer
from lightx2v.models.networks.hunyuan.infer.transformer_infer import HunyuanTransformerInfer
from lightx2v.models.networks.hunyuan.infer.feature_caching.transformer_infer import (
    HunyuanTransformerInferTaylorCaching,
    HunyuanTransformerInferTeaCaching,
)

import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap

from lightx2v.utils.envs import *
from loguru import logger


class HunyuanModel:
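    """HunyuanVideo DiT wrapper: owns the pre/post/transformer weights and their
    matching infer modules, with optional quantization, feature caching,
    distributed attention, and CPU offload."""
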
    pre_weight_class = HunyuanPreWeights
    post_weight_class = HunyuanPostWeights
    transformer_weight_class = HunyuanTransformerWeights

    def __init__(self, model_path, config, device, args):
        self.model_path = model_path
        self.config = config
        self.device = device
        self.args = args
        self._init_infer_class()
        self._init_weights()

        # One-shot quantization mode: dump naive-quant weights, then exit.
        if GET_RUNNING_FLAG() == "save_naive_quant":
            assert self.config.get("naive_quant_path") is not None, "naive_quant_path is None"
            self.save_weights(self.config.naive_quant_path)
            sys.exit(0)

        self._init_infer()

        # Optionally wrap attention for distributed inference.
        if config["parallel_attn_type"]:
            if config["parallel_attn_type"] == "ulysses":
                ulysses_dist_wrap.parallelize_hunyuan(self)
            elif config["parallel_attn_type"] == "ring":
                ring_dist_wrap.parallelize_hunyuan(self)
            else:
                raise ValueError(f"Unsupported parallel_attn_type: {config['parallel_attn_type']}")

        # Start with all weights on CPU when offloading is enabled.
        if self.config["cpu_offload"]:
            self.to_cpu()

    def _init_infer_class(self):
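        """Pick the infer classes; the transformer variant depends on the feature_caching strategy."""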
        self.pre_infer_class = HunyuanPreInfer
        self.post_infer_class = HunyuanPostInfer
        if self.config["feature_caching"] == "NoCaching":
            self.transformer_infer_class = HunyuanTransformerInfer
        elif self.config["feature_caching"] == "TaylorSeer":
            self.transformer_infer_class = HunyuanTransformerInferTaylorCaching
        elif self.config["feature_caching"] == "Tea":
            self.transformer_infer_class = HunyuanTransformerInferTeaCaching
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")

    def _load_ckpt(self):
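        """Load the stock HunyuanVideo transformer checkpoint for the current task (t2v or i2v)."""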
        if self.args.task == "t2v":
            ckpt_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt")
        else:
            ckpt_path = os.path.join(self.model_path, "hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt")
        weight_dict = torch.load(ckpt_path, map_location=self.device, weights_only=True)["module"]
        return weight_dict

    def _load_ckpt_quant_model(self):
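        """Load pre-quantized weights previously saved by save_weights()."""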
        assert self.config.get("naive_quant_path") is not None, "naive_quant_path is None"
        logger.info(f"Loading quant model from {self.config.naive_quant_path}")
        quant_weights_path = os.path.join(self.config.naive_quant_path, "quant_weights.pth")
        weight_dict = torch.load(quant_weights_path, map_location=self.device, weights_only=True)
        return weight_dict

    def _init_weights(self):
        # Load the raw checkpoint unless a pre-quantized one should be used.
        if (
            GET_RUNNING_FLAG() == "save_naive_quant"
            or self.config["mm_config"].get("weight_auto_quant", False)
            or self.config["mm_config"].get("mm_type", "Default") == "Default"
        ):
            weight_dict = self._load_ckpt()
        else:
            weight_dict = self._load_ckpt_quant_model()
        # init weights
        self.pre_weight = self.pre_weight_class(self.config)
        self.post_weight = self.post_weight_class(self.config)
        self.transformer_weights = self.transformer_weight_class(self.config)
        # load weights
        self.pre_weight.load(weight_dict)
        self.post_weight.load(weight_dict)
        self.transformer_weights.load(weight_dict)

    def _init_infer(self):
        self.pre_infer = self.pre_infer_class(self.config)
        self.post_infer = self.post_infer_class(self.config)
        self.transformer_infer = self.transformer_infer_class(self.config)

    def save_weights(self, save_path):
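        """Merge the pre/post/transformer state dicts and save them as quant_weights.pth."""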
        if not os.path.exists(save_path):
            os.makedirs(save_path)

        pre_state_dict = self.pre_weight.state_dict()
        logger.info(pre_state_dict.keys())

        post_state_dict = self.post_weight.state_dict()
        logger.info(post_state_dict.keys())

        transformer_state_dict = self.transformer_weights.state_dict()
        logger.info(transformer_state_dict.keys())

        save_dict = {}
        save_dict.update(pre_state_dict)
        save_dict.update(post_state_dict)
        save_dict.update(transformer_state_dict)

        save_path = os.path.join(save_path, "quant_weights.pth")
        torch.save(save_dict, save_path)
        logger.info(f"Save weights to {save_path}")

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler
        self.pre_infer.set_scheduler(scheduler)
        self.post_infer.set_scheduler(scheduler)
        self.transformer_infer.set_scheduler(scheduler)

    def to_cpu(self):
        self.pre_weight.to_cpu()
        self.post_weight.to_cpu()
        self.transformer_weights.to_cpu()

    def to_cuda(self):
        self.pre_weight.to_cuda()
        self.post_weight.to_cuda()
        self.transformer_weights.to_cuda()

    @torch.no_grad()
    def infer(self, inputs):
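        """Run one denoising step: pre -> transformer -> post. The noise prediction is stored on the scheduler."""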
        if self.config["cpu_offload"]:
            # Bring the pre/post weights onto the GPU for this step.
            self.pre_weight.to_cuda()
            self.post_weight.to_cuda()

        inputs = self.pre_infer.infer(self.pre_weight, inputs)
        inputs = self.transformer_infer.infer(self.transformer_weights, *inputs)
        self.scheduler.noise_pred = self.post_infer.infer(self.post_weight, *inputs)

        if self.config["cpu_offload"]:
            self.pre_weight.to_cpu()
            self.post_weight.to_cpu()
        # TeaCaching counts denoising steps; reset the counter after the last step.
        if self.config["feature_caching"] == "Tea":
            self.scheduler.cnt += 1
            if self.scheduler.cnt == self.scheduler.num_steps:
                self.scheduler.cnt = 0
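

# Usage sketch (illustrative, not part of the library): a minimal driver, assuming
# a HunyuanVideo checkpoint laid out as _load_ckpt expects and a scheduler object
# from the surrounding lightx2v runner. The config keys below are exactly the ones
# this file reads; the paths and the commented-out scheduler/inputs are placeholders.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = {
        "parallel_attn_type": None,      # or "ulysses" / "ring"
        "cpu_offload": False,
        "feature_caching": "NoCaching",  # or "TaylorSeer" / "Tea"
        "mm_config": {},                 # default mm_type -> load the fp checkpoint
    }
    args = SimpleNamespace(task="t2v")

    model = HunyuanModel("/path/to/HunyuanVideo", config, torch.device("cuda"), args)
    # model.set_scheduler(scheduler)  # scheduler is built by the lightx2v runner
    # model.infer(inputs)             # one call per denoising step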