import os
import gc
import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.schedulers.wan.changing_resolution.scheduler import (
    WanScheduler4ChangingResolutionInterface,
)
from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
    WanSchedulerCaching,
    WanSchedulerTaylorCaching,
)
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.utils import *
from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.video_encoders.hf.wan.vae_tiny import WanVAE_tiny
from lightx2v.utils.utils import cache_video
from loguru import logger


@RUNNER_REGISTER("wan2.1")
class WanRunner(DefaultRunner):
    def __init__(self, config):
        super().__init__(config)

    def load_transformer(self):
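        # Build the WAN diffusion transformer; LoRA weights, if configured,
        # are merged into the base weights below.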
        model = WanModel(
            self.config.model_path,
            self.config,
            self.init_device,
        )
        if self.config.get("lora_configs"):
            # Merging LoRA into quantized weights is only supported when the
            # weights are quantized after the merge (weight_auto_quant).
            assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
            lora_wrapper = WanLoraWrapper(model)
            for lora_config in self.config.lora_configs:
                lora_path = lora_config["path"]
                strength = lora_config.get("strength", 1.0)
                lora_name = lora_wrapper.load_lora(lora_path)
                lora_wrapper.apply_lora(lora_name, strength)
                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
        return model

    def load_image_encoder(self):
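        # The CLIP image encoder is only needed for image-to-video (i2v).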
        image_encoder = None
        if self.config.task == "i2v":
            # Quantization config: choose a pre-quantized or original CLIP checkpoint.
            clip_quantized = self.config.get("clip_quantized", False)
            if clip_quantized:
                clip_quant_scheme = self.config.get("clip_quant_scheme", None)
                assert clip_quant_scheme is not None
                tmp_clip_quant_scheme = clip_quant_scheme.split("-")[0]
                clip_model_name = f"clip-{tmp_clip_quant_scheme}.pth"
                clip_quantized_ckpt = find_torch_model_path(self.config, "clip_quantized_ckpt", clip_model_name, tmp_clip_quant_scheme)
                clip_original_ckpt = None
            else:
                clip_quantized_ckpt = None
                clip_quant_scheme = None
                clip_model_name = "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"
                clip_original_ckpt = find_torch_model_path(self.config, "clip_original_ckpt", clip_model_name, "original")

            image_encoder = CLIPModel(
                dtype=torch.float16,
                device=self.init_device,
                checkpoint_path=clip_original_ckpt,
                clip_quantized=clip_quantized,
                clip_quantized_ckpt=clip_quantized_ckpt,
                quant_scheme=clip_quant_scheme,
            )

        return image_encoder

    def load_text_encoder(self):
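        # Build the UMT5-XXL text encoder, optionally CPU-offloaded and/or
        # loaded from a pre-quantized checkpoint.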
        # Offload config: run T5 on CPU when t5_cpu_offload is enabled.
        t5_offload = self.config.get("t5_cpu_offload", False)
        if t5_offload:
            t5_device = torch.device("cpu")
        else:
            t5_device = torch.device("cuda")

        # Quantization config: choose a pre-quantized or original T5 checkpoint.
        t5_quantized = self.config.get("t5_quantized", False)
        if t5_quantized:
            t5_quant_scheme = self.config.get("t5_quant_scheme", None)
            assert t5_quant_scheme is not None
            tmp_t5_quant_scheme = t5_quant_scheme.split("-")[0]
            t5_model_name = f"models_t5_umt5-xxl-enc-{tmp_t5_quant_scheme}.pth"
            t5_quantized_ckpt = find_torch_model_path(self.config, "t5_quantized_ckpt", t5_model_name, tmp_t5_quant_scheme)
            t5_original_ckpt = None
            tokenizer_path = os.path.join(os.path.dirname(t5_quantized_ckpt), "google/umt5-xxl")
        else:
            t5_quant_scheme = None
            t5_quantized_ckpt = None
            t5_model_name = "models_t5_umt5-xxl-enc-bf16.pth"
            t5_original_ckpt = find_torch_model_path(self.config, "t5_original_ckpt", t5_model_name, "original")
            tokenizer_path = os.path.join(os.path.dirname(t5_original_ckpt), "google/umt5-xxl")

        text_encoder = T5EncoderModel(
            text_len=self.config["text_len"],
            dtype=torch.bfloat16,
            device=t5_device,
            checkpoint_path=t5_original_ckpt,
            tokenizer_path=tokenizer_path,
            shard_fn=None,
            cpu_offload=t5_offload,
            offload_granularity=self.config.get("t5_offload_granularity", "model"),
            t5_quantized=t5_quantized,
            t5_quantized_ckpt=t5_quantized_ckpt,
            quant_scheme=t5_quant_scheme,
        )
        text_encoders = [text_encoder]
        return text_encoders

    def load_vae_encoder(self):
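        # The VAE encoder is only used to embed the conditioning image for i2v.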
        if self.config.task != "i2v":
            return None
        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
            "device": self.init_device,
            "parallel": self.config.parallel_vae,
            "use_tiling": self.config.get("use_tiling_vae", False),
        }
        return WanVAE(**vae_config)

    def load_vae_decoder(self):
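        # Decode with either the full WAN VAE or the lightweight tiny VAE (taew2_1).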
        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
            "device": self.init_device,
            "parallel": self.config.parallel_vae,
            "use_tiling": self.config.get("use_tiling_vae", False),
        }
        if self.config.get("use_tiny_vae", False):
            tiny_vae_path = find_torch_model_path(self.config, "tiny_vae_path", "taew2_1.pth")
            vae_decoder = WanVAE_tiny(
                vae_pth=tiny_vae_path,
                device=self.init_device,
            ).to("cuda")
        else:
            vae_decoder = WanVAE(**vae_config)
        return vae_decoder

    def load_vae(self):
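        # Reuse the encoder instance as the decoder when possible, so the VAE
        # weights are not loaded twice.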
        vae_encoder = self.load_vae_encoder()
        if vae_encoder is None or self.config.get("use_tiny_vae", False):
            vae_decoder = self.load_vae_decoder()
        else:
            vae_decoder = vae_encoder
        return vae_encoder, vae_decoder

    def init_scheduler(self):
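        # Select the scheduler implementation that matches the configured
        # feature-caching mode, then optionally wrap it for changing resolution.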
        if self.config.feature_caching == "NoCaching":
            scheduler_class = WanScheduler
        elif self.config.feature_caching == "TaylorSeer":
            scheduler_class = WanSchedulerTaylorCaching
        elif self.config.feature_caching in ["Tea", "Ada", "Custom", "FirstBlock", "DualBlock", "DynamicBlock"]:
            scheduler_class = WanSchedulerCaching
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")

        if self.config.get("changing_resolution", False):
            scheduler = WanScheduler4ChangingResolutionInterface(scheduler_class, self.config)
        else:
            scheduler = scheduler_class(self.config)
        self.model.set_scheduler(scheduler)

    def run_text_encoder(self, text, img):
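        # Encode the prompt and negative prompt; with lazy_load/unload_modules,
        # the text encoder is created on demand and freed afterwards.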
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.text_encoders = self.load_text_encoder()
        text_encoder_output = {}
        n_prompt = self.config.get("negative_prompt", "")
        context = self.text_encoders[0].infer([text])
        context_null = self.text_encoders[0].infer([n_prompt if n_prompt else ""])
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.text_encoders[0]
            torch.cuda.empty_cache()
            gc.collect()
        text_encoder_output["context"] = context
        text_encoder_output["context_null"] = context_null
        return text_encoder_output

    def run_image_encoder(self, img):
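        # Normalize the image to [-1, 1] and extract CLIP visual features.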
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.image_encoder = self.load_image_encoder()
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
        clip_encoder_out = self.image_encoder.visual([img[None, :, :, :]], self.config).squeeze(0).to(torch.bfloat16)
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.image_encoder
            torch.cuda.empty_cache()
            gc.collect()
        return clip_encoder_out

    def run_vae_encoder(self, img):
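        # Derive the latent grid from the target area and the input aspect ratio,
        # then encode the conditioning image (at several scales when
        # changing_resolution is enabled).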
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
        h, w = img.shape[1:]
        aspect_ratio = h / w
        # Fit the latent grid to the target pixel area at the input aspect ratio,
        # snapped to multiples of the patch size.
        max_area = self.config.target_height * self.config.target_width
        lat_h = round(np.sqrt(max_area * aspect_ratio) // self.config.vae_stride[1] // self.config.patch_size[1] * self.config.patch_size[1])
        lat_w = round(np.sqrt(max_area / aspect_ratio) // self.config.vae_stride[2] // self.config.patch_size[2] * self.config.patch_size[2])

        if self.config.get("changing_resolution", False):
            self.config.lat_h, self.config.lat_w = lat_h, lat_w
            vae_encode_out_list = []
            # Encode at each intermediate scale, rounding latent sizes down to even values.
            for rate in self.config.resolution_rate:
                lat_h, lat_w = (
                    int(self.config.lat_h * rate) // 2 * 2,
                    int(self.config.lat_w * rate) // 2 * 2,
                )
                vae_encode_out_list.append(self.get_vae_encoder_output(img, lat_h, lat_w))
            # The last entry is the full-resolution encoding.
            vae_encode_out_list.append(self.get_vae_encoder_output(img, self.config.lat_h, self.config.lat_w))
            return vae_encode_out_list
        else:
            self.config.lat_h, self.config.lat_w = lat_h, lat_w
            vae_encode_out = self.get_vae_encoder_output(img, lat_h, lat_w)
            return vae_encode_out

    def get_vae_encoder_output(self, img, lat_h, lat_w):
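        # Pixel dimensions corresponding to the requested latent grid.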
        h = lat_h * self.config.vae_stride[1]
        w = lat_w * self.config.vae_stride[2]

        msk = torch.ones(
            1,
            self.config.target_video_length,
            lat_h,
            lat_w,
            device=torch.device("cuda"),
        )
        # Mask is 1 only for the conditioning frame; all generated frames are 0.
        msk[:, 1:] = 0
        # Repeat the first-frame mask 4x and fold the time axis into groups of 4,
        # matching the VAE's temporal compression.
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_encoder = self.load_vae_encoder()
        vae_encode_out = self.vae_encoder.encode(
            [
                torch.concat(
                    [
                        # First frame: the conditioning image, resized to the pixel
                        # size implied by the latent grid.
                        torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                        # Remaining frames are zero placeholders.
                        torch.zeros(3, self.config.target_video_length - 1, h, w),
                    ],
                    dim=1,
                ).cuda()
            ],
            self.config,
        )[0]
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_encoder
            torch.cuda.empty_cache()
            gc.collect()
        # Prepend the frame mask to the encoded latents along the channel dim.
        vae_encode_out = torch.concat([msk, vae_encode_out]).to(torch.bfloat16)
        return vae_encode_out

    def get_encoder_output_i2v(self, clip_encoder_out, vae_encode_out, text_encoder_output, img):
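        # Package text and image conditioning into the structure the model expects.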
        image_encoder_output = {
            "clip_encoder_out": clip_encoder_out,
            "vae_encode_out": vae_encode_out,
        }
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": image_encoder_output,
        }

    def set_target_shape(self):
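        # Latent shape: channels, compressed time axis, and spatial dims after
        # the VAE stride (i2v uses the latent grid computed during VAE encoding).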
        num_channels_latents = self.config.get("num_channels_latents", 16)
        if self.config.task == "i2v":
            self.config.target_shape = (
                num_channels_latents,
                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
                self.config.lat_h,
                self.config.lat_w,
            )
        elif self.config.task == "t2v":
            self.config.target_shape = (
                num_channels_latents,
                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
                int(self.config.target_height) // self.config.vae_stride[1],
                int(self.config.target_width) // self.config.vae_stride[2],
            )

    def save_video_func(self, images):
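        # Write frames in [-1, 1] to save_video_path at the configured fps.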
        cache_video(
            tensor=images,
            save_file=self.config.save_video_path,
            fps=self.config.get("fps", 16),
            nrow=1,
            normalize=True,
            value_range=(-1, 1),
        )