import gc
import os

import numpy as np
import torch
import torch.distributed as dist
import torchvision.transforms.functional as TF
from PIL import Image
from loguru import logger

from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.networks.wan.model import Wan22MoeModel, WanModel
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.wan.changing_resolution.scheduler import (
    WanScheduler4ChangingResolutionInterface,
)
from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
    WanSchedulerCaching,
    WanSchedulerTaylorCaching,
)
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
from lightx2v.models.video_encoders.hf.wan.vae_tiny import WanVAE_tiny
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import *
from lightx2v.utils.utils import best_output_size, cache_video, find_torch_model_path


@RUNNER_REGISTER("wan2.1")
class WanRunner(DefaultRunner):
    def __init__(self, config):
        super().__init__(config)

    def load_transformer(self):
        model = WanModel(
            self.config.model_path,
            self.config,
            self.init_device,
        )
        if self.config.get("lora_configs") and self.config.lora_configs:
            assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
            lora_wrapper = WanLoraWrapper(model)
            for lora_config in self.config.lora_configs:
                lora_path = lora_config["path"]
                strength = lora_config.get("strength", 1.0)
                lora_name = lora_wrapper.load_lora(lora_path)
                lora_wrapper.apply_lora(lora_name, strength)
                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
        return model

    def load_image_encoder(self):
        image_encoder = None
        if self.config.task == "i2v" and self.config.get("use_image_encoder", True):
            # Quantization config: optionally load a pre-quantized CLIP checkpoint
            clip_quantized = self.config.get("clip_quantized", False)
            if clip_quantized:
                clip_quant_scheme = self.config.get("clip_quant_scheme", None)
                assert clip_quant_scheme is not None
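                # A quant scheme may carry a backend suffix (e.g. "fp8-<backend>");
                # only the base dtype is encoded in the checkpoint filename.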
                tmp_clip_quant_scheme = clip_quant_scheme.split("-")[0]
                clip_model_name = f"clip-{tmp_clip_quant_scheme}.pth"
                clip_quantized_ckpt = find_torch_model_path(self.config, "clip_quantized_ckpt", clip_model_name, tmp_clip_quant_scheme)
                clip_original_ckpt = None
            else:
                clip_quantized_ckpt = None
                clip_quant_scheme = None
                clip_model_name = "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"
                clip_original_ckpt = find_torch_model_path(self.config, "clip_original_ckpt", clip_model_name, "original")

            image_encoder = CLIPModel(
                dtype=torch.float16,
                device=self.init_device,
                checkpoint_path=clip_original_ckpt,
                clip_quantized=clip_quantized,
                clip_quantized_ckpt=clip_quantized_ckpt,
                quant_scheme=clip_quant_scheme,
            )

        return image_encoder

    def load_text_encoder(self):
        # Offload config: run the T5 encoder on CPU when t5_cpu_offload is set
        t5_offload = self.config.get("t5_cpu_offload", False)
        if t5_offload:
            t5_device = torch.device("cpu")
        else:
            t5_device = torch.device("cuda")

        # Quantization config: optionally load a pre-quantized T5 checkpoint
        t5_quantized = self.config.get("t5_quantized", False)
        if t5_quantized:
            t5_quant_scheme = self.config.get("t5_quant_scheme", None)
            assert t5_quant_scheme is not None
            tmp_t5_quant_scheme = t5_quant_scheme.split("-")[0]
            t5_model_name = f"models_t5_umt5-xxl-enc-{tmp_t5_quant_scheme}.pth"
            t5_quantized_ckpt = find_torch_model_path(self.config, "t5_quantized_ckpt", t5_model_name, tmp_t5_quant_scheme)
            t5_original_ckpt = None
            tokenizer_path = os.path.join(os.path.dirname(t5_quantized_ckpt), "google/umt5-xxl")
        else:
            t5_quant_scheme = None
            t5_quantized_ckpt = None
            t5_model_name = "models_t5_umt5-xxl-enc-bf16.pth"
            t5_original_ckpt = find_torch_model_path(self.config, "t5_original_ckpt", t5_model_name, "original")
            tokenizer_path = os.path.join(os.path.dirname(t5_original_ckpt), "google/umt5-xxl")

        text_encoder = T5EncoderModel(
            text_len=self.config["text_len"],
            dtype=torch.bfloat16,
            device=t5_device,
            checkpoint_path=t5_original_ckpt,
            tokenizer_path=tokenizer_path,
            shard_fn=None,
            cpu_offload=t5_offload,
            offload_granularity=self.config.get("t5_offload_granularity", "model"),
            t5_quantized=t5_quantized,
            t5_quantized_ckpt=t5_quantized_ckpt,
            quant_scheme=t5_quant_scheme,
        )
        text_encoders = [text_encoder]
        return text_encoders

    def load_vae_encoder(self):
        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
            "device": self.init_device,
            "parallel": self.config.parallel and self.config.parallel.get("vae_p_size", False) and self.config.parallel.vae_p_size > 1,
            "use_tiling": self.config.get("use_tiling_vae", False),
        }
        if self.config.task != "i2v":
            return None
        else:
            return WanVAE(**vae_config)

    def load_vae_decoder(self):
        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
            "device": self.init_device,
            "parallel": self.config.parallel and self.config.parallel.get("vae_p_size", False) and self.config.parallel.vae_p_size > 1,
            "use_tiling": self.config.get("use_tiling_vae", False),
        }
        if self.config.get("use_tiny_vae", False):
            tiny_vae_path = find_torch_model_path(self.config, "tiny_vae_path", "taew2_1.pth")
            vae_decoder = WanVAE_tiny(
                vae_pth=tiny_vae_path,
                device=self.init_device,
            ).to("cuda")
        else:
            vae_decoder = WanVAE(**vae_config)
        return vae_decoder

    def load_vae(self):
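        # Reuse the encoder as the decoder when possible; t2v has no encoder,
        # and the tiny VAE decoder is a separate lightweight model.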
        vae_encoder = self.load_vae_encoder()
        if vae_encoder is None or self.config.get("use_tiny_vae", False):
            vae_decoder = self.load_vae_decoder()
        else:
            vae_decoder = vae_encoder
        return vae_encoder, vae_decoder

    def init_scheduler(self):
        if self.config.feature_caching == "NoCaching":
            scheduler_class = WanScheduler
        elif self.config.feature_caching == "TaylorSeer":
            scheduler_class = WanSchedulerTaylorCaching
        elif self.config.feature_caching in ["Tea", "Ada", "Custom", "FirstBlock", "DualBlock", "DynamicBlock"]:
            scheduler_class = WanSchedulerCaching
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")

        if self.config.get("changing_resolution", False):
            scheduler = WanScheduler4ChangingResolutionInterface(scheduler_class, self.config)
        else:
            scheduler = scheduler_class(self.config)
        self.model.set_scheduler(scheduler)

    def run_text_encoder(self, text, img):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.text_encoders = self.load_text_encoder()
        n_prompt = self.config.get("negative_prompt", "")

        if self.config["cfg_parallel"]:
            cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
            cfg_p_rank = dist.get_rank(cfg_p_group)
            if cfg_p_rank == 0:
                context = self.text_encoders[0].infer([text])
                text_encoder_output = {"context": context}
            else:
                context_null = self.text_encoders[0].infer([n_prompt])
                text_encoder_output = {"context_null": context_null}
        else:
            context = self.text_encoders[0].infer([text])
            context_null = self.text_encoders[0].infer([n_prompt])
            text_encoder_output = {
                "context": context,
                "context_null": context_null,
            }

        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.text_encoders[0]
            torch.cuda.empty_cache()
            gc.collect()

        return text_encoder_output

    def run_image_encoder(self, img):
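        # Normalize the PIL image to [-1, 1] and run CLIP's visual tower to get
        # the image-conditioning embedding.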
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.image_encoder = self.load_image_encoder()
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
        clip_encoder_out = self.image_encoder.visual([img[None, :, :, :]], self.config).squeeze(0).to(torch.bfloat16)
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.image_encoder
            torch.cuda.empty_cache()
            gc.collect()
        return clip_encoder_out

    def run_vae_encoder(self, img):
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
        h, w = img.shape[1:]
        aspect_ratio = h / w
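        # Fit the image's aspect ratio into a budget of target_height * target_width
        # pixels, snapping to patch_size multiples on the VAE latent grid. E.g.,
        # assuming the common Wan2.1 defaults vae_stride=(4, 8, 8) and
        # patch_size=(1, 2, 2), a square image with a 480x832 budget yields
        # lat_h = lat_w = 78.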
        max_area = self.config.target_height * self.config.target_width
        lat_h = round(np.sqrt(max_area * aspect_ratio) // self.config.vae_stride[1] // self.config.patch_size[1] * self.config.patch_size[1])
        lat_w = round(np.sqrt(max_area / aspect_ratio) // self.config.vae_stride[2] // self.config.patch_size[2] * self.config.patch_size[2])

        if self.config.get("changing_resolution", False):
            self.config.lat_h, self.config.lat_w = lat_h, lat_w
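            # Changing resolution: encode once per downscaled (even-sized) latent
            # resolution, then append the full-resolution encoding for the final steps.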
            vae_encode_out_list = []
            for i in range(len(self.config["resolution_rate"])):
                lat_h, lat_w = (
                    int(self.config.lat_h * self.config.resolution_rate[i]) // 2 * 2,
                    int(self.config.lat_w * self.config.resolution_rate[i]) // 2 * 2,
                )
                vae_encode_out_list.append(self.get_vae_encoder_output(img, lat_h, lat_w))
            vae_encode_out_list.append(self.get_vae_encoder_output(img, self.config.lat_h, self.config.lat_w))
            return vae_encode_out_list
        else:
            self.config.lat_h, self.config.lat_w = lat_h, lat_w
            vae_encoder_out = self.get_vae_encoder_output(img, lat_h, lat_w)
            return vae_encoder_out

    def get_vae_encoder_output(self, img, lat_h, lat_w):
        h = lat_h * self.config.vae_stride[1]
        w = lat_w * self.config.vae_stride[2]
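
        # First-frame conditioning mask: 1 for the reference frame, 0 for the
        # frames to generate; frame 0 is repeated 4x to line up with the VAE's
        # temporal compression before reshaping to the latent frame count.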
        msk = torch.ones(
            1,
            self.config.target_video_length,
            lat_h,
            lat_w,
            device=torch.device("cuda"),
        )
        msk[:, 1:] = 0
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_encoder = self.load_vae_encoder()
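        # Encode a clip whose first frame is the bicubic-resized reference image
        # and whose remaining target_video_length - 1 frames are zeros.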
        vae_encoder_out = self.vae_encoder.encode(
            [
                torch.concat(
                    [
                        torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                        torch.zeros(3, self.config.target_video_length - 1, h, w),
                    ],
                    dim=1,
                ).cuda()
            ],
            self.config,
        )[0]
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_encoder
            torch.cuda.empty_cache()
            gc.collect()
        vae_encoder_out = torch.concat([msk, vae_encoder_out]).to(torch.bfloat16)
        return vae_encoder_out

    def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img):
        image_encoder_output = {
            "clip_encoder_out": clip_encoder_out,
            "vae_encoder_out": vae_encoder_out,
        }
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": image_encoder_output,
        }

    def set_target_shape(self):
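        # Latent shape the scheduler denoises. E.g., assuming the common defaults
        # target_video_length=81 and vae_stride=(4, 8, 8), the temporal dim is
        # (81 - 1) // 4 + 1 = 21 latent frames.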
        num_channels_latents = self.config.get("num_channels_latents", 16)
        if self.config.task == "i2v":
            self.config.target_shape = (
                num_channels_latents,
                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
                self.config.lat_h,
                self.config.lat_w,
            )
        elif self.config.task == "t2v":
            self.config.target_shape = (
                num_channels_latents,
                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
                int(self.config.target_height) // self.config.vae_stride[1],
                int(self.config.target_width) // self.config.vae_stride[2],
            )

    def save_video_func(self, images):
        cache_video(
            tensor=images,
            save_file=self.config.save_video_path,
            fps=self.config.get("fps", 16),
            nrow=1,
            normalize=True,
            value_range=(-1, 1),
        )


class MultiModelStruct:
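    """Pair of Wan2.2 MoE experts behind a single model interface.

    The high-noise expert denoises timesteps >= boundary * num_train_timesteps
    (0.875 * 1000 = 875 with the defaults); the low-noise expert handles the
    rest. Only the currently active expert is kept on the GPU.
    """
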
    def __init__(self, model_list, config, boundary=0.875, num_train_timesteps=1000):
        self.model = model_list  # [high_noise_model, low_noise_model]
        assert len(self.model) == 2, "MultiModelStruct only supports 2 models now."
        self.config = config
        self.boundary = boundary
        self.boundary_timestep = self.boundary * num_train_timesteps
        self.cur_model_index = -1
        logger.info(f"boundary: {self.boundary}, boundary_timestep: {self.boundary_timestep}")

    @property
    def device(self):
        return self.model[self.cur_model_index].device

    def set_scheduler(self, shared_scheduler):
        self.scheduler = shared_scheduler
        for model in self.model:
            model.set_scheduler(shared_scheduler)

    def infer(self, inputs):
        self.get_current_model_index()
        self.model[self.cur_model_index].infer(inputs)

    def get_current_model_index(self):
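        # Pick the expert for the current timestep; on a switch, offload the old
        # expert to CPU before moving the new one onto the GPU.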
        if self.scheduler.timesteps[self.scheduler.step_index] >= self.boundary_timestep:
            logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}")
            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[0]
            if self.cur_model_index == -1:
                self.to_cuda(model_index=0)
            elif self.cur_model_index == 1:  # 1 -> 0
                self.offload_cpu(model_index=1)
                self.to_cuda(model_index=0)
            self.cur_model_index = 0
        else:
            logger.info(f"using - LOW - noise model at step_index {self.scheduler.step_index + 1}")
            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[1]
            if self.cur_model_index == -1:
                self.to_cuda(model_index=1)
            elif self.cur_model_index == 0:  # 0 -> 1
                self.offload_cpu(model_index=0)
                self.to_cuda(model_index=1)
            self.cur_model_index = 1

    def offload_cpu(self, model_index):
        self.model[model_index].to_cpu()

    def to_cuda(self, model_index):
        self.model[model_index].to_cuda()


@RUNNER_REGISTER("wan2.2_moe")
class Wan22MoeRunner(WanRunner):
    def __init__(self, config):
        super().__init__(config)

    def load_transformer(self):
        # encoder -> high_noise_model -> low_noise_model -> vae -> video_output
        high_noise_model = Wan22MoeModel(
            os.path.join(self.config.model_path, "high_noise_model"),
            self.config,
            self.init_device,
        )
        low_noise_model = Wan22MoeModel(
            os.path.join(self.config.model_path, "low_noise_model"),
            self.config,
            self.init_device,
        )
        return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config.boundary)


@RUNNER_REGISTER("wan2.2")
class Wan22DenseRunner(WanRunner):
    def __init__(self, config):
        super().__init__(config)

    def load_vae_decoder(self):
        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
            "device": self.init_device,
        }
        vae_decoder = Wan2_2_VAE(**vae_config)
        return vae_decoder

    def load_vae_encoder(self):
        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
            "device": self.init_device,
        }
        if self.config.task != "i2v":
            return None
        else:
            return Wan2_2_VAE(**vae_config)

    def load_vae(self):
        vae_encoder = self.load_vae_encoder()
        vae_decoder = self.load_vae_decoder()
        return vae_encoder, vae_decoder

    def run_vae_encoder(self, img):
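        # Wan2.2 dense i2v: pick the valid output size closest to the source
        # aspect ratio under the pixel budget, then resize and center-crop.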
        max_area = self.config.target_height * self.config.target_width
        ih, iw = img.height, img.width
        dh, dw = self.config.patch_size[1] * self.config.vae_stride[1], self.config.patch_size[2] * self.config.vae_stride[2]
        ow, oh = best_output_size(iw, ih, dw, dh, max_area)

        scale = max(ow / iw, oh / ih)
        img = img.resize((round(iw * scale), round(ih * scale)), Image.LANCZOS)

        # center-crop
        x1 = (img.width - ow) // 2
        y1 = (img.height - oh) // 2
        img = img.crop((x1, y1, x1 + ow, y1 + oh))
        assert img.width == ow and img.height == oh

        # to tensor
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda().unsqueeze(1)
        vae_encoder_out = self.get_vae_encoder_output(img)
        self.config.lat_w, self.config.lat_h = ow // self.config.vae_stride[2], oh // self.config.vae_stride[1]

        return vae_encoder_out

    def get_vae_encoder_output(self, img):
        z = self.vae_encoder.encode(img, self.config)
        return z