# wan_runner.py — runner wiring for the Wan2.1 text/image-to-video pipeline.
# NOTE(review): this file had VCS blame-view artifacts (author/line-number
# residue) interleaved with the code; they have been stripped so the module
# parses. Imports are grouped stdlib / third-party / local per PEP 8.
import os

import numpy as np
import torch
import torch.distributed as dist
import torchvision.transforms.functional as TF
from PIL import Image
from loguru import logger

from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.memory_profiler import peak_memory_decorator
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
    WanSchedulerTeaCaching,
)
from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE


@RUNNER_REGISTER("wan2.1")
class WanRunner(DefaultRunner):
    """Pipeline runner for Wan2.1 text-to-video (t2v) and image-to-video (i2v).

    Loads the text/image encoders, the diffusion transformer and the VAE,
    selects a scheduler based on the feature-caching mode, and prepares the
    per-run encoder outputs and latent target shape.
    """

    def __init__(self, config):
        super().__init__(config)

    @ProfilingContext("Load models")
    def load_model(self):
        """Instantiate all sub-models for the configured task.

        Returns:
            tuple: ``(model, text_encoders, vae_model, image_encoder)`` where
            ``image_encoder`` is ``None`` unless ``config.task == "i2v"``.
        """
        # Under parallel attention each rank must pin its own CUDA device
        # before any model is materialized on GPU.
        if self.config["parallel_attn_type"]:
            torch.cuda.set_device(dist.get_rank())

        # With CPU offload, weights are first loaded on CPU and moved to GPU
        # on demand; otherwise everything is initialized directly on CUDA.
        init_device = torch.device("cpu" if self.config.cpu_offload else "cuda")

        text_encoder = T5EncoderModel(
            text_len=self.config["text_len"],
            dtype=torch.bfloat16,
            device=init_device,
            checkpoint_path=os.path.join(self.config.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
            tokenizer_path=os.path.join(self.config.model_path, "google/umt5-xxl"),
            shard_fn=None,
        )
        text_encoders = [text_encoder]

        model = WanModel(self.config.model_path, self.config, init_device)

        # Optionally merge a LoRA into the base weights before inference.
        if self.config.lora_path:
            lora_wrapper = WanLoraWrapper(model)
            lora_name = lora_wrapper.load_lora(self.config.lora_path)
            lora_wrapper.apply_lora(lora_name, self.config.strength_model)
            logger.info(f"Loaded LoRA: {lora_name}")

        vae_model = WanVAE(
            vae_pth=os.path.join(self.config.model_path, "Wan2.1_VAE.pth"),
            device=init_device,
            parallel=self.config.parallel_vae,
        )

        # The CLIP image encoder is only needed to condition on a source image.
        image_encoder = None
        if self.config.task == "i2v":
            image_encoder = CLIPModel(
                dtype=torch.float16,
                device=init_device,
                checkpoint_path=os.path.join(
                    self.config.model_path,
                    "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
                ),
                tokenizer_path=os.path.join(self.config.model_path, "xlm-roberta-large"),
            )

        return model, text_encoders, vae_model, image_encoder

    def init_scheduler(self):
        """Attach a scheduler to the model, chosen by ``config.feature_caching``.

        Raises:
            NotImplementedError: for an unrecognized feature-caching mode.
        """
        if self.config.feature_caching == "NoCaching":
            scheduler = WanScheduler(self.config)
        elif self.config.feature_caching == "Tea":
            scheduler = WanSchedulerTeaCaching(self.config)
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
        self.model.set_scheduler(scheduler)

    @peak_memory_decorator
    def run_text_encoder(self, text, text_encoders, config, image_encoder_output):
        """Encode the prompt and the negative prompt with the T5 encoder.

        Returns:
            dict: ``{"context": ..., "context_null": ...}`` — embeddings for
            the positive and negative prompts (empty string when no negative
            prompt is configured).
        """
        n_prompt = config.get("negative_prompt", "")
        context = text_encoders[0].infer([text], config)
        context_null = text_encoders[0].infer([n_prompt if n_prompt else ""], config)
        return {"context": context, "context_null": context_null}

    @peak_memory_decorator
    def run_image_encoder(self, config, image_encoder, vae_model):
        """Encode the conditioning image for i2v: CLIP features + masked VAE latents.

        Side effects: stores the computed latent grid size in ``config.lat_h``
        and ``config.lat_w`` for later use by :meth:`set_target_shape`.

        Returns:
            dict: ``{"clip_encoder_out": ..., "vae_encode_out": ...}``.
        """
        img = Image.open(config.image_path).convert("RGB")
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()  # normalize to [-1, 1]
        clip_encoder_out = image_encoder.visual([img[:, None, :, :]], config).squeeze(0).to(torch.bfloat16)

        # Pick a latent grid that preserves the image aspect ratio while
        # keeping the total pixel area at target_height * target_width, and is
        # divisible by the patch size.
        h, w = img.shape[1:]
        aspect_ratio = h / w
        max_area = config.target_height * config.target_width
        lat_h = round(np.sqrt(max_area * aspect_ratio) // config.vae_stride[1] // config.patch_size[1] * config.patch_size[1])
        lat_w = round(np.sqrt(max_area / aspect_ratio) // config.vae_stride[2] // config.patch_size[2] * config.patch_size[2])
        h = lat_h * config.vae_stride[1]
        w = lat_w * config.vae_stride[2]

        config.lat_h = lat_h
        config.lat_w = lat_w

        # Temporal mask: only the first frame is "known" (1); the VAE packs
        # 4 pixel frames per latent frame, so the first frame is repeated 4x.
        msk = torch.ones(1, config.target_video_length, lat_h, lat_w, device=torch.device("cuda"))
        msk[:, 1:] = 0
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]

        # VAE-encode the first frame (resized to the latent-aligned size)
        # followed by zero frames for the rest of the clip.
        vae_encode_out = vae_model.encode(
            [
                torch.concat(
                    [
                        torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                        torch.zeros(3, config.target_video_length - 1, h, w),
                    ],
                    dim=1,
                ).cuda()
            ],
            config,
        )[0]
        vae_encode_out = torch.concat([msk, vae_encode_out]).to(torch.bfloat16)
        return {"clip_encoder_out": clip_encoder_out, "vae_encode_out": vae_encode_out}

    def set_target_shape(self):
        """Compute and store the latent tensor shape in ``config.target_shape``.

        For i2v the spatial size comes from ``config.lat_h``/``lat_w`` (set by
        :meth:`run_image_encoder`); for t2v it is derived from the target
        resolution and the VAE stride.

        Raises:
            NotImplementedError: for an unrecognized task (previously this
            silently left ``config.target_shape`` unset, failing much later).
        """
        num_channels_latents = self.config.get("num_channels_latents", 16)
        # The VAE compresses time 4x; frame 0 maps to one latent frame.
        num_latent_frames = (self.config.target_video_length - 1) // 4 + 1
        if self.config.task == "i2v":
            self.config.target_shape = (
                num_channels_latents,
                num_latent_frames,
                self.config.lat_h,
                self.config.lat_w,
            )
        elif self.config.task == "t2v":
            self.config.target_shape = (
                num_channels_latents,
                num_latent_frames,
                int(self.config.target_height) // self.config.vae_stride[1],
                int(self.config.target_width) // self.config.vae_stride[2],
            )
        else:
            raise NotImplementedError(f"Unsupported task type: {self.config.task}")