import os
import torch
import glob
import json
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.transformer_weights import (
    WanTransformerWeights,
)
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.models.networks.wan.infer.transformer_infer import (
    WanTransformerInfer,
)
from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import (
    WanTransformerInferTeaCaching,
    WanTransformerInferTaylorCaching,
    WanTransformerInferAdaCaching,
    WanTransformerInferCustomCaching,
)
from safetensors import safe_open
import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
from lightx2v.utils.envs import *
from loguru import logger


class WanModel:
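    """Wan DiT model wrapper.

    Loads the checkpoint weights (optionally quantized), builds the pre/transformer/post
    inference stages, and writes the (optionally CFG-guided) noise prediction for one
    denoising step to the attached scheduler.
    """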
    pre_weight_class = WanPreWeights
    post_weight_class = WanPostWeights
    transformer_weight_class = WanTransformerWeights

    def __init__(self, model_path, config, device):
        self.model_path = model_path
        self.config = config
        self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)

        self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
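        # For quantized DiT weights, the quant scheme is the second dash-separated field of mm_type,
        # and the quantized checkpoint defaults to a same-named subdirectory of model_path.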
        if self.dit_quantized:
            dit_quant_scheme = self.config.mm_config.get("mm_type").split("-")[1]
            self.dit_quantized_ckpt = self.config.get("dit_quantized_ckpt", os.path.join(model_path, dit_quant_scheme))
        else:
            self.dit_quantized_ckpt = None
        self.weight_auto_quant = self.config.mm_config.get("weight_auto_quant", False)
        if self.dit_quantized:
            assert self.weight_auto_quant or self.dit_quantized_ckpt is not None

        self.device = device
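        # Resolve the infer classes, load the checkpoint weights, then build the infer objects.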
        self._init_infer_class()
        self._init_weights()
        self._init_infer()

        # Optionally wrap the model for distributed attention (Ulysses or ring sequence parallelism).
        if config["parallel_attn_type"]:
            if config["parallel_attn_type"] == "ulysses":
                ulysses_dist_wrap.parallelize_wan(self)
            elif config["parallel_attn_type"] == "ring":
                ring_dist_wrap.parallelize_wan(self)
            else:
                raise Exception(f"Unsupported parallel_attn_type: {config['parallel_attn_type']}")

    def _init_infer_class(self):
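        # feature_caching selects the transformer inference implementation: plain inference
        # or one of the feature-caching variants (Tea, TaylorSeer, Ada, Custom).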
        self.pre_infer_class = WanPreInfer
        self.post_infer_class = WanPostInfer
        if self.config["feature_caching"] == "NoCaching":
            self.transformer_infer_class = WanTransformerInfer
        elif self.config["feature_caching"] == "Tea":
            self.transformer_infer_class = WanTransformerInferTeaCaching
        elif self.config["feature_caching"] == "TaylorSeer":
            self.transformer_infer_class = WanTransformerInferTaylorCaching
        elif self.config["feature_caching"] == "Ada":
            self.transformer_infer_class = WanTransformerInferAdaCaching
        elif self.config["feature_caching"] == "Custom":
            self.transformer_infer_class = WanTransformerInferCustomCaching
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")

    def _load_safetensor_to_dict(self, file_path, use_bf16, skip_bf16):
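        """Read one safetensors file into a dict on self.device, casting tensors to bfloat16
        when use_bf16 is set or when the key contains none of the skip_bf16 substrings."""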
        with safe_open(file_path, framework="pt") as f:
            return {key: (f.get_tensor(key).to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()}

    def _load_ckpt(self, use_bf16, skip_bf16):
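        """Load all *.safetensors shards under model_path (falling back to model_path/original)
        and merge them into a single weight dict."""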
        safetensors_pattern = os.path.join(self.model_path, "*.safetensors")
        safetensors_files = glob.glob(safetensors_pattern)

        if not safetensors_files:
            original_pattern = os.path.join(self.model_path, "original", "*.safetensors")
            safetensors_files = glob.glob(original_pattern)

            if not safetensors_files:
                raise FileNotFoundError(f"No .safetensors files found in directory: {self.model_path}")

        weight_dict = {}
        for file_path in safetensors_files:
            file_weights = self._load_safetensor_to_dict(file_path, use_bf16, skip_bf16)
            weight_dict.update(file_weights)
        return weight_dict

    def _load_quant_ckpt(self, use_bf16, skip_bf16):
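        """Load a quantized checkpoint by following its *.index.json weight map; float32 tensors
        are downcast to bfloat16 according to use_bf16/skip_bf16, other dtypes are kept as-is."""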
        ckpt_path = self.config.dit_quantized_ckpt
        logger.info(f"Loading quant dit model from {ckpt_path}")

        index_files = [f for f in os.listdir(ckpt_path) if f.endswith(".index.json")]
        if not index_files:
            raise FileNotFoundError(f"No *.index.json found in {ckpt_path}")

        index_path = os.path.join(ckpt_path, index_files[0])
        logger.info(f"Using safetensors index: {index_path}")

        with open(index_path, "r") as f:
            index_data = json.load(f)

        weight_dict = {}
        for filename in set(index_data["weight_map"].values()):
            safetensor_path = os.path.join(ckpt_path, filename)
            with safe_open(safetensor_path, framework="pt") as f:
                logger.info(f"Loading weights from {safetensor_path}")
                for k in f.keys():
                    if f.get_tensor(k).dtype == torch.float:
                        if use_bf16 or all(s not in k for s in skip_bf16):
                            weight_dict[k] = f.get_tensor(k).pin_memory().to(torch.bfloat16).to(self.device)
                        else:
                            weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
                    else:
                        weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)

        return weight_dict

    def _load_quant_split_ckpt(self, use_bf16, skip_bf16):
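        """Lazy-load path: read only the non-block weights from non_block.safetensors here; the
        transformer block weights are expected to be loaded on demand elsewhere."""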
        lazy_load_model_path = self.config.dit_quantized_ckpt
        logger.info(f"Loading split quant model from {lazy_load_model_path}")
        pre_post_weight_dict = {}

        safetensor_path = os.path.join(lazy_load_model_path, "non_block.safetensors")
        with safe_open(safetensor_path, framework="pt", device="cpu") as f:
            for k in f.keys():
                if f.get_tensor(k).dtype == torch.float:
                    if use_bf16 or all(s not in k for s in skip_bf16):
                        pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(torch.bfloat16).to(self.device)
                    else:
                        pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
                else:
                    pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)

        return pre_post_weight_dict

    def _init_weights(self, weight_dict=None):
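        """Create the pre/post/transformer weight containers and populate them, either from a
        caller-provided weight_dict or by loading the (possibly quantized) checkpoint."""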
        use_bf16 = GET_DTYPE() == "BF16"
        # Some layers run with float32 to achieve high accuracy
        skip_bf16 = {
            "norm",
            "embedding",
            "modulation",
            "time",
            "img_emb.proj.0",
            "img_emb.proj.4",
        }
        if weight_dict is None:
            if not self.dit_quantized or self.weight_auto_quant:
                self.original_weight_dict = self._load_ckpt(use_bf16, skip_bf16)
            else:
                if not self.config.get("lazy_load", False):
                    self.original_weight_dict = self._load_quant_ckpt(use_bf16, skip_bf16)
                else:
                    self.original_weight_dict = self._load_quant_split_ckpt(use_bf16, skip_bf16)
        else:
            self.original_weight_dict = weight_dict
        # init weights
        self.pre_weight = self.pre_weight_class(self.config)
        self.post_weight = self.post_weight_class(self.config)
        self.transformer_weights = self.transformer_weight_class(self.config)
        # load weights
        self.pre_weight.load(self.original_weight_dict)
        self.post_weight.load(self.original_weight_dict)
        self.transformer_weights.load(self.original_weight_dict)

    def _init_infer(self):
        self.pre_infer = self.pre_infer_class(self.config)
        self.post_infer = self.post_infer_class(self.config)
        self.transformer_infer = self.transformer_infer_class(self.config)

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler
        self.pre_infer.set_scheduler(scheduler)
        self.post_infer.set_scheduler(scheduler)
        self.transformer_infer.set_scheduler(scheduler)

    def to_cpu(self):
        self.pre_weight.to_cpu()
        self.post_weight.to_cpu()
        self.transformer_weights.to_cpu()

    def to_cuda(self):
        self.pre_weight.to_cuda()
        self.post_weight.to_cuda()
        self.transformer_weights.to_cuda()

    @torch.no_grad()
    def infer(self, inputs):
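        """Run one forward pass: a conditional pass, plus, when enable_cfg is set, an unconditional
        pass combined via classifier-free guidance; the result is stored in scheduler.noise_pred."""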
        if self.config["cpu_offload"]:
            self.pre_weight.to_cuda()
            self.post_weight.to_cuda()

        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
        noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]

        self.scheduler.noise_pred = noise_pred_cond

        if self.clean_cuda_cache:
            del x, embed, pre_infer_out, noise_pred_cond, grid_sizes
            torch.cuda.empty_cache()

        if self.config["enable_cfg"]:
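            # Unconditional pass; final prediction is uncond + sample_guide_scale * (cond - uncond).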
            embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
            noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]

            self.scheduler.noise_pred = noise_pred_uncond + self.config.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)

            if self.config["cpu_offload"]:
                self.pre_weight.to_cpu()
                self.post_weight.to_cpu()

                if self.clean_cuda_cache:
                    del x, embed, pre_infer_out, noise_pred_uncond, grid_sizes
                    torch.cuda.empty_cache()
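

# Minimal usage sketch (hypothetical: the real config construction, scheduler, and inputs come from
# the surrounding lightx2v runner and are not shown in this file):
#
#   model = WanModel("/path/to/wan_checkpoint", config, torch.device("cuda"))
#   model.set_scheduler(scheduler)
#   model.infer(inputs)  # writes the guided noise prediction to scheduler.noise_pred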