"megatron/training/dist_signal_handler.py" did not exist on "3aca141586a4b8cdc983c3ecf5f7baf60506c7f8"
causvid_model.py 3.05 KB
Newer Older
1
import os
PengGao's avatar
PengGao committed
2

3
import torch
PengGao's avatar
PengGao committed
4
5
from safetensors import safe_open

helloyongyang's avatar
helloyongyang committed
6
from lightx2v.common.ops.attn.radial_attn import MaskMap
PengGao's avatar
PengGao committed
7
8
9
10
11
from lightx2v.models.networks.wan.infer.causvid.transformer_infer import (
    WanTransformerInferCausVid,
)
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
12
13
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
PengGao's avatar
PengGao committed
14
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
15
16
17
from lightx2v.models.networks.wan.weights.transformer_weights import (
    WanTransformerWeights,
)
gushiqiao's avatar
gushiqiao committed
18
from lightx2v.utils.envs import *
19
20


Zhuguanyu Wu's avatar
Zhuguanyu Wu committed
21
class WanCausVidModel(WanModel):
    """Wan model variant that uses the CausVid transformer inference stage.

    The weight containers and the pre/post inference stages are the stock Wan
    ones; only the transformer inference class is replaced with
    ``WanTransformerInferCausVid``. Checkpoints are looked up under
    ``<model_path>/causvid_models`` first, falling back to ``WanModel``'s own
    loader when no CausVid checkpoint file is present.
    """

    pre_weight_class = WanPreWeights
    post_weight_class = WanPostWeights
    transformer_weight_class = WanTransformerWeights

    def __init__(self, model_path, config, device):
        """Forward all arguments unchanged to ``WanModel.__init__``."""
        super().__init__(model_path, config, device)

    def _init_infer_class(self):
        """Select the inference classes for the pre, transformer and post stages."""
        self.pre_infer_class = WanPreInfer
        self.post_infer_class = WanPostInfer
        self.transformer_infer_class = WanTransformerInferCausVid

    def _cast_and_place(self, key, tensor, use_bf16, skip_bf16):
        """Apply the dtype policy to one weight, pin it, and move it to ``self.device``.

        The tensor is cast to bfloat16 when ``use_bf16`` is truthy, or when
        none of the substrings in ``skip_bf16`` occurs in ``key``; otherwise it
        keeps its stored dtype.
        """
        if use_bf16 or all(s not in key for s in skip_bf16):
            tensor = tensor.to(torch.bfloat16)
        return tensor.pin_memory().to(self.device)

    def _load_ckpt(self, use_bf16, skip_bf16):
        """Load the CausVid checkpoint as a ``{name: tensor}`` dict.

        Lookup order:
          1. ``<model_path>/causvid_models/causal_model.safetensors``
          2. ``<model_path>/causvid_models/causal_model.pt``
          3. the parent class's ``_load_ckpt``

        Args:
            use_bf16: when truthy, every tensor is cast to bfloat16.
            skip_bf16: iterable of substrings; when ``use_bf16`` is falsy, a
                tensor whose key contains any of them keeps its stored dtype.

        Returns:
            Dict mapping weight names to tensors placed on ``self.device``
            (or whatever the parent loader returns in the fallback case).
        """
        ckpt_dir = os.path.join(self.model_path, "causvid_models")

        safetensors_path = os.path.join(ckpt_dir, "causal_model.safetensors")
        if os.path.exists(safetensors_path):
            with safe_open(safetensors_path, framework="pt") as f:
                return {
                    key: self._cast_and_place(key, f.get_tensor(key), use_bf16, skip_bf16)
                    for key in f.keys()
                }

        ckpt_path = os.path.join(ckpt_dir, "causal_model.pt")
        if os.path.exists(ckpt_path):
            # weights_only=True restricts unpickling to tensors/primitive containers.
            raw_weights = torch.load(ckpt_path, map_location="cpu", weights_only=True)
            return {
                key: self._cast_and_place(key, tensor, use_bf16, skip_bf16)
                for key, tensor in raw_weights.items()
            }

        return super()._load_ckpt(use_bf16, skip_bf16)

    @torch.no_grad()
    def infer(self, inputs, kv_start, kv_end):
        """Run pre -> transformer -> post and store the result on the scheduler.

        Args:
            inputs: forwarded unchanged to the pre-infer stage.
            kv_start: forwarded to the pre-infer and transformer stages.
            kv_end: forwarded to the pre-infer and transformer stages.

        Side effects:
            * Lazily creates ``self.transformer_infer.mask_map`` on first use.
            * Writes element 0 of the post-infer output to
              ``self.scheduler.noise_pred``.
            * When ``config["cpu_offload"]`` is truthy, moves the pre/post
              weights to the GPU for the duration of the call and back to the
              CPU afterwards, including when a stage raises.
        """
        if self.transformer_infer.mask_map is None:
            # Dims 1..3 of the latent shape size the mask.
            # NOTE(review): `c` is presumably the latent frame count and the
            # `// 2` a 2x2 spatial patchify -- confirm against the scheduler.
            _, c, h, w = self.scheduler.latents.shape
            video_token_num = c * (h // 2) * (w // 2)
            self.transformer_infer.mask_map = MaskMap(video_token_num, c)

        offload = self.config["cpu_offload"]
        try:
            if offload:
                self.pre_weight.to_cuda()
                self.post_weight.to_cuda()

            embed, grid_sizes, pre_infer_out = self.pre_infer.infer(
                self.pre_weight,
                inputs,
                positive=True,
                kv_start=kv_start,
                kv_end=kv_end,
            )

            x = self.transformer_infer.infer(
                self.transformer_weights,
                grid_sizes,
                embed,
                *pre_infer_out,
                kv_start,
                kv_end,
            )
            self.scheduler.noise_pred = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
        finally:
            # Always release the GPU copies so a failed step does not leave
            # the offloaded weights resident on the device.
            if offload:
                self.pre_weight.to_cpu()
                self.post_weight.to_cpu()