import os
import sys

import torch
import glob
import json

from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.transformer_weights import (
    WanTransformerWeights,
)

from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.models.networks.wan.infer.transformer_infer import (
    WanTransformerInfer,
)

from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import (
    WanTransformerInferTeaCaching,
)
from safetensors import safe_open

import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap

from lightx2v.utils.envs import *
from loguru import logger


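# WanModel ties together checkpoint loading (plain, quantized, or lazily split)
# and the pre/transformer/post inference pipeline.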
class WanModel:
    pre_weight_class = WanPreWeights
    post_weight_class = WanPostWeights
    transformer_weight_class = WanTransformerWeights

    def __init__(self, model_path, config, device):
        self.model_path = model_path
        self.config = config

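        # DiT quantization flags: a non-default mm_type selects quantized matmuls;
        # the weights must then be auto-quantized or come from dit_quantized_ckpt.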
        self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
        self.dit_quantized_ckpt = self.config.get("dit_quantized_ckpt", None)
        self.weight_auto_quant = self.config.mm_config.get("weight_auto_quant", False)
        if self.dit_quantized:
            assert self.weight_auto_quant or self.dit_quantized_ckpt is not None

        self.device = device
        self._init_infer_class()
        self._init_weights()
        self._init_infer()
        self.current_lora = None

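        # Optionally wrap attention with a distributed implementation (ulysses or ring).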
        if config["parallel_attn_type"]:
            if config["parallel_attn_type"] == "ulysses":
                ulysses_dist_wrap.parallelize_wan(self)
            elif config["parallel_attn_type"] == "ring":
                ring_dist_wrap.parallelize_wan(self)
            else:
                raise Exception(f"Unsuppotred parallel_attn_type")

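    # Choose infer implementations based on the feature_caching mode.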
    def _init_infer_class(self):
        self.pre_infer_class = WanPreInfer
        self.post_infer_class = WanPostInfer
        if self.config["feature_caching"] == "NoCaching":
            self.transformer_infer_class = WanTransformerInfer
        elif self.config["feature_caching"] == "Tea":
            self.transformer_infer_class = WanTransformerInferTeaCaching
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")

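    # Load a single .safetensors file into a {name: tensor} dict on self.device (bfloat16 by default).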
    def _load_safetensor_to_dict(self, file_path):
        use_bfloat16 = self.config.get("use_bfloat16", True)
        with safe_open(file_path, framework="pt") as f:
            if use_bfloat16:
                tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16).to(self.device) for key in f.keys()}
            else:
                tensor_dict = {key: f.get_tensor(key).to(self.device) for key in f.keys()}
        return tensor_dict

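    # Load the unquantized checkpoint from every *.safetensors file under model_path.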
    def _load_ckpt(self):
        safetensors_pattern = os.path.join(self.model_path, "*.safetensors")
        safetensors_files = glob.glob(safetensors_pattern)

        if not safetensors_files:
            raise FileNotFoundError(f"No .safetensors files found in directory: {self.model_path}")

        weight_dict = {}
        for file_path in safetensors_files:
            file_weights = self._load_safetensor_to_dict(file_path)
            weight_dict.update(file_weights)
        return weight_dict

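    # Load a quantized DiT checkpoint: a single .pth file, or safetensors shards resolved via *.index.json.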
    def _load_quant_ckpt(self):
        ckpt_path = self.config.dit_quantized_ckpt
        logger.info(f"Loading quant dit model from {ckpt_path}")

        if ckpt_path.endswith(".pth"):
            logger.info(f"Loading {ckpt_path} as PyTorch model.")
            weight_dict = torch.load(ckpt_path, map_location=self.device, weights_only=True)
        else:
            index_files = [f for f in os.listdir(ckpt_path) if f.endswith(".index.json")]
            if not index_files:
                raise FileNotFoundError(f"No .pth file or *.index.json found in {ckpt_path}")

            index_path = os.path.join(ckpt_path, index_files[0])
            logger.info(f"Using safetensors index: {index_path}")

            with open(index_path, "r") as f:
                index_data = json.load(f)

            weight_dict = {}
            for filename in set(index_data["weight_map"].values()):
                safetensor_path = os.path.join(ckpt_path, filename)
                with safe_open(safetensor_path, framework="pt", device=str(self.device)) as f:
                    logger.info(f"Loading weights from {safetensor_path}")
                    for k in f.keys():
                        weight_dict[k] = f.get_tensor(k)
                        if weight_dict[k].dtype == torch.float:
                            weight_dict[k] = weight_dict[k].to(torch.bfloat16)

        return weight_dict

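    # Lazy-load variant: read the non-block weights fully; from block_*.safetensors keep only the modulation tensors here.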
    def _load_quant_split_ckpt(self):
        lazy_load_model_path = self.config.dit_quantized_ckpt
        logger.info(f"Loading splited quant model from {lazy_load_model_path}")
        pre_post_weight_dict, transformer_weight_dict = {}, {}

        safetensor_path = os.path.join(lazy_load_model_path, "non_block.safetensors")
        with safe_open(safetensor_path, framework="pt", device=str(self.device)) as f:
            for k in f.keys():
                pre_post_weight_dict[k] = f.get_tensor(k)
                if pre_post_weight_dict[k].dtype == torch.float:
                    pre_post_weight_dict[k] = pre_post_weight_dict[k].to(torch.bfloat16)

        safetensors_pattern = os.path.join(lazy_load_model_path, "block_*.safetensors")
        safetensors_files = glob.glob(safetensors_pattern)
        if not safetensors_files:
            raise FileNotFoundError(f"No .safetensors files found in directory: {lazy_load_model_path}")

        for file_path in safetensors_files:
            with safe_open(file_path, framework="pt") as f:
                for k in f.keys():
                    if "modulation" in k:
                        transformer_weight_dict[k] = f.get_tensor(k)
                        if transformer_weight_dict[k].dtype == torch.float:
                            transformer_weight_dict[k] = transformer_weight_dict[k].to(torch.bfloat16)

        return pre_post_weight_dict, transformer_weight_dict

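    # Build pre/post/transformer weight objects and load them from the selected checkpoint format.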
    def _init_weights(self, weight_dict=None):
        if weight_dict is None:
            if not self.dit_quantized or self.weight_auto_quant:
                self.original_weight_dict = self._load_ckpt()
            else:
                if not self.config.get("lazy_load", False):
                    self.original_weight_dict = self._load_quant_ckpt()
                else:
                    (
                        self.original_weight_dict,
                        self.transformer_weight_dict,
                    ) = self._load_quant_split_ckpt()
        else:
            self.original_weight_dict = weight_dict

        # init weights
        self.pre_weight = self.pre_weight_class(self.config)
        self.post_weight = self.post_weight_class(self.config)
        self.transformer_weights = self.transformer_weight_class(self.config)
        # load weights
        self.pre_weight.load(self.original_weight_dict)
        self.post_weight.load(self.original_weight_dict)
        if hasattr(self, "transformer_weight_dict"):
            self.transformer_weights.load(self.transformer_weight_dict)
        else:
            self.transformer_weights.load(self.original_weight_dict)

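    # Instantiate the infer modules selected in _init_infer_class.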
    def _init_infer(self):
        self.pre_infer = self.pre_infer_class(self.config)
        self.post_infer = self.post_infer_class(self.config)
        self.transformer_infer = self.transformer_infer_class(self.config)

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler
        self.pre_infer.set_scheduler(scheduler)
        self.post_infer.set_scheduler(scheduler)
        self.transformer_infer.set_scheduler(scheduler)

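    # Move all weight groups between host and GPU memory (used with cpu_offload).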
    def to_cpu(self):
        self.pre_weight.to_cpu()
        self.post_weight.to_cpu()
        self.transformer_weights.to_cpu()

    def to_cuda(self):
        self.pre_weight.to_cuda()
        self.post_weight.to_cuda()
        self.transformer_weights.to_cuda()

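    # One denoising step: conditional pass, optional CFG pass, result stored on the scheduler as noise_pred.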
    @torch.no_grad()
    def infer(self, inputs):
        if self.config["cpu_offload"]:
            self.pre_weight.to_cuda()
            self.post_weight.to_cuda()

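        # Conditional (positive-prompt) pass.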
        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
        noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]

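        # Tea caching bookkeeping: advance the step counter and wrap it at the end of the schedule.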
        if self.config["feature_caching"] == "Tea":
            self.scheduler.cnt += 1
            if self.scheduler.cnt >= self.scheduler.num_steps:
                self.scheduler.cnt = 0
        self.scheduler.noise_pred = noise_pred_cond

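        # Classifier-free guidance: run the unconditional (negative-prompt) pass.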
        if self.config["enable_cfg"]:
            embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
            noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]

            if self.config["feature_caching"] == "Tea":
                self.scheduler.cnt += 1
                if self.scheduler.cnt >= self.scheduler.num_steps:
                    self.scheduler.cnt = 0

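            # Blend conditional and unconditional predictions using sample_guide_scale.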
            self.scheduler.noise_pred = noise_pred_uncond + self.config.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)

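        # Offload pre/post weights back to CPU when cpu_offload is enabled.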
            if self.config["cpu_offload"]:
                self.pre_weight.to_cpu()
                self.post_weight.to_cpu()