import os
import torch
import time
import glob
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.transformer_weights import (
    WanTransformerWeights,
)
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.models.networks.wan.infer.causvid.transformer_infer import (
    WanTransformerInferCausVid,
)
from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import WanTransformerInferTeaCaching
from safetensors import safe_open
import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap


class WanCausVidModel(WanModel):
    """Wan video model variant for CausVid (causal) inference.

    Differs from the base ``WanModel`` in two ways:
    - the transformer inference stage is swapped for the CausVid variant,
      which threads a KV-cache window (``kv_start``/``kv_end``) through the
      forward pass;
    - checkpoint loading prefers a ``causal_model.pt`` file when one is
      present in the model directory, falling back to the parent loader
      otherwise.
    """

    pre_weight_class = WanPreWeights
    post_weight_class = WanPostWeights
    transformer_weight_class = WanTransformerWeights

    def __init__(self, model_path, config, device):
        super().__init__(model_path, config, device)

    def _init_infer_class(self):
        # Pre/post inference are the same as the base model; only the
        # transformer stage uses the CausVid (KV-window aware) implementation.
        self.pre_infer_class = WanPreInfer
        self.post_infer_class = WanPostInfer
        self.transformer_infer_class = WanTransformerInferCausVid

    def _load_ckpt(self):
        """Load model weights, preferring ``causal_model.pt``.

        Returns:
            dict: parameter name -> tensor, moved to ``self.device`` and cast
            to bfloat16 unless ``use_bfloat16`` is False in the config.
            When ``causal_model.pt`` is absent, returns whatever the parent
            class loader produces.
        """
        use_bfloat16 = self.config.get("use_bfloat16", True)

        ckpt_path = os.path.join(self.model_path, "causal_model.pt")
        if not os.path.exists(ckpt_path):
            # File does not exist: fall back to the parent class loader.
            return super()._load_ckpt()

        weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)

        # dtype=None leaves each tensor's original dtype untouched.
        dtype = torch.bfloat16 if use_bfloat16 else None
        # Build a fresh dict instead of reassigning values while iterating
        # the same dict's .items() view (the original relied on a CPython
        # implementation detail that same-key updates don't invalidate it).
        return {key: value.to(device=self.device, dtype=dtype) for key, value in weight_dict.items()}

    @torch.no_grad()
    def infer(self, inputs, kv_start, kv_end):
        """Run one denoising step over the KV window given by kv_start/kv_end.

        Side effects: writes the predicted noise into
        ``self.scheduler.noise_pred``; returns nothing. When ``cpu_offload``
        is enabled in the config, pre/post weights are staged onto the GPU
        for the duration of the step and moved back to CPU afterwards.
        """
        if self.config["cpu_offload"]:
            self.pre_weight.to_cuda()
            self.post_weight.to_cuda()

        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True, kv_start=kv_start, kv_end=kv_end)

        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out, kv_start, kv_end)
        self.scheduler.noise_pred = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]

        if self.config["cpu_offload"]:
            self.pre_weight.to_cpu()
            self.post_weight.to_cpu()