model.py 20.3 KB
Newer Older
gushiqiao's avatar
gushiqiao committed
1
import gc
Yang Yong (雍洋)'s avatar
Yang Yong (雍洋) committed
2
import glob
3
4
import os

helloyongyang's avatar
helloyongyang committed
5
import torch
6
import torch.distributed as dist
helloyongyang's avatar
helloyongyang committed
7
import torch.nn.functional as F
PengGao's avatar
PengGao committed
8
9
10
from loguru import logger
from safetensors import safe_open

11
from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import (
12
13
    WanTransformerInferAdaCaching,
    WanTransformerInferCustomCaching,
Rongjin Yang's avatar
Rongjin Yang committed
14
15
    WanTransformerInferDualBlock,
    WanTransformerInferDynamicBlock,
PengGao's avatar
PengGao committed
16
    WanTransformerInferFirstBlock,
Musisoul's avatar
Musisoul committed
17
    WanTransformerInferMagCaching,
PengGao's avatar
PengGao committed
18
19
20
    WanTransformerInferTaylorCaching,
    WanTransformerInferTeaCaching,
)
21
22
23
from lightx2v.models.networks.wan.infer.offload.transformer_infer import (
    WanOffloadTransformerInfer,
)
PengGao's avatar
PengGao committed
24
25
26
27
28
29
30
31
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.models.networks.wan.infer.transformer_infer import (
    WanTransformerInfer,
)
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.transformer_weights import (
    WanTransformerWeights,
32
)
Yang Yong(雍洋)'s avatar
Yang Yong(雍洋) committed
33
from lightx2v.utils.custom_compiler import CompiledMethodsMixin, compiled_method
34
from lightx2v.utils.envs import *
35
from lightx2v.utils.utils import *
helloyongyang's avatar
helloyongyang committed
36

37
38
39
40
41
try:
    import gguf
except ImportError:
    gguf = None

helloyongyang's avatar
helloyongyang committed
42

Yang Yong(雍洋)'s avatar
Yang Yong(雍洋) committed
43
class WanModel(CompiledMethodsMixin):
helloyongyang's avatar
helloyongyang committed
44
45
46
    # Weight-container classes instantiated in _init_weights via
    # ``self.pre_weight_class(self.config)`` / ``self.transformer_weight_class(self.config)``.
    # Kept as class attributes (not hard-coded) so a subclass can swap them out.
    pre_weight_class = WanPreWeights
    transformer_weight_class = WanTransformerWeights

47
    def __init__(self, model_path, config, device, model_type="wan2.1"):
        """Read runtime options from ``config``, then load weights and build the infer stages.

        Args:
            model_path: Checkpoint file, or a directory of ``*.safetensors`` files.
            config: Dict-like run configuration (accessed with ``.get`` and ``[]``).
            device: Target ``torch.device`` for loaded weights (its ``.type`` is read
                during loading).
            model_type: Model variant tag, e.g. ``"wan2.1"`` or
                ``"wan2.2_moe_high_noise"``.

        Raises:
            ValueError: if ``dit_quantized`` is enabled with an unsupported
                ``dit_quant_scheme``.
        """
        super().__init__()

        self.model_path = model_path
        self.config = config
        self.device = device
        self.model_type = model_type
        self.cpu_offload = self.config.get("cpu_offload", False)
        self.offload_granularity = self.config.get("offload_granularity", "block")

        if self.config["seq_parallel"]:
            self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
        else:
            self.seq_p_group = None

        self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
        self.dit_quantized = self.config.get("dit_quantized", False)
        if self.dit_quantized:
            # Explicit check rather than ``assert``: asserts are stripped under
            # ``python -O`` and gave no hint about what went wrong.
            supported_schemes = (
                "Default-Force-FP32",
                "fp8-vllm",
                "int8-vllm",
                "fp8-q8f",
                "int8-q8f",
                "fp8-b128-deepgemm",
                "fp8-sgl",
                "int8-sgl",
                "int8-torchao",
                "nvfp4",
                "mxfp4",
                "mxfp6-mxfp8",
                "mxfp8",
                "int8-tmo",
            )
            quant_scheme = self.config.get("dit_quant_scheme", "Default")
            if quant_scheme not in supported_schemes:
                raise ValueError(f"Unsupported dit_quant_scheme {quant_scheme!r}; expected one of {supported_schemes}")

        self._init_infer_class()
        self._init_weights()
        self._init_infer()

    def _init_infer_class(self):
        """Choose the pre/post/transformer inference classes for this run.

        Pre and post inference are fixed. The transformer runner depends on
        ``config["feature_caching"]``: ``"NoCaching"`` picks the plain runner, or its
        offloading variant when ``cpu_offload`` is enabled; every other supported
        mode maps to a dedicated caching runner.

        Raises:
            NotImplementedError: if ``feature_caching`` names no known mode.
        """
        self.pre_infer_class = WanPreInfer
        self.post_infer_class = WanPostInfer

        mode = self.config["feature_caching"]
        if mode == "NoCaching":
            self.transformer_infer_class = WanOffloadTransformerInfer if self.cpu_offload else WanTransformerInfer
            return

        # (mode name, runner class) pairs, matched with ``==`` in declaration order.
        caching_runners = (
            ("Tea", WanTransformerInferTeaCaching),
            ("TaylorSeer", WanTransformerInferTaylorCaching),
            ("Ada", WanTransformerInferAdaCaching),
            ("Custom", WanTransformerInferCustomCaching),
            ("FirstBlock", WanTransformerInferFirstBlock),
            ("DualBlock", WanTransformerInferDualBlock),
            ("DynamicBlock", WanTransformerInferDynamicBlock),
            ("Mag", WanTransformerInferMagCaching),
        )
        runner = next((cls for name, cls in caching_runners if mode == name), None)
        if runner is None:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
        self.transformer_infer_class = runner
helloyongyang's avatar
helloyongyang committed
108

gushiqiao's avatar
gushiqiao committed
109
110
111
112
113
114
    def _should_load_weights(self):
        """Return whether this process should read the checkpoint from disk.

        - No ``device_mesh`` configured: single-process run, always load.
        - Mesh configured but ``torch.distributed`` not initialised: do not load.
        - ``load_from_rank0`` set: only global rank 0 loads (and logs the path).
        - Otherwise: every rank loads its own copy.
        """
        if self.config.get("device_mesh") is None:
            return True
        if not dist.is_initialized():
            return False
        if not self.config.get("load_from_rank0", False):
            return True
        if dist.get_rank() != 0:
            return False
        logger.info(f"Loading weights from {self.model_path}")
        return True

124
    def _should_init_empty_model(self):
        """Return True when LoRA configs target this model variant.

        ``_init_weights`` skips applying weights at construction time when this is
        True. With ``lora_configs`` present:

        - ``"wan2.1"`` always qualifies;
        - ``"wan2.2_moe_high_noise"`` qualifies if a config is named ``"high_noise_model"``;
        - ``"wan2.2_moe_low_noise"`` qualifies if a config is named ``"low_noise_model"``.
        """
        lora_configs = self.config.get("lora_configs")
        if not lora_configs:
            return False

        if self.model_type == "wan2.1":
            return True
        if self.model_type == "wan2.2_moe_high_noise":
            wanted_name = "high_noise_model"
        elif self.model_type == "wan2.2_moe_low_noise":
            wanted_name = "low_noise_model"
        else:
            return False

        return any(entry["name"] == wanted_name for entry in lora_configs)

138
    def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
        """Read one safetensors file into a dict of tensors cast to the runtime dtypes.

        Keys containing any substring from ``self.remove_keys`` (when that attribute
        exists) are skipped. A key goes to ``GET_DTYPE()`` when ``unified_dtype`` is
        true or it matches none of ``sensitive_layer``; otherwise it goes to
        ``GET_SENSITIVE_DTYPE()``.

        Args:
            file_path: Path to a ``.safetensors`` file.
            unified_dtype: Whether every tensor should use ``GET_DTYPE()``.
            sensitive_layer: Iterable of name substrings kept in the sensitive dtype.

        Returns:
            dict[str, torch.Tensor]: tensors materialised on this rank's device.
        """
        remove_keys = getattr(self, "remove_keys", [])

        if self.device.type in ("cuda", "mlu") and dist.is_initialized():
            # Use the node-local rank for the device index: on multi-node jobs the
            # global rank exceeds the per-node device count (e.g. ``cuda:8`` on an
            # 8-GPU node). Falls back to the global rank when LOCAL_RANK is unset.
            local_rank = int(os.environ.get("LOCAL_RANK", dist.get_rank()))
            device = torch.device(self.device.type, local_rank)
        else:
            device = self.device

        default_dtype = GET_DTYPE()
        sensitive_dtype = GET_SENSITIVE_DTYPE()

        weights = {}
        with safe_open(file_path, framework="pt", device=str(device)) as f:
            for key in f.keys():
                if any(remove_key in key for remove_key in remove_keys):
                    continue
                keep_default = unified_dtype or all(s not in key for s in sensitive_layer)
                weights[key] = f.get_tensor(key).to(default_dtype if keep_default else sensitive_dtype)
        return weights
helloyongyang's avatar
helloyongyang committed
152

153
    def _load_ckpt(self, unified_dtype, sensitive_layer):
        """Load an unquantized DiT checkpoint into a single state dict.

        The source is ``config["dit_original_ckpt"]`` when set, else ``self.model_path``.
        It may be one safetensors file, or a directory whose ``*.safetensors`` files
        are merged in sorted filename order (later files win on duplicate keys).
        A file equal to ``config["adapter_model_path"]`` is skipped.

        Args:
            unified_dtype: Whether every tensor should use ``GET_DTYPE()``.
            sensitive_layer: Name substrings kept in ``GET_SENSITIVE_DTYPE()``.

        Returns:
            dict[str, torch.Tensor]: the merged weights.
        """
        safetensors_path = self.config.get("dit_original_ckpt", None) or self.model_path

        if os.path.isdir(safetensors_path):
            # glob order is filesystem-dependent; sort so the merge order (and
            # therefore which file wins on a duplicate key) is reproducible.
            safetensors_files = sorted(glob.glob(os.path.join(safetensors_path, "*.safetensors")))
            if not safetensors_files:
                logger.warning(f"No *.safetensors files found in {safetensors_path}")
        else:
            safetensors_files = [safetensors_path]

        adapter_path = self.config.get("adapter_model_path", None)
        weight_dict = {}
        for file_path in safetensors_files:
            if adapter_path is not None and adapter_path == file_path:
                continue
            logger.info(f"Loading weights from {file_path}")
            weight_dict.update(self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer))

        return weight_dict

175
    def _load_quant_ckpt(self, unified_dtype, sensitive_layer):
        """Load a quantized DiT checkpoint into a single state dict on ``self.device``.

        The source is ``config["dit_quantized_ckpt"]`` when set, else ``self.model_path``:
        one safetensors file, or a directory of ``*.safetensors`` files merged in
        sorted order. A file equal to ``config["adapter_model_path"]`` is skipped, as
        are keys containing any ``self.remove_keys`` substring. Floating-point tensors
        (fp16/bf16/fp32) are cast to ``GET_DTYPE()`` or, for sensitive layers when the
        dtypes differ, ``GET_SENSITIVE_DTYPE()``. Other dtypes (quantized payloads,
        scales) are moved unchanged.

        For the ``nvfp4`` scheme, ``calib.pt`` next to the checkpoint supplies
        ``absmax`` entries stored as ``<name>.input_absmax``.

        Returns:
            dict[str, torch.Tensor]: the merged weights.
        """
        remove_keys = getattr(self, "remove_keys", [])
        float_dtypes = (torch.float16, torch.bfloat16, torch.float32)

        safetensors_path = self.config.get("dit_quantized_ckpt", None) or self.model_path

        if os.path.isdir(safetensors_path):
            # Sorted for a reproducible merge order (glob order is filesystem-dependent).
            safetensors_files = sorted(glob.glob(os.path.join(safetensors_path, "*.safetensors")))
        else:
            safetensors_files = [safetensors_path]
            # calib.pt (nvfp4) is looked up next to the single checkpoint file.
            safetensors_path = os.path.dirname(safetensors_path)

        adapter_path = self.config.get("adapter_model_path", None)
        weight_dict = {}
        for file_path in safetensors_files:
            if adapter_path is not None and adapter_path == file_path:
                continue
            with safe_open(file_path, framework="pt") as f:
                logger.info(f"Loading weights from {file_path}")
                for k in f.keys():
                    if any(remove_key in k for remove_key in remove_keys):
                        continue
                    # Read each tensor once; it is used for both the dtype check and the cast.
                    tensor = f.get_tensor(k)
                    if tensor.dtype in float_dtypes:
                        if unified_dtype or all(s not in k for s in sensitive_layer):
                            tensor = tensor.to(GET_DTYPE())
                        else:
                            tensor = tensor.to(GET_SENSITIVE_DTYPE())
                    weight_dict[k] = tensor.to(self.device)

        if self.config.get("dit_quant_scheme", "Default") == "nvfp4":
            calib_path = os.path.join(safetensors_path, "calib.pt")
            # NOTE(review): torch.load unpickles arbitrary objects; only use trusted checkpoints.
            calib_data = torch.load(calib_path, map_location="cpu")
            logger.info(f"[CALIB] Loaded calibration data from: {calib_path}")
            for k, v in calib_data["absmax"].items():
                weight_dict[k.replace(".weight", ".input_absmax")] = v.to(self.device)

        return weight_dict

220
    def _load_quant_split_ckpt(self, unified_dtype, sensitive_layer):  # Need rewrite
        """Load only the non-block weights of a split (lazy-load) quantized checkpoint.

        Reads ``non_block.safetensors`` from the quantized checkpoint directory.
        Transformer-block weights are not read here (presumably loaded lazily
        elsewhere; TODO confirm). Floating-point tensors (fp16/bf16/fp32) are cast to
        ``GET_DTYPE()`` or, for sensitive layers when the dtypes differ,
        ``GET_SENSITIVE_DTYPE()``. Everything is then moved to ``self.device``.

        The directory is ``self.dit_quantized_ckpt`` if that attribute is set, else
        ``config["dit_quantized_ckpt"]``, else ``self.model_path``.

        Returns:
            dict[str, torch.Tensor]: the non-block weights.
        """
        # ``dit_quantized_ckpt`` is not assigned by __init__, so fall back to the
        # config key / model path instead of raising AttributeError.
        lazy_load_model_path = getattr(self, "dit_quantized_ckpt", None) or self.config.get("dit_quantized_ckpt", None) or self.model_path
        logger.info(f"Loading splited quant model from {lazy_load_model_path}")
        float_dtypes = (torch.float16, torch.bfloat16, torch.float32)

        pre_post_weight_dict = {}
        safetensor_path = os.path.join(lazy_load_model_path, "non_block.safetensors")
        with safe_open(safetensor_path, framework="pt", device="cpu") as f:
            for k in f.keys():
                # Read each tensor once; it is used for both the dtype check and the cast.
                tensor = f.get_tensor(k)
                if tensor.dtype in float_dtypes:
                    if unified_dtype or all(s not in k for s in sensitive_layer):
                        tensor = tensor.to(GET_DTYPE())
                    else:
                        tensor = tensor.to(GET_SENSITIVE_DTYPE())
                pre_post_weight_dict[k] = tensor.to(self.device)

        return pre_post_weight_dict
241

242
243
244
245
246
247
248
249
    def _load_gguf_ckpt(self):
        """Placeholder loader for GGUF-quantized DiT checkpoints (not implemented yet).

        Opens the file with ``gguf.GGUFReader`` and walks its tensors, but does not
        convert or return them; returns ``None``. The path is
        ``self.dit_quantized_ckpt`` if set, else ``config["dit_quantized_ckpt"]``,
        else ``self.model_path``.

        Raises:
            ImportError: if the optional ``gguf`` package is not installed.
        """
        if gguf is None:
            # The module-level import is optional; fail with a clear message instead
            # of "'NoneType' object has no attribute 'GGUFReader'".
            raise ImportError("Loading GGUF checkpoints requires the optional `gguf` package.")
        gguf_path = getattr(self, "dit_quantized_ckpt", None) or self.config.get("dit_quantized_ckpt", None) or self.model_path
        logger.info(f"Loading gguf-quant dit model from {gguf_path}")
        reader = gguf.GGUFReader(gguf_path)
        for _tensor in reader.tensors:
            # TODO: implement _load_gguf_ckpt
            pass

lijiaqi2's avatar
lijiaqi2 committed
250
    def _init_weights(self, weight_dict=None):
        """Gather the DiT state dict and build the pre/transformer weight containers.

        If ``weight_dict`` is supplied it is used directly. Otherwise the ranks chosen
        by ``_should_load_weights`` read it from disk: an original checkpoint, a
        quantized checkpoint, or, with ``lazy_load``, the split quantized layout.
        It is then broadcast from rank 0 when both ``device_mesh`` and
        ``load_from_rank0`` are configured, and extended with
        ``self.adapter_weights_dict`` when that attribute exists.

        The dict is parked in ``self.original_weight_dict`` and copied into the
        containers right away unless ``_should_init_empty_model`` says to defer.
        """
        # When both dtypes agree there is no per-layer precision split to apply.
        unified_dtype = GET_DTYPE() == GET_SENSITIVE_DTYPE()
        # Name substrings of layers kept in the sensitive dtype (float32 in the
        # original comment) for accuracy; the last two are VACE projections.
        sensitive_layer = {
            "norm",
            "embedding",
            "modulation",
            "time",
            "img_emb.proj.0",
            "img_emb.proj.4",
            "before_proj",
            "after_proj",
        }

        if weight_dict is not None:
            self.original_weight_dict = weight_dict
        else:
            is_weight_loader = self._should_load_weights()
            if is_weight_loader:
                if not self.dit_quantized:
                    weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
                elif self.config.get("lazy_load", False):
                    weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)
                else:
                    weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)

            if self.config.get("device_mesh") is not None and self.config.get("load_from_rank0", False):
                weight_dict = self._load_weights_from_rank0(weight_dict, is_weight_loader)

            # NOTE(review): on a non-loader rank without the rank-0 broadcast,
            # weight_dict is still None here; the update below would then fail.
            if hasattr(self, "adapter_weights_dict"):
                weight_dict.update(self.adapter_weights_dict)

            self.original_weight_dict = weight_dict

        self.pre_weight = self.pre_weight_class(self.config)
        self.transformer_weights = self.transformer_weight_class(self.config)
        if not self._should_init_empty_model():
            self._apply_weights()
gushiqiao's avatar
gushiqiao committed
292

293
294
295
296
297
    def _apply_weights(self, weight_dict=None):
        """Copy the pending state dict into the pre/transformer weight containers.

        A non-None ``weight_dict`` first replaces ``self.original_weight_dict``.
        After loading, ``self.original_weight_dict`` is deleted, then the CUDA
        caching allocator and Python's garbage collector are flushed.
        """
        if weight_dict is not None:
            self.original_weight_dict = weight_dict
            # Only drops the local name; the dict stays reachable via the attribute.
            del weight_dict
            gc.collect()

        for container in (self.pre_weight, self.transformer_weights):
            container.load(self.original_weight_dict)

        del self.original_weight_dict
        torch.cuda.empty_cache()
        gc.collect()

306
307
    def _load_weights_from_rank0(self, weight_dict, is_weight_loader):
        """Broadcast the state dict held by global rank 0 to every rank.

        The loader rank first broadcasts a ``{key: {"shape", "dtype"}}`` metadata
        dict so the other ranks can allocate matching empty tensors. Each tensor is
        then broadcast in sorted-key order, so all ranks issue the collectives in
        the same sequence.

        With ``cpu_offload`` the destination tensors live on CPU, and each one is
        staged through a temporary CUDA tensor for the broadcast (presumably because
        the process-group backend needs device tensors; confirm). Otherwise the
        tensors are allocated and broadcast directly on the current CUDA device.

        Args:
            weight_dict: name -> tensor mapping on the loader rank; ignored (may be
                ``None``) elsewhere.
            is_weight_loader: whether this rank holds the source weights.

        Returns:
            dict[str, torch.Tensor]: identical weights on every rank, on CPU or CUDA.
        """
        logger.info("Loading distributed weights")
        global_src_rank = 0
        target_device = "cpu" if self.cpu_offload else "cuda"

        # 1) Share shapes/dtypes so the non-loader ranks can allocate receive buffers.
        meta_dict = {key: {"shape": tensor.shape, "dtype": tensor.dtype} for key, tensor in weight_dict.items()} if is_weight_loader else None
        obj_list = [meta_dict]
        dist.broadcast_object_list(obj_list, src=global_src_rank)
        synced_meta_dict = obj_list[0]

        distributed_weight_dict = {key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) for key, meta in synced_meta_dict.items()}

        if target_device == "cuda":
            dist.barrier(device_ids=[torch.cuda.current_device()])

        # 2) Broadcast tensor data in a deterministic order shared by all ranks.
        for key in sorted(synced_meta_dict.keys()):
            dst = distributed_weight_dict[key]
            if is_weight_loader:
                dst.copy_(weight_dict[key], non_blocking=True)

            if target_device == "cpu":
                gpu_tensor = dst.cuda() if is_weight_loader else torch.empty_like(dst, device="cuda")
                dist.broadcast(gpu_tensor, src=global_src_rank)
                # Blocking device-to-host copy straight into the destination.
                dst.copy_(gpu_tensor)
                del gpu_tensor
            else:
                dist.broadcast(dst, src=global_src_rank)

        if target_device == "cuda":
            torch.cuda.synchronize()
        else:
            # Return the staging buffers once, rather than after every tensor
            # (each empty_cache call synchronizes the device).
            torch.cuda.empty_cache()

        logger.info(f"Weights distributed across {dist.get_world_size()} devices on {target_device}")

        return distributed_weight_dict

helloyongyang's avatar
helloyongyang committed
365
366
367
    def _init_infer(self):
        """Instantiate the pre, post and transformer inference runners from the config.

        When the transformer runner exposes ``offload_manager``, its CUDA buffers are
        initialised from the transformer weights' offload block/phase buffers.
        """
        cfg = self.config
        self.pre_infer = self.pre_infer_class(cfg)
        self.post_infer = self.post_infer_class(cfg)
        self.transformer_infer = self.transformer_infer_class(cfg)

        if not hasattr(self.transformer_infer, "offload_manager"):
            return
        weights = self.transformer_weights
        self.transformer_infer.offload_manager.init_cuda_buffer(weights.offload_block_buffers, weights.offload_phase_buffers)
helloyongyang's avatar
helloyongyang committed
371
372
373

    def set_scheduler(self, scheduler):
        """Attach ``scheduler`` to the model and hand it to every inference stage."""
        self.scheduler = scheduler
        for stage in (self.pre_infer, self.post_infer, self.transformer_infer):
            stage.set_scheduler(scheduler)

TorynCurtis's avatar
TorynCurtis committed
378
379
380
381
382
383
384
385
    def to_cpu(self):
        """Move the pre-processing and transformer weights to host memory."""
        for weights in (self.pre_weight, self.transformer_weights):
            weights.to_cpu()

    def to_cuda(self):
        """Move the pre-processing and transformer weights to the CUDA device."""
        for weights in (self.pre_weight, self.transformer_weights):
            weights.to_cuda()

helloyongyang's avatar
helloyongyang committed
386
387
    @torch.no_grad()
    def infer(self, inputs):
        """Predict noise for the current scheduler step and store it on the scheduler.

        With CPU offload, weights are moved to the GPU before the step and back
        afterwards:

        - ``offload_granularity == "model"``: the whole model moves, only on the
          first and last step, and never for ``wan2.2_moe`` variants.
        - Any other granularity: the pre weights and the transformer's non-block
          weights move every step.

        With ``enable_cfg``, the conditional and unconditional passes run either
        sequentially or split across the 2-rank ``cfg_p`` mesh group, and are
        combined using ``scheduler.sample_guide_scale``. The result is written to
        ``self.scheduler.noise_pred``; nothing is returned.
        """
        if self.cpu_offload:
            if self.offload_granularity != "model":
                self.pre_weight.to_cuda()
                self.transformer_weights.non_block_weights_to_cuda()
            elif self.scheduler.step_index == 0 and "wan2.2_moe" not in self.config["model_cls"]:
                self.to_cuda()

        if not self.config["enable_cfg"]:
            # ==================== No CFG ====================
            self.scheduler.noise_pred = self._infer_cond_uncond(inputs, infer_condition=True)
        else:
            if self.config["cfg_parallel"]:
                # ==================== CFG Parallel Processing ====================
                # Rank 0 of the cfg_p group computes the conditional branch, rank 1
                # the unconditional one; all_gather gives both ranks both results.
                cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
                assert dist.get_world_size(cfg_p_group) == 2, "cfg_p_world_size must be equal to 2"
                is_cond_rank = dist.get_rank(cfg_p_group) == 0
                local_pred = self._infer_cond_uncond(inputs, infer_condition=is_cond_rank)

                gathered = [torch.zeros_like(local_pred) for _ in range(2)]
                dist.all_gather(gathered, local_pred, group=cfg_p_group)
                noise_pred_cond, noise_pred_uncond = gathered
            else:
                # ==================== CFG Processing ====================
                noise_pred_cond = self._infer_cond_uncond(inputs, infer_condition=True)
                noise_pred_uncond = self._infer_cond_uncond(inputs, infer_condition=False)

            guide_scale = self.scheduler.sample_guide_scale
            self.scheduler.noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)

        if self.cpu_offload:
            if self.offload_granularity != "model":
                self.pre_weight.to_cpu()
                self.transformer_weights.non_block_weights_to_cpu()
            elif self.scheduler.step_index == self.scheduler.infer_steps - 1 and "wan2.2_moe" not in self.config["model_cls"]:
                self.to_cpu()
gushiqiao's avatar
gushiqiao committed
427

Yang Yong(雍洋)'s avatar
Yang Yong(雍洋) committed
428
    @compiled_method()
    @torch.no_grad()
    def _infer_cond_uncond(self, inputs, infer_condition=True):
        """Run one DiT forward pass (pre -> transformer -> post) for one CFG branch.

        ``infer_condition`` is recorded on ``self.scheduler`` before running. Under
        ``seq_parallel`` the pre-infer output is sharded across the sequence-parallel
        group before the transformer, and the transformer output is gathered back
        afterwards. Returns element ``[0]`` of ``post_infer.infer``'s result.
        """
        self.scheduler.infer_condition = infer_condition

        pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs)
        seq_parallel = self.config["seq_parallel"]
        if seq_parallel:
            pre_infer_out = self._seq_parallel_pre_process(pre_infer_out)

        hidden = self.transformer_infer.infer(self.transformer_weights, pre_infer_out)
        if seq_parallel:
            hidden = self._seq_parallel_post_process(hidden)

        noise_pred = self.post_infer.infer(hidden, pre_infer_out)[0]

        if self.clean_cuda_cache:
            # Drop the intermediates first so empty_cache can release their blocks.
            del hidden, pre_infer_out
            torch.cuda.empty_cache()

        return noise_pred

    @torch.no_grad()
    def _seq_parallel_pre_process(self, pre_infer_out):
        """Shard the pre-infer outputs across the sequence-parallel group.

        ``x`` is zero-padded so that ``x.shape[0]`` becomes a multiple of the group
        size. The pad spec ``(0, 0, 0, pad)`` pads the second-to-last dim, which
        equals dim 0 only when ``x`` is 2-D (presumably (seq, hidden); confirm).
        This rank then keeps its ``torch.chunk`` slice along dim 0.

        For ``model_cls`` in wan2.2 / wan2.2_audio with task i2v / s2v, ``embed``
        and ``embed0`` are padded by the amount ``embed`` needs and sharded the
        same way. ``pre_infer_out`` is modified in place and returned.
        """
        n_shards = dist.get_world_size(self.seq_p_group)
        shard_idx = dist.get_rank(self.seq_p_group)

        def pad_needed(tensor):
            # Rows to append so dim 0 divides evenly across the shards.
            return (n_shards - (tensor.shape[0] % n_shards)) % n_shards

        def my_shard(tensor):
            return torch.chunk(tensor, n_shards, dim=0)[shard_idx]

        x = pre_infer_out.x
        x_pad = pad_needed(x)
        if x_pad > 0:
            x = F.pad(x, (0, 0, 0, x_pad))
        pre_infer_out.x = my_shard(x)

        if self.config["model_cls"] in ["wan2.2", "wan2.2_audio"] and self.config["task"] in ["i2v", "s2v"]:
            embed, embed0 = pre_infer_out.embed, pre_infer_out.embed0
            embed_pad = pad_needed(embed)
            if embed_pad > 0:
                embed = F.pad(embed, (0, 0, 0, embed_pad))
                embed0 = F.pad(embed0, (0, 0, 0, 0, 0, embed_pad))
            pre_infer_out.embed = my_shard(embed)
            pre_infer_out.embed0 = my_shard(embed0)

        return pre_infer_out

    @torch.no_grad()
    def _seq_parallel_post_process(self, x):
        """All-gather the per-rank shards of ``x`` and concatenate them along dim 0.

        NOTE(review): the zero-padding that ``_seq_parallel_pre_process`` appends is
        not removed here, so the result keeps the padded length. Presumably it is
        trimmed downstream; confirm.
        """
        group = self.seq_p_group
        shards = [torch.empty_like(x) for _ in range(dist.get_world_size(group))]
        dist.all_gather(shards, x, group=group)
        return torch.cat(shards, dim=0)