wan_runner.py 13.3 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
import os
2
import gc
helloyongyang's avatar
helloyongyang committed
3
4
5
6
7
8
9
import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
gushiqiao's avatar
gushiqiao committed
10
11
12
from lightx2v.models.schedulers.wan.changing_resolution.scheduler import (
    WanScheduler4ChangingResolution,
)
13
14
from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
    WanSchedulerTeaCaching,
15
16
17
    WanSchedulerTaylorCaching,
    WanSchedulerAdaCaching,
    WanSchedulerCustomCaching,
Rongjin Yang's avatar
Rongjin Yang committed
18
19
20
    WanSchedulerFirstBlock,
    WanSchedulerDualBlock,
    WanSchedulerDynamicBlock,
21
)
helloyongyang's avatar
helloyongyang committed
22
from lightx2v.utils.profiler import ProfilingContext
23
from lightx2v.utils.utils import *
helloyongyang's avatar
helloyongyang committed
24
25
26
27
28
from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
29
from lightx2v.models.video_encoders.hf.wan.vae_tiny import WanVAE_tiny
30
from lightx2v.utils.utils import cache_video
root's avatar
root committed
31
from loguru import logger
helloyongyang's avatar
helloyongyang committed
32
33
34
35
36
37
38


@RUNNER_REGISTER("wan2.1")
class WanRunner(DefaultRunner):
    def __init__(self, config):
        """Create a Wan 2.1 runner; construction is delegated entirely to ``DefaultRunner``."""
        super().__init__(config)

    def load_transformer(self):
        """Build the Wan DiT model and apply any LoRA adapters listed in the config.

        Each entry of ``config.lora_configs`` must provide ``"path"`` and may
        provide ``"strength"`` (defaults to 1.0).

        Returns:
            The ``WanModel`` instance, with all configured LoRAs loaded and applied.

        Raises:
            ValueError: if LoRAs are requested while ``dit_quantized`` is set and
                ``mm_config`` does not enable ``weight_auto_quant``.
        """
        model = WanModel(
            self.config.model_path,
            self.config,
            self.init_device,
        )

        lora_configs = self.config.get("lora_configs")
        if not lora_configs:
            return model

        # Same constraint the original assert expressed, but as an explicit check so
        # it is not stripped under ``python -O``. A quantized DiT is only accepted for
        # LoRA when weight_auto_quant is enabled (presumably so quantization happens
        # after the LoRA merge — TODO confirm).
        if self.config.get("dit_quantized", False) and not self.config.mm_config.get("weight_auto_quant", False):
            raise ValueError("LoRA on a quantized DiT requires mm_config.weight_auto_quant to be enabled")

        lora_wrapper = WanLoraWrapper(model)
        for lora_config in lora_configs:
            lora_path = lora_config["path"]
            strength = lora_config.get("strength", 1.0)
            lora_name = lora_wrapper.load_lora(lora_path)
            lora_wrapper.apply_lora(lora_name, strength)
            logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
        return model

    def load_image_encoder(self):
        """Build the CLIP image encoder used for image-to-video conditioning.

        With ``clip_quantized`` set, a quantized checkpoint named after the part of
        ``clip_quant_scheme`` before the first ``-`` is used. Otherwise the original
        open-clip XLM-RoBERTa checkpoint is loaded.

        Returns:
            A ``CLIPModel`` when ``config.task == "i2v"``, otherwise ``None``.

        Raises:
            ValueError: if ``clip_quantized`` is set but ``clip_quant_scheme`` is missing.
        """
        if self.config.task != "i2v":
            return None

        clip_quantized = self.config.get("clip_quantized", False)
        if clip_quantized:
            clip_quant_scheme = self.config.get("clip_quant_scheme", None)
            # Explicit check instead of assert: under -O the assert vanishes and the
            # .split below fails with an opaque AttributeError on None.
            if clip_quant_scheme is None:
                raise ValueError("clip_quant_scheme must be set when clip_quantized is enabled")
            scheme_prefix = clip_quant_scheme.split("-")[0]
            clip_model_name = f"clip-{scheme_prefix}.pth"
            clip_quantized_ckpt = find_torch_model_path(self.config, "clip_quantized_ckpt", clip_model_name, scheme_prefix)
            clip_original_ckpt = None
        else:
            clip_quantized_ckpt = None
            clip_quant_scheme = None
            clip_model_name = "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"
            clip_original_ckpt = find_torch_model_path(self.config, "clip_original_ckpt", clip_model_name, "original")

        return CLIPModel(
            dtype=torch.float16,
            device=self.init_device,
            checkpoint_path=clip_original_ckpt,
            clip_quantized=clip_quantized,
            clip_quantized_ckpt=clip_quantized_ckpt,
            quant_scheme=clip_quant_scheme,
        )

    def load_text_encoder(self):
        """Build the UMT5-XXL text encoder.

        ``t5_cpu_offload`` places the encoder on the CPU (otherwise CUDA) and is
        forwarded as ``cpu_offload``. With ``t5_quantized`` set, a quantized
        checkpoint named after the part of ``t5_quant_scheme`` before the first ``-``
        is used. Otherwise the bf16 checkpoint is loaded.

        Returns:
            A one-element list containing the ``T5EncoderModel``.

        Raises:
            ValueError: if ``t5_quantized`` is set but ``t5_quant_scheme`` is missing.
        """
        t5_offload = self.config.get("t5_cpu_offload", False)
        t5_device = torch.device("cpu") if t5_offload else torch.device("cuda")

        t5_quantized = self.config.get("t5_quantized", False)
        if t5_quantized:
            t5_quant_scheme = self.config.get("t5_quant_scheme", None)
            # Explicit check instead of assert: under -O the assert vanishes and the
            # .split below fails with an opaque AttributeError on None.
            if t5_quant_scheme is None:
                raise ValueError("t5_quant_scheme must be set when t5_quantized is enabled")
            scheme_prefix = t5_quant_scheme.split("-")[0]
            t5_model_name = f"models_t5_umt5-xxl-enc-{scheme_prefix}.pth"
            t5_quantized_ckpt = find_torch_model_path(self.config, "t5_quantized_ckpt", t5_model_name, scheme_prefix)
            t5_original_ckpt = None
        else:
            t5_quant_scheme = None
            t5_quantized_ckpt = None
            t5_model_name = "models_t5_umt5-xxl-enc-bf16.pth"
            t5_original_ckpt = find_torch_model_path(self.config, "t5_original_ckpt", t5_model_name, "original")

        text_encoder = T5EncoderModel(
            text_len=self.config["text_len"],
            dtype=torch.bfloat16,
            device=t5_device,
            checkpoint_path=t5_original_ckpt,
            tokenizer_path=os.path.join(self.config.model_path, "google/umt5-xxl"),
            shard_fn=None,
            cpu_offload=t5_offload,
            offload_granularity=self.config.get("t5_offload_granularity", "model"),
            t5_quantized=t5_quantized,
            t5_quantized_ckpt=t5_quantized_ckpt,
            quant_scheme=t5_quant_scheme,
        )
        return [text_encoder]

    def load_vae_encoder(self):
        """Build the Wan 2.1 VAE used to encode the conditioning image.

        Returns:
            A ``WanVAE`` when ``config.task == "i2v"``, otherwise ``None``.
        """
        # Check the task first so the VAE checkpoint path (and the VAE options) are
        # not resolved for tasks that never use the encoder.
        if self.config.task != "i2v":
            return None

        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth", "original"),
            "device": self.init_device,
            "parallel": self.config.parallel_vae,
            "use_tiling": self.config.get("use_tiling_vae", False),
        }
        return WanVAE(**vae_config)

    def load_vae_decoder(self):
        """Build the VAE used to decode latents back into video frames.

        Returns:
            A ``WanVAE_tiny`` (moved to CUDA) when ``use_tiny_vae`` is set,
            otherwise a full ``WanVAE``.
        """
        if self.config.get("use_tiny_vae", False):
            tiny_vae_path = find_torch_model_path(self.config, "tiny_vae_path", "taew2_1.pth", "original")
            return WanVAE_tiny(
                vae_pth=tiny_vae_path,
                device=self.init_device,
            ).to("cuda")

        # The full VAE checkpoint is only looked up when it is actually loaded, so a
        # tiny-VAE run does not depend on resolving "Wan2.1_VAE.pth".
        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth", "original"),
            "device": self.init_device,
            "parallel": self.config.parallel_vae,
            "use_tiling": self.config.get("use_tiling_vae", False),
        }
        return WanVAE(**vae_config)

    def load_vae(self):
        """Load the VAE encoder/decoder pair.

        The encoder instance is reused as the decoder unless ``load_vae_encoder``
        returned ``None`` or a separate tiny decoder (``use_tiny_vae``) is requested.

        Returns:
            Tuple ``(vae_encoder, vae_decoder)``; ``vae_encoder`` may be ``None``.
        """
        vae_encoder = self.load_vae_encoder()
        reuse_encoder = vae_encoder is not None and not self.config.get("use_tiny_vae", False)
        vae_decoder = vae_encoder if reuse_encoder else self.load_vae_decoder()
        return vae_encoder, vae_decoder

    def init_scheduler(self):
        """Create the diffusion scheduler selected by the config and attach it to the model.

        ``changing_resolution`` takes precedence. Otherwise the scheduler class is
        chosen by ``config.feature_caching``.

        Raises:
            NotImplementedError: if ``feature_caching`` names no known scheduler.
        """
        if self.config.get("changing_resolution", False):
            scheduler = WanScheduler4ChangingResolution(self.config)
        else:
            # Checked in order with ==, exactly like a chain of if/elif comparisons.
            schedulers_by_mode = (
                ("NoCaching", WanScheduler),
                ("Tea", WanSchedulerTeaCaching),
                ("TaylorSeer", WanSchedulerTaylorCaching),
                ("Ada", WanSchedulerAdaCaching),
                ("Custom", WanSchedulerCustomCaching),
                ("FirstBlock", WanSchedulerFirstBlock),
                ("DualBlock", WanSchedulerDualBlock),
                ("DynamicBlock", WanSchedulerDynamicBlock),
            )
            for mode, scheduler_cls in schedulers_by_mode:
                if self.config.feature_caching == mode:
                    scheduler = scheduler_cls(self.config)
                    break
            else:
                raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
        self.model.set_scheduler(scheduler)

    def run_text_encoder(self, text, img):
        """Encode the prompt and the configured negative prompt with the T5 encoder.

        With ``lazy_load`` or ``unload_modules`` enabled, the encoder is loaded
        before inference and dropped afterwards (CUDA cache emptied, GC run).

        Args:
            text: Prompt passed to the encoder's ``infer``.
            img: Not used by this implementation.

        Returns:
            Dict with ``"context"`` (prompt embedding) and ``"context_null"``
            (negative-prompt embedding; an empty string is encoded when no
            negative prompt is configured).
        """
        if any(self.config.get(flag, False) for flag in ("lazy_load", "unload_modules")):
            self.text_encoders = self.load_text_encoder()

        negative_prompt = self.config.get("negative_prompt", "")
        # Index the list on each call rather than keeping a local alias, so the
        # encoder is unreferenced once it is removed below.
        context = self.text_encoders[0].infer([text])
        context_null = self.text_encoders[0].infer([negative_prompt or ""])

        if any(self.config.get(flag, False) for flag in ("lazy_load", "unload_modules")):
            del self.text_encoders[0]
            torch.cuda.empty_cache()
            gc.collect()

        return {"context": context, "context_null": context_null}

    def run_image_encoder(self, img):
        """Compute CLIP visual features for the conditioning image.

        With ``lazy_load`` or ``unload_modules`` enabled, the image encoder is loaded
        before use and deleted afterwards (CUDA cache emptied, GC run).

        Args:
            img: An image accepted by torchvision's ``TF.to_tensor``.

        Returns:
            The encoder's ``visual`` output with dim 0 squeezed, cast to bfloat16.
        """
        if any(self.config.get(flag, False) for flag in ("lazy_load", "unload_modules")):
            self.image_encoder = self.load_image_encoder()

        # to_tensor yields values in [0, 1]; shift/scale to [-1, 1] and move to CUDA.
        pixels = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
        batch = [pixels[None, :, :, :]]
        clip_encoder_out = self.image_encoder.visual(batch, self.config).squeeze(0).to(torch.bfloat16)

        if any(self.config.get(flag, False) for flag in ("lazy_load", "unload_modules")):
            del self.image_encoder
            torch.cuda.empty_cache()
            gc.collect()

        return clip_encoder_out

    def run_vae_encoder(self, img):
        """Encode the conditioning image with the VAE at the target resolution.

        The latent grid (``lat_h``, ``lat_w``) keeps the image's aspect ratio. Its
        area corresponds to ``target_height * target_width`` pixels, and each side is
        floored to a multiple of the patch size. The grid is stored on
        ``self.config`` as ``lat_h`` / ``lat_w``.

        Args:
            img: An image accepted by torchvision's ``TF.to_tensor``.

        Returns:
            A single ``get_vae_encoder_output`` result. With ``changing_resolution``,
            a list with one result per ``resolution_rate`` entry (latent sides scaled
            and forced even), followed by the full-resolution result.
        """
        # to_tensor yields values in [0, 1]; shift/scale to [-1, 1] and move to CUDA.
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
        h, w = img.shape[1:]
        aspect_ratio = h / w
        max_area = self.config.target_height * self.config.target_width
        lat_h = round(np.sqrt(max_area * aspect_ratio) // self.config.vae_stride[1] // self.config.patch_size[1] * self.config.patch_size[1])
        lat_w = round(np.sqrt(max_area / aspect_ratio) // self.config.vae_stride[2] // self.config.patch_size[2] * self.config.patch_size[2])
        # Both paths record the full-resolution latent size on the config.
        self.config.lat_h, self.config.lat_w = lat_h, lat_w

        if not self.config.get("changing_resolution", False):
            return self.get_vae_encoder_output(img, lat_h, lat_w)

        vae_encode_out_list = []
        for rate in self.config["resolution_rate"]:
            scaled_h = int(self.config.lat_h * rate) // 2 * 2
            scaled_w = int(self.config.lat_w * rate) // 2 * 2
            vae_encode_out_list.append(self.get_vae_encoder_output(img, scaled_h, scaled_w))
        vae_encode_out_list.append(self.get_vae_encoder_output(img, self.config.lat_h, self.config.lat_w))
        return vae_encode_out_list

    def get_vae_encoder_output(self, img, lat_h, lat_w):
        """Encode ``img`` with the VAE at latent size (``lat_h``, ``lat_w``) and prepend a frame mask.

        Args:
            img: Image tensor on CUDA, indexed as (channels, height, width), as
                produced by ``run_vae_encoder`` (scaled to [-1, 1]).
            lat_h: Latent height; the pixel height is ``lat_h * vae_stride[1]``.
            lat_w: Latent width; the pixel width is ``lat_w * vae_stride[2]``.

        Returns:
            ``concat([mask, latents])`` along dim 0, cast to bfloat16. The mask has
            4 channels that are 1 at latent-time index 0 and 0 elsewhere.
        """
        pixel_h = lat_h * self.config.vae_stride[1]
        pixel_w = lat_w * self.config.vae_stride[2]
        num_frames = self.config.target_video_length

        # Frame mask: 1 for the first frame, 0 for the rest. The first entry is
        # repeated 4x and the time axis is folded into groups of 4, giving a
        # (4, (num_frames + 3) // 4, lat_h, lat_w) tensor. Presumably num_frames is
        # 4k + 1; otherwise the view raises.
        mask = torch.ones(1, num_frames, lat_h, lat_w, device=torch.device("cuda"))
        mask[:, 1:] = 0
        leading = torch.repeat_interleave(mask[:, 0:1], repeats=4, dim=1)
        mask = torch.concat([leading, mask[:, 1:]], dim=1)
        mask = mask.view(1, mask.shape[1] // 4, 4, lat_h, lat_w).transpose(1, 2)[0]

        if any(self.config.get(flag, False) for flag in ("lazy_load", "unload_modules")):
            self.vae_encoder = self.load_vae_encoder()

        # Encoder input: the bicubically resized image as frame 0 followed by
        # num_frames - 1 zero frames, i.e. (3, num_frames, pixel_h, pixel_w),
        # assembled on CPU and then moved to CUDA.
        encoder_input = [
            torch.concat(
                [
                    torch.nn.functional.interpolate(img[None].cpu(), size=(pixel_h, pixel_w), mode="bicubic").transpose(0, 1),
                    torch.zeros(3, num_frames - 1, pixel_h, pixel_w),
                ],
                dim=1,
            ).cuda()
        ]
        latents = self.vae_encoder.encode(encoder_input, self.config)[0]
        del encoder_input  # release the input clip before the optional cache flush

        if any(self.config.get(flag, False) for flag in ("lazy_load", "unload_modules")):
            del self.vae_encoder
            torch.cuda.empty_cache()
            gc.collect()

        return torch.concat([mask, latents]).to(torch.bfloat16)

    def get_encoder_output_i2v(self, clip_encoder_out, vae_encode_out, text_encoder_output, img):
        """Bundle the text and image encoder outputs for an image-to-video run.

        ``img`` is accepted but not used here.

        Returns:
            ``{"text_encoder_output": ..., "image_encoder_output": {"clip_encoder_out": ..., "vae_encode_out": ...}}``
        """
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": dict(clip_encoder_out=clip_encoder_out, vae_encode_out=vae_encode_out),
        }

    def set_target_shape(self):
        """Store the latent tensor shape ``(channels, frames, height, width)`` on ``config.target_shape``.

        The frame count is ``(target_video_length - 1) // vae_stride[0] + 1``. For
        ``i2v`` the spatial size is the ``lat_h`` / ``lat_w`` recorded on the config.
        For ``t2v`` it is the target pixel size divided by the VAE stride. Other tasks
        leave ``target_shape`` untouched.
        """
        latent_channels = self.config.get("num_channels_latents", 16)
        task = self.config.task
        if task not in ("i2v", "t2v"):
            return

        latent_frames = (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1
        if task == "i2v":
            spatial = (self.config.lat_h, self.config.lat_w)
        else:
            spatial = (
                int(self.config.target_height) // self.config.vae_stride[1],
                int(self.config.target_width) // self.config.vae_stride[2],
            )
        self.config.target_shape = (latent_channels, latent_frames, *spatial)

    def save_video_func(self, images):
        """Write ``images`` to ``config.save_video_path`` via ``cache_video``.

        Uses ``config.fps`` (default 16), ``nrow=1``, and normalization from the
        ``(-1, 1)`` value range.
        """
        target_path = self.config.save_video_path
        frame_rate = self.config.get("fps", 16)
        cache_video(tensor=images, save_file=target_path, fps=frame_rate, nrow=1, normalize=True, value_range=(-1, 1))