# wan_distill_runner.py
import os

from loguru import logger

from lightx2v.models.networks.wan.distill_model import Wan22MoeDistillModel, WanDistillModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.runners.wan.wan_runner import MultiModelStruct, WanRunner
from lightx2v.models.schedulers.wan.step_distill.scheduler import Wan22StepDistillScheduler, WanStepDistillScheduler
from lightx2v.utils.registry_factory import RUNNER_REGISTER


@RUNNER_REGISTER("wan2.1_distill")
class WanDistillRunner(WanRunner):
    """Runner registered as ``wan2.1_distill``.

    Builds either a ``WanDistillModel`` from ``config.model_path`` or, when
    ``config.lora_configs`` is non-empty, a base ``WanModel`` with every
    configured LoRA applied, and pairs it with a ``WanStepDistillScheduler``.
    """

    def load_transformer(self):
        """Build the transformer for this runner.

        Each entry of ``config.lora_configs`` is read with ``entry["path"]``
        (required) and ``entry.get("strength", 1.0)``.

        Returns:
            A ``WanModel`` with all configured LoRAs applied when
            ``config.lora_configs`` is set and non-empty; otherwise a
            ``WanDistillModel`` loaded from ``config.model_path``.
        """
        lora_configs = self.config.get("lora_configs")
        if not lora_configs:
            return WanDistillModel(self.config.model_path, self.config, self.init_device)

        model = WanModel(
            self.config.model_path,
            self.config,
            self.init_device,
        )
        lora_wrapper = WanLoraWrapper(model)
        # LoRAs are applied in the order they appear in the config.
        for lora_config in lora_configs:
            lora_path = lora_config["path"]
            strength = lora_config.get("strength", 1.0)
            lora_name = lora_wrapper.load_lora(lora_path)
            lora_wrapper.apply_lora(lora_name, strength)
            logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
        return model

    def init_scheduler(self):
        """Create a ``WanStepDistillScheduler`` and attach it to ``self.model``.

        Presumably called after ``self.model`` has been populated from
        ``load_transformer()``. TODO confirm against ``WanRunner``.

        Raises:
            NotImplementedError: if ``config.feature_caching`` is anything
                other than ``"NoCaching"``.
        """
        if self.config.feature_caching != "NoCaching":
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
        scheduler = WanStepDistillScheduler(self.config)
        self.model.set_scheduler(scheduler)


class MultiDistillModelStruct(MultiModelStruct):
    """Holds a high-noise and a low-noise model and switches between them by step index.

    Scheduler steps with ``step_index < boundary_step_index`` use the
    high-noise model (index 0). All later steps use the low-noise model
    (index 1).
    """

    def __init__(self, model_list, config, boundary_step_index=2):
        """
        Args:
            model_list: ``[high_noise_model, low_noise_model]``. Must contain
                exactly two entries.
            config: runner config. ``config.sample_guide_scale`` is indexed
                with ``[0]`` for the high-noise model and ``[1]`` for the
                low-noise model.
            boundary_step_index: first scheduler step index served by the
                low-noise model.

        Raises:
            ValueError: if ``model_list`` does not contain exactly two models.
        """
        # NOTE(review): MultiModelStruct.__init__ is not called here; confirm the
        # parent sets no other attributes this subclass relies on.
        if len(model_list) != 2:
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            raise ValueError("MultiModelStruct only supports 2 models now.")
        self.model = model_list  # [high_noise_model, low_noise_model]
        self.config = config
        self.boundary_step_index = boundary_step_index
        # -1 means no model has been moved to CUDA yet.
        self.cur_model_index = -1
        logger.info(f"boundary step index: {self.boundary_step_index}")

    def get_current_model_index(self):
        """Select the model for the scheduler's current step and make it resident.

        Side effects:
            - Sets ``self.scheduler.sample_guide_scale`` from
              ``config.sample_guide_scale[index]``.
            - If the selected index differs from ``cur_model_index``, offloads
              the previously loaded model (if any) to CPU, moves the selected
              one to CUDA, and updates ``cur_model_index``.
        """
        step_index = self.scheduler.step_index
        if step_index < self.boundary_step_index:
            target_index, label = 0, "HIGH"
        else:
            target_index, label = 1, "LOW"
        # The log reports the step as 1-based.
        logger.info(f"using - {label} - noise model at step_index {step_index + 1}")
        self.scheduler.sample_guide_scale = self.config.sample_guide_scale[target_index]
        self._activate_model(target_index)

    def _activate_model(self, target_index):
        """Offload the currently loaded model (if any) and move ``target_index`` to CUDA."""
        if self.cur_model_index == target_index:
            return
        if self.cur_model_index != -1:
            self.offload_cpu(model_index=self.cur_model_index)
        self.to_cuda(model_index=target_index)
        self.cur_model_index = target_index


@RUNNER_REGISTER("wan2.2_moe_distill")
class Wan22MoeDistillRunner(WanDistillRunner):
    """Runner registered as ``wan2.2_moe_distill`` with high-noise and low-noise experts.

    Each expert is loaded independently:

    - If ``config.lora_configs`` has at least one entry whose ``name`` equals
      the expert name, a ``WanModel`` is loaded from
      ``<model_path>/<expert>`` and every matching LoRA is applied.
    - Otherwise a ``Wan22MoeDistillModel`` is loaded from
      ``<model_path>/distill_models/<expert>``.
    """

    # Expert sub-directory name -> prefix used in the LoRA log message.
    _MOE_EXPERTS = {"high_noise_model": "High", "low_noise_model": "Low"}

    def load_transformer(self):
        """Load both experts and wrap them in a ``MultiDistillModelStruct``.

        Returns:
            ``MultiDistillModelStruct([high_noise_model, low_noise_model], ...)``,
            switching experts at ``config.boundary_step_index``.
        """
        lora_configs = self.config.get("lora_configs") or []

        # Entries not addressed to either expert would otherwise be dropped silently.
        for lora_config in lora_configs:
            if lora_config.get("name", "") not in self._MOE_EXPERTS:
                logger.warning(f"Ignoring LoRA config not targeting a known expert {sorted(self._MOE_EXPERTS)}: {lora_config}")

        high_noise_model = self._load_moe_expert("high_noise_model", lora_configs)
        low_noise_model = self._load_moe_expert("low_noise_model", lora_configs)
        return MultiDistillModelStruct([high_noise_model, low_noise_model], self.config, self.config.boundary_step_index)

    def _load_moe_expert(self, expert_name, lora_configs):
        """Load one expert, applying the LoRAs whose ``name`` equals ``expert_name``.

        Args:
            expert_name: ``"high_noise_model"`` or ``"low_noise_model"``.
            lora_configs: LoRA entries. Only those with a matching ``name`` are used.

        Returns:
            A ``WanModel`` with LoRAs applied if any entry matched, otherwise a
            ``Wan22MoeDistillModel`` loaded from the ``distill_models`` directory.
        """
        expert_loras = [cfg for cfg in lora_configs if cfg.get("name", "") == expert_name]
        if not expert_loras:
            return Wan22MoeDistillModel(
                os.path.join(self.config.model_path, "distill_models", expert_name),
                self.config,
                self.init_device,
            )

        model = WanModel(
            os.path.join(self.config.model_path, expert_name),
            self.config,
            self.init_device,
        )
        lora_wrapper = WanLoraWrapper(model)
        label = self._MOE_EXPERTS[expert_name]
        for lora_config in expert_loras:
            lora_path = lora_config["path"]
            strength = lora_config.get("strength", 1.0)
            lora_name = lora_wrapper.load_lora(lora_path)
            lora_wrapper.apply_lora(lora_name, strength)
            logger.info(f"{label} noise model loaded LoRA: {lora_name} with strength: {strength}")
        return model

    def init_scheduler(self):
        """Create a ``Wan22StepDistillScheduler`` and attach it to ``self.model``.

        Raises:
            NotImplementedError: if ``config.feature_caching`` is anything
                other than ``"NoCaching"``.
        """
        if self.config.feature_caching != "NoCaching":
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
        scheduler = Wan22StepDistillScheduler(self.config)
        self.model.set_scheduler(scheduler)