# wan_runner.py
import os
import gc
import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
    WanSchedulerTeaCaching,
)
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.video_encoders.hf.wan.vae_tiny import WanVAE_tiny
from lightx2v.utils.utils import cache_video
from loguru import logger


@RUNNER_REGISTER("wan2.1")
class WanRunner(DefaultRunner):
    def __init__(self, config):
        super().__init__(config)

    def load_transformer(self):
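        # Build the Wan DiT and optionally merge a LoRA checkpoint into it.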
        model = WanModel(
            self.config.model_path,
            self.config,
            self.init_device,
        )
        if self.config.lora_path:
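            # The assert below enforces that a quantized DiT is only combined with LoRA when
            # weights are quantized on the fly (weight_auto_quant), presumably so the LoRA can
            # be merged before quantization.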
            assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
            lora_wrapper = WanLoraWrapper(model)
            lora_name = lora_wrapper.load_lora(self.config.lora_path)
            lora_wrapper.apply_lora(lora_name, self.config.strength_model)
            logger.info(f"Loaded LoRA: {lora_name}")
        return model

    def load_image_encoder(self):
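        # The CLIP image encoder is only needed for image-to-video (i2v); other tasks return None.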
        image_encoder = None
        if self.config.task == "i2v":
            image_encoder = CLIPModel(
                dtype=torch.float16,
                device=self.init_device,
                checkpoint_path=os.path.join(
                    self.config.model_path,
                    "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
                ),
                clip_quantized=self.config.get("clip_quantized", False),
                clip_quantized_ckpt=self.config.get("clip_quantized_ckpt", None),
                quant_scheme=self.config.get("clip_quant_scheme", None),
            )
        return image_encoder

    def load_text_encoder(self):
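        # umT5-XXL text encoder; supports CPU offload and optional quantized checkpoints via the t5_* config keys.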
        text_encoder = T5EncoderModel(
            text_len=self.config["text_len"],
            dtype=torch.bfloat16,
            device=self.init_device,
            checkpoint_path=os.path.join(self.config.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
            tokenizer_path=os.path.join(self.config.model_path, "google/umt5-xxl"),
            shard_fn=None,
            cpu_offload=self.config.cpu_offload,
            offload_granularity=self.config.get("t5_offload_granularity", "model"),
            t5_quantized=self.config.get("t5_quantized", False),
            t5_quantized_ckpt=self.config.get("t5_quantized_ckpt", None),
            quant_scheme=self.config.get("t5_quant_scheme", None),
        )
        text_encoders = [text_encoder]
        return text_encoders

    def load_vae_encoder(self):
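        # The VAE encoder is only needed for i2v (to encode the reference image); other tasks return None.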
        vae_config = {
            "vae_pth": os.path.join(self.config.model_path, "Wan2.1_VAE.pth"),
            "device": self.init_device,
            "parallel": self.config.parallel_vae,
            "use_tiling": self.config.get("use_tiling_vae", False),
        }
        if self.config.task != "i2v":
            return None
        else:
            return WanVAE(**vae_config)

    def load_vae_decoder(self):
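        # Use the lightweight tiny VAE decoder when configured, otherwise the full Wan2.1 VAE.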
        vae_config = {
            "vae_pth": os.path.join(self.config.model_path, "Wan2.1_VAE.pth"),
            "device": self.init_device,
            "parallel": self.config.parallel_vae,
            "use_tiling": self.config.get("use_tiling_vae", False),
        }
        if self.config.get("tiny_vae", False):
            vae_decoder = WanVAE_tiny(
                vae_pth=self.config.tiny_vae_path,
                device=self.init_device,
            ).to("cuda")
        else:
            vae_decoder = WanVAE(**vae_config)
        return vae_decoder

    def load_vae(self):
        return self.load_vae_encoder(), self.load_vae_decoder()

    def init_scheduler(self):
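        # Pick the scheduler variant that matches the configured feature-caching mode.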
        if self.config.feature_caching == "NoCaching":
            scheduler = WanScheduler(self.config)
        elif self.config.feature_caching == "Tea":
            scheduler = WanSchedulerTeaCaching(self.config)
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
        self.model.set_scheduler(scheduler)

    def run_text_encoder(self, text, img):
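        # Encode the prompt and the (possibly empty) negative prompt with T5; with lazy_load the
        # encoder is built on demand and freed right after inference.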
        if self.config.get("lazy_load", False):
            self.text_encoders = self.load_text_encoder()
        text_encoder_output = {}
        n_prompt = self.config.get("negative_prompt", "")
        context = self.text_encoders[0].infer([text])
        context_null = self.text_encoders[0].infer([n_prompt if n_prompt else ""])
        if self.config.get("lazy_load", False):
            del self.text_encoders[0]
            torch.cuda.empty_cache()
            gc.collect()
        text_encoder_output["context"] = context
        text_encoder_output["context_null"] = context_null
        return text_encoder_output

    def run_image_encoder(self, img):
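        # Normalize the image to [-1, 1] and run the CLIP visual branch; with lazy_load the encoder is freed afterwards.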
        if self.config.get("lazy_load", False):
            self.image_encoder = self.load_image_encoder()
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
        clip_encoder_out = self.image_encoder.visual([img[:, None, :, :]], self.config).squeeze(0).to(torch.bfloat16)
        if self.config.get("lazy_load", False):
            del self.image_encoder
            torch.cuda.empty_cache()
            gc.collect()
        return clip_encoder_out

    def run_vae_encoder(self, img):
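        # Encode the reference image into VAE latents plus a first-frame mask used as i2v conditioning.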
        kwargs = {}
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
        h, w = img.shape[1:]
        aspect_ratio = h / w
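        # Pick latent dims that keep the aspect ratio within the target pixel budget and are multiples of the patch size.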
        max_area = self.config.target_height * self.config.target_width
        lat_h = round(np.sqrt(max_area * aspect_ratio) // self.config.vae_stride[1] // self.config.patch_size[1] * self.config.patch_size[1])
        lat_w = round(np.sqrt(max_area / aspect_ratio) // self.config.vae_stride[2] // self.config.patch_size[2] * self.config.patch_size[2])
        h = lat_h * self.config.vae_stride[1]
        w = lat_w * self.config.vae_stride[2]

        self.config.lat_h, kwargs["lat_h"] = lat_h, lat_h
        self.config.lat_w, kwargs["lat_w"] = lat_w, lat_w

        msk = torch.ones(
            1,
            self.config.target_video_length,
            lat_h,
            lat_w,
            device=torch.device("cuda"),
        )
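        # Only the first frame is known; repeat its mask 4x and regroup the mask into blocks of 4 frames per latent step.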
        msk[:, 1:] = 0
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]
        if self.config.get("lazy_load", False):
            self.vae_encoder = self.load_vae_encoder()
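        # Encode the resized first frame followed by zero frames for the remaining timesteps.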
        vae_encode_out = self.vae_encoder.encode(
            [
                torch.concat(
                    [
                        torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                        torch.zeros(3, self.config.target_video_length - 1, h, w),
                    ],
                    dim=1,
                ).cuda()
            ],
            self.config,
        )[0]
        if self.config.get("lazy_load", False):
            del self.vae_encoder
            torch.cuda.empty_cache()
            gc.collect()
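        # Prepend the 4-channel frame mask to the image latents along the channel dimension.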
        vae_encode_out = torch.concat([msk, vae_encode_out]).to(torch.bfloat16)
        return vae_encode_out, kwargs

    def get_encoder_output_i2v(self, clip_encoder_out, vae_encode_out, text_encoder_output, img):
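        # Package CLIP features and VAE latents together with the text embeddings for the i2v pipeline.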
        image_encoder_output = {
            "clip_encoder_out": clip_encoder_out,
            "vae_encode_out": vae_encode_out,
        }
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": image_encoder_output,
        }

    def set_target_shape(self):
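        # i2v uses the latent dims computed from the reference image; t2v derives them from target_height / target_width.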
        ret = {}
        num_channels_latents = self.config.get("num_channels_latents", 16)
        if self.config.task == "i2v":
            self.config.target_shape = (
                num_channels_latents,
                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
                self.config.lat_h,
                self.config.lat_w,
            )
            ret["lat_h"] = self.config.lat_h
            ret["lat_w"] = self.config.lat_w
        elif self.config.task == "t2v":
            self.config.target_shape = (
                num_channels_latents,
                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
                int(self.config.target_height) // self.config.vae_stride[1],
                int(self.config.target_width) // self.config.vae_stride[2],
            )
        ret["target_shape"] = self.config.target_shape
        return ret

    def save_video_func(self, images):
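        # Write the decoded frames to disk, mapping values from [-1, 1] to a displayable range.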
        cache_video(
            tensor=images,
            save_file=self.config.save_video_path,
            fps=self.config.get("fps", 16),
            nrow=1,
            normalize=True,
            value_range=(-1, 1),
        )