import os
import gc

import numpy as np
import torch
import torchvision.transforms.functional as TF
import torch.distributed as dist
from loguru import logger
from PIL import Image

from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.schedulers.wan.changing_resolution.scheduler import (
    WanScheduler4ChangingResolutionInterface,
)
from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
    WanSchedulerCaching,
    WanSchedulerTaylorCaching,
)
from lightx2v.utils.utils import *

from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
from lightx2v.models.networks.wan.model import WanModel, Wan22MoeModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
from lightx2v.models.video_encoders.hf.wan.vae_tiny import WanVAE_tiny

from lightx2v.utils.utils import cache_video, best_output_size
from lightx2v.utils.profiler import ProfilingContext


@RUNNER_REGISTER("wan2.1")
class WanRunner(DefaultRunner):
    def __init__(self, config):
        super().__init__(config)

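    # Build the Wan 2.1 DiT backbone; optionally load and apply LoRA weights listed in config.lora_configs.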
    def load_transformer(self):
        model = WanModel(
            self.config.model_path,
            self.config,
            self.init_device,
        )
        if self.config.get("lora_configs") and self.config.lora_configs:
            assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
            lora_wrapper = WanLoraWrapper(model)
            for lora_config in self.config.lora_configs:
                lora_path = lora_config["path"]
                strength = lora_config.get("strength", 1.0)
                lora_name = lora_wrapper.load_lora(lora_path)
                lora_wrapper.apply_lora(lora_name, strength)
                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
        return model

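    # i2v only: load the open-clip xlm-roberta-large ViT-huge-14 image encoder, using a
    # quantized checkpoint when clip_quantized is set, otherwise the original fp16 weights.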
    def load_image_encoder(self):
        image_encoder = None
        if self.config.task == "i2v" and self.config.get("use_image_encoder", True):
            # quant_config
            clip_quantized = self.config.get("clip_quantized", False)
            if clip_quantized:
                clip_quant_scheme = self.config.get("clip_quant_scheme", None)
                assert clip_quant_scheme is not None
                tmp_clip_quant_scheme = clip_quant_scheme.split("-")[0]
                clip_model_name = f"clip-{tmp_clip_quant_scheme}.pth"
                clip_quantized_ckpt = find_torch_model_path(self.config, "clip_quantized_ckpt", clip_model_name, tmp_clip_quant_scheme)
                clip_original_ckpt = None
            else:
                clip_quantized_ckpt = None
                clip_quant_scheme = None
                clip_model_name = "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"
                clip_original_ckpt = find_torch_model_path(self.config, "clip_original_ckpt", clip_model_name, "original")

            image_encoder = CLIPModel(
                dtype=torch.float16,
                device=self.init_device,
                checkpoint_path=clip_original_ckpt,
                clip_quantized=clip_quantized,
                clip_quantized_ckpt=clip_quantized_ckpt,
                quant_scheme=clip_quant_scheme,
            )

        return image_encoder

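    # Load the umt5-xxl T5 text encoder, honoring the CPU-offload and quantization options in the config.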
    def load_text_encoder(self):
        # offload config
        t5_offload = self.config.get("t5_cpu_offload", False)
        if t5_offload:
            t5_device = torch.device("cpu")
        else:
            t5_device = torch.device("cuda")

        # quant_config
        t5_quantized = self.config.get("t5_quantized", False)
        if t5_quantized:
            t5_quant_scheme = self.config.get("t5_quant_scheme", None)
            assert t5_quant_scheme is not None
            tmp_t5_quant_scheme = t5_quant_scheme.split("-")[0]
            t5_model_name = f"models_t5_umt5-xxl-enc-{tmp_t5_quant_scheme}.pth"
            t5_quantized_ckpt = find_torch_model_path(self.config, "t5_quantized_ckpt", t5_model_name, tmp_t5_quant_scheme)
            t5_original_ckpt = None
            tokenizer_path = os.path.join(os.path.dirname(t5_quantized_ckpt), "google/umt5-xxl")
        else:
            t5_quant_scheme = None
            t5_quantized_ckpt = None
            t5_model_name = "models_t5_umt5-xxl-enc-bf16.pth"
            t5_original_ckpt = find_torch_model_path(self.config, "t5_original_ckpt", t5_model_name, "original")
            tokenizer_path = os.path.join(os.path.dirname(t5_original_ckpt), "google/umt5-xxl")

        text_encoder = T5EncoderModel(
            text_len=self.config["text_len"],
            dtype=torch.bfloat16,
            device=t5_device,
            checkpoint_path=t5_original_ckpt,
            tokenizer_path=tokenizer_path,
            shard_fn=None,
            cpu_offload=t5_offload,
            offload_granularity=self.config.get("t5_offload_granularity", "model"),
            t5_quantized=t5_quantized,
            t5_quantized_ckpt=t5_quantized_ckpt,
            quant_scheme=t5_quant_scheme,
        )
        text_encoders = [text_encoder]
        return text_encoders

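    # Wan 2.1 VAE loaders: the encoder is only needed for i2v; the decoder can be swapped
    # for the lightweight taew2_1 "tiny" VAE when use_tiny_vae is set.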
    def load_vae_encoder(self):
        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
            "device": self.init_device,
            "parallel": self.config.parallel and self.config.parallel.get("vae_p_size", False) and self.config.parallel.vae_p_size > 1,
            "use_tiling": self.config.get("use_tiling_vae", False),
        }
        if self.config.task != "i2v":
            return None
        else:
            return WanVAE(**vae_config)

    def load_vae_decoder(self):
        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
            "device": self.init_device,
            "parallel": self.config.parallel and self.config.parallel.get("vae_p_size", False) and self.config.parallel.vae_p_size > 1,
            "use_tiling": self.config.get("use_tiling_vae", False),
        }
        if self.config.get("use_tiny_vae", False):
            tiny_vae_path = find_torch_model_path(self.config, "tiny_vae_path", "taew2_1.pth")
            vae_decoder = WanVAE_tiny(
                vae_pth=tiny_vae_path,
                device=self.init_device,
            ).to("cuda")
        else:
            vae_decoder = WanVAE(**vae_config)
        return vae_decoder

    def load_vae(self):
        vae_encoder = self.load_vae_encoder()
        if vae_encoder is None or self.config.get("use_tiny_vae", False):
            vae_decoder = self.load_vae_decoder()
        else:
            vae_decoder = vae_encoder
        return vae_encoder, vae_decoder

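    # Pick a scheduler class based on the feature-caching strategy, optionally wrapped
    # for changing-resolution inference, and attach it to the model.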
    def init_scheduler(self):
        if self.config.feature_caching == "NoCaching":
            scheduler_class = WanScheduler
        elif self.config.feature_caching == "TaylorSeer":
            scheduler_class = WanSchedulerTaylorCaching
        elif self.config.feature_caching in ["Tea", "Ada", "Custom", "FirstBlock", "DualBlock", "DynamicBlock"]:
            scheduler_class = WanSchedulerCaching
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")

        if self.config.get("changing_resolution", False):
            scheduler = WanScheduler4ChangingResolutionInterface(scheduler_class, self.config)
        else:
            scheduler = scheduler_class(self.config)
        self.model.set_scheduler(scheduler)

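    # Encode the positive and negative prompts. Under CFG parallelism, rank 0 of the cfg_p
    # group encodes the positive prompt while the other rank encodes the negative prompt.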
    def run_text_encoder(self, text, img):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.text_encoders = self.load_text_encoder()
        n_prompt = self.config.get("negative_prompt", "")

        if self.config["cfg_parallel"]:
            cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
            cfg_p_rank = dist.get_rank(cfg_p_group)
            if cfg_p_rank == 0:
                context = self.text_encoders[0].infer([text])
                text_encoder_output = {"context": context}
            else:
                context_null = self.text_encoders[0].infer([n_prompt])
                text_encoder_output = {"context_null": context_null}
        else:
            context = self.text_encoders[0].infer([text])
            context_null = self.text_encoders[0].infer([n_prompt])
            text_encoder_output = {
                "context": context,
                "context_null": context_null,
            }

        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.text_encoders[0]
            torch.cuda.empty_cache()
            gc.collect()

        return text_encoder_output

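    # Normalize the conditioning image to [-1, 1] and run it through the CLIP visual encoder.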
    def run_image_encoder(self, img):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.image_encoder = self.load_image_encoder()
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
        clip_encoder_out = self.image_encoder.visual([img[None, :, :, :]], self.config).squeeze(0).to(torch.bfloat16)
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.image_encoder
            torch.cuda.empty_cache()
            gc.collect()
        return clip_encoder_out

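    # Derive latent height/width from the target area and the image aspect ratio, then
    # VAE-encode the image (once per resolution when changing_resolution is enabled).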
    def run_vae_encoder(self, img):
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
        h, w = img.shape[1:]
        aspect_ratio = h / w
        max_area = self.config.target_height * self.config.target_width
        lat_h = round(np.sqrt(max_area * aspect_ratio) // self.config.vae_stride[1] // self.config.patch_size[1] * self.config.patch_size[1])
        lat_w = round(np.sqrt(max_area / aspect_ratio) // self.config.vae_stride[2] // self.config.patch_size[2] * self.config.patch_size[2])

        if self.config.get("changing_resolution", False):
            self.config.lat_h, self.config.lat_w = lat_h, lat_w
            vae_encode_out_list = []
            for i in range(len(self.config["resolution_rate"])):
                lat_h, lat_w = (
                    int(self.config.lat_h * self.config.resolution_rate[i]) // 2 * 2,
                    int(self.config.lat_w * self.config.resolution_rate[i]) // 2 * 2,
                )
                vae_encode_out_list.append(self.get_vae_encoder_output(img, lat_h, lat_w))
            vae_encode_out_list.append(self.get_vae_encoder_output(img, self.config.lat_h, self.config.lat_w))
            return vae_encode_out_list
        else:
            self.config.lat_h, self.config.lat_w = lat_h, lat_w
            vae_encoder_out = self.get_vae_encoder_output(img, lat_h, lat_w)
            return vae_encoder_out

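    # Build the i2v conditioning latent: only the first frame is marked valid in the mask
    # (repeated 4x along the time axis to line up with the temporally compressed latents),
    # the remaining frames are zero-padded before encoding, and the mask is concatenated
    # on top of the encoded latents.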
    def get_vae_encoder_output(self, img, lat_h, lat_w):
        h = lat_h * self.config.vae_stride[1]
        w = lat_w * self.config.vae_stride[2]

        msk = torch.ones(
            1,
            self.config.target_video_length,
            lat_h,
            lat_w,
            device=torch.device("cuda"),
        )
        msk[:, 1:] = 0
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_encoder = self.load_vae_encoder()
        vae_encoder_out = self.vae_encoder.encode(
            [
                torch.concat(
                    [
                        torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                        torch.zeros(3, self.config.target_video_length - 1, h, w),
                    ],
                    dim=1,
                ).cuda()
            ],
            self.config,
        )[0]
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_encoder
            torch.cuda.empty_cache()
            gc.collect()
        vae_encoder_out = torch.concat([msk, vae_encoder_out]).to(torch.bfloat16)
        return vae_encoder_out

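    # Bundle the text and image encoder outputs into a single input dict.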
    def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img):
        image_encoder_output = {
            "clip_encoder_out": clip_encoder_out,
            "vae_encoder_out": vae_encoder_out,
        }
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": image_encoder_output,
        }

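    # Compute the latent target shape: temporal length is (target_video_length - 1) // vae_stride[0] + 1;
    # spatial size comes from lat_h/lat_w for i2v or from target_height/width for t2v.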
    def set_target_shape(self):
        num_channels_latents = self.config.get("num_channels_latents", 16)
        if self.config.task == "i2v":
            self.config.target_shape = (
                num_channels_latents,
                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
                self.config.lat_h,
                self.config.lat_w,
            )
        elif self.config.task == "t2v":
            self.config.target_shape = (
                num_channels_latents,
                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
                int(self.config.target_height) // self.config.vae_stride[1],
                int(self.config.target_width) // self.config.vae_stride[2],
            )

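    # Write the decoded frames to save_video_path via cache_video, denormalizing from [-1, 1].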
    def save_video_func(self, images):
        cache_video(
            tensor=images,
            save_file=self.config.save_video_path,
            fps=self.config.get("fps", 16),
            nrow=1,
            normalize=True,
            value_range=(-1, 1),
        )


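# Wan 2.2 MoE pairs a high-noise and a low-noise model. MultiModelStruct presents the pair
# as a single model: each step routes to the high-noise model while the current timestep is
# >= boundary * num_train_timesteps and to the low-noise model afterwards, applying the
# matching sample_guide_scale and keeping only the active model on GPU (the other is
# offloaded to CPU).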
class MultiModelStruct:
    def __init__(self, model_list, config, boundary=0.875, num_train_timesteps=1000):
        self.model = model_list  # [high_noise_model, low_noise_model]
        assert len(self.model) == 2, "MultiModelStruct only supports 2 models now."
        self.config = config
        self.boundary = boundary
        self.boundary_timestep = self.boundary * num_train_timesteps
        self.cur_model_index = -1
        logger.info(f"boundary: {self.boundary}, boundary_timestep: {self.boundary_timestep}")

    @property
    def device(self):
        return self.model[self.cur_model_index].device

    def set_scheduler(self, shared_scheduler):
        self.scheduler = shared_scheduler
        for model in self.model:
            model.set_scheduler(shared_scheduler)

    def infer(self, inputs):
        self.get_current_model_index()
        self.model[self.cur_model_index].infer(inputs)

    def get_current_model_index(self):
        if self.scheduler.timesteps[self.scheduler.step_index] >= self.boundary_timestep:
            logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}")
            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[0]
            if self.cur_model_index == -1:
                self.to_cuda(model_index=0)
            elif self.cur_model_index == 1:  # 1 -> 0
                self.offload_cpu(model_index=1)
                self.to_cuda(model_index=0)
            self.cur_model_index = 0
        else:
            logger.info(f"using - LOW - noise model at step_index {self.scheduler.step_index + 1}")
            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[1]
            if self.cur_model_index == -1:
                self.to_cuda(model_index=1)
            elif self.cur_model_index == 0:  # 0 -> 1
                self.offload_cpu(model_index=0)
                self.to_cuda(model_index=1)
            self.cur_model_index = 1

    def offload_cpu(self, model_index):
        self.model[model_index].to_cpu()

    def to_cuda(self, model_index):
        self.model[model_index].to_cuda()


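# Runner for Wan 2.2 MoE: same pipeline as Wan 2.1, but load_transformer returns the
# high-/low-noise model pair wrapped in a MultiModelStruct.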
@RUNNER_REGISTER("wan2.2_moe")
class Wan22MoeRunner(WanRunner):
    def __init__(self, config):
        super().__init__(config)

    def load_transformer(self):
        # encoder -> high_noise_model -> low_noise_model -> vae -> video_output
        high_noise_model = Wan22MoeModel(
            os.path.join(self.config.model_path, "high_noise_model"),
            self.config,
            self.init_device,
        )
        low_noise_model = Wan22MoeModel(
            os.path.join(self.config.model_path, "low_noise_model"),
            self.config,
            self.init_device,
        )
        return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config.boundary)


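# Runner for dense Wan 2.2: uses the Wan2.2 VAE and resizes/center-crops the conditioning
# image to the best output size before encoding.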
@RUNNER_REGISTER("wan2.2")
class Wan22DenseRunner(WanRunner):
    def __init__(self, config):
        super().__init__(config)

    def load_vae_decoder(self):
        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
            "device": self.init_device,
        }
        vae_decoder = Wan2_2_VAE(**vae_config)
        return vae_decoder

    def load_vae_encoder(self):
        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
            "device": self.init_device,
        }
        if self.config.task != "i2v":
            return None
        else:
            return Wan2_2_VAE(**vae_config)

    def load_vae(self):
        vae_encoder = self.load_vae_encoder()
        vae_decoder = self.load_vae_decoder()
        return vae_encoder, vae_decoder

    def run_vae_encoder(self, img):
        max_area = self.config.target_height * self.config.target_width
        ih, iw = img.height, img.width
        dh, dw = self.config.patch_size[1] * self.config.vae_stride[1], self.config.patch_size[2] * self.config.vae_stride[2]
        ow, oh = best_output_size(iw, ih, dw, dh, max_area)

        scale = max(ow / iw, oh / ih)
        img = img.resize((round(iw * scale), round(ih * scale)), Image.LANCZOS)

        # center-crop
        x1 = (img.width - ow) // 2
        y1 = (img.height - oh) // 2
        img = img.crop((x1, y1, x1 + ow, y1 + oh))
        assert img.width == ow and img.height == oh

        # to tensor
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda().unsqueeze(1)
        vae_encoder_out = self.get_vae_encoder_output(img)
        self.config.lat_w, self.config.lat_h = ow // self.config.vae_stride[2], oh // self.config.vae_stride[1]

        return vae_encoder_out

    def get_vae_encoder_output(self, img):
        z = self.vae_encoder.encode(img, self.config)
        return z