wan_distill_runner.py 7.71 KB
Newer Older
1
2
import os

PengGao's avatar
PengGao committed
3
4
from loguru import logger

5
from lightx2v.models.networks.wan.distill_model import WanDistillModel
6
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
helloyongyang's avatar
helloyongyang committed
7
from lightx2v.models.networks.wan.model import WanModel
8
9
from lightx2v.models.runners.wan.wan_runner import MultiModelStruct, WanRunner
from lightx2v.models.schedulers.wan.step_distill.scheduler import Wan22StepDistillScheduler, WanStepDistillScheduler
10
from lightx2v.utils.profiler import *
PengGao's avatar
PengGao committed
11
from lightx2v.utils.registry_factory import RUNNER_REGISTER
12
13
14
15
16
17
18


@RUNNER_REGISTER("wan2.1_distill")
class WanDistillRunner(WanRunner):
    """Runner for step-distilled Wan2.1 models.

    The transformer is a ``WanModel`` with every entry of ``config["lora_configs"]``
    applied when that list is non-empty. Otherwise it is a ``WanDistillModel`` loaded
    from ``config["model_path"]``. Sampling uses ``WanStepDistillScheduler``.
    """

    def __init__(self, config):
        super().__init__(config)

    def load_transformer(self):
        """Build the transformer used by this runner.

        Returns:
            A ``WanModel`` with each configured LoRA loaded and applied (in list order)
            when ``config["lora_configs"]`` is non-empty, otherwise a ``WanDistillModel``.
        """
        lora_configs = self.config.get("lora_configs")
        if not lora_configs:
            return WanDistillModel(self.config["model_path"], self.config, self.init_device)

        # LoRA path: start from the base WanModel and apply each configured LoRA in order.
        model = WanModel(
            self.config["model_path"],
            self.config,
            self.init_device,
        )
        lora_wrapper = WanLoraWrapper(model)
        for lora_config in lora_configs:
            lora_path = lora_config["path"]
            strength = lora_config.get("strength", 1.0)
            lora_name = lora_wrapper.load_lora(lora_path)
            lora_wrapper.apply_lora(lora_name, strength)
            logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
        return model

    def init_scheduler(self):
        """Create ``self.scheduler`` as a ``WanStepDistillScheduler``.

        Raises:
            NotImplementedError: if ``config["feature_caching"]`` is anything other
                than ``"NoCaching"`` (no caching variant exists for this runner).
        """
        feature_caching = self.config["feature_caching"]
        if feature_caching != "NoCaching":
            raise NotImplementedError(f"Unsupported feature_caching type: {feature_caching}")
        self.scheduler = WanStepDistillScheduler(self.config)


class MultiDistillModelStruct(MultiModelStruct):
    """Two-expert (high-noise / low-noise) model holder that switches by step index.

    Steps with ``scheduler.step_index < boundary_step_index`` use ``model[0]`` (the
    high-noise expert); all later steps use ``model[1]`` (the low-noise expert).
    """

    def __init__(self, model_list, config, boundary_step_index=2):
        """
        Args:
            model_list: ``[high_noise_model, low_noise_model]``; exactly two entries.
            config: runner config. ``sample_guide_scale`` (indexed 0/1), ``cpu_offload``
                and ``offload_granularity`` are read when switching experts.
            boundary_step_index: first step index served by the low-noise expert.

        Raises:
            ValueError: if ``model_list`` does not hold exactly two models.
        """
        # NOTE(review): MultiModelStruct.__init__ is not invoked here (same as before);
        # confirm the parent does not need to initialise additional state.
        self.model = model_list  # [high_noise_model, low_noise_model]
        # Explicit check instead of `assert`, which is stripped when Python runs with -O.
        if len(self.model) != 2:
            raise ValueError(f"{type(self).__name__} only supports 2 models now, got {len(self.model)}.")
        self.config = config
        self.boundary_step_index = boundary_step_index
        # -1 means no expert has been selected yet (nothing moved to CUDA by this class).
        self.cur_model_index = -1
        logger.info(f"boundary step index: {self.boundary_step_index}")

    @ProfilingContext4DebugL2("Swtich models in infer_main costs")
    def get_current_model_index(self):
        """Select the expert for the current scheduler step and make it active.

        Sets ``scheduler.sample_guide_scale`` to the chosen expert's entry of
        ``config["sample_guide_scale"]``. When ``cpu_offload`` is enabled with
        ``offload_granularity == "model"``, it offloads the previously active expert to
        CPU and moves the chosen one to CUDA. It finally records the choice in
        ``cur_model_index``.
        """
        step_index = self.scheduler.step_index
        if step_index < self.boundary_step_index:
            target_index = 0
            logger.info(f"using - HIGH - noise model at step_index {step_index + 1}")
        else:
            target_index = 1
            logger.info(f"using - LOW - noise model at step_index {step_index + 1}")

        self.scheduler.sample_guide_scale = self.config["sample_guide_scale"][target_index]

        offload_whole_model = self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model"
        # Only move weights when the active expert actually changes; on the very first
        # call (cur_model_index == -1) there is nothing to offload.
        if offload_whole_model and self.cur_model_index != target_index:
            if self.cur_model_index != -1:
                self.offload_cpu(model_index=self.cur_model_index)
            self.to_cuda(model_index=target_index)

        self.cur_model_index = target_index


@RUNNER_REGISTER("wan2.2_moe_distill")
class Wan22MoeDistillRunner(WanDistillRunner):
    """Runner for step-distilled Wan2.2 MoE models with high-/low-noise experts."""

    def __init__(self, config):
        super().__init__(config)
        self.high_noise_model_path = self._resolve_expert_path("high_noise")
        self.low_noise_model_path = self._resolve_expert_path("low_noise")

    def _resolve_expert_path(self, prefix):
        """Return the checkpoint location for the ``<prefix>_model`` expert.

        Resolution order:
          1. ``config["<prefix>_quantized_ckpt"]`` when ``dit_quantized`` is set;
          2. ``config["<prefix>_original_ckpt"]`` when ``dit_quantized`` is NOT set;
          3. ``<model_path>/<prefix>_model`` if that directory exists, otherwise
             ``<model_path>/distill_models/<prefix>_model``.

        Both experts use the same rule. Previously the high-noise expert accepted an
        original checkpoint even when ``dit_quantized`` was set, while the low-noise
        expert did not.
        """
        model_dir = f"{prefix}_model"
        path = os.path.join(self.config["model_path"], model_dir)
        if not os.path.isdir(path):
            path = os.path.join(self.config["model_path"], "distill_models", model_dir)

        dit_quantized = self.config.get("dit_quantized", False)
        quantized_ckpt = self.config.get(f"{prefix}_quantized_ckpt", None)
        original_ckpt = self.config.get(f"{prefix}_original_ckpt", None)
        if dit_quantized and quantized_ckpt:
            return quantized_ckpt
        if not dit_quantized and original_ckpt:
            return original_ckpt
        return path

    def load_transformer(self):
        """Load both experts and wrap them in a ``MultiDistillModelStruct``.

        An expert becomes a ``WanModel`` with LoRAs applied when ``config["lora_configs"]``
        contains entries whose ``name`` is ``"high_noise_model"`` / ``"low_noise_model"``
        respectively. Otherwise it is a ``WanDistillModel``. The high-noise expert is
        built first.
        """
        high_noise_model = self._build_moe_expert("high_noise_model", self.high_noise_model_path, "wan2.2_moe_high_noise", "High")
        low_noise_model = self._build_moe_expert("low_noise_model", self.low_noise_model_path, "wan2.2_moe_low_noise", "Low")
        return MultiDistillModelStruct([high_noise_model, low_noise_model], self.config, self.config["boundary_step_index"])

    def _build_moe_expert(self, lora_target, model_path, model_type, log_label):
        """Build one expert: ``WanModel`` + its LoRAs, or ``WanDistillModel`` when none target it.

        Args:
            lora_target: value of a LoRA config's ``name`` that selects this expert.
            model_path: checkpoint location for the expert.
            model_type: ``model_type`` passed to the model constructor.
            log_label: ``"High"`` / ``"Low"`` prefix used in the LoRA log message.
        """
        matching_loras = [lora_config for lora_config in (self.config.get("lora_configs") or []) if lora_config.get("name", "") == lora_target]
        if not matching_loras:
            return WanDistillModel(
                model_path,
                self.config,
                self.init_device,
                model_type=model_type,
            )

        model = WanModel(
            model_path,
            self.config,
            self.init_device,
            model_type=model_type,
        )
        lora_wrapper = WanLoraWrapper(model)
        for lora_config in matching_loras:
            lora_path = lora_config["path"]
            strength = lora_config.get("strength", 1.0)
            lora_name = lora_wrapper.load_lora(lora_path)
            lora_wrapper.apply_lora(lora_name, strength)
            logger.info(f"{log_label} noise model loaded LoRA: {lora_name} with strength: {strength}")
        return model

    def init_scheduler(self):
        """Create ``self.scheduler`` as a ``Wan22StepDistillScheduler``.

        Raises:
            NotImplementedError: if ``config["feature_caching"]`` is not ``"NoCaching"``.
        """
        feature_caching = self.config["feature_caching"]
        if feature_caching != "NoCaching":
            raise NotImplementedError(f"Unsupported feature_caching type: {feature_caching}")
        self.scheduler = Wan22StepDistillScheduler(self.config)