# wan_runner.py

import os
import gc

import numpy as np
import torch
import torchvision.transforms.functional as TF
import torch.distributed as dist
from loguru import logger
from PIL import Image

from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.schedulers.wan.changing_resolution.scheduler import (
    WanScheduler4ChangingResolutionInterface,
)
from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
    WanSchedulerCaching,
    WanSchedulerTaylorCaching,
)
from lightx2v.utils.utils import *
from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
from lightx2v.models.networks.wan.model import WanModel, Wan22MoeModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
from lightx2v.models.video_encoders.hf.wan.vae_tiny import WanVAE_tiny
from lightx2v.utils.utils import cache_video, best_output_size
from lightx2v.utils.profiler import ProfilingContext
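
# Runners registered in this module:
#   "wan2.1"     -> WanRunner
#   "wan2.2_moe" -> Wan22MoeRunner (high/low-noise expert pair)
#   "wan2.2"     -> Wan22DenseRunner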


@RUNNER_REGISTER("wan2.1")
class WanRunner(DefaultRunner):
    def __init__(self, config):
        super().__init__(config)

    def load_transformer(self):
        model = WanModel(
            self.config.model_path,
            self.config,
            self.init_device,
        )
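        # Optionally merge LoRA weights into the transformer. Each entry in
        # `lora_configs` provides a "path" and an optional "strength" (default 1.0),
        # e.g. {"path": "/path/to/lora.safetensors", "strength": 0.8} (illustrative).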
        if self.config.get("lora_configs") and self.config.lora_configs:
43
            assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
44
            lora_wrapper = WanLoraWrapper(model)
45
46
47
            for lora_config in self.config.lora_configs:
                lora_path = lora_config["path"]
                strength = lora_config.get("strength", 1.0)
GoatWu's avatar
GoatWu committed
48
                lora_name = lora_wrapper.load_lora(lora_path)
49
50
                lora_wrapper.apply_lora(lora_name, strength)
                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
51
52
        return model

    def load_image_encoder(self):
        image_encoder = None
        if self.config.task == "i2v" and self.config.get("use_image_encoder", True):
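            # The CLIP image encoder is only needed for i2v. With `clip_quantized`
            # set, a pre-quantized checkpoint named "clip-<prefix>.pth" is resolved,
            # where <prefix> is the part of `clip_quant_scheme` before the first "-";
            # otherwise the original open-clip checkpoint is loaded.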
            # quant_config
            clip_quantized = self.config.get("clip_quantized", False)
            if clip_quantized:
                clip_quant_scheme = self.config.get("clip_quant_scheme", None)
                assert clip_quant_scheme is not None
                tmp_clip_quant_scheme = clip_quant_scheme.split("-")[0]
                clip_model_name = f"clip-{tmp_clip_quant_scheme}.pth"
                clip_quantized_ckpt = find_torch_model_path(self.config, "clip_quantized_ckpt", clip_model_name, tmp_clip_quant_scheme)
                clip_original_ckpt = None
            else:
                clip_quantized_ckpt = None
                clip_quant_scheme = None
                clip_model_name = "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"
                clip_original_ckpt = find_torch_model_path(self.config, "clip_original_ckpt", clip_model_name, "original")

            image_encoder = CLIPModel(
                dtype=torch.float16,
                device=self.init_device,
                checkpoint_path=clip_original_ckpt,
                clip_quantized=clip_quantized,
                clip_quantized_ckpt=clip_quantized_ckpt,
                quant_scheme=clip_quant_scheme,
            )

        return image_encoder

    def load_text_encoder(self):
        # offload config
        t5_offload = self.config.get("t5_cpu_offload", False)
        if t5_offload:
            t5_device = torch.device("cpu")
        else:
            t5_device = torch.device("cuda")

        # quant_config
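        # Mirrors the CLIP logic in load_image_encoder: with `t5_quantized` set, a
        # checkpoint named "models_t5_umt5-xxl-enc-<prefix>.pth" is resolved from the
        # scheme prefix; otherwise the bf16 original is used. The tokenizer is
        # expected next to the chosen checkpoint under "google/umt5-xxl".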
        t5_quantized = self.config.get("t5_quantized", False)
        if t5_quantized:
            t5_quant_scheme = self.config.get("t5_quant_scheme", None)
            assert t5_quant_scheme is not None
            tmp_t5_quant_scheme = t5_quant_scheme.split("-")[0]
            t5_model_name = f"models_t5_umt5-xxl-enc-{tmp_t5_quant_scheme}.pth"
            t5_quantized_ckpt = find_torch_model_path(self.config, "t5_quantized_ckpt", t5_model_name, tmp_t5_quant_scheme)
            t5_original_ckpt = None
            tokenizer_path = os.path.join(os.path.dirname(t5_quantized_ckpt), "google/umt5-xxl")
        else:
            t5_quant_scheme = None
            t5_quantized_ckpt = None
            t5_model_name = "models_t5_umt5-xxl-enc-bf16.pth"
            t5_original_ckpt = find_torch_model_path(self.config, "t5_original_ckpt", t5_model_name, "original")
            tokenizer_path = os.path.join(os.path.dirname(t5_original_ckpt), "google/umt5-xxl")

        text_encoder = T5EncoderModel(
            text_len=self.config["text_len"],
            dtype=torch.bfloat16,
            device=t5_device,
            checkpoint_path=t5_original_ckpt,
            tokenizer_path=tokenizer_path,
            shard_fn=None,
            cpu_offload=t5_offload,
            offload_granularity=self.config.get("t5_offload_granularity", "model"),
            t5_quantized=t5_quantized,
            t5_quantized_ckpt=t5_quantized_ckpt,
            quant_scheme=t5_quant_scheme,
        )
        text_encoders = [text_encoder]
        return text_encoders

    def load_vae_encoder(self):
        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
            "device": self.init_device,
            "parallel": self.config.parallel_vae,
            "use_tiling": self.config.get("use_tiling_vae", False),
        }
        if self.config.task != "i2v":
            return None
        else:
            return WanVAE(**vae_config)

    def load_vae_decoder(self):
        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
            "device": self.init_device,
            "parallel": self.config.parallel_vae,
            "use_tiling": self.config.get("use_tiling_vae", False),
        }
        if self.config.get("use_tiny_vae", False):
            tiny_vae_path = find_torch_model_path(self.config, "tiny_vae_path", "taew2_1.pth")
            vae_decoder = WanVAE_tiny(
                vae_pth=tiny_vae_path,
                device=self.init_device,
            ).to("cuda")
        else:
            vae_decoder = WanVAE(**vae_config)
        return vae_decoder

    def load_vae(self):
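        # Share a single WanVAE instance for encode and decode when possible; a
        # separate decoder is only built for t2v (no encoder) or when the tiny
        # taew2_1 VAE is used for decoding.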
        vae_encoder = self.load_vae_encoder()
        if vae_encoder is None or self.config.get("use_tiny_vae", False):
            vae_decoder = self.load_vae_decoder()
        else:
            vae_decoder = vae_encoder
        return vae_encoder, vae_decoder

    def init_scheduler(self):
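        # Pick the scheduler implementation from the configured feature-caching
        # strategy, then optionally wrap it for changing-resolution sampling.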
        if self.config.feature_caching == "NoCaching":
            scheduler_class = WanScheduler
        elif self.config.feature_caching == "TaylorSeer":
            scheduler_class = WanSchedulerTaylorCaching
        elif self.config.feature_caching in ["Tea", "Ada", "Custom", "FirstBlock", "DualBlock", "DynamicBlock"]:
            scheduler_class = WanSchedulerCaching
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")

        if self.config.get("changing_resolution", False):
            scheduler = WanScheduler4ChangingResolutionInterface(scheduler_class, self.config)
        else:
            scheduler = scheduler_class(self.config)
        self.model.set_scheduler(scheduler)

    def run_text_encoder(self, text, img):
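        # With `lazy_load`/`unload_modules`, the T5 encoder is instantiated on
        # demand and released right after inference to reduce memory pressure.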
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.text_encoders = self.load_text_encoder()
        text_encoder_output = {}
        n_prompt = self.config.get("negative_prompt", "")
        context = self.text_encoders[0].infer([text])
        context_null = self.text_encoders[0].infer([n_prompt if n_prompt else ""])
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.text_encoders[0]
            torch.cuda.empty_cache()
            gc.collect()
        text_encoder_output["context"] = context
        text_encoder_output["context_null"] = context_null
        return text_encoder_output

    def run_image_encoder(self, img):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.image_encoder = self.load_image_encoder()
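        # Normalize the PIL image to [-1, 1], run the CLIP visual tower, and cast
        # the conditioning embedding to bfloat16.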
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
        clip_encoder_out = self.image_encoder.visual([img[None, :, :, :]], self.config).squeeze(0).to(torch.bfloat16)
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.image_encoder
            torch.cuda.empty_cache()
            gc.collect()
        return clip_encoder_out

    def run_vae_encoder(self, img):
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
        h, w = img.shape[1:]
        aspect_ratio = h / w
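        # Choose a latent grid (lat_h, lat_w) that preserves the input aspect ratio,
        # keeps roughly target_height * target_width pixels, and stays a multiple of
        # the transformer patch size after the VAE spatial stride is applied.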
        max_area = self.config.target_height * self.config.target_width
        lat_h = round(np.sqrt(max_area * aspect_ratio) // self.config.vae_stride[1] // self.config.patch_size[1] * self.config.patch_size[1])
        lat_w = round(np.sqrt(max_area / aspect_ratio) // self.config.vae_stride[2] // self.config.patch_size[2] * self.config.patch_size[2])

        if self.config.get("changing_resolution", False):
            self.config.lat_h, self.config.lat_w = lat_h, lat_w
            vae_encode_out_list = []
            for i in range(len(self.config["resolution_rate"])):
                lat_h, lat_w = (
                    int(self.config.lat_h * self.config.resolution_rate[i]) // 2 * 2,
                    int(self.config.lat_w * self.config.resolution_rate[i]) // 2 * 2,
                )
                vae_encode_out_list.append(self.get_vae_encoder_output(img, lat_h, lat_w))
            vae_encode_out_list.append(self.get_vae_encoder_output(img, self.config.lat_h, self.config.lat_w))
            return vae_encode_out_list
        else:
            self.config.lat_h, self.config.lat_w = lat_h, lat_w
            vae_encoder_out = self.get_vae_encoder_output(img, lat_h, lat_w)
            return vae_encoder_out

    def get_vae_encoder_output(self, img, lat_h, lat_w):
        h = lat_h * self.config.vae_stride[1]
        w = lat_w * self.config.vae_stride[2]
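
        # Build the i2v conditioning mask at latent spatial resolution: only the
        # first frame is marked known (ones), later frames are zeros. The first-frame
        # mask is repeated 4x and the frame axis regrouped in blocks of 4, lining up
        # with the 4x temporal compression applied by the VAE.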
        msk = torch.ones(
            1,
            self.config.target_video_length,
            lat_h,
            lat_w,
            device=torch.device("cuda"),
        )
        msk[:, 1:] = 0
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_encoder = self.load_vae_encoder()
        vae_encoder_out = self.vae_encoder.encode(
            [
                torch.concat(
                    [
                        torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                        torch.zeros(3, self.config.target_video_length - 1, h, w),
                    ],
                    dim=1,
                ).cuda()
            ],
            self.config,
        )[0]
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_encoder
            torch.cuda.empty_cache()
            gc.collect()
        vae_encoder_out = torch.concat([msk, vae_encoder_out]).to(torch.bfloat16)
        return vae_encoder_out

    def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img):
        image_encoder_output = {
            "clip_encoder_out": clip_encoder_out,
            "vae_encoder_out": vae_encoder_out,
        }
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": image_encoder_output,
        }

    def set_target_shape(self):
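        # Latent shape is (C, T_lat, H_lat, W_lat): frames are compressed by
        # vae_stride[0] (keeping the first frame), spatial dims by vae_stride[1]/[2].
        # i2v reuses lat_h/lat_w computed during VAE encoding; t2v derives them from
        # target_height/target_width.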
        num_channels_latents = self.config.get("num_channels_latents", 16)
        if self.config.task == "i2v":
            self.config.target_shape = (
                num_channels_latents,
                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
                self.config.lat_h,
                self.config.lat_w,
            )
        elif self.config.task == "t2v":
            self.config.target_shape = (
                num_channels_latents,
                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
                int(self.config.target_height) // self.config.vae_stride[1],
                int(self.config.target_width) // self.config.vae_stride[2],
            )

    def save_video_func(self, images):
        cache_video(
            tensor=images,
            save_file=self.config.save_video_path,
            fps=self.config.get("fps", 16),
            nrow=1,
            normalize=True,
            value_range=(-1, 1),
        )


class MultiModelStruct:
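    # Wraps the Wan2.2 MoE expert pair [high_noise_model, low_noise_model] behind a
    # single model interface: each denoising step picks an expert by comparing the
    # current timestep with boundary * num_train_timesteps, and the inactive expert
    # is offloaded to CPU to save GPU memory.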
    def __init__(self, model_list, config, boundary=0.875, num_train_timesteps=1000):
        self.model = model_list  # [high_noise_model, low_noise_model]
        assert len(self.model) == 2, "MultiModelStruct only supports 2 models now."
        self.config = config
        self.boundary = boundary
        self.boundary_timestep = self.boundary * num_train_timesteps
        self.cur_model_index = -1
        logger.info(f"boundary: {self.boundary}, boundary_timestep: {self.boundary_timestep}")

    @property
    def device(self):
        return self.model[self.cur_model_index].device

    def set_scheduler(self, shared_scheduler):
        self.scheduler = shared_scheduler
        for model in self.model:
            model.set_scheduler(shared_scheduler)

    def infer(self, inputs):
        self.get_current_model_index()
        self.model[self.cur_model_index].infer(inputs)

    def get_current_model_index(self):
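        # Steps with timestep >= boundary_timestep use the HIGH-noise expert with
        # sample_guide_scale[0]; lower-noise steps switch to the LOW-noise expert
        # with sample_guide_scale[1]. On a switch, the previously active expert is
        # moved to CPU before the new one is moved onto the GPU.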
        if self.scheduler.timesteps[self.scheduler.step_index] >= self.boundary_timestep:
            logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}")
            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[0]
            if self.cur_model_index == -1:
                self.to_cuda(model_index=0)
            elif self.cur_model_index == 1:  # 1 -> 0
                self.offload_cpu(model_index=1)
                self.to_cuda(model_index=0)
            self.cur_model_index = 0
        else:
            logger.info(f"using - LOW - noise model at step_index {self.scheduler.step_index + 1}")
            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[1]
            if self.cur_model_index == -1:
                self.to_cuda(model_index=1)
            elif self.cur_model_index == 0:  # 0 -> 1
                self.offload_cpu(model_index=0)
                self.to_cuda(model_index=1)
            self.cur_model_index = 1

    def offload_cpu(self, model_index):
        self.model[model_index].to_cpu()

    def to_cuda(self, model_index):
        self.model[model_index].to_cuda()


@RUNNER_REGISTER("wan2.2_moe")
class Wan22MoeRunner(WanRunner):
    def __init__(self, config):
        super().__init__(config)

    def load_transformer(self):
        # encoder -> high_noise_model -> low_noise_model -> vae -> video_output
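        # `config.boundary` is the fraction of num_train_timesteps where the
        # high-noise expert hands over to the low-noise expert; `sample_guide_scale`
        # is expected to provide one guidance scale per expert.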
        high_noise_model = Wan22MoeModel(
            os.path.join(self.config.model_path, "high_noise_model"),
            self.config,
            self.init_device,
        )
        low_noise_model = Wan22MoeModel(
            os.path.join(self.config.model_path, "low_noise_model"),
            self.config,
            self.init_device,
        )
        return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config.boundary)


@RUNNER_REGISTER("wan2.2")
class Wan22DenseRunner(WanRunner):
    def __init__(self, config):
        super().__init__(config)

    def load_vae_decoder(self):
        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
            "device": self.init_device,
        }
        vae_decoder = Wan2_2_VAE(**vae_config)
        return vae_decoder

    def load_vae_encoder(self):
        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
            "device": self.init_device,
        }
        if self.config.task != "i2v":
            return None
        else:
            return Wan2_2_VAE(**vae_config)

    def load_vae(self):
        vae_encoder = self.load_vae_encoder()
        vae_decoder = self.load_vae_decoder()
        return vae_encoder, vae_decoder

    def run_vae_encoder(self, img):
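        # Wan2.2 dense i2v sizing: pick (ow, oh) with best_output_size, LANCZOS-resize,
        # center-crop to exactly that size, then encode a single frame with the
        # Wan2.2 VAE; lat_h/lat_w follow from the crop and the VAE spatial stride.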
        max_area = self.config.target_height * self.config.target_width
        ih, iw = img.height, img.width
        dh, dw = self.config.patch_size[1] * self.config.vae_stride[1], self.config.patch_size[2] * self.config.vae_stride[2]
        ow, oh = best_output_size(iw, ih, dw, dh, max_area)

        scale = max(ow / iw, oh / ih)
        img = img.resize((round(iw * scale), round(ih * scale)), Image.LANCZOS)

        # center-crop
        x1 = (img.width - ow) // 2
        y1 = (img.height - oh) // 2
        img = img.crop((x1, y1, x1 + ow, y1 + oh))
        assert img.width == ow and img.height == oh

        # to tensor
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda().unsqueeze(1)
        vae_encoder_out = self.get_vae_encoder_output(img)
        self.config.lat_w, self.config.lat_h = ow // self.config.vae_stride[2], oh // self.config.vae_stride[1]

        return vae_encoder_out

    def get_vae_encoder_output(self, img):
        z = self.vae_encoder.encode(img)
        return z