import os

from loguru import logger

from lightx2v.models.networks.wan.distill_model import Wan22MoeDistillModel, WanDistillModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.networks.wan.model import Wan22MoeModel, WanModel
from lightx2v.models.runners.wan.wan_runner import MultiModelStruct, WanRunner
from lightx2v.models.schedulers.wan.step_distill.scheduler import Wan22StepDistillScheduler, WanStepDistillScheduler
from lightx2v.utils.registry_factory import RUNNER_REGISTER


@RUNNER_REGISTER("wan2.1_distill")
class WanDistillRunner(WanRunner):
    def __init__(self, config):
        super().__init__(config)

    def load_transformer(self):
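        # If LoRA configs are provided, load the base WanModel and merge each configured
        # LoRA into it; otherwise load the distilled checkpoint directly via WanDistillModel.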
        if self.config.get("lora_configs") and self.config.lora_configs:
            model = WanModel(
                self.config.model_path,
                self.config,
                self.init_device,
            )
            lora_wrapper = WanLoraWrapper(model)
            for lora_config in self.config.lora_configs:
                lora_path = lora_config["path"]
                strength = lora_config.get("strength", 1.0)
                lora_name = lora_wrapper.load_lora(lora_path)
                lora_wrapper.apply_lora(lora_name, strength)
                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
        else:
            model = WanDistillModel(self.config.model_path, self.config, self.init_device)
        return model

    def init_scheduler(self):
        if self.config.feature_caching == "NoCaching":
            scheduler = WanStepDistillScheduler(self.config)
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
        self.model.set_scheduler(scheduler)


class MultiDistillModelStruct(MultiModelStruct):
    def __init__(self, model_list, config, boundary_step_index=2):
        self.model = model_list  # [high_noise_model, low_noise_model]
        assert len(self.model) == 2, "MultiDistillModelStruct only supports 2 models for now."
        self.config = config
        self.boundary_step_index = boundary_step_index
        self.cur_model_index = -1
        logger.info(f"boundary step index: {self.boundary_step_index}")

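    # Steps before `boundary_step_index` run on the high-noise expert (index 0); the
    # remaining steps run on the low-noise expert (index 1). On a switch, the previously
    # active expert is offloaded to CPU so only one model occupies GPU memory at a time.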
    def get_current_model_index(self):
        if self.scheduler.step_index < self.boundary_step_index:
            logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}")
            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[0]
            if self.cur_model_index == -1:
                self.to_cuda(model_index=0)
            elif self.cur_model_index == 1:  # 1 -> 0
                self.offload_cpu(model_index=1)
                self.to_cuda(model_index=0)
            self.cur_model_index = 0
        else:
            logger.info(f"using - LOW - noise model at step_index {self.scheduler.step_index + 1}")
            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[1]
            if self.cur_model_index == -1:
                self.to_cuda(model_index=1)
            elif self.cur_model_index == 0:  # 0 -> 1
                self.offload_cpu(model_index=0)
                self.to_cuda(model_index=1)
            self.cur_model_index = 1


@RUNNER_REGISTER("wan2.2_moe_distill")
class Wan22MoeDistillRunner(WanDistillRunner):
    def __init__(self, config):
        super().__init__(config)

    def load_transformer(self):
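        # Each expert is built independently: from the base Wan22MoeModel with its named
        # LoRA merged in, or from the pre-distilled checkpoint under "distill_models/"
        # when no LoRA entry is configured for that expert.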
        use_high_lora, use_low_lora = False, False
        if self.config.get("lora_configs") and self.config.lora_configs:
            for lora_config in self.config.lora_configs:
                if lora_config.get("name", "") == "high_noise_model":
                    use_high_lora = True
                elif lora_config.get("name", "") == "low_noise_model":
                    use_low_lora = True

        if use_high_lora:
            high_noise_model = Wan22MoeModel(
                os.path.join(self.config.model_path, "high_noise_model"),
                self.config,
                self.init_device,
            )
            high_lora_wrapper = WanLoraWrapper(high_noise_model)
            for lora_config in self.config.lora_configs:
                if lora_config.get("name", "") == "high_noise_model":
                    lora_path = lora_config["path"]
                    strength = lora_config.get("strength", 1.0)
                    lora_name = high_lora_wrapper.load_lora(lora_path)
                    high_lora_wrapper.apply_lora(lora_name, strength)
                    logger.info(f"High noise model loaded LoRA: {lora_name} with strength: {strength}")
        else:
            high_noise_model = Wan22MoeDistillModel(
                os.path.join(self.config.model_path, "distill_models", "high_noise_model"),
                self.config,
                self.init_device,
            )

        if use_low_lora:
            low_noise_model = Wan22MoeModel(
                os.path.join(self.config.model_path, "low_noise_model"),
                self.config,
                self.init_device,
            )
            low_lora_wrapper = WanLoraWrapper(low_noise_model)
            for lora_config in self.config.lora_configs:
                if lora_config.get("name", "") == "low_noise_model":
                    lora_path = lora_config["path"]
                    strength = lora_config.get("strength", 1.0)
                    lora_name = low_lora_wrapper.load_lora(lora_path)
                    low_lora_wrapper.apply_lora(lora_name, strength)
                    logger.info(f"Low noise model loaded LoRA: {lora_name} with strength: {strength}")
        else:
            low_noise_model = Wan22MoeDistillModel(
                os.path.join(self.config.model_path, "distill_models", "low_noise_model"),
                self.config,
                self.init_device,
            )

        return MultiDistillModelStruct([high_noise_model, low_noise_model], self.config, self.config.boundary_step_index)

    def init_scheduler(self):
        if self.config.feature_caching == "NoCaching":
            scheduler = Wan22StepDistillScheduler(self.config)
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
        self.model.set_scheduler(scheduler)
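

# Illustrative config sketch (not part of the library): the keys below are exactly the
# ones read in this file, while the values and paths are placeholders. The real config
# object appears to support both dict-style .get() access and attribute access.
#
# {
#     "model_path": "/path/to/wan_checkpoints",          # placeholder path
#     "feature_caching": "NoCaching",                    # only value handled by init_scheduler
#     "boundary_step_index": 2,                          # high -> low noise switch step (wan2.2 moe)
#     "sample_guide_scale": [3.5, 3.5],                  # [high_noise, low_noise] guidance, placeholder values
#     "lora_configs": [
#         {"name": "high_noise_model", "path": "/path/to/high_noise_lora.safetensors", "strength": 1.0},
#         {"name": "low_noise_model", "path": "/path/to/low_noise_lora.safetensors", "strength": 1.0},
#     ],
# }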