import gc
import os

import numpy as np
import torch
import torch.distributed as dist
import torchvision.transforms.functional as TF
from PIL import Image
from loguru import logger

from lightx2v.models.input_encoders.hf.wan.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.wan.xlm_roberta.model import CLIPModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.wan.changing_resolution.scheduler import (
    WanScheduler4ChangingResolutionInterface,
)
from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
    WanSchedulerCaching,
    WanSchedulerTaylorCaching,
)
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
from lightx2v.models.video_encoders.hf.wan.vae_tiny import Wan2_2_VAE_tiny, WanVAE_tiny
from lightx2v.server.metrics import monitor_cli
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import *
from lightx2v_platform.base.global_var import AI_DEVICE


@RUNNER_REGISTER("wan2.1")
class WanRunner(DefaultRunner):
    def __init__(self, config):
        super().__init__(config)
        self.vae_cls = WanVAE
        self.tiny_vae_cls = WanVAE_tiny
        self.vae_name = config.get("vae_name", "Wan2.1_VAE.pth")
        self.tiny_vae_name = "taew2_1.pth"

    def load_transformer(self):
        model = WanModel(
            self.config["model_path"],
            self.config,
            self.init_device,
        )
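        # LoRA weights are merged into the base DiT here; this path is incompatible with a quantized DiT checkpoint (see the assert below).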
        if self.config.get("lora_configs") and self.config.lora_configs:
            assert not self.config.get("dit_quantized", False)
            lora_wrapper = WanLoraWrapper(model)
            for lora_config in self.config.lora_configs:
                lora_path = lora_config["path"]
                strength = lora_config.get("strength", 1.0)
                lora_name = lora_wrapper.load_lora(lora_path)
                lora_wrapper.apply_lora(lora_name, strength)
                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
        return model

    def load_image_encoder(self):
        image_encoder = None
        if self.config["task"] in ["i2v", "flf2v", "animate", "s2v"] and self.config.get("use_image_encoder", True):
            # offload config
            clip_offload = self.config.get("clip_cpu_offload", self.config.get("cpu_offload", False))
            if clip_offload:
                clip_device = torch.device("cpu")
            else:
                clip_device = torch.device(AI_DEVICE)
            # quant_config
            clip_quantized = self.config.get("clip_quantized", False)
            if clip_quantized:
                clip_quant_scheme = self.config.get("clip_quant_scheme", None)
                assert clip_quant_scheme is not None
                tmp_clip_quant_scheme = clip_quant_scheme.split("-")[0]
                clip_model_name = f"models_clip_open-clip-xlm-roberta-large-vit-huge-14-{tmp_clip_quant_scheme}.pth"
                clip_quantized_ckpt = find_torch_model_path(self.config, "clip_quantized_ckpt", clip_model_name)
                clip_original_ckpt = None
            else:
                clip_quantized_ckpt = None
                clip_quant_scheme = None
                clip_model_name = "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"
                clip_original_ckpt = find_torch_model_path(self.config, "clip_original_ckpt", clip_model_name)

            image_encoder = CLIPModel(
                dtype=torch.float16,
                device=clip_device,
                checkpoint_path=clip_original_ckpt,
                clip_quantized=clip_quantized,
                clip_quantized_ckpt=clip_quantized_ckpt,
                quant_scheme=clip_quant_scheme,
                cpu_offload=clip_offload,
                use_31_block=self.config.get("use_31_block", True),
                load_from_rank0=self.config.get("load_from_rank0", False),
            )

        return image_encoder

    def load_text_encoder(self):
        # offload config
        t5_offload = self.config.get("t5_cpu_offload", self.config.get("cpu_offload"))
        if t5_offload:
            t5_device = torch.device("cpu")
        else:
            t5_device = torch.device(AI_DEVICE)
        tokenizer_path = os.path.join(self.config["model_path"], "google/umt5-xxl")
        # quant_config
        t5_quantized = self.config.get("t5_quantized", False)
        if t5_quantized:
            t5_quant_scheme = self.config.get("t5_quant_scheme", None)
            assert t5_quant_scheme is not None
            tmp_t5_quant_scheme = t5_quant_scheme.split("-")[0]
            t5_model_name = f"models_t5_umt5-xxl-enc-{tmp_t5_quant_scheme}.pth"
            t5_quantized_ckpt = find_torch_model_path(self.config, "t5_quantized_ckpt", t5_model_name)
            t5_original_ckpt = None
        else:
            t5_quant_scheme = None
            t5_quantized_ckpt = None
            t5_model_name = "models_t5_umt5-xxl-enc-bf16.pth"
            t5_original_ckpt = find_torch_model_path(self.config, "t5_original_ckpt", t5_model_name)

        text_encoder = T5EncoderModel(
            text_len=self.config["text_len"],
            dtype=torch.bfloat16,
            device=t5_device,
            checkpoint_path=t5_original_ckpt,
            tokenizer_path=tokenizer_path,
            shard_fn=None,
            cpu_offload=t5_offload,
            t5_quantized=t5_quantized,
            t5_quantized_ckpt=t5_quantized_ckpt,
            quant_scheme=t5_quant_scheme,
            load_from_rank0=self.config.get("load_from_rank0", False),
        )
        text_encoders = [text_encoder]
        return text_encoders

    def load_vae_encoder(self):
        # offload config
        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
        if vae_offload:
            vae_device = torch.device("cpu")
        else:
            vae_device = torch.device(AI_DEVICE)

        vae_config = {
            "vae_path": find_torch_model_path(self.config, "vae_path", self.vae_name),
            "device": vae_device,
            "parallel": self.config["parallel"],
            "use_tiling": self.config.get("use_tiling_vae", False),
            "cpu_offload": vae_offload,
            "dtype": GET_DTYPE(),
            "load_from_rank0": self.config.get("load_from_rank0", False),
            "use_lightvae": self.config.get("use_lightvae", False),
        }
        if self.config["task"] not in ["i2v", "flf2v", "animate", "vace", "s2v"]:
            return None
        else:
            return self.vae_cls(**vae_config)

    def load_vae_decoder(self):
        # offload config
        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
        if vae_offload:
            vae_device = torch.device("cpu")
        else:
            vae_device = torch.device(self.init_device)

        vae_config = {
            "vae_path": find_torch_model_path(self.config, "vae_path", self.vae_name),
            "device": vae_device,
            "parallel": self.config["parallel"],
            "use_tiling": self.config.get("use_tiling_vae", False),
            "cpu_offload": vae_offload,
            "use_lightvae": self.config.get("use_lightvae", False),
            "dtype": GET_DTYPE(),
            "load_from_rank0": self.config.get("load_from_rank0", False),
        }
        if self.config.get("use_tae", False):
            tae_path = find_torch_model_path(self.config, "tae_path", self.tiny_vae_name)
            vae_decoder = self.tiny_vae_cls(vae_path=tae_path, device=self.init_device, need_scaled=self.config.get("need_scaled", False)).to("cuda")
        else:
            vae_decoder = self.vae_cls(**vae_config)
        return vae_decoder

    def load_vae(self):
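        # Reuse one VAE instance for both encode and decode when possible; a separate decoder is only built when there is no encoder (t2v-style tasks) or when a tiny TAE decoder is requested.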
        vae_encoder = self.load_vae_encoder()
        if vae_encoder is None or self.config.get("use_tae", False):
            vae_decoder = self.load_vae_decoder()
        else:
            vae_decoder = vae_encoder
        return vae_encoder, vae_decoder

    def init_scheduler(self):
        if self.config["feature_caching"] == "NoCaching":
            scheduler_class = WanScheduler
        elif self.config["feature_caching"] == "TaylorSeer":
            scheduler_class = WanSchedulerTaylorCaching
        elif self.config.feature_caching in ["Tea", "Ada", "Custom", "FirstBlock", "DualBlock", "DynamicBlock", "Mag"]:
            scheduler_class = WanSchedulerCaching
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")

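        # When changing_resolution is enabled, the chosen scheduler is wrapped so denoising can run through the configured resolution stages (see resolution_rate in run_vae_encoder).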
        if self.config.get("changing_resolution", False):
            self.scheduler = WanScheduler4ChangingResolutionInterface(scheduler_class, self.config)
        else:
            self.scheduler = scheduler_class(self.config)

    @ProfilingContext4DebugL1(
        "Run Text Encoder",
        recorder_mode=GET_RECORDER_MODE(),
        metrics_func=monitor_cli.lightx2v_run_text_encode_duration,
        metrics_labels=["WanRunner"],
    )
    def run_text_encoder(self, input_info):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.text_encoders = self.load_text_encoder()

        prompt = input_info.prompt_enhanced if self.config["use_prompt_enhancer"] else input_info.prompt
        if GET_RECORDER_MODE():
            monitor_cli.lightx2v_input_prompt_len.observe(len(prompt))
        neg_prompt = input_info.negative_prompt

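        # Under CFG parallelism, rank 0 of the cfg_p group encodes the positive prompt and the other rank encodes the negative prompt, so each rank only carries the context it needs.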
        if self.config.get("enable_cfg", False) and self.config["cfg_parallel"]:
            cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
            cfg_p_rank = dist.get_rank(cfg_p_group)
            if cfg_p_rank == 0:
                context = self.text_encoders[0].infer([prompt])
                context = torch.stack([torch.cat([u, u.new_zeros(self.config["text_len"] - u.size(0), u.size(1))]) for u in context])
                text_encoder_output = {"context": context}
            else:
                context_null = self.text_encoders[0].infer([neg_prompt])
                context_null = torch.stack([torch.cat([u, u.new_zeros(self.config["text_len"] - u.size(0), u.size(1))]) for u in context_null])
                text_encoder_output = {"context_null": context_null}
        else:
            context = self.text_encoders[0].infer([prompt])
            context = torch.stack([torch.cat([u, u.new_zeros(self.config["text_len"] - u.size(0), u.size(1))]) for u in context])
            if self.config.get("enable_cfg", False):
                context_null = self.text_encoders[0].infer([neg_prompt])
                context_null = torch.stack([torch.cat([u, u.new_zeros(self.config["text_len"] - u.size(0), u.size(1))]) for u in context_null])
            else:
                context_null = None
            text_encoder_output = {
                "context": context,
                "context_null": context_null,
            }

        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.text_encoders[0]
            torch.cuda.empty_cache()
            gc.collect()

        return text_encoder_output

    @ProfilingContext4DebugL1(
        "Run Image Encoder",
        recorder_mode=GET_RECORDER_MODE(),
        metrics_func=monitor_cli.lightx2v_run_img_encode_duration,
        metrics_labels=["WanRunner"],
    )
    def run_image_encoder(self, first_frame, last_frame=None):
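        # Encode the conditioning frame(s) with CLIP; when a last frame is provided (e.g. flf2v), both frames are encoded together.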
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.image_encoder = self.load_image_encoder()
        if last_frame is None:
            clip_encoder_out = self.image_encoder.visual([first_frame]).squeeze(0).to(GET_DTYPE())
        else:
            clip_encoder_out = self.image_encoder.visual([first_frame, last_frame]).squeeze(0).to(GET_DTYPE())
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.image_encoder
            torch.cuda.empty_cache()
            gc.collect()
        return clip_encoder_out

    @ProfilingContext4DebugL1(
        "Run VAE Encoder",
        recorder_mode=GET_RECORDER_MODE(),
        metrics_func=monitor_cli.lightx2v_run_vae_encoder_image_duration,
        metrics_labels=["WanRunner"],
    )
    def run_vae_encoder(self, first_frame, last_frame=None):
        h, w = first_frame.shape[2:]
        aspect_ratio = h / w
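        # Fit the frame into the target pixel budget while keeping its aspect ratio, then snap the latent grid to multiples of the patch size.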
        max_area = self.config["target_height"] * self.config["target_width"]
        latent_h = round(np.sqrt(max_area * aspect_ratio) // self.config["vae_stride"][1] // self.config["patch_size"][1] * self.config["patch_size"][1])
        latent_w = round(np.sqrt(max_area / aspect_ratio) // self.config["vae_stride"][2] // self.config["patch_size"][2] * self.config["patch_size"][2])
        latent_shape = self.get_latent_shape_with_lat_hw(latent_h, latent_w)  # Important: latent_shape is used to set the input_info

        if self.config.get("changing_resolution", False):
            assert last_frame is None
            vae_encode_out_list = []
            for i in range(len(self.config["resolution_rate"])):
                latent_h_tmp, latent_w_tmp = (
                    int(latent_h * self.config["resolution_rate"][i]) // 2 * 2,
                    int(latent_w * self.config["resolution_rate"][i]) // 2 * 2,
                )
                vae_encode_out_list.append(self.get_vae_encoder_output(first_frame, latent_h_tmp, latent_w_tmp))
            vae_encode_out_list.append(self.get_vae_encoder_output(first_frame, latent_h, latent_w))
            return vae_encode_out_list, latent_shape
        else:
            if last_frame is not None:
                first_frame_size = first_frame.shape[2:]
                last_frame_size = last_frame.shape[2:]
                if first_frame_size != last_frame_size:
                    last_frame_resize_ratio = max(first_frame_size[0] / last_frame_size[0], first_frame_size[1] / last_frame_size[1])
                    last_frame_size = [
                        round(last_frame_size[0] * last_frame_resize_ratio),
                        round(last_frame_size[1] * last_frame_resize_ratio),
                    ]
                    last_frame = TF.center_crop(last_frame, last_frame_size)
            vae_encoder_out = self.get_vae_encoder_output(first_frame, latent_h, latent_w, last_frame)
            return vae_encoder_out, latent_shape

    def get_vae_encoder_output(self, first_frame, lat_h, lat_w, last_frame=None):
        h = lat_h * self.config["vae_stride"][1]
        w = lat_w * self.config["vae_stride"][2]
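        # Build the frame mask: 1 marks conditioning frames (first, and optionally last), 0 marks frames to generate; the first-frame row is repeated 4x to match the temporal compression of the first latent frame.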
        msk = torch.ones(
            1,
            self.config["target_video_length"],
            lat_h,
            lat_w,
            device=torch.device(AI_DEVICE),
        )
        if last_frame is not None:
            msk[:, 1:-1] = 0
        else:
            msk[:, 1:] = 0

        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]

        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_encoder = self.load_vae_encoder()

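        # Assemble the conditioning video: the resized first frame, zeros for the frames to be generated, and the resized last frame if one was given.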
        if last_frame is not None:
            vae_input = torch.concat(
                [
                    torch.nn.functional.interpolate(first_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                    torch.zeros(3, self.config["target_video_length"] - 2, h, w),
                    torch.nn.functional.interpolate(last_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                ],
                dim=1,
            ).to(AI_DEVICE)
        else:
            vae_input = torch.concat(
                [
                    torch.nn.functional.interpolate(first_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                    torch.zeros(3, self.config["target_video_length"] - 1, h, w),
                ],
                dim=1,
            ).to(AI_DEVICE)

        vae_encoder_out = self.vae_encoder.encode(vae_input.unsqueeze(0).to(GET_DTYPE()))

        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_encoder
            torch.cuda.empty_cache()
            gc.collect()
        vae_encoder_out = torch.concat([msk, vae_encoder_out]).to(GET_DTYPE())
        return vae_encoder_out

    def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img=None):
        image_encoder_output = {
            "clip_encoder_out": clip_encoder_out,
            "vae_encoder_out": vae_encoder_out,
        }
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": image_encoder_output,
        }

    def get_latent_shape_with_lat_hw(self, latent_h, latent_w):
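        # Latent shape is [C, T, H, W]: num_channels_latents channels (default 16) and one latent frame per vae_stride[0] input frames plus the first frame.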
        latent_shape = [
            self.config.get("num_channels_latents", 16),
            (self.config["target_video_length"] - 1) // self.config["vae_stride"][0] + 1,
            latent_h,
            latent_w,
        ]
        return latent_shape

    def get_latent_shape_with_target_hw(self):
        latent_shape = [
            self.config.get("num_channels_latents", 16),
            (self.config["target_video_length"] - 1) // self.config["vae_stride"][0] + 1,
            int(self.config["target_height"]) // self.config["vae_stride"][1],
            int(self.config["target_width"]) // self.config["vae_stride"][2],
        ]
        return latent_shape


class MultiModelStruct:
    def __init__(self, model_list, config, boundary=0.875, num_train_timesteps=1000):
        self.model = model_list  # [high_noise_model, low_noise_model]
        assert len(self.model) == 2, "MultiModelStruct only supports 2 models now."
        self.config = config
        self.boundary = boundary
        self.boundary_timestep = self.boundary * num_train_timesteps
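        # Steps whose timestep is >= boundary_timestep are denoised by the high-noise expert; the remaining (lower-noise) steps use the low-noise expert.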
        self.cur_model_index = -1
        logger.info(f"boundary: {self.boundary}, boundary_timestep: {self.boundary_timestep}")

    @property
    def device(self):
        return self.model[self.cur_model_index].device

    def set_scheduler(self, shared_scheduler):
        self.scheduler = shared_scheduler
        for model in self.model:
            model.set_scheduler(shared_scheduler)

    def infer(self, inputs):
        self.get_current_model_index()
        self.model[self.cur_model_index].infer(inputs)

    @ProfilingContext4DebugL2("Switch models in infer_main costs")
    def get_current_model_index(self):
        if self.scheduler.timesteps[self.scheduler.step_index] >= self.boundary_timestep:
            logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}")
            self.scheduler.sample_guide_scale = self.config["sample_guide_scale"][0]
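            # With model-granularity offload, keep only the active expert on the GPU and move the other back to CPU when switching.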
            if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model":
                if self.cur_model_index == -1:
                    self.to_cuda(model_index=0)
                elif self.cur_model_index == 1:  # 1 -> 0
                    self.offload_cpu(model_index=1)
                    self.to_cuda(model_index=0)
            self.cur_model_index = 0
        else:
            logger.info(f"using - LOW - noise model at step_index {self.scheduler.step_index + 1}")
            self.scheduler.sample_guide_scale = self.config["sample_guide_scale"][1]
            if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model":
                if self.cur_model_index == -1:
                    self.to_cuda(model_index=1)
                elif self.cur_model_index == 0:  # 0 -> 1
                    self.offload_cpu(model_index=0)
                    self.to_cuda(model_index=1)
            self.cur_model_index = 1

    def offload_cpu(self, model_index):
        self.model[model_index].to_cpu()

    def to_cuda(self, model_index):
        self.model[model_index].to_cuda()


@RUNNER_REGISTER("wan2.2_moe")
class Wan22MoeRunner(WanRunner):
    def __init__(self, config):
        super().__init__(config)
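        # Resolve the two expert checkpoints: default to <model_path>/{high,low}_noise_model, fall back to the distill_models/ subdirectory, and let explicit quantized/original ckpt paths from the config take precedence.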
        self.high_noise_model_path = os.path.join(self.config["model_path"], "high_noise_model")
        if not os.path.isdir(self.high_noise_model_path):
            self.high_noise_model_path = os.path.join(self.config["model_path"], "distill_models", "high_noise_model")
        if self.config.get("dit_quantized", False) and self.config.get("high_noise_quantized_ckpt", None):
            self.high_noise_model_path = self.config["high_noise_quantized_ckpt"]
        elif self.config.get("high_noise_original_ckpt", None):
            self.high_noise_model_path = self.config["high_noise_original_ckpt"]

        self.low_noise_model_path = os.path.join(self.config["model_path"], "low_noise_model")
        if not os.path.isdir(self.low_noise_model_path):
            self.low_noise_model_path = os.path.join(self.config["model_path"], "distill_models", "low_noise_model")
        if self.config.get("dit_quantized", False) and self.config.get("low_noise_quantized_ckpt", None):
            self.low_noise_model_path = self.config["low_noise_quantized_ckpt"]
        elif not self.config.get("dit_quantized", False) and self.config.get("low_noise_original_ckpt", None):
            self.low_noise_model_path = self.config["low_noise_original_ckpt"]

    def load_transformer(self):
        # encoder -> high_noise_model -> low_noise_model -> vae -> video_output
        high_noise_model = WanModel(
            self.high_noise_model_path,
            self.config,
            self.init_device,
            model_type="wan2.2_moe_high_noise",
        )
        low_noise_model = WanModel(
            self.low_noise_model_path,
            self.config,
            self.init_device,
            model_type="wan2.2_moe_low_noise",
        )

        if self.config.get("lora_configs") and self.config["lora_configs"]:
            assert not self.config.get("dit_quantized", False)

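            # Route each LoRA to the matching expert based on its filename prefix ("high..." or "low...").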
            for lora_config in self.config["lora_configs"]:
                lora_path = lora_config["path"]
                strength = lora_config.get("strength", 1.0)
                base_name = os.path.basename(lora_path)
                if base_name.startswith("high"):
                    lora_wrapper = WanLoraWrapper(high_noise_model)
                    lora_name = lora_wrapper.load_lora(lora_path)
                    lora_wrapper.apply_lora(lora_name, strength)
                    logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
                elif base_name.startswith("low"):
                    lora_wrapper = WanLoraWrapper(low_noise_model)
                    lora_name = lora_wrapper.load_lora(lora_path)
                    lora_wrapper.apply_lora(lora_name, strength)
                    logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
                else:
                    raise ValueError(f"Unsupported LoRA path: {lora_path}")

        return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config["boundary"])


@RUNNER_REGISTER("wan2.2")
class Wan22DenseRunner(WanRunner):
    def __init__(self, config):
        super().__init__(config)
        self.vae_encoder_need_img_original = True
        self.vae_cls = Wan2_2_VAE
        self.tiny_vae_cls = Wan2_2_VAE_tiny
        self.vae_name = "Wan2.2_VAE.pth"
        self.tiny_vae_name = "taew2_2.pth"

    @ProfilingContext4DebugL1(
        "Run VAE Encoder",
        recorder_mode=GET_RECORDER_MODE(),
        metrics_func=monitor_cli.lightx2v_run_vae_encoder_image_duration,
        metrics_labels=["Wan22DenseRunner"],
    )
    def run_vae_encoder(self, img):
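        # Choose an output size (via best_output_size) that fits the target pixel budget and aligns with patch_size * vae_stride, then resize and center-crop the image to match before encoding.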
        max_area = self.config.target_height * self.config.target_width
        ih, iw = img.height, img.width
        dh, dw = self.config.patch_size[1] * self.config.vae_stride[1], self.config.patch_size[2] * self.config.vae_stride[2]
        ow, oh = best_output_size(iw, ih, dw, dh, max_area)

        scale = max(ow / iw, oh / ih)
        img = img.resize((round(iw * scale), round(ih * scale)), Image.LANCZOS)

        # center-crop
        x1 = (img.width - ow) // 2
        y1 = (img.height - oh) // 2
        img = img.crop((x1, y1, x1 + ow, y1 + oh))
        assert img.width == ow and img.height == oh

        # to tensor
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(AI_DEVICE).unsqueeze(1)
        vae_encoder_out = self.get_vae_encoder_output(img)
        latent_w, latent_h = ow // self.config["vae_stride"][2], oh // self.config["vae_stride"][1]
        latent_shape = self.get_latent_shape_with_lat_hw(latent_h, latent_w)
        return vae_encoder_out, latent_shape

    def get_vae_encoder_output(self, img):
        z = self.vae_encoder.encode(img.unsqueeze(0).to(GET_DTYPE()))
        return z