import os

import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from loguru import logger

from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
from lightx2v.models.networks.wan.distill_model import WanDistillModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.schedulers.wan.step_distill.scheduler import WanStepDistillScheduler
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.video_encoders.hf.wan.vae_tiny import WanVAE_tiny
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import cache_video


@RUNNER_REGISTER("wan2.1_distill")
class WanDistillRunner(WanRunner):
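    """Runner for step-distilled Wan2.1 models: loads the distilled transformer (or a LoRA-merged base model) and the step-distillation scheduler."""
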
    def __init__(self, config):
        super().__init__(config)

    def load_transformer(self):
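        # With LoRA configs, merge the LoRA weights into the base WanModel; otherwise load the distilled checkpoint directly.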
        if self.config.get("lora_configs"):
            model = WanModel(
                self.config.model_path,
                self.config,
                self.init_device,
            )
            lora_wrapper = WanLoraWrapper(model)
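            # Merge every configured LoRA (a dict with "path" and optional "strength") into the transformer.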
            for lora_config in self.config.lora_configs:
                lora_path = lora_config["path"]
                strength = lora_config.get("strength", 1.0)
                lora_name = lora_wrapper.load_lora(lora_path)
                lora_wrapper.apply_lora(lora_name, strength)
                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
        else:
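            # No LoRA configs: load the step-distilled checkpoint as-is.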
            model = WanDistillModel(self.config.model_path, self.config, self.init_device)
        return model

    def init_scheduler(self):
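        # Step-distilled sampling uses its own scheduler; feature-caching modes are not implemented for it.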
        if self.config.feature_caching == "NoCaching":
            scheduler = WanStepDistillScheduler(self.config)
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
        self.model.set_scheduler(scheduler)