import glob
import json
import os

import torch
import torch.distributed as dist
from loguru import logger
from safetensors import safe_open

from lightx2v.common.ops.attn import MaskMap
from lightx2v.models.networks.wan.infer.dist_infer.transformer_infer import WanTransformerDistInfer
from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import (
    WanTransformerInferAdaCaching,
    WanTransformerInferCustomCaching,
    WanTransformerInferDualBlock,
    WanTransformerInferDynamicBlock,
    WanTransformerInferFirstBlock,
    WanTransformerInferTaylorCaching,
    WanTransformerInferTeaCaching,
)
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.models.networks.wan.infer.transformer_infer import (
    WanTransformerInfer,
)
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.transformer_weights import (
    WanTransformerWeights,
)
from lightx2v.utils.envs import *
from lightx2v.utils.utils import *

try:
    import gguf
except ImportError:
    gguf = None


class WanModel:
    pre_weight_class = WanPreWeights
    post_weight_class = WanPostWeights
    transformer_weight_class = WanTransformerWeights

    def __init__(self, model_path, config, device, seq_p_group=None):
        self.model_path = model_path
        self.config = config
        self.cpu_offload = self.config.get("cpu_offload", False)
        self.offload_granularity = self.config.get("offload_granularity", "block")
        self.seq_p_group = seq_p_group

        self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
        self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"

        if self.dit_quantized:
            dit_quant_scheme = self.config.mm_config.get("mm_type").split("-")[1]
            if self.config.model_cls == "wan2.1_distill":
                dit_quant_scheme = "distill_" + dit_quant_scheme
            if dit_quant_scheme == "gguf":
                self.dit_quantized_ckpt = find_gguf_model_path(config, "dit_quantized_ckpt", subdir=dit_quant_scheme)
                self.config.use_gguf = True
            else:
                self.dit_quantized_ckpt = find_hf_model_path(config, self.model_path, "dit_quantized_ckpt", subdir=dit_quant_scheme)
            quant_config_path = os.path.join(self.dit_quantized_ckpt, "config.json")
            if os.path.exists(quant_config_path):
                with open(quant_config_path, "r") as f:
                    quant_model_config = json.load(f)
                self.config.update(quant_model_config)
        else:
            self.dit_quantized_ckpt = None
            assert not self.config.get("lazy_load", False)

        self.config.dit_quantized_ckpt = self.dit_quantized_ckpt

        self.weight_auto_quant = self.config.mm_config.get("weight_auto_quant", False)
        if self.dit_quantized:
            assert self.weight_auto_quant or self.dit_quantized_ckpt is not None

        self.device = device
        self._init_infer_class()
        self._init_weights()
        self._init_infer()

    def _init_infer_class(self):
        self.pre_infer_class = WanPreInfer
        self.post_infer_class = WanPostInfer
        if self.seq_p_group is not None:
            self.transformer_infer_class = WanTransformerDistInfer
        else:
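            # Select the transformer infer implementation that matches the configured feature-caching strategy.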
            if self.config["feature_caching"] == "NoCaching":
                self.transformer_infer_class = WanTransformerInfer
            elif self.config["feature_caching"] == "Tea":
                self.transformer_infer_class = WanTransformerInferTeaCaching
            elif self.config["feature_caching"] == "TaylorSeer":
                self.transformer_infer_class = WanTransformerInferTaylorCaching
            elif self.config["feature_caching"] == "Ada":
                self.transformer_infer_class = WanTransformerInferAdaCaching
            elif self.config["feature_caching"] == "Custom":
                self.transformer_infer_class = WanTransformerInferCustomCaching
            elif self.config["feature_caching"] == "FirstBlock":
                self.transformer_infer_class = WanTransformerInferFirstBlock
            elif self.config["feature_caching"] == "DualBlock":
                self.transformer_infer_class = WanTransformerInferDualBlock
            elif self.config["feature_caching"] == "DynamicBlock":
                self.transformer_infer_class = WanTransformerInferDynamicBlock
            else:
                raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")

    def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
        with safe_open(file_path, framework="pt") as f:
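            # Cast each tensor to the unified dtype, except keys matching a sensitive layer, which keep the
            # higher-precision dtype; tensors are pinned and then moved to the target device.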
            return {
                key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())).pin_memory().to(self.device)
                for key in f.keys()
            }

    def _load_ckpt(self, unified_dtype, sensitive_layer):
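        # Collect every *.safetensors shard from the original checkpoint directory into a single flat state dict.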
        safetensors_path = find_hf_model_path(self.config, self.model_path, "dit_original_ckpt", subdir="original")
        safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
        weight_dict = {}
        for file_path in safetensors_files:
            file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
            weight_dict.update(file_weights)
        return weight_dict

    def _load_quant_ckpt(self, unified_dtype, sensitive_layer):
        ckpt_path = self.dit_quantized_ckpt
        logger.info(f"Loading quant dit model from {ckpt_path}")

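        # The quantized checkpoint is sharded; its *.index.json maps each weight name to the shard file that holds it.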
        index_files = [f for f in os.listdir(ckpt_path) if f.endswith(".index.json")]
        if not index_files:
            raise FileNotFoundError(f"No *.index.json found in {ckpt_path}")

        index_path = os.path.join(ckpt_path, index_files[0])
        logger.info(f" Using safetensors index: {index_path}")

        with open(index_path, "r") as f:
            index_data = json.load(f)

        weight_dict = {}
        for filename in set(index_data["weight_map"].values()):
            safetensor_path = os.path.join(ckpt_path, filename)
            with safe_open(safetensor_path, framework="pt") as f:
                logger.info(f"Loading weights from {safetensor_path}")
                for k in f.keys():
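                    # Floating-point tensors are cast to the target dtype (sensitive layers keep the
                    # higher-precision dtype); non-floating-point tensors, e.g. quantized weights, move as-is.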
                    if f.get_tensor(k).dtype in [torch.float16, torch.bfloat16, torch.float]:
                        if unified_dtype or all(s not in k for s in sensitive_layer):
                            weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_DTYPE()).to(self.device)
                        else:
                            weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_SENSITIVE_DTYPE()).to(self.device)
                    else:
                        weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)

        return weight_dict

    def _load_quant_split_ckpt(self, unified_dtype, sensitive_layer):
        lazy_load_model_path = self.dit_quantized_ckpt
        logger.info(f"Loading split quant model from {lazy_load_model_path}")
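        # In lazy-load mode only the non-block (pre/post) weights are loaded eagerly here; transformer block
        # weights are presumably streamed from the split checkpoint on demand elsewhere.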
        pre_post_weight_dict = {}

        safetensor_path = os.path.join(lazy_load_model_path, "non_block.safetensors")
        with safe_open(safetensor_path, framework="pt", device="cpu") as f:
            for k in f.keys():
                if f.get_tensor(k).dtype in [torch.float16, torch.bfloat16, torch.float]:
                    if unified_dtype or all(s not in k for s in sensitive_layer):
                        pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_DTYPE()).to(self.device)
                    else:
                        pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_SENSITIVE_DTYPE()).to(self.device)
                else:
                    pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)

        return pre_post_weight_dict

    def _load_gguf_ckpt(self):
        gguf_path = self.dit_quantized_ckpt
        logger.info(f"Loading gguf-quant dit model from {gguf_path}")
        reader = gguf.GGUFReader(gguf_path)
        for tensor in reader.tensors:
            # TODO: implement _load_gguf_ckpt
            pass

    def _init_weights(self, weight_dict=None):
        unified_dtype = GET_DTYPE() == GET_SENSITIVE_DTYPE()
        # Sensitive layers are kept at a higher-precision dtype (e.g. float32) for accuracy
        sensitive_layer = {
            "norm",
            "embedding",
            "modulation",
            "time",
            "img_emb.proj.0",
            "img_emb.proj.4",
        }

        if weight_dict is None:
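            # Only one rank reads the checkpoint from disk: this process in single-device mode, otherwise
            # rank 0 of the sequence-parallel group; the other ranks receive the weights via broadcast below.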
            is_weight_loader = False
            if self.seq_p_group is None:
                is_weight_loader = True
                logger.info(f"Loading original dit model from {self.model_path}")
            elif dist.is_initialized():
                if dist.get_rank(group=self.seq_p_group) == 0:
                    is_weight_loader = True
                    logger.info(f"Loading original dit model from {self.model_path}")

            cpu_weight_dict = {}
            if is_weight_loader:
                if not self.dit_quantized or self.weight_auto_quant:
                    cpu_weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
                else:
                    if not self.config.get("lazy_load", False):
                        cpu_weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)
                    else:
                        cpu_weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)

            if self.seq_p_group is None:  # single-GPU mode (no sequence parallelism)
                self.original_weight_dict = {}
                for key, tensor in cpu_weight_dict.items():
                    self.original_weight_dict[key] = tensor.to("cuda", non_blocking=True)
            else:
                seq_p_group = self.seq_p_group
                global_src_rank = dist.get_process_group_ranks(seq_p_group)[0]

                meta_dict = {}
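                # Broadcast each tensor's shape and dtype first so the non-loader ranks can allocate
                # matching empty buffers before the tensor broadcast.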
                if is_weight_loader:
                    for key, tensor in cpu_weight_dict.items():
                        meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype}

                obj_list = [meta_dict] if is_weight_loader else [None]
                dist.broadcast_object_list(obj_list, src=global_src_rank, group=seq_p_group)
                synced_meta_dict = obj_list[0]

                self.original_weight_dict = {}
                for key, meta in synced_meta_dict.items():
                    self.original_weight_dict[key] = torch.empty(meta["shape"], dtype=meta["dtype"], device="cuda")

                dist.barrier(group=seq_p_group, device_ids=[torch.cuda.current_device()])
                for key in sorted(synced_meta_dict.keys()):
                    tensor_to_broadcast = self.original_weight_dict[key]
                    if is_weight_loader:
                        tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)

                    dist.broadcast(tensor_to_broadcast, src=global_src_rank, group=seq_p_group)

            if is_weight_loader:
                del cpu_weight_dict
        else:
            self.original_weight_dict = weight_dict

        # init weights
        self.pre_weight = self.pre_weight_class(self.config)
        self.post_weight = self.post_weight_class(self.config)
        self.transformer_weights = self.transformer_weight_class(self.config)
        # load weights
        self.pre_weight.load(self.original_weight_dict)
        self.post_weight.load(self.original_weight_dict)
        self.transformer_weights.load(self.original_weight_dict)

    def _init_infer(self):
        self.pre_infer = self.pre_infer_class(self.config)
        self.post_infer = self.post_infer_class(self.config)

        if self.seq_p_group is not None:
            self.transformer_infer = self.transformer_infer_class(self.config, self.seq_p_group)
        else:
            self.transformer_infer = self.transformer_infer_class(self.config)

        if self.config["cfg_parallel"]:
            self.infer_func = self.infer_with_cfg_parallel
        else:
            self.infer_func = self.infer_wo_cfg_parallel

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler
        self.pre_infer.set_scheduler(scheduler)
        self.post_infer.set_scheduler(scheduler)
        self.transformer_infer.set_scheduler(scheduler)

    def to_cpu(self):
        self.pre_weight.to_cpu()
        self.post_weight.to_cpu()
        self.transformer_weights.to_cpu()

    def to_cuda(self):
        self.pre_weight.to_cuda()
        self.post_weight.to_cuda()
        self.transformer_weights.to_cuda()

    @torch.no_grad()
    def infer(self, inputs):
        return self.infer_func(inputs)

    @torch.no_grad()
    def infer_wo_cfg_parallel(self, inputs):
        if self.cpu_offload:
            if self.offload_granularity == "model" and self.scheduler.step_index == 0:
                self.to_cuda()
            elif self.offload_granularity != "model":
                self.pre_weight.to_cuda()
                self.post_weight.to_cuda()

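        # Lazily build the attention mask map from the latent shape (tokens are counted over 2x2 spatial patches).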
        if self.transformer_infer.mask_map is None:
            _, c, h, w = self.scheduler.latents.shape
            video_token_num = c * (h // 2) * (w // 2)
            self.transformer_infer.mask_map = MaskMap(video_token_num, c)

        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
        noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]

        self.scheduler.noise_pred = noise_pred_cond

        if self.clean_cuda_cache:
            del x, embed, pre_infer_out, noise_pred_cond, grid_sizes
            torch.cuda.empty_cache()

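        # Classifier-free guidance: run a second, unconditional pass and blend it with the conditional
        # prediction using the guidance scale.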
        if self.config["enable_cfg"]:
            embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
            noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]

            self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)

            if self.clean_cuda_cache:
                del x, embed, pre_infer_out, noise_pred_uncond, grid_sizes
                torch.cuda.empty_cache()

        if self.cpu_offload:
            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1:
                self.to_cpu()
            elif self.offload_granularity != "model":
                self.pre_weight.to_cpu()
                self.post_weight.to_cpu()

    @torch.no_grad()
    def infer_with_cfg_parallel(self, inputs):
        assert self.config["enable_cfg"], "enable_cfg must be True"
        cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
        assert dist.get_world_size(cfg_p_group) == 2, "cfg_p world size must be equal to 2"
        cfg_p_rank = dist.get_rank(cfg_p_group)

        if cfg_p_rank == 0:
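            # Rank 0 runs the conditional pass; rank 1 runs the unconditional pass in the else branch.
            # The two noise predictions are all-gathered and combined with the guidance scale below.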
            embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
            noise_pred = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
        else:
            embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
            noise_pred = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]

        noise_pred_list = [torch.zeros_like(noise_pred) for _ in range(2)]
        dist.all_gather(noise_pred_list, noise_pred, group=cfg_p_group)

        noise_pred_cond = noise_pred_list[0]  # cfg_p_rank == 0
        noise_pred_uncond = noise_pred_list[1]  # cfg_p_rank == 1
        self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)