import gc
import os

import numpy as np
import torch
import torch.distributed as dist
import torchvision.transforms.functional as TF
from PIL import Image
from loguru import logger

from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.wan.changing_resolution.scheduler import (
    WanScheduler4ChangingResolutionInterface,
)
from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
    WanSchedulerCaching,
    WanSchedulerTaylorCaching,
)
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
from lightx2v.models.video_encoders.hf.wan.vae_tiny import WanVAE_tiny
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import *
from lightx2v.utils.utils import best_output_size, cache_video


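# Runner for the Wan2.1 pipeline: wires together the WanModel DiT, T5 text
# encoder, optional CLIP image encoder, Wan VAE, and the configured scheduler.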
@RUNNER_REGISTER("wan2.1")
class WanRunner(DefaultRunner):
    def __init__(self, config):
        super().__init__(config)
        device_mesh = self.config.get("device_mesh")
        if device_mesh is not None:
            self.seq_p_group = device_mesh.get_group(mesh_dim="seq_p")
        else:
            self.seq_p_group = None

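    # Build the WanModel DiT and, if configured, merge LoRA weights into it.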
    def load_transformer(self):
        model = WanModel(
            self.config.model_path,
            self.config,
            self.init_device,
            self.seq_p_group,
        )
        if self.config.get("lora_configs") and self.config.lora_configs:
            assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
            lora_wrapper = WanLoraWrapper(model)
            for lora_config in self.config.lora_configs:
                lora_path = lora_config["path"]
                strength = lora_config.get("strength", 1.0)
                lora_name = lora_wrapper.load_lora(lora_path)
                lora_wrapper.apply_lora(lora_name, strength)
                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
        return model

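    # The CLIP image encoder is only needed for i2v; a quantized checkpoint is
    # loaded when clip_quantized is set, otherwise the original open-clip weights.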
    def load_image_encoder(self):
        image_encoder = None
        if self.config.task == "i2v" and self.config.get("use_image_encoder", True):
            # quant_config
            clip_quantized = self.config.get("clip_quantized", False)
            if clip_quantized:
                clip_quant_scheme = self.config.get("clip_quant_scheme", None)
                assert clip_quant_scheme is not None
                tmp_clip_quant_scheme = clip_quant_scheme.split("-")[0]
                clip_model_name = f"clip-{tmp_clip_quant_scheme}.pth"
                clip_quantized_ckpt = find_torch_model_path(self.config, "clip_quantized_ckpt", clip_model_name, tmp_clip_quant_scheme)
                clip_original_ckpt = None
            else:
                clip_quantized_ckpt = None
                clip_quant_scheme = None
                clip_model_name = "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"
                clip_original_ckpt = find_torch_model_path(self.config, "clip_original_ckpt", clip_model_name, "original")

            image_encoder = CLIPModel(
                dtype=torch.float16,
                device=self.init_device,
                checkpoint_path=clip_original_ckpt,
                clip_quantized=clip_quantized,
                clip_quantized_ckpt=clip_quantized_ckpt,
                quant_scheme=clip_quant_scheme,
                seq_p_group=self.seq_p_group,
            )

        return image_encoder

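    # UMT5-XXL text encoder; supports CPU offload (t5_cpu_offload) and
    # quantized checkpoints (t5_quantized / t5_quant_scheme).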
    def load_text_encoder(self):
        # offload config
        t5_offload = self.config.get("t5_cpu_offload", False)
        if t5_offload:
            t5_device = torch.device("cpu")
        else:
            t5_device = torch.device("cuda")

        # quant_config
        t5_quantized = self.config.get("t5_quantized", False)
        if t5_quantized:
            t5_quant_scheme = self.config.get("t5_quant_scheme", None)
            assert t5_quant_scheme is not None
            tmp_t5_quant_scheme = t5_quant_scheme.split("-")[0]
            t5_model_name = f"models_t5_umt5-xxl-enc-{tmp_t5_quant_scheme}.pth"
            t5_quantized_ckpt = find_torch_model_path(self.config, "t5_quantized_ckpt", t5_model_name, tmp_t5_quant_scheme)
            t5_original_ckpt = None
            tokenizer_path = os.path.join(os.path.dirname(t5_quantized_ckpt), "google/umt5-xxl")
        else:
            t5_quant_scheme = None
            t5_quantized_ckpt = None
            t5_model_name = "models_t5_umt5-xxl-enc-bf16.pth"
            t5_original_ckpt = find_torch_model_path(self.config, "t5_original_ckpt", t5_model_name, "original")
            tokenizer_path = os.path.join(os.path.dirname(t5_original_ckpt), "google/umt5-xxl")

        text_encoder = T5EncoderModel(
            text_len=self.config["text_len"],
            dtype=torch.bfloat16,
            device=t5_device,
            checkpoint_path=t5_original_ckpt,
            tokenizer_path=tokenizer_path,
            shard_fn=None,
            cpu_offload=t5_offload,
            offload_granularity=self.config.get("t5_offload_granularity", "model"),
            t5_quantized=t5_quantized,
            t5_quantized_ckpt=t5_quantized_ckpt,
            quant_scheme=t5_quant_scheme,
            seq_p_group=self.seq_p_group,
        )
        text_encoders = [text_encoder]
        return text_encoders

    def load_vae_encoder(self):
        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
            "device": self.init_device,
            "parallel": self.config.parallel and self.config.parallel.get("vae_p_size", False) and self.config.parallel.vae_p_size > 1,
            "use_tiling": self.config.get("use_tiling_vae", False),
            "seq_p_group": self.seq_p_group,
        }
        if self.config.task != "i2v":
            return None
        else:
            return WanVAE(**vae_config)

    def load_vae_decoder(self):
        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
            "device": self.init_device,
            "parallel": self.config.parallel and self.config.parallel.get("vae_p_size", False) and self.config.parallel.vae_p_size > 1,
            "use_tiling": self.config.get("use_tiling_vae", False),
        }
        if self.config.get("use_tiny_vae", False):
            tiny_vae_path = find_torch_model_path(self.config, "tiny_vae_path", "taew2_1.pth")
            vae_decoder = WanVAE_tiny(
                vae_pth=tiny_vae_path,
                device=self.init_device,
            ).to("cuda")
        else:
            vae_decoder = WanVAE(**vae_config)
        return vae_decoder

    def load_vae(self):
        vae_encoder = self.load_vae_encoder()
        if vae_encoder is None or self.config.get("use_tiny_vae", False):
            vae_decoder = self.load_vae_decoder()
        else:
            vae_decoder = vae_encoder
        return vae_encoder, vae_decoder

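    # Select the scheduler implementation from the feature-caching strategy and
    # optionally wrap it for changing-resolution inference.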
    def init_scheduler(self):
        if self.config.feature_caching == "NoCaching":
            scheduler_class = WanScheduler
        elif self.config.feature_caching == "TaylorSeer":
            scheduler_class = WanSchedulerTaylorCaching
        elif self.config.feature_caching in ["Tea", "Ada", "Custom", "FirstBlock", "DualBlock", "DynamicBlock"]:
            scheduler_class = WanSchedulerCaching
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")

        if self.config.get("changing_resolution", False):
            scheduler = WanScheduler4ChangingResolutionInterface(scheduler_class, self.config)
        else:
            scheduler = scheduler_class(self.config)
        self.model.set_scheduler(scheduler)

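    # With cfg_parallel enabled, rank 0 of the cfg_p group encodes the positive
    # prompt and the remaining ranks encode the negative prompt; otherwise both
    # contexts are computed locally.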
    def run_text_encoder(self, text, img):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.text_encoders = self.load_text_encoder()
        n_prompt = self.config.get("negative_prompt", "")

        if self.config["cfg_parallel"]:
            cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
            cfg_p_rank = dist.get_rank(cfg_p_group)
            if cfg_p_rank == 0:
                context = self.text_encoders[0].infer([text])
                text_encoder_output = {"context": context}
            else:
                context_null = self.text_encoders[0].infer([n_prompt])
                text_encoder_output = {"context_null": context_null}
        else:
            context = self.text_encoders[0].infer([text])
            context_null = self.text_encoders[0].infer([n_prompt])
            text_encoder_output = {
                "context": context,
                "context_null": context_null,
            }

        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.text_encoders[0]
            torch.cuda.empty_cache()
            gc.collect()

        return text_encoder_output

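    # Normalize the PIL image to [-1, 1] and extract CLIP visual features.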
    def run_image_encoder(self, img):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.image_encoder = self.load_image_encoder()
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
        clip_encoder_out = self.image_encoder.visual([img[None, :, :, :]], self.config).squeeze(0).to(GET_DTYPE())
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.image_encoder
            torch.cuda.empty_cache()
            gc.collect()
        return clip_encoder_out

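    # i2v preprocessing: derive the latent height/width from the target area and
    # the image aspect ratio; with changing_resolution, the image is encoded at
    # each scaled resolution plus the final one.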
    def run_vae_encoder(self, img):
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
        h, w = img.shape[1:]
        aspect_ratio = h / w
        max_area = self.config.target_height * self.config.target_width
        lat_h = round(np.sqrt(max_area * aspect_ratio) // self.config.vae_stride[1] // self.config.patch_size[1] * self.config.patch_size[1])
        lat_w = round(np.sqrt(max_area / aspect_ratio) // self.config.vae_stride[2] // self.config.patch_size[2] * self.config.patch_size[2])

        if self.config.get("changing_resolution", False):
            self.config.lat_h, self.config.lat_w = lat_h, lat_w
            vae_encode_out_list = []
            for i in range(len(self.config["resolution_rate"])):
                lat_h, lat_w = (
                    int(self.config.lat_h * self.config.resolution_rate[i]) // 2 * 2,
                    int(self.config.lat_w * self.config.resolution_rate[i]) // 2 * 2,
                )
                vae_encode_out_list.append(self.get_vae_encoder_output(img, lat_h, lat_w))
            vae_encode_out_list.append(self.get_vae_encoder_output(img, self.config.lat_h, self.config.lat_w))
            return vae_encode_out_list
        else:
            self.config.lat_h, self.config.lat_w = lat_h, lat_w
            vae_encoder_out = self.get_vae_encoder_output(img, lat_h, lat_w)
            return vae_encoder_out

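    # Encode the conditioning image (first frame) plus zero frames, then prepend
    # the temporal mask: frame 0 is kept (repeated 4x so it lines up with the
    # VAE's 4x temporal compression of frames into latents), later frames are masked out.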
    def get_vae_encoder_output(self, img, lat_h, lat_w):
        h = lat_h * self.config.vae_stride[1]
        w = lat_w * self.config.vae_stride[2]

        msk = torch.ones(
            1,
            self.config.target_video_length,
            lat_h,
            lat_w,
            device=torch.device("cuda"),
        )
        msk[:, 1:] = 0
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_encoder = self.load_vae_encoder()
        vae_encoder_out = self.vae_encoder.encode(
            [
                torch.concat(
                    [
                        torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                        torch.zeros(3, self.config.target_video_length - 1, h, w),
                    ],
                    dim=1,
                ).cuda()
            ],
            self.config,
        )[0]
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_encoder
            torch.cuda.empty_cache()
            gc.collect()
        vae_encoder_out = torch.concat([msk, vae_encoder_out]).to(GET_DTYPE())
        return vae_encoder_out

    def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img):
        image_encoder_output = {
            "clip_encoder_out": clip_encoder_out,
            "vae_encoder_out": vae_encoder_out,
        }
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": image_encoder_output,
        }

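    # Latent target shape is (C, T, H, W) with T = (frames - 1) // vae_stride[0] + 1;
    # i2v uses the lat_h/lat_w computed during VAE encoding, t2v uses target_height/width.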
    def set_target_shape(self):
        num_channels_latents = self.config.get("num_channels_latents", 16)
        if self.config.task == "i2v":
            self.config.target_shape = (
                num_channels_latents,
                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
                self.config.lat_h,
                self.config.lat_w,
            )
        elif self.config.task == "t2v":
            self.config.target_shape = (
                num_channels_latents,
                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
                int(self.config.target_height) // self.config.vae_stride[1],
                int(self.config.target_width) // self.config.vae_stride[2],
            )

    def save_video_func(self, images):
        cache_video(
            tensor=images,
            save_file=self.config.save_video_path,
            fps=self.config.get("fps", 16),
            nrow=1,
            normalize=True,
            value_range=(-1, 1),
        )


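# Holds the Wan2.2 MoE pair [high_noise_model, low_noise_model] and switches
# between them per step: timesteps >= boundary * num_train_timesteps run the
# high-noise expert, later steps run the low-noise expert, with the inactive
# expert offloaded to CPU.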
class MultiModelStruct:
    def __init__(self, model_list, config, boundary=0.875, num_train_timesteps=1000):
        self.model = model_list  # [high_noise_model, low_noise_model]
        assert len(self.model) == 2, "MultiModelStruct only supports 2 models now."
        self.config = config
        self.boundary = boundary
        self.boundary_timestep = self.boundary * num_train_timesteps
        self.cur_model_index = -1
        logger.info(f"boundary: {self.boundary}, boundary_timestep: {self.boundary_timestep}")

    @property
    def device(self):
        return self.model[self.cur_model_index].device

    def set_scheduler(self, shared_scheduler):
        self.scheduler = shared_scheduler
        for model in self.model:
            model.set_scheduler(shared_scheduler)

    def infer(self, inputs):
        self.get_current_model_index()
        self.model[self.cur_model_index].infer(inputs)

    def get_current_model_index(self):
        if self.scheduler.timesteps[self.scheduler.step_index] >= self.boundary_timestep:
            logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}")
            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[0]
            if self.cur_model_index == -1:
                self.to_cuda(model_index=0)
            elif self.cur_model_index == 1:  # 1 -> 0
                self.offload_cpu(model_index=1)
                self.to_cuda(model_index=0)
            self.cur_model_index = 0
        else:
            logger.info(f"using - LOW - noise model at step_index {self.scheduler.step_index + 1}")
            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[1]
            if self.cur_model_index == -1:
                self.to_cuda(model_index=1)
            elif self.cur_model_index == 0:  # 0 -> 1
                self.offload_cpu(model_index=0)
                self.to_cuda(model_index=1)
            self.cur_model_index = 1

    def offload_cpu(self, model_index):
        self.model[model_index].to_cpu()

    def to_cuda(self, model_index):
        self.model[model_index].to_cuda()


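# Wan2.2 MoE runner: loads separate high- and low-noise DiT weights from
# model_path/high_noise_model and model_path/low_noise_model.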
@RUNNER_REGISTER("wan2.2_moe")
class Wan22MoeRunner(WanRunner):
    def __init__(self, config):
        super().__init__(config)

    def load_transformer(self):
        # encoder -> high_noise_model -> low_noise_model -> vae -> video_output
        high_noise_model = WanModel(
            os.path.join(self.config.model_path, "high_noise_model"),
            self.config,
            self.init_device,
        )
        low_noise_model = WanModel(
            os.path.join(self.config.model_path, "low_noise_model"),
            self.config,
            self.init_device,
        )
        return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config.boundary)


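# Wan2.2 dense runner: same flow as Wan2.1 but with the Wan2.2 VAE and an
# aspect-ratio-preserving resize plus center crop before VAE encoding.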
@RUNNER_REGISTER("wan2.2")
class Wan22DenseRunner(WanRunner):
    def __init__(self, config):
        super().__init__(config)

    def load_vae_decoder(self):
        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
            "device": self.init_device,
        }
        vae_decoder = Wan2_2_VAE(**vae_config)
        return vae_decoder

    def load_vae_encoder(self):
        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
            "device": self.init_device,
        }
        if self.config.task != "i2v":
            return None
        else:
            return Wan2_2_VAE(**vae_config)

    def load_vae(self):
        vae_encoder = self.load_vae_encoder()
        vae_decoder = self.load_vae_decoder()
        return vae_encoder, vae_decoder

    def run_vae_encoder(self, img):
        max_area = self.config.target_height * self.config.target_width
        ih, iw = img.height, img.width
        dh, dw = self.config.patch_size[1] * self.config.vae_stride[1], self.config.patch_size[2] * self.config.vae_stride[2]
        ow, oh = best_output_size(iw, ih, dw, dh, max_area)

        scale = max(ow / iw, oh / ih)
        img = img.resize((round(iw * scale), round(ih * scale)), Image.LANCZOS)

        # center-crop
        x1 = (img.width - ow) // 2
        y1 = (img.height - oh) // 2
        img = img.crop((x1, y1, x1 + ow, y1 + oh))
        assert img.width == ow and img.height == oh

        # to tensor
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda().unsqueeze(1)
        vae_encoder_out = self.get_vae_encoder_output(img)
        self.config.lat_w, self.config.lat_h = ow // self.config.vae_stride[2], oh // self.config.vae_stride[1]

        return vae_encoder_out

    def get_vae_encoder_output(self, img):
        z = self.vae_encoder.encode(img, self.config)
        return z