from loguru import logger

from lightx2v.models.networks.wan.distill_model import WanDistillModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.schedulers.wan.step_distill.scheduler import WanStepDistillScheduler
from lightx2v.utils.registry_factory import RUNNER_REGISTER


@RUNNER_REGISTER("wan2.1_distill")
class WanDistillRunner(WanRunner):
    """Runner for step-distilled Wan 2.1 inference.

    Builds either a base ``WanModel`` with LoRA weights merged in (when the
    config carries a non-empty ``lora_configs`` list) or a ``WanDistillModel``
    loaded directly, and pairs the model with the step-distillation scheduler.
    """

    def __init__(self, config):
        super().__init__(config)

    def load_transformer(self):
        """Build and return the transformer backbone.

        Returns:
            WanModel | WanDistillModel: the base model with every configured
            LoRA applied, or a distilled model when no LoRAs are configured.
        """
        # .get() truthiness already covers both "key absent" and "empty list",
        # so a single check suffices (was: `get(...) and self.config.lora_configs`).
        if self.config.get("lora_configs"):
            model = WanModel(
                self.config.model_path,
                self.config,
                self.init_device,
            )
            lora_wrapper = WanLoraWrapper(model)
            for lora_config in self.config.lora_configs:
                lora_path = lora_config["path"]
                # Per-LoRA strength is optional; 1.0 applies the weights fully.
                strength = lora_config.get("strength", 1.0)
                lora_name = lora_wrapper.load_lora(lora_path)
                lora_wrapper.apply_lora(lora_name, strength)
                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
        else:
            model = WanDistillModel(self.config.model_path, self.config, self.init_device)
        return model

    def init_scheduler(self):
        """Create the step-distill scheduler and attach it to ``self.model``.

        Raises:
            NotImplementedError: if ``feature_caching`` is anything other than
                ``"NoCaching"`` — feature caching is not supported for the
                distilled pipeline.
        """
        if self.config.feature_caching == "NoCaching":
            scheduler = WanStepDistillScheduler(self.config)
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
        self.model.set_scheduler(scheduler)