import os

from loguru import logger

from lightx2v.models.networks.wan.distill_model import WanDistillModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.runners.wan.wan_runner import MultiModelStruct, WanRunner
from lightx2v.models.schedulers.wan.step_distill.scheduler import Wan22StepDistillScheduler, WanStepDistillScheduler
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER


@RUNNER_REGISTER("wan2.1_distill")
class WanDistillRunner(WanRunner):
    """Runner for step-distilled Wan2.1 models.

    Loads either a base ``WanModel`` with LoRA-based distillation weights
    applied (when ``lora_configs`` is present in the config) or a pre-merged
    ``WanDistillModel``, and pairs it with the step-distillation scheduler.
    """

    def __init__(self, config):
        super().__init__(config)

    def load_transformer(self):
        """Build the transformer backbone.

        Returns:
            A ``WanModel`` with every configured LoRA applied when
            ``lora_configs`` is set and non-empty, otherwise a
            ``WanDistillModel`` loaded directly from ``model_path``.
        """
        # Truthiness covers both "key absent" and "empty list".
        if self.config.get("lora_configs"):
            # Base (non-distilled) model; the distillation behavior comes
            # from the LoRA weights applied below.
            model = WanModel(
                self.config["model_path"],
                self.config,
                self.init_device,
            )
            lora_wrapper = WanLoraWrapper(model)
            for lora_config in self.config["lora_configs"]:
                lora_path = lora_config["path"]
                strength = lora_config.get("strength", 1.0)
                lora_name = lora_wrapper.load_lora(lora_path)
                lora_wrapper.apply_lora(lora_name, strength)
                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
        else:
            model = WanDistillModel(self.config["model_path"], self.config, self.init_device)
        return model

    def init_scheduler(self):
        """Create the step-distillation scheduler.

        Raises:
            NotImplementedError: if any feature-caching mode other than
                "NoCaching" is requested; the distill scheduler does not
                support feature caching.
        """
        if self.config["feature_caching"] == "NoCaching":
            self.scheduler = WanStepDistillScheduler(self.config)
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")


class MultiDistillModelStruct(MultiModelStruct):
    """Container for the high-noise / low-noise expert pair (Wan2.2 MoE distill).

    The high-noise expert serves the first ``boundary_step_index`` sampling
    steps and the low-noise expert serves the rest.  With model-granularity
    CPU offload enabled, only the active expert is kept on the GPU.
    """

    def __init__(self, model_list, config, boundary_step_index=2):
        # Order is significant: index 0 is the high-noise expert,
        # index 1 the low-noise expert.
        self.model = model_list  # [high_noise_model, low_noise_model]
        assert len(self.model) == 2, "MultiModelStruct only supports 2 models now."
        self.config = config
        self.boundary_step_index = boundary_step_index
        self.cur_model_index = -1  # -1: no expert activated yet
        logger.info(f"boundary step index: {self.boundary_step_index}")

    @ProfilingContext4DebugL2("Switch models in infer_main costs")
    def get_current_model_index(self):
        """Select the expert for the current scheduler step.

        NOTE(review): despite the name this returns nothing — it mutates
        ``self.cur_model_index`` and, under model-level offload, moves
        expert weights between CPU and GPU.  ``self.scheduler`` is assumed
        to be attached by the parent ``MultiModelStruct`` — confirm against
        wan_runner.
        """
        if self.scheduler.step_index < self.boundary_step_index:
            self._switch_to(0, "HIGH")
        else:
            self._switch_to(1, "LOW")

    def _switch_to(self, index, label):
        """Activate expert ``index`` (0 = high noise, 1 = low noise)."""
        logger.info(f"using - {label} - noise model at step_index {self.scheduler.step_index + 1}")
        # Each expert has its own guidance scale: sample_guide_scale = [high, low].
        self.scheduler.sample_guide_scale = self.config["sample_guide_scale"][index]
        if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model":
            if self.cur_model_index == -1:
                # First activation: nothing to evict yet.
                self.to_cuda(model_index=index)
            elif self.cur_model_index != index:
                # Expert swap: evict the idle expert before loading the new one.
                self.offload_cpu(model_index=self.cur_model_index)
                self.to_cuda(model_index=index)
        self.cur_model_index = index


@RUNNER_REGISTER("wan2.2_moe_distill")
class Wan22MoeDistillRunner(WanDistillRunner):
    """Runner for the Wan2.2 MoE step-distilled pipeline.

    Loads a high-noise and a low-noise expert — each either a base
    ``WanModel`` with its matching LoRA applied, or a pre-merged
    ``WanDistillModel`` — and wraps them in a ``MultiDistillModelStruct``
    that switches experts at ``boundary_step_index``.
    """

    def __init__(self, config):
        super().__init__(config)

    def load_transformer(self):
        """Load both experts and return them wrapped in a MultiDistillModelStruct."""
        # Names of experts that have a LoRA entry in the config.
        lora_names = {cfg.get("name", "") for cfg in (self.config.get("lora_configs") or [])}
        high_noise_model = self._load_expert(
            "high_noise_model",
            "dit_distill_ckpt_high",
            "High noise model",
            "high_noise_model" in lora_names,
        )
        low_noise_model = self._load_expert(
            "low_noise_model",
            "dit_distill_ckpt_low",
            "Low noise model",
            "low_noise_model" in lora_names,
        )
        return MultiDistillModelStruct([high_noise_model, low_noise_model], self.config, self.config["boundary_step_index"])

    def _load_expert(self, subdir, ckpt_config_key, log_label, use_lora):
        """Load one expert model (high- or low-noise).

        Args:
            subdir: expert sub-directory name ("high_noise_model" or
                "low_noise_model"); also the ``name`` matched against
                entries in ``lora_configs``.
            ckpt_config_key: config key holding the distilled checkpoint
                path (used only in the non-LoRA branch).
            log_label: human-readable prefix for the LoRA log line.
            use_lora: when True, load the base ``WanModel`` from
                ``model_path/subdir`` and apply every matching LoRA;
                otherwise load the pre-merged ``WanDistillModel`` from
                ``model_path/distill_models/subdir``.

        Returns:
            The loaded expert model.
        """
        if use_lora:
            model = WanModel(
                os.path.join(self.config["model_path"], subdir),
                self.config,
                self.init_device,
            )
            lora_wrapper = WanLoraWrapper(model)
            for lora_config in self.config["lora_configs"]:
                # Apply only the LoRA entries targeting this expert.
                if lora_config.get("name", "") == subdir:
                    lora_path = lora_config["path"]
                    strength = lora_config.get("strength", 1.0)
                    lora_name = lora_wrapper.load_lora(lora_path)
                    lora_wrapper.apply_lora(lora_name, strength)
                    logger.info(f"{log_label} loaded LoRA: {lora_name} with strength: {strength}")
        else:
            model = WanDistillModel(
                os.path.join(self.config["model_path"], "distill_models", subdir),
                self.config,
                self.init_device,
                ckpt_config_key=ckpt_config_key,
            )
        return model

    def init_scheduler(self):
        """Create the Wan2.2 step-distillation scheduler.

        Raises:
            NotImplementedError: if any feature-caching mode other than
                "NoCaching" is requested.
        """
        if self.config["feature_caching"] == "NoCaching":
            self.scheduler = Wan22StepDistillScheduler(self.config)
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")