# wan_runner.py
import os
import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
    WanSchedulerTeaCaching,
)
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.video_encoders.hf.wan.vae_tiny import WanVAE_tiny
import torch.distributed as dist
from loguru import logger


@RUNNER_REGISTER("wan2.1")
class WanRunner(DefaultRunner):
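    # Runner for the Wan2.1 pipeline: wires together the UMT5 text encoder, an
    # optional CLIP image encoder (i2v only), the WanModel transformer and the
    # Wan VAE, and is registered under the "wan2.1" key.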
    def __init__(self, config):
        super().__init__(config)

    def load_transformer(self):
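        # Materialize the transformer weights on CPU when cpu_offload is enabled
        # (keeps peak GPU memory low), otherwise directly on the GPU; an optional
        # LoRA is merged at the configured strength.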
        if self.config.cpu_offload:
            init_device = torch.device("cpu")
        else:
            init_device = torch.device("cuda")
        model = WanModel(self.config.model_path, self.config, init_device)
        if self.config.lora_path:
            lora_wrapper = WanLoraWrapper(model)
            lora_name = lora_wrapper.load_lora(self.config.lora_path)
            lora_wrapper.apply_lora(lora_name, self.config.strength_model)
            logger.info(f"Loaded LoRA: {lora_name}")
        return model

    @ProfilingContext("Load models")
    def load_model(self):
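        # With attention parallelism enabled, bind this process to its own GPU
        # before any weights are created.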
        if self.config["parallel_attn_type"]:
            cur_rank = dist.get_rank()
            torch.cuda.set_device(cur_rank)
        image_encoder = None
        if self.config.cpu_offload:
            init_device = torch.device("cpu")
        else:
            init_device = torch.device("cuda")

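        # UMT5-XXL text encoder; checkpoint and tokenizer paths follow the
        # standard Wan2.1 weight layout under model_path.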
        text_encoder = T5EncoderModel(
            text_len=self.config["text_len"],
            dtype=torch.bfloat16,
            device=init_device,
            checkpoint_path=os.path.join(self.config.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
            tokenizer_path=os.path.join(self.config.model_path, "google/umt5-xxl"),
            shard_fn=None,
            cpu_offload=self.config.cpu_offload,
            offload_granularity=self.config.get("text_encoder_offload_granularity", "model"),
        )
        text_encoders = [text_encoder]
        model = WanModel(self.config.model_path, self.config, init_device)

        if self.config.lora_path:
            lora_wrapper = WanLoraWrapper(model)
            lora_name = lora_wrapper.load_lora(self.config.lora_path)
            lora_wrapper.apply_lora(lora_name, self.config.strength_model)
            logger.info(f"Loaded LoRA: {lora_name}")

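        # VAE selection: a lightweight tiny VAE can replace the full Wan2.1 VAE
        # for decoding; the full VAE supports tiled and parallel decoding.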
        if self.config.get("tiny_vae", False):
            vae_model = WanVAE_tiny(
                vae_pth=self.config.tiny_vae_path,
                device=init_device,
            )
            vae_model = vae_model.to("cuda")
        else:
            vae_model = WanVAE(
                vae_pth=os.path.join(self.config.model_path, "Wan2.1_VAE.pth"),
                device=init_device,
                parallel=self.config.parallel_vae,
                use_tiling=self.config.get("use_tiling_vae", False),
            )
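        # i2v additionally needs a CLIP visual encoder for the reference image.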
        if self.config.task == "i2v":
            image_encoder = CLIPModel(
                dtype=torch.float16,
                device=init_device,
                checkpoint_path=os.path.join(
                    self.config.model_path,
                    "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
                ),
                tokenizer_path=os.path.join(self.config.model_path, "xlm-roberta-large"),
            )
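            # When the tiny VAE is active it only handles decoding, so keep the
            # full VAE around for encoding the conditioning image and return it
            # alongside the CLIP encoder.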
            if self.config.get("tiny_vae", False):
                org_vae = WanVAE(
                    vae_pth=os.path.join(self.config.model_path, "Wan2.1_VAE.pth"),
                    device=init_device,
                    parallel=self.config.parallel_vae,
                    use_tiling=self.config.get("use_tiling_vae", False),
                )
                image_encoder = [image_encoder, org_vae]

        return model, text_encoders, vae_model, image_encoder

    def init_scheduler(self):
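        # Scheduler choice follows the feature-caching strategy: plain
        # WanScheduler for "NoCaching", the TeaCache variant for "Tea".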
        if self.config.feature_caching == "NoCaching":
            scheduler = WanScheduler(self.config)
        elif self.config.feature_caching == "Tea":
            scheduler = WanSchedulerTeaCaching(self.config)
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
        self.model.set_scheduler(scheduler)

    def run_text_encoder(self, text, text_encoders, config, image_encoder_output):
        text_encoder_output = {}
        n_prompt = config.get("negative_prompt", "")
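        # Encode both the prompt and the (possibly empty) negative prompt; the
        # latter provides the unconditional branch for classifier-free guidance.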
        context = text_encoders[0].infer([text])
        context_null = text_encoders[0].infer([n_prompt if n_prompt else ""])
        text_encoder_output["context"] = context
        text_encoder_output["context_null"] = context_null
        return text_encoder_output

    def run_image_encoder(self, config, image_encoder, vae_model):
        if self.config.get("tiny_vae", False):
            clip_image_encoder, vae_image_encoder = image_encoder[0], image_encoder[1]
        else:
            clip_image_encoder, vae_image_encoder = image_encoder, vae_model
        img = Image.open(config.image_path).convert("RGB")
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
        clip_encoder_out = clip_image_encoder.visual([img[:, None, :, :]], config).squeeze(0).to(torch.bfloat16)
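        # Pick latent dims so the pixel area stays close to target_height *
        # target_width at the image's aspect ratio, with lat_h/lat_w snapped to
        # the patch size and h/w to the VAE spatial stride.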
        h, w = img.shape[1:]
        aspect_ratio = h / w
        max_area = config.target_height * config.target_width
        lat_h = round(np.sqrt(max_area * aspect_ratio) // config.vae_stride[1] // config.patch_size[1] * config.patch_size[1])
        lat_w = round(np.sqrt(max_area / aspect_ratio) // config.vae_stride[2] // config.patch_size[2] * config.patch_size[2])
        h = lat_h * config.vae_stride[1]
        w = lat_w * config.vae_stride[2]

        config.lat_h = lat_h
        config.lat_w = lat_w

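        # Build the i2v conditioning mask: only the first frame is known; its
        # latent is repeated 4x to line up with the VAE's temporal compression
        # before being reshaped to (4, T_latent, lat_h, lat_w).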
        msk = torch.ones(1, config.target_video_length, lat_h, lat_w, device=torch.device("cuda"))
        msk[:, 1:] = 0
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]
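        # Encode the reference frame (resized to h x w) padded with zero frames
        # to the full clip length, then prepend the mask channels.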
        vae_encode_out = vae_image_encoder.encode(
            [
                torch.concat(
                    [
                        torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                        torch.zeros(3, config.target_video_length - 1, h, w),
                    ],
                    dim=1,
                ).cuda()
            ],
            config,
        )[0]
        vae_encode_out = torch.concat([msk, vae_encode_out]).to(torch.bfloat16)
        return {"clip_encoder_out": clip_encoder_out, "vae_encode_out": vae_encode_out}

    def set_target_shape(self):
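        # Latent shape the denoiser generates: (C, T_latent, H_latent, W_latent).
        # i2v reuses lat_h/lat_w computed from the reference image; t2v derives
        # them from the target resolution and the VAE spatial stride.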
        num_channels_latents = self.config.get("num_channels_latents", 16)
        if self.config.task == "i2v":
            self.config.target_shape = (
                num_channels_latents,
                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
                self.config.lat_h,
                self.config.lat_w,
            )
        elif self.config.task == "t2v":
            self.config.target_shape = (
                num_channels_latents,
                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
                int(self.config.target_height) // self.config.vae_stride[1],
                int(self.config.target_width) // self.config.vae_stride[2],
            )