import gc
import os

import numpy as np
import torch
import torch.distributed as dist
import torchvision.transforms.functional as TF
from PIL import Image
from loguru import logger

from lightx2v.models.input_encoders.hf.wan.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.wan.xlm_roberta.model import CLIPModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.wan.changing_resolution.scheduler import (
    WanScheduler4ChangingResolutionInterface,
)
from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
    WanSchedulerCaching,
    WanSchedulerTaylorCaching,
)
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
from lightx2v.models.video_encoders.hf.wan.vae_tiny import Wan2_2_VAE_tiny, WanVAE_tiny
from lightx2v.server.metrics import monitor_cli
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import *
from lightx2v_platform.base.global_var import AI_DEVICE


@RUNNER_REGISTER("wan2.1")
class WanRunner(DefaultRunner):
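    """Runner for Wan2.1 generation tasks.

    Wires together the UMT5 text encoder, the optional CLIP image encoder,
    the Wan VAE and the Wan diffusion transformer, and picks a scheduler
    according to the options in `config` (offload, quantization, LoRA,
    feature caching, changing resolution, ...).
    """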
    def __init__(self, config):
        super().__init__(config)
        self.vae_cls = WanVAE
        self.tiny_vae_cls = WanVAE_tiny
        self.vae_name = config.get("vae_name", "Wan2.1_VAE.pth")
        self.tiny_vae_name = "taew2_1.pth"

44
45
    def load_transformer(self):
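        """Build the Wan diffusion model and merge any configured LoRAs.

        LoRA merging is only supported on non-quantized DiT weights, hence the
        assertion on `dit_quantized`.
        """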
        model = WanModel(
            self.config["model_path"],
            self.config,
            self.init_device,
        )
        if self.config.get("lora_configs") and self.config.lora_configs:
            assert not self.config.get("dit_quantized", False)
            lora_wrapper = WanLoraWrapper(model)
            for lora_config in self.config.lora_configs:
                lora_path = lora_config["path"]
                strength = lora_config.get("strength", 1.0)
                lora_name = lora_wrapper.load_lora(lora_path)
                lora_wrapper.apply_lora(lora_name, strength)
                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
        return model

    def load_image_encoder(self):
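        """Load the CLIP image encoder (open-clip XLM-RoBERTa ViT-H/14).

        Only used for image-conditioned tasks (i2v, flf2v, animate, s2v); honors
        per-module CPU offload and an optional quantized checkpoint.
        """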
        image_encoder = None
        if self.config["task"] in ["i2v", "flf2v", "animate", "s2v"] and self.config.get("use_image_encoder", True):
            # offload config
            clip_offload = self.config.get("clip_cpu_offload", self.config.get("cpu_offload", False))
            if clip_offload:
                clip_device = torch.device("cpu")
            else:
                clip_device = torch.device(AI_DEVICE)

            # quant_config
            clip_quantized = self.config.get("clip_quantized", False)
            if clip_quantized:
                clip_quant_scheme = self.config.get("clip_quant_scheme", None)
                assert clip_quant_scheme is not None
                tmp_clip_quant_scheme = clip_quant_scheme.split("-")[0]
                clip_model_name = f"models_clip_open-clip-xlm-roberta-large-vit-huge-14-{tmp_clip_quant_scheme}.pth"
                clip_quantized_ckpt = find_torch_model_path(self.config, "clip_quantized_ckpt", clip_model_name)
                clip_original_ckpt = None
            else:
                clip_quantized_ckpt = None
                clip_quant_scheme = None
                clip_model_name = "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"
                clip_original_ckpt = find_torch_model_path(self.config, "clip_original_ckpt", clip_model_name)

            image_encoder = CLIPModel(
                dtype=torch.float16,
                device=clip_device,
                checkpoint_path=clip_original_ckpt,
                clip_quantized=clip_quantized,
                clip_quantized_ckpt=clip_quantized_ckpt,
                quant_scheme=clip_quant_scheme,
                cpu_offload=clip_offload,
                use_31_block=self.config.get("use_31_block", True),
                load_from_rank0=self.config.get("load_from_rank0", False),
            )

        return image_encoder

    def load_text_encoder(self):
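        """Load the UMT5-XXL text encoder, from either the original bf16 or a quantized checkpoint."""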
        # offload config
        t5_offload = self.config.get("t5_cpu_offload", self.config.get("cpu_offload"))
        if t5_offload:
            t5_device = torch.device("cpu")
        else:
            t5_device = torch.device(AI_DEVICE)
        tokenizer_path = os.path.join(self.config["model_path"], "google/umt5-xxl")

        # quant_config
        t5_quantized = self.config.get("t5_quantized", False)
        if t5_quantized:
            t5_quant_scheme = self.config.get("t5_quant_scheme", None)
            assert t5_quant_scheme is not None
            tmp_t5_quant_scheme = t5_quant_scheme.split("-")[0]
            t5_model_name = f"models_t5_umt5-xxl-enc-{tmp_t5_quant_scheme}.pth"
            t5_quantized_ckpt = find_torch_model_path(self.config, "t5_quantized_ckpt", t5_model_name)
            t5_original_ckpt = None
        else:
            t5_quant_scheme = None
            t5_quantized_ckpt = None
            t5_model_name = "models_t5_umt5-xxl-enc-bf16.pth"
            t5_original_ckpt = find_torch_model_path(self.config, "t5_original_ckpt", t5_model_name)

        text_encoder = T5EncoderModel(
            text_len=self.config["text_len"],
            dtype=torch.bfloat16,
            device=t5_device,
            checkpoint_path=t5_original_ckpt,
            tokenizer_path=tokenizer_path,
            shard_fn=None,
            cpu_offload=t5_offload,
            t5_quantized=t5_quantized,
            t5_quantized_ckpt=t5_quantized_ckpt,
            quant_scheme=t5_quant_scheme,
            load_from_rank0=self.config.get("load_from_rank0", False),
        )
        text_encoders = [text_encoder]
        return text_encoders

    def load_vae_encoder(self):
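        """Load the Wan VAE encoder; returns None for tasks that take no image/video condition."""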
        # offload config
        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
        if vae_offload:
            vae_device = torch.device("cpu")
        else:
            vae_device = torch.device(AI_DEVICE)

        vae_config = {
            "vae_path": find_torch_model_path(self.config, "vae_path", self.vae_name),
            "device": vae_device,
            "parallel": self.config["parallel"],
            "use_tiling": self.config.get("use_tiling_vae", False),
            "cpu_offload": vae_offload,
            "dtype": GET_DTYPE(),
            "load_from_rank0": self.config.get("load_from_rank0", False),
            "use_lightvae": self.config.get("use_lightvae", False),
        }
        if self.config["task"] not in ["i2v", "flf2v", "animate", "vace", "s2v"]:
            return None
        else:
            return self.vae_cls(**vae_config)

    def load_vae_decoder(self):
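        """Load the Wan VAE decoder, or the tiny TAE decoder when `use_tae` is enabled."""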
        # offload config
        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
        if vae_offload:
            vae_device = torch.device("cpu")
        else:
            vae_device = torch.device(AI_DEVICE)

        vae_config = {
            "vae_path": find_torch_model_path(self.config, "vae_path", self.vae_name),
            "device": vae_device,
            "parallel": self.config["parallel"],
            "use_tiling": self.config.get("use_tiling_vae", False),
            "cpu_offload": vae_offload,
            "use_lightvae": self.config.get("use_lightvae", False),
            "dtype": GET_DTYPE(),
            "load_from_rank0": self.config.get("load_from_rank0", False),
        }
        if self.config.get("use_tae", False):
            tae_path = find_torch_model_path(self.config, "tae_path", self.tiny_vae_name)
            vae_decoder = self.tiny_vae_cls(vae_path=tae_path, device=self.init_device, need_scaled=self.config.get("need_scaled", False)).to(AI_DEVICE)
        else:
            vae_decoder = self.vae_cls(**vae_config)
        return vae_decoder

    def load_vae(self):
        vae_encoder = self.load_vae_encoder()
        if vae_encoder is None or self.config.get("use_tae", False):
            vae_decoder = self.load_vae_decoder()
        else:
            vae_decoder = vae_encoder
        return vae_encoder, vae_decoder

    def init_scheduler(self):
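        """Select the scheduler class from `feature_caching` and wrap it for changing-resolution inference if requested."""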
        if self.config["feature_caching"] == "NoCaching":
            scheduler_class = WanScheduler
        elif self.config["feature_caching"] == "TaylorSeer":
            scheduler_class = WanSchedulerTaylorCaching
        elif self.config.feature_caching in ["Tea", "Ada", "Custom", "FirstBlock", "DualBlock", "DynamicBlock", "Mag"]:
            scheduler_class = WanSchedulerCaching
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")

        if self.config.get("changing_resolution", False):
            self.scheduler = WanScheduler4ChangingResolutionInterface(scheduler_class, self.config)
        else:
            self.scheduler = scheduler_class(self.config)

    @ProfilingContext4DebugL1(
        "Run Text Encoder",
        recorder_mode=GET_RECORDER_MODE(),
        metrics_func=monitor_cli.lightx2v_run_text_encode_duration,
        metrics_labels=["WanRunner"],
    )
    def run_text_encoder(self, input_info):
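        """Encode the positive and negative prompts and pad them to `text_len`.

        With cfg_parallel enabled, rank 0 of the cfg_p group encodes the positive
        prompt while the other rank encodes the negative one; otherwise both are
        computed locally (the negative only when CFG is enabled).
        """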
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.text_encoders = self.load_text_encoder()

        prompt = input_info.prompt_enhanced if self.config["use_prompt_enhancer"] else input_info.prompt
        if GET_RECORDER_MODE():
            monitor_cli.lightx2v_input_prompt_len.observe(len(prompt))
        neg_prompt = input_info.negative_prompt

        if self.config.get("enable_cfg", False) and self.config["cfg_parallel"]:
            cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
            cfg_p_rank = dist.get_rank(cfg_p_group)
            if cfg_p_rank == 0:
                context = self.text_encoders[0].infer([prompt])
                context = torch.stack([torch.cat([u, u.new_zeros(self.config["text_len"] - u.size(0), u.size(1))]) for u in context])
                text_encoder_output = {"context": context}
            else:
                context_null = self.text_encoders[0].infer([neg_prompt])
                context_null = torch.stack([torch.cat([u, u.new_zeros(self.config["text_len"] - u.size(0), u.size(1))]) for u in context_null])
                text_encoder_output = {"context_null": context_null}
        else:
            context = self.text_encoders[0].infer([prompt])
            context = torch.stack([torch.cat([u, u.new_zeros(self.config["text_len"] - u.size(0), u.size(1))]) for u in context])
            if self.config.get("enable_cfg", False):
                context_null = self.text_encoders[0].infer([neg_prompt])
                context_null = torch.stack([torch.cat([u, u.new_zeros(self.config["text_len"] - u.size(0), u.size(1))]) for u in context_null])
            else:
                context_null = None
            text_encoder_output = {
                "context": context,
                "context_null": context_null,
            }

        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.text_encoders[0]
            torch.cuda.empty_cache()
            gc.collect()

        return text_encoder_output

    @ProfilingContext4DebugL1(
        "Run Image Encoder",
        recorder_mode=GET_RECORDER_MODE(),
        metrics_func=monitor_cli.lightx2v_run_img_encode_duration,
        metrics_labels=["WanRunner"],
    )
    def run_image_encoder(self, first_frame, last_frame=None):
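        """Run the CLIP visual encoder on the first frame (and the last frame, if given)."""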
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.image_encoder = self.load_image_encoder()
        if last_frame is None:
            clip_encoder_out = self.image_encoder.visual([first_frame]).squeeze(0).to(GET_DTYPE())
        else:
            clip_encoder_out = self.image_encoder.visual([first_frame, last_frame]).squeeze(0).to(GET_DTYPE())
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.image_encoder
            torch.cuda.empty_cache()
            gc.collect()
        return clip_encoder_out

    def _adjust_latent_for_grid_splitting(self, latent_h, latent_w, world_size):
        """
        Adjust latent dimensions for optimal 2D grid splitting.
        Prefers balanced grids like 2x4 or 4x2 over 1x8 or 8x1.
        """
        world_size_h, world_size_w = 1, 1
        if world_size <= 1:
            return latent_h, latent_w, world_size_h, world_size_w

        # Define priority grids for different world sizes
        priority_grids = []
        if world_size == 8:
            # For 8 cards, prefer 2x4 and 4x2 over 1x8 and 8x1
            priority_grids = [(2, 4), (4, 2), (1, 8), (8, 1)]
        elif world_size == 4:
            priority_grids = [(2, 2), (1, 4), (4, 1)]
        elif world_size == 2:
            priority_grids = [(1, 2), (2, 1)]
        else:
            # For other sizes, try factor pairs
            for h in range(1, int(np.sqrt(world_size)) + 1):
                if world_size % h == 0:
                    w = world_size // h
                    priority_grids.append((h, w))

        # Try priority grids first
        for world_size_h, world_size_w in priority_grids:
            if latent_h % world_size_h == 0 and latent_w % world_size_w == 0:
                return latent_h, latent_w, world_size_h, world_size_w

        # If no perfect fit, find minimal padding solution
        best_grid = (1, world_size)  # fallback
        min_total_padding = float("inf")

        for world_size_h, world_size_w in priority_grids:
            # Calculate required padding
            pad_h = (world_size_h - (latent_h % world_size_h)) % world_size_h
            pad_w = (world_size_w - (latent_w % world_size_w)) % world_size_w
            total_padding = pad_h + pad_w

            # Prefer grids with minimal total padding
            if total_padding < min_total_padding:
                min_total_padding = total_padding
                best_grid = (world_size_h, world_size_w)

        # Apply padding
        world_size_h, world_size_w = best_grid
        pad_h = (world_size_h - (latent_h % world_size_h)) % world_size_h
        pad_w = (world_size_w - (latent_w % world_size_w)) % world_size_w

        return latent_h + pad_h, latent_w + pad_w, world_size_h, world_size_w

    @ProfilingContext4DebugL1(
        "Run VAE Encoder",
        recorder_mode=GET_RECORDER_MODE(),
        metrics_func=monitor_cli.lightx2v_run_vae_encoder_image_duration,
        metrics_labels=["WanRunner"],
    )
    def run_vae_encoder(self, first_frame, last_frame=None):
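        """Encode the conditioning frame(s) into VAE latent space.

        Latent height/width are derived from the target area and aspect ratio and,
        when running distributed, padded so the latent grid splits evenly across
        devices. In changing-resolution mode, one encoding per resolution stage is
        returned in addition to the full-resolution one.
        """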
helloyongyang's avatar
helloyongyang committed
333
        h, w = first_frame.shape[2:]
helloyongyang's avatar
helloyongyang committed
334
        aspect_ratio = h / w
335
        max_area = self.config["target_height"] * self.config["target_width"]
Musisoul's avatar
Musisoul committed
336
337
338
339
340
341
342
343
344
345
346
347
348

        # Calculate initial latent dimensions
        ori_latent_h = round(np.sqrt(max_area * aspect_ratio) // self.config["vae_stride"][1] // self.config["patch_size"][1] * self.config["patch_size"][1])
        ori_latent_w = round(np.sqrt(max_area / aspect_ratio) // self.config["vae_stride"][2] // self.config["patch_size"][2] * self.config["patch_size"][2])

        # Adjust latent dimensions for optimal 2D grid splitting when using distributed processing
        if dist.is_initialized() and dist.get_world_size() > 1:
            latent_h, latent_w, world_size_h, world_size_w = self._adjust_latent_for_grid_splitting(ori_latent_h, ori_latent_w, dist.get_world_size())
            logger.info(f"ori latent: {ori_latent_h}x{ori_latent_w}, adjust_latent: {latent_h}x{latent_w}, grid: {world_size_h}x{world_size_w}")
        else:
            latent_h, latent_w = ori_latent_h, ori_latent_w
            world_size_h, world_size_w = None, None

349
        latent_shape = self.get_latent_shape_with_lat_hw(latent_h, latent_w)  # Important: latent_shape is used to set the input_info
350
351

        if self.config.get("changing_resolution", False):
gushiqiao's avatar
gushiqiao committed
352
            assert last_frame is None
353
354
            vae_encode_out_list = []
            for i in range(len(self.config["resolution_rate"])):
355
356
357
                latent_h_tmp, latent_w_tmp = (
                    int(latent_h * self.config["resolution_rate"][i]) // 2 * 2,
                    int(latent_w * self.config["resolution_rate"][i]) // 2 * 2,
358
                )
Musisoul's avatar
Musisoul committed
359
360
                vae_encode_out_list.append(self.get_vae_encoder_output(first_frame, latent_h_tmp, latent_w_tmp, world_size_h=world_size_h, world_size_w=world_size_w))
            vae_encode_out_list.append(self.get_vae_encoder_output(first_frame, latent_h, latent_w, world_size_h=world_size_h, world_size_w=world_size_w))
361
            return vae_encode_out_list, latent_shape
362
        else:
gushiqiao's avatar
gushiqiao committed
363
            if last_frame is not None:
helloyongyang's avatar
helloyongyang committed
364
365
                first_frame_size = first_frame.shape[2:]
                last_frame_size = last_frame.shape[2:]
gushiqiao's avatar
gushiqiao committed
366
367
368
369
370
371
372
                if first_frame_size != last_frame_size:
                    last_frame_resize_ratio = max(first_frame_size[0] / last_frame_size[0], first_frame_size[1] / last_frame_size[1])
                    last_frame_size = [
                        round(last_frame_size[0] * last_frame_resize_ratio),
                        round(last_frame_size[1] * last_frame_resize_ratio),
                    ]
                    last_frame = TF.center_crop(last_frame, last_frame_size)
Musisoul's avatar
Musisoul committed
373
            vae_encoder_out = self.get_vae_encoder_output(first_frame, latent_h, latent_w, last_frame, world_size_h=world_size_h, world_size_w=world_size_w)
374
            return vae_encoder_out, latent_shape
375

    def get_vae_encoder_output(self, first_frame, lat_h, lat_w, last_frame=None, world_size_h=None, world_size_w=None):
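        """Build the masked conditioning latent for i2v/flf2v.

        A temporal mask marks the provided frames (first frame, plus last frame for
        flf2v); the frames are resized to the pixel size implied by (lat_h, lat_w),
        zero-padded along time, VAE-encoded and concatenated with the mask.
        """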
        h = lat_h * self.config["vae_stride"][1]
        w = lat_w * self.config["vae_stride"][2]
        msk = torch.ones(
            1,
            self.config["target_video_length"],
            lat_h,
            lat_w,
            device=torch.device(AI_DEVICE),
        )
        if last_frame is not None:
            msk[:, 1:-1] = 0
        else:
            msk[:, 1:] = 0

        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]

        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_encoder = self.load_vae_encoder()

        if last_frame is not None:
            vae_input = torch.concat(
                [
                    torch.nn.functional.interpolate(first_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                    torch.zeros(3, self.config["target_video_length"] - 2, h, w),
                    torch.nn.functional.interpolate(last_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                ],
                dim=1,
            ).to(AI_DEVICE)
        else:
            vae_input = torch.concat(
                [
                    torch.nn.functional.interpolate(first_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                    torch.zeros(3, self.config["target_video_length"] - 1, h, w),
                ],
                dim=1,
            ).to(AI_DEVICE)

        vae_encoder_out = self.vae_encoder.encode(vae_input.unsqueeze(0).to(GET_DTYPE()), world_size_h=world_size_h, world_size_w=world_size_w)

        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_encoder
            torch.cuda.empty_cache()
            gc.collect()
        vae_encoder_out = torch.concat([msk, vae_encoder_out]).to(GET_DTYPE())
        return vae_encoder_out

    def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img=None):
        image_encoder_output = {
            "clip_encoder_out": clip_encoder_out,
            "vae_encoder_out": vae_encoder_out,
        }
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": image_encoder_output,
        }

    def get_latent_shape_with_lat_hw(self, latent_h, latent_w):
        latent_shape = [
            self.config.get("num_channels_latents", 16),
            (self.config["target_video_length"] - 1) // self.config["vae_stride"][0] + 1,
            latent_h,
            latent_w,
        ]
        return latent_shape

    def get_latent_shape_with_target_hw(self):
        latent_shape = [
            self.config.get("num_channels_latents", 16),
            (self.config["target_video_length"] - 1) // self.config["vae_stride"][0] + 1,
            int(self.config["target_height"]) // self.config["vae_stride"][1],
            int(self.config["target_width"]) // self.config["vae_stride"][2],
        ]
        return latent_shape


class MultiModelStruct:
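    """Wraps the Wan2.2 high-noise and low-noise models behind a single interface.

    Steps with timesteps at or above `boundary * num_train_timesteps` use the
    high-noise model; later steps use the low-noise model, optionally offloading
    the inactive one to CPU.
    """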
    def __init__(self, model_list, config, boundary=0.875, num_train_timesteps=1000):
        self.model = model_list  # [high_noise_model, low_noise_model]
        assert len(self.model) == 2, "MultiModelStruct only supports 2 models now."
        self.config = config
        self.boundary = boundary
        self.boundary_timestep = self.boundary * num_train_timesteps
        self.cur_model_index = -1
        logger.info(f"boundary: {self.boundary}, boundary_timestep: {self.boundary_timestep}")

    @property
    def device(self):
        return self.model[self.cur_model_index].device

    def set_scheduler(self, shared_scheduler):
        self.scheduler = shared_scheduler
        for model in self.model:
            model.set_scheduler(shared_scheduler)

    def infer(self, inputs):
        self.get_current_model_index()
        self.model[self.cur_model_index].infer(inputs)

    @ProfilingContext4DebugL2("Switch models in infer_main costs")
    def get_current_model_index(self):
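        """Pick the model for the current step and set its guide scale.

        Index 0 (high noise) is used while the timestep is at or above the boundary,
        index 1 (low noise) afterwards; with model-granularity CPU offload the
        inactive model is swapped out of GPU memory.
        """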
        if self.scheduler.timesteps[self.scheduler.step_index] >= self.boundary_timestep:
            logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}")
            self.scheduler.sample_guide_scale = self.config["sample_guide_scale"][0]
            if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model":
                if self.cur_model_index == -1:
                    self.to_cuda(model_index=0)
                elif self.cur_model_index == 1:  # 1 -> 0
                    self.offload_cpu(model_index=1)
                    self.to_cuda(model_index=0)
            self.cur_model_index = 0
        else:
            logger.info(f"using - LOW - noise model at step_index {self.scheduler.step_index + 1}")
            self.scheduler.sample_guide_scale = self.config["sample_guide_scale"][1]
            if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model":
                if self.cur_model_index == -1:
                    self.to_cuda(model_index=1)
                elif self.cur_model_index == 0:  # 0 -> 1
                    self.offload_cpu(model_index=0)
                    self.to_cuda(model_index=1)
            self.cur_model_index = 1

    def offload_cpu(self, model_index):
        self.model[model_index].to_cpu()

    def to_cuda(self, model_index):
        self.model[model_index].to_cuda()


@RUNNER_REGISTER("wan2.2_moe")
class Wan22MoeRunner(WanRunner):
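    """Runner for Wan2.2 MoE checkpoints.

    Loads separate high-noise and low-noise models and routes LoRAs to them by
    filename prefix ("high" / "low").
    """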
    def __init__(self, config):
        super().__init__(config)
        self.high_noise_model_path = os.path.join(self.config["model_path"], "high_noise_model")
        if not os.path.isdir(self.high_noise_model_path):
            self.high_noise_model_path = os.path.join(self.config["model_path"], "distill_models", "high_noise_model")
        if self.config.get("dit_quantized", False) and self.config.get("high_noise_quantized_ckpt", None):
            self.high_noise_model_path = self.config["high_noise_quantized_ckpt"]
        elif self.config.get("high_noise_original_ckpt", None):
            self.high_noise_model_path = self.config["high_noise_original_ckpt"]

        self.low_noise_model_path = os.path.join(self.config["model_path"], "low_noise_model")
        if not os.path.isdir(self.low_noise_model_path):
            self.low_noise_model_path = os.path.join(self.config["model_path"], "distill_models", "low_noise_model")
        if self.config.get("dit_quantized", False) and self.config.get("low_noise_quantized_ckpt", None):
            self.low_noise_model_path = self.config["low_noise_quantized_ckpt"]
        elif not self.config.get("dit_quantized", False) and self.config.get("low_noise_original_ckpt", None):
            self.low_noise_model_path = self.config["low_noise_original_ckpt"]

    def load_transformer(self):
        # encoder -> high_noise_model -> low_noise_model -> vae -> video_output
        high_noise_model = WanModel(
            self.high_noise_model_path,
            self.config,
            self.init_device,
            model_type="wan2.2_moe_high_noise",
        )
        low_noise_model = WanModel(
            self.low_noise_model_path,
            self.config,
            self.init_device,
            model_type="wan2.2_moe_low_noise",
        )

        if self.config.get("lora_configs") and self.config["lora_configs"]:
            assert not self.config.get("dit_quantized", False)

            for lora_config in self.config["lora_configs"]:
                lora_path = lora_config["path"]
                strength = lora_config.get("strength", 1.0)
                base_name = os.path.basename(lora_path)
                if base_name.startswith("high"):
                    lora_wrapper = WanLoraWrapper(high_noise_model)
                    lora_name = lora_wrapper.load_lora(lora_path)
                    lora_wrapper.apply_lora(lora_name, strength)
                    logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
                elif base_name.startswith("low"):
                    lora_wrapper = WanLoraWrapper(low_noise_model)
                    lora_name = lora_wrapper.load_lora(lora_path)
                    lora_wrapper.apply_lora(lora_name, strength)
                    logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
                else:
                    raise ValueError(f"Unsupported LoRA path: {lora_path}")

        return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config["boundary"])


@RUNNER_REGISTER("wan2.2")
class Wan22DenseRunner(WanRunner):
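    """Runner for the dense Wan2.2 model: uses the Wan2.2 VAE and its tiny (TAE) variant."""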
    def __init__(self, config):
        super().__init__(config)
        self.vae_encoder_need_img_original = True
        self.vae_cls = Wan2_2_VAE
        self.tiny_vae_cls = Wan2_2_VAE_tiny
        self.vae_name = "Wan2.2_VAE.pth"
        self.tiny_vae_name = "taew2_2.pth"

    @ProfilingContext4DebugL1(
        "Run VAE Encoder",
        recorder_mode=GET_RECORDER_MODE(),
        metrics_func=monitor_cli.lightx2v_run_vae_encoder_image_duration,
        metrics_labels=["Wan22DenseRunner"],
    )
    def run_vae_encoder(self, img):
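        """Resize and center-crop the input PIL image to the best output size, then VAE-encode it."""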
        max_area = self.config.target_height * self.config.target_width
        ih, iw = img.height, img.width
        dh, dw = self.config.patch_size[1] * self.config.vae_stride[1], self.config.patch_size[2] * self.config.vae_stride[2]
        ow, oh = best_output_size(iw, ih, dw, dh, max_area)

        scale = max(ow / iw, oh / ih)
        img = img.resize((round(iw * scale), round(ih * scale)), Image.LANCZOS)

        # center-crop
        x1 = (img.width - ow) // 2
        y1 = (img.height - oh) // 2
        img = img.crop((x1, y1, x1 + ow, y1 + oh))
        assert img.width == ow and img.height == oh

        # to tensor
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(AI_DEVICE).unsqueeze(1)
        vae_encoder_out = self.get_vae_encoder_output(img)
        latent_w, latent_h = ow // self.config["vae_stride"][2], oh // self.config["vae_stride"][1]
        latent_shape = self.get_latent_shape_with_lat_hw(latent_h, latent_w)
        return vae_encoder_out, latent_shape

    def get_vae_encoder_output(self, img):
        z = self.vae_encoder.encode(img.unsqueeze(0).to(GET_DTYPE()))
        return z