model.py
import os
import sys
import torch
import glob
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.transformer_weights import (
    WanTransformerWeights,
)
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.models.networks.wan.infer.transformer_infer import (
    WanTransformerInfer,
)
from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import WanTransformerInferTeaCaching
from safetensors import safe_open
import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
from lightx2v.utils.envs import *
from loguru import logger


class WanModel:
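    """Wrapper around the Wan transformer network.

    Bundles the pre/post/transformer weight containers with their matching
    inference modules, loads checkpoints from .safetensors files (optionally
    as a naively quantized dump), and supports CPU offload as well as
    Ulysses/Ring sequence-parallel attention wrapping.
    """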
    pre_weight_class = WanPreWeights
    post_weight_class = WanPostWeights
    transformer_weight_class = WanTransformerWeights

    def __init__(self, model_path, config, device):
        self.model_path = model_path
        self.config = config
        self.device = device
        self._init_infer_class()
        self._init_weights()
        # In "save_naive_quant" mode, dump the quantized weights to disk and exit early.
        if GET_RUNNING_FLAG() == "save_naive_quant":
            assert self.config.get("naive_quant_path") is not None, "naive_quant_path is None"
            self.save_weights(self.config.naive_quant_path)
            sys.exit(0)

        self._init_infer()
        self.current_lora = None

        # Optionally wrap the model for distributed sequence-parallel attention.
        if config["parallel_attn_type"]:
            if config["parallel_attn_type"] == "ulysses":
                ulysses_dist_wrap.parallelize_wan(self)
            elif config["parallel_attn_type"] == "ring":
                ring_dist_wrap.parallelize_wan(self)
            else:
                raise NotImplementedError(f"Unsupported parallel_attn_type: {config['parallel_attn_type']}")

        if self.config["cpu_offload"]:
            self.to_cpu()

    def _init_infer_class(self):
        self.pre_infer_class = WanPreInfer
        self.post_infer_class = WanPostInfer
        if self.config["feature_caching"] == "NoCaching":
            self.transformer_infer_class = WanTransformerInfer
        elif self.config["feature_caching"] == "Tea":
            self.transformer_infer_class = WanTransformerInferTeaCaching
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")

    def _load_safetensor_to_dict(self, file_path):
        use_bfloat16 = self.config.get("use_bfloat16", True)
        with safe_open(file_path, framework="pt") as f:
            if use_bfloat16:
                tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16).to(self.device) for key in f.keys()}
            else:
                tensor_dict = {key: f.get_tensor(key).to(self.device) for key in f.keys()}
        return tensor_dict

    def _load_ckpt(self):
        safetensors_pattern = os.path.join(self.model_path, "*.safetensors")
        safetensors_files = glob.glob(safetensors_pattern)

        if not safetensors_files:
            raise FileNotFoundError(f"No .safetensors files found in directory: {self.model_path}")
        weight_dict = {}
        for file_path in safetensors_files:
            file_weights = self._load_safetensor_to_dict(file_path)
            weight_dict.update(file_weights)
        return weight_dict

    def _load_ckpt_quant_model(self):
        assert self.config.get("naive_quant_path") is not None, "naive_quant_path is None"
        logger.info(f"Loading quant model from {self.config.naive_quant_path}")
        quant_weights_path = os.path.join(self.config.naive_quant_path, "quant_weights.pth")
        weight_dict = torch.load(quant_weights_path, map_location=self.device, weights_only=True)
        return weight_dict

    def _init_weights(self, weight_dict=None):
        if weight_dict is None:
            # Use the raw checkpoint when saving naive-quant weights, when weights are
            # auto-quantized at load time, or when running the default (unquantized)
            # mm_type; otherwise load the pre-quantized dump.
            if (
                GET_RUNNING_FLAG() == "save_naive_quant"
                or self.config["mm_config"].get("weight_auto_quant", False)
                or self.config["mm_config"].get("mm_type", "Default") == "Default"
            ):
                self.original_weight_dict = self._load_ckpt()
            else:
                self.original_weight_dict = self._load_ckpt_quant_model()
        else:
            self.original_weight_dict = weight_dict
        # init weights
        self.pre_weight = self.pre_weight_class(self.config)
        self.post_weight = self.post_weight_class(self.config)
        self.transformer_weights = self.transformer_weight_class(self.config)
        # load weights
        self.pre_weight.load(self.original_weight_dict)
        self.post_weight.load(self.original_weight_dict)
        self.transformer_weights.load(self.original_weight_dict)

    def _init_infer(self):
        self.pre_infer = self.pre_infer_class(self.config)
        self.post_infer = self.post_infer_class(self.config)
        self.transformer_infer = self.transformer_infer_class(self.config)

    def save_weights(self, save_path):
        if not os.path.exists(save_path):
            os.makedirs(save_path)

        pre_state_dict = self.pre_weight.state_dict()
        logger.info(pre_state_dict.keys())

        post_state_dict = self.post_weight.state_dict()
        logger.info(post_state_dict.keys())

        transformer_state_dict = self.transformer_weights.state_dict()
        logger.info(transformer_state_dict.keys())

        save_dict = {}
        save_dict.update(pre_state_dict)
        save_dict.update(post_state_dict)
        save_dict.update(transformer_state_dict)

        save_path = os.path.join(save_path, "quant_weights.pth")
        torch.save(save_dict, save_path)
        logger.info(f"Save weights to {save_path}")

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler
        self.pre_infer.set_scheduler(scheduler)
        self.post_infer.set_scheduler(scheduler)
        self.transformer_infer.set_scheduler(scheduler)

    def to_cpu(self):
        self.pre_weight.to_cpu()
        self.post_weight.to_cpu()
        self.transformer_weights.to_cpu()

    def to_cuda(self):
        self.pre_weight.to_cuda()
        self.post_weight.to_cuda()
        self.transformer_weights.to_cuda()

    @torch.no_grad()
    def infer(self, inputs):
        if self.config["cpu_offload"]:
            self.pre_weight.to_cuda()
            self.post_weight.to_cuda()

        # Conditional (positive-prompt) pass.
        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
        noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]

        # TeaCaching tracks the current step; reset the counter after the last step.
        if self.config["feature_caching"] == "Tea":
            self.scheduler.cnt += 1
            if self.scheduler.cnt >= self.scheduler.num_steps:
                self.scheduler.cnt = 0
        self.scheduler.noise_pred = noise_pred_cond

        if self.config["enable_cfg"]:
            # Unconditional (negative-prompt) pass for classifier-free guidance.
            embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
            noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]

            if self.config["feature_caching"] == "Tea":
                self.scheduler.cnt += 1
                if self.scheduler.cnt >= self.scheduler.num_steps:
                    self.scheduler.cnt = 0

            # Combine conditional and unconditional predictions with the guidance scale.
            self.scheduler.noise_pred = noise_pred_uncond + self.config.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)

            if self.config["cpu_offload"]:
                self.pre_weight.to_cpu()
                self.post_weight.to_cpu()
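

# Illustrative usage sketch (not part of the original file). The config shown here
# is an assumption inferred from the keys this class reads (it must support both
# item and attribute access, e.g. an EasyDict), and the scheduler/inputs objects
# are provided by the lightx2v runners, which are the real entry points.
#
#   config = EasyDict({
#       "mm_config": {},
#       "feature_caching": "NoCaching",
#       "parallel_attn_type": None,
#       "cpu_offload": False,
#       "enable_cfg": True,
#       "sample_guide_scale": 5.0,
#   })
#   model = WanModel("/path/to/wan/checkpoints", config, torch.device("cuda"))
#   model.set_scheduler(scheduler)          # scheduler supplied by the runner
#   for step in range(scheduler.num_steps):
#       model.infer(inputs)                 # updates scheduler.noise_pred in place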