import os
import torch
import glob
from lightx2v.text2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.text2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.text2v.models.networks.wan.weights.transformer_weights import (
    WanTransformerWeights,
)
from lightx2v.text2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.text2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.text2v.models.networks.wan.infer.transformer_infer import (
    WanTransformerInfer,
)
from lightx2v.text2v.models.networks.wan.infer.feature_caching.transformer_infer import (
    WanTransformerInferFeatureCaching,
)
from safetensors import safe_open
from lightx2v.attentions.distributed.ulysses.wrap import parallelize_wan


class WanModel:
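    """Wrapper around the Wan text-to-video network.

    Pairs the pre-/post-/transformer weight containers with their matching
    inference classes, loads safetensors checkpoints from `model_path`, and
    exposes a classifier-free-guidance denoising step via `infer`.
    """
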
    pre_weight_class = WanPreWeights
    post_weight_class = WanPostWeights
    transformer_weight_class = WanTransformerWeights

    def __init__(self, model_path, config):
        self.model_path = model_path
        self.config = config
        self._init_infer_class()
        self._init_weights()
        self._init_infer()

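        # Optionally shard attention with Ulysses-style sequence parallelism.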
        if config["parallel_attn"]:
            parallelize_wan(self)

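        # Park all weights on CPU until they are needed on the GPU.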
        if self.config["cpu_offload"]:
            self.to_cpu()

    def _init_infer_class(self):
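        # "Tea" selects the TeaCache-style feature-caching transformer path.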
        self.pre_infer_class = WanPreInfer
        self.post_infer_class = WanPostInfer
        if self.config["feature_caching"] == "NoCaching":
            self.transformer_infer_class = WanTransformerInfer
        elif self.config["feature_caching"] == "Tea":
            self.transformer_infer_class = WanTransformerInferFeatureCaching
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")

    def _load_safetensor_to_dict(self, file_path):
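        """Load one .safetensors shard, casting every tensor to bfloat16 on the current CUDA device."""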
        with safe_open(file_path, framework="pt") as f:
            tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16).cuda() for key in f.keys()}
        return tensor_dict

    def _load_ckpt(self):
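        """Gather all .safetensors shards under `model_path` into a single weight dict."""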
        safetensors_pattern = os.path.join(self.model_path, "*.safetensors")
        safetensors_files = glob.glob(safetensors_pattern)

        if not safetensors_files:
            raise FileNotFoundError(f"No .safetensors files found in directory: {self.model_path}")
        weight_dict = {}
        for file_path in safetensors_files:
            file_weights = self._load_safetensor_to_dict(file_path)
            weight_dict.update(file_weights)
        return weight_dict

    def _init_weights(self):
        weight_dict = self._load_ckpt()
        # init weights
        self.pre_weight = self.pre_weight_class(self.config)
        self.post_weight = self.post_weight_class(self.config)
        self.transformer_weights = self.transformer_weight_class(self.config)
        # load weights
        self.pre_weight.load_weights(weight_dict)
        self.post_weight.load_weights(weight_dict)
        self.transformer_weights.load_weights(weight_dict)

    def _init_infer(self):
        self.pre_infer = self.pre_infer_class(self.config)
        self.post_infer = self.post_infer_class(self.config)
        self.transformer_infer = self.transformer_infer_class(self.config)

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler
        self.transformer_infer.set_scheduler(scheduler)

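    # Offload helpers, used when config["cpu_offload"] is enabled.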
    def to_cpu(self):
        self.pre_weight.to_cpu()
        self.post_weight.to_cpu()
        self.transformer_weights.to_cpu()

    def to_cuda(self):
        self.pre_weight.to_cuda()
        self.post_weight.to_cuda()
        self.transformer_weights.to_cuda()

    @torch.no_grad()
    def infer(self, text_encoders_output, image_encoder_output, args):
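        """Predict noise for the current scheduler step.

        Runs a conditional pass (text context) and an unconditional pass
        (null context), then blends them with classifier-free guidance.
        """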
        timestep = torch.stack([self.scheduler.timesteps[self.scheduler.step_index]])

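        # Conditional pass: use the text prompt context.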
        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(
            self.pre_weight,
            [self.scheduler.latents],
            timestep,
            text_encoders_output["context"],
            self.scheduler.seq_len,
            image_encoder_output["clip_encoder_out"],
            [image_encoder_output["vae_encode_out"]],
        )
        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
        noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]

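        # Advance the feature-caching step counter, wrapping at the end of the schedule.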
        if self.config["feature_caching"] == "Tea":
            self.scheduler.cnt += 1
            if self.scheduler.cnt >= self.scheduler.num_steps:
                self.scheduler.cnt = 0

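        # Unconditional pass: use the null text context.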
        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(
            self.pre_weight,
            [self.scheduler.latents],
            timestep,
            text_encoders_output["context_null"],
            self.scheduler.seq_len,
            image_encoder_output["clip_encoder_out"],
            [image_encoder_output["vae_encode_out"]],
        )
        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
        noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]

        if self.config["feature_caching"] == "Tea":
            self.scheduler.cnt += 1
            if self.scheduler.cnt >= self.scheduler.num_steps:
                self.scheduler.cnt = 0

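        # Classifier-free guidance: blend conditional and unconditional predictions.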
        self.scheduler.noise_pred = noise_pred_uncond + args.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)