import glob
import json
import os

import torch
import torch.distributed as dist
from loguru import logger
from safetensors import safe_open

from lightx2v.common.ops.attn import MaskMap
from lightx2v.models.networks.wan.infer.dist_infer.transformer_infer import WanTransformerDistInfer
from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import (
    WanTransformerInferAdaCaching,
    WanTransformerInferCustomCaching,
    WanTransformerInferDualBlock,
    WanTransformerInferDynamicBlock,
    WanTransformerInferFirstBlock,
    WanTransformerInferTaylorCaching,
    WanTransformerInferTeaCaching,
)
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.models.networks.wan.infer.transformer_infer import (
    WanTransformerInfer,
)
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.transformer_weights import (
    WanTransformerWeights,
)
from lightx2v.utils.envs import *
from lightx2v.utils.utils import *

try:
    import gguf
except ImportError:
    gguf = None


class WanModel:
    pre_weight_class = WanPreWeights
    post_weight_class = WanPostWeights
    transformer_weight_class = WanTransformerWeights

    def __init__(self, model_path, config, device, seq_p_group=None):
        self.model_path = model_path
        self.config = config
        self.cpu_offload = self.config.get("cpu_offload", False)
        self.offload_granularity = self.config.get("offload_granularity", "block")
        self.seq_p_group = seq_p_group

        self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
        self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"

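        # A non-default mm_type means the DiT weights are quantized: resolve the matching
        # quantized checkpoint directory and merge its config.json into the runtime config.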
        if self.dit_quantized:
            dit_quant_scheme = self.config.mm_config.get("mm_type").split("-")[1]
            if self.config.model_cls == "wan2.1_distill":
                dit_quant_scheme = "distill_" + dit_quant_scheme
            if dit_quant_scheme == "gguf":
                self.dit_quantized_ckpt = find_gguf_model_path(config, "dit_quantized_ckpt", subdir=dit_quant_scheme)
                self.config.use_gguf = True
            else:
                self.dit_quantized_ckpt = find_hf_model_path(config, self.model_path, "dit_quantized_ckpt", subdir=dit_quant_scheme)
            quant_config_path = os.path.join(self.dit_quantized_ckpt, "config.json")
            if os.path.exists(quant_config_path):
                with open(quant_config_path, "r") as f:
                    quant_model_config = json.load(f)
                self.config.update(quant_model_config)
        else:
            self.dit_quantized_ckpt = None
            assert not self.config.get("lazy_load", False)

        self.config.dit_quantized_ckpt = self.dit_quantized_ckpt

        self.weight_auto_quant = self.config.mm_config.get("weight_auto_quant", False)
        if self.dit_quantized:
            assert self.weight_auto_quant or self.dit_quantized_ckpt is not None

        self.device = device
        self._init_infer_class()
        self._init_weights()
        self._init_infer()

    def _init_infer_class(self):
        self.pre_infer_class = WanPreInfer
        self.post_infer_class = WanPostInfer
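        # Sequence parallelism uses the distributed transformer infer; otherwise the class
        # is selected by the configured feature-caching strategy.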
        if self.seq_p_group is not None:
            self.transformer_infer_class = WanTransformerDistInfer
        else:
            if self.config["feature_caching"] == "NoCaching":
                self.transformer_infer_class = WanTransformerInfer
            elif self.config["feature_caching"] == "Tea":
                self.transformer_infer_class = WanTransformerInferTeaCaching
            elif self.config["feature_caching"] == "TaylorSeer":
                self.transformer_infer_class = WanTransformerInferTaylorCaching
            elif self.config["feature_caching"] == "Ada":
                self.transformer_infer_class = WanTransformerInferAdaCaching
            elif self.config["feature_caching"] == "Custom":
                self.transformer_infer_class = WanTransformerInferCustomCaching
            elif self.config["feature_caching"] == "FirstBlock":
                self.transformer_infer_class = WanTransformerInferFirstBlock
            elif self.config["feature_caching"] == "DualBlock":
                self.transformer_infer_class = WanTransformerInferDualBlock
            elif self.config["feature_caching"] == "DynamicBlock":
                self.transformer_infer_class = WanTransformerInferDynamicBlock
            else:
                raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")

    def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
        with safe_open(file_path, framework="pt") as f:
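            # Non-sensitive tensors are cast to GET_DTYPE(); sensitive layers keep
            # GET_SENSITIVE_DTYPE() unless the two dtypes are unified anyway.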
            return {
                key: (
                    f.get_tensor(key).to(GET_DTYPE())
                    if unified_dtype or all(s not in key for s in sensitive_layer)
                    else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())
                )
                .pin_memory()
                .to(self.device)
                for key in f.keys()
            }

    def _load_ckpt(self, unified_dtype, sensitive_layer):
        safetensors_path = find_hf_model_path(self.config, self.model_path, "dit_original_ckpt", subdir="original")
        safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
        weight_dict = {}
        for file_path in safetensors_files:
            file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
            weight_dict.update(file_weights)
        return weight_dict

    def _load_quant_ckpt(self, unified_dtype, sensitive_layer):
        ckpt_path = self.dit_quantized_ckpt
        logger.info(f"Loading quant dit model from {ckpt_path}")

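        # Sharded checkpoints ship a *.index.json whose weight_map records which
        # safetensors shard holds each tensor.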
        index_files = [f for f in os.listdir(ckpt_path) if f.endswith(".index.json")]
        if not index_files:
            raise FileNotFoundError(f"No *.index.json found in {ckpt_path}")

        index_path = os.path.join(ckpt_path, index_files[0])
        logger.info(f" Using safetensors index: {index_path}")

        with open(index_path, "r") as f:
            index_data = json.load(f)

        weight_dict = {}
        for filename in set(index_data["weight_map"].values()):
            safetensor_path = os.path.join(ckpt_path, filename)
            with safe_open(safetensor_path, framework="pt") as f:
                logger.info(f"Loading weights from {safetensor_path}")
                for k in f.keys():
                    if f.get_tensor(k).dtype in [torch.float16, torch.bfloat16, torch.float]:
                        if unified_dtype or all(s not in k for s in sensitive_layer):
                            weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_DTYPE()).to(self.device)
                        else:
                            weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_SENSITIVE_DTYPE()).to(self.device)
                    else:
                        weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)

        return weight_dict

    def _load_quant_split_ckpt(self, unified_dtype, sensitive_layer):
        lazy_load_model_path = self.dit_quantized_ckpt
        logger.info(f"Loading splited quant model from {lazy_load_model_path}")
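        # Lazy-load mode: only the non-block (pre/post) weights are loaded here,
        # from non_block.safetensors.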
        pre_post_weight_dict = {}

        safetensor_path = os.path.join(lazy_load_model_path, "non_block.safetensors")
        with safe_open(safetensor_path, framework="pt", device="cpu") as f:
            for k in f.keys():
                if f.get_tensor(k).dtype in [torch.float16, torch.bfloat16, torch.float]:
                    if unified_dtype or all(s not in k for s in sensitive_layer):
                        pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_DTYPE()).to(self.device)
                    else:
                        pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_SENSITIVE_DTYPE()).to(self.device)
                else:
                    pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)

        return pre_post_weight_dict

    def _load_gguf_ckpt(self):
        if gguf is None:
            raise ImportError("gguf must be installed to load GGUF-quantized checkpoints")
        gguf_path = self.dit_quantized_ckpt
        logger.info(f"Loading gguf-quant dit model from {gguf_path}")
        reader = gguf.GGUFReader(gguf_path)
        for tensor in reader.tensors:
            # TODO: implement _load_gguf_ckpt
            pass

    def _init_weights(self, weight_dict=None):
        unified_dtype = GET_DTYPE() == GET_SENSITIVE_DTYPE()
        # Some layers run with float32 to achieve high accuracy
        sensitive_layer = {
            "norm",
            "embedding",
            "modulation",
            "time",
            "img_emb.proj.0",
            "img_emb.proj.4",
        }

        if weight_dict is None:
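            # Decide which process actually reads the checkpoint: single-device setups
            # always load; under a device mesh only global rank 0 touches the files.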
            is_weight_loader = False
            if self.config.get("device_mesh") is None:
                is_weight_loader = True
                logger.info(f"Loading original dit model from {self.model_path}")
            elif dist.is_initialized():
                if dist.get_rank() == 0:
                    is_weight_loader = True
                    logger.info(f"Loading original dit model from {self.model_path}")

            cpu_weight_dict = {}
            if is_weight_loader:
                if not self.dit_quantized or self.weight_auto_quant:
                    cpu_weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
                else:
                    if not self.config.get("lazy_load", False):
                        cpu_weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)
                    else:
                        cpu_weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)

            if self.config.get("device_mesh") is None:  # 单卡模式
                self.original_weight_dict = {}
                init_device = "cpu" if self.cpu_offload else "cuda"
                for key, tensor in cpu_weight_dict.items():
                    self.original_weight_dict[key] = tensor.to(init_device, non_blocking=True)
            else:
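                # Multi-device: rank 0 holds the CPU copy. Broadcast tensor metadata first,
                # allocate empty CUDA tensors on every rank, then broadcast the actual data.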
                global_src_rank = 0

                meta_dict = {}
                if is_weight_loader:
                    for key, tensor in cpu_weight_dict.items():
                        meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype}

                obj_list = [meta_dict] if is_weight_loader else [None]
                dist.broadcast_object_list(obj_list, src=global_src_rank)
                synced_meta_dict = obj_list[0]

                self.original_weight_dict = {}
                for key, meta in synced_meta_dict.items():
                    self.original_weight_dict[key] = torch.empty(meta["shape"], dtype=meta["dtype"], device="cuda")

                dist.barrier(device_ids=[torch.cuda.current_device()])
                for key in sorted(synced_meta_dict.keys()):
                    tensor_to_broadcast = self.original_weight_dict[key]
                    if is_weight_loader:
                        tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)

                    dist.broadcast(tensor_to_broadcast, src=global_src_rank)

            if is_weight_loader:
                del cpu_weight_dict
        else:
            self.original_weight_dict = weight_dict

        # init weights
        self.pre_weight = self.pre_weight_class(self.config)
        self.post_weight = self.post_weight_class(self.config)
        self.transformer_weights = self.transformer_weight_class(self.config)
        # load weights
        self.pre_weight.load(self.original_weight_dict)
        self.post_weight.load(self.original_weight_dict)
        self.transformer_weights.load(self.original_weight_dict)

        del self.original_weight_dict
        torch.cuda.empty_cache()

    def _init_infer(self):
        self.pre_infer = self.pre_infer_class(self.config)
        self.post_infer = self.post_infer_class(self.config)

        if self.seq_p_group is not None:
            self.transformer_infer = self.transformer_infer_class(self.config, self.seq_p_group)
        else:
            self.transformer_infer = self.transformer_infer_class(self.config)

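        # With CFG parallelism the conditional/unconditional passes run on separate ranks;
        # otherwise both passes run sequentially on this rank.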
        if self.config["cfg_parallel"]:
            self.infer_func = self.infer_with_cfg_parallel
        else:
            self.infer_func = self.infer_wo_cfg_parallel

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler
        self.pre_infer.set_scheduler(scheduler)
        self.post_infer.set_scheduler(scheduler)
        self.transformer_infer.set_scheduler(scheduler)

    def to_cpu(self):
        self.pre_weight.to_cpu()
        self.post_weight.to_cpu()
        self.transformer_weights.to_cpu()

    def to_cuda(self):
        self.pre_weight.to_cuda()
        self.post_weight.to_cuda()
        self.transformer_weights.to_cuda()

    @torch.no_grad()
    def infer(self, inputs):
        return self.infer_func(inputs)

    @torch.no_grad()
    def infer_wo_cfg_parallel(self, inputs):
        if self.cpu_offload:
            if self.offload_granularity == "model" and self.scheduler.step_index == 0:
                self.to_cuda()
            elif self.offload_granularity != "model":
                self.pre_weight.to_cuda()
                self.post_weight.to_cuda()

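        # Lazily build the attention MaskMap from the latent grid; the spatial dims are
        # halved because each token covers a 2x2 latent patch.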
        if self.transformer_infer.mask_map is None:
            _, c, h, w = self.scheduler.latents.shape
            video_token_num = c * (h // 2) * (w // 2)
            self.transformer_infer.mask_map = MaskMap(video_token_num, c)

        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
        noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]

        self.scheduler.noise_pred = noise_pred_cond

        if self.clean_cuda_cache:
            del x, embed, pre_infer_out, noise_pred_cond, grid_sizes
            torch.cuda.empty_cache()

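        # Classifier-free guidance: run the unconditional pass and blend it with the
        # conditional prediction using sample_guide_scale.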
        if self.config["enable_cfg"]:
            embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
            noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]

            self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)

            if self.clean_cuda_cache:
                del x, embed, pre_infer_out, noise_pred_uncond, grid_sizes
                torch.cuda.empty_cache()

        if self.cpu_offload:
            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1:
                self.to_cpu()
            elif self.offload_granularity != "model":
                self.pre_weight.to_cpu()
                self.post_weight.to_cpu()

    @torch.no_grad()
    def infer_with_cfg_parallel(self, inputs):
        assert self.config["enable_cfg"], "enable_cfg must be True"
        cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
        cfg_p_world_size = dist.get_world_size(cfg_p_group)
        assert cfg_p_world_size == 2, f"cfg_p_world_size must be equal to 2, but got {cfg_p_world_size}"
        cfg_p_rank = dist.get_rank(cfg_p_group)
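        # Rank 0 of the cfg_p group computes the conditional prediction, rank 1 the
        # unconditional one; both are gathered below and blended with the guidance scale.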

        if cfg_p_rank == 0:
            embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
            noise_pred = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
        else:
            embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
            noise_pred = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]

        noise_pred_list = [torch.zeros_like(noise_pred) for _ in range(2)]
        dist.all_gather(noise_pred_list, noise_pred, group=cfg_p_group)

        noise_pred_cond = noise_pred_list[0]  # cfg_p_rank == 0
        noise_pred_uncond = noise_pred_list[1]  # cfg_p_rank == 1
        self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)