causvid_model.py 2.4 KB
Newer Older
1
import os
PengGao's avatar
PengGao committed
2

3
import torch
PengGao's avatar
PengGao committed
4
5
6
7
8
9

from lightx2v.models.networks.wan.infer.causvid.transformer_infer import (
    WanTransformerInferCausVid,
)
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
10
11
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
PengGao's avatar
PengGao committed
12
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
13
14
15
from lightx2v.models.networks.wan.weights.transformer_weights import (
    WanTransformerWeights,
)
gushiqiao's avatar
gushiqiao committed
16
from lightx2v.utils.envs import *
helloyongyang's avatar
fix ci  
helloyongyang committed
17
from lightx2v.utils.utils import find_torch_model_path
18
19


Zhuguanyu Wu's avatar
Zhuguanyu Wu committed
20
class WanCausVidModel(WanModel):
    """WAN video model variant wired to the CausVid causal-video transformer.

    Differs from the base ``WanModel`` only in the transformer inference
    class it selects and in loading a dedicated ``causvid_model.pt``
    checkpoint when one is present.
    """

    # Weight-container classes consumed by the base WanModel loading logic.
    pre_weight_class = WanPreWeights
    post_weight_class = WanPostWeights
    transformer_weight_class = WanTransformerWeights

helloyongyang's avatar
helloyongyang committed
25
26
    def __init__(self, model_path, config, device):
        """Delegate construction entirely to WanModel (path, config, target device)."""
        super().__init__(model_path, config, device)

    def _init_infer_class(self):
        """Pick the inference implementations used by this model variant."""
        # The CausVid-specific transformer inference class is the one piece
        # that distinguishes this model from the stock WAN pipeline.
        self.transformer_infer_class = WanTransformerInferCausVid
        # Pre- and post-processing stages reuse the standard WAN implementations.
        self.pre_infer_class = WanPreInfer
        self.post_infer_class = WanPostInfer

33
    def _load_ckpt(self, unified_dtype, sensitive_layer):
        """Load the distilled CausVid checkpoint if present, else defer to WanModel.

        Args:
            unified_dtype: When truthy, cast every tensor to ``GET_DTYPE()``
                regardless of layer sensitivity.
            sensitive_layer: Iterable of key substrings; tensors whose key
                contains any of them keep ``GET_SENSITIVE_DTYPE()`` precision.

        Returns:
            Dict mapping parameter names to tensors already moved to
            ``self.device``, or whatever the parent loader returns when no
            ``causvid_model.pt`` is found.
        """
        ckpt_path = find_torch_model_path(self.config, self.model_path, "causvid_model.pt")
        if not os.path.exists(ckpt_path):
            # No distilled checkpoint available; fall back to the base loader.
            return super()._load_ckpt(unified_dtype, sensitive_layer)

        # weights_only=True keeps torch.load from executing arbitrary pickled code.
        weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
        converted = {}
        # Iterate items() directly instead of re-looking each key up, and keep
        # the dtype decision readable rather than one long conditional expression.
        for key, tensor in weight_dict.items():
            if unified_dtype or all(s not in key for s in sensitive_layer):
                tensor = tensor.to(GET_DTYPE())
            else:
                # Precision-sensitive layers stay in the higher-precision dtype.
                tensor = tensor.to(GET_SENSITIVE_DTYPE())
            # Pin host memory before the copy so the host-to-device transfer
            # can use the fast pinned-memory path.
            converted[key] = tensor.pin_memory().to(self.device)
        return converted
44
45
46
47
48

    @torch.no_grad()
    def infer(self, inputs, kv_start, kv_end):
        # Run one denoising step over the causal KV-cache window delimited by
        # kv_start/kv_end (exact inclusivity is determined by the transformer
        # infer implementation — confirm there). The predicted noise is stored
        # on self.scheduler.noise_pred as a side effect; nothing is returned.
        if self.config["cpu_offload"]:
            # Stage the weights needed for this step onto the GPU.
            self.pre_weight.to_cuda()
            self.transformer_weights.post_weights_to_cuda()

        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, kv_start=kv_start, kv_end=kv_end)

        # Transformer backbone consumes the pre-infer outputs plus the KV window.
        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out, kv_start, kv_end)
        # post_infer returns a sequence; element 0 is the noise prediction
        # consumed by the scheduler.
        self.scheduler.noise_pred = self.post_infer.infer(x, embed, grid_sizes)[0]

        if self.config["cpu_offload"]:
            # Return the staged weights to host memory to cap GPU residency.
            self.pre_weight.to_cpu()
            self.transformer_weights.post_weights_to_cpu()