import os
import gc
import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.schedulers.wan.changing_resolution.scheduler import (
    WanScheduler4ChangingResolutionInterface,
)
from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
    WanSchedulerCaching,
    WanSchedulerTaylorCaching,
)
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.utils import *
from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
from lightx2v.models.networks.wan.model import WanModel, Wan22MoeModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.video_encoders.hf.wan.vae_tiny import WanVAE_tiny
from lightx2v.utils.utils import cache_video
from loguru import logger


@RUNNER_REGISTER("wan2.1")
class WanRunner(DefaultRunner):
    def __init__(self, config):
        super().__init__(config)

    def load_transformer(self):
        model = WanModel(
            self.config.model_path,
            self.config,
            self.init_device,
        )
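        # Optionally apply LoRA weights to the freshly loaded base model.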
        if self.config.get("lora_configs") and self.config.lora_configs:
41
            assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
            lora_wrapper = WanLoraWrapper(model)
            for lora_config in self.config.lora_configs:
                lora_path = lora_config["path"]
                strength = lora_config.get("strength", 1.0)
                lora_name = lora_wrapper.load_lora(lora_path)
                lora_wrapper.apply_lora(lora_name, strength)
                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
        return model

    def load_image_encoder(self):
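        # The CLIP image encoder is only used for i2v and can be disabled via use_image_encoder.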
        image_encoder = None
        if self.config.task == "i2v" and self.config.get("use_image_encoder", True):
            # quant_config
            clip_quantized = self.config.get("clip_quantized", False)
            if clip_quantized:
                clip_quant_scheme = self.config.get("clip_quant_scheme", None)
                assert clip_quant_scheme is not None
                tmp_clip_quant_scheme = clip_quant_scheme.split("-")[0]
                clip_model_name = f"clip-{tmp_clip_quant_scheme}.pth"
                clip_quantized_ckpt = find_torch_model_path(self.config, "clip_quantized_ckpt", clip_model_name, tmp_clip_quant_scheme)
                clip_original_ckpt = None
            else:
                clip_quantized_ckpt = None
                clip_quant_scheme = None
                clip_model_name = "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"
                clip_original_ckpt = find_torch_model_path(self.config, "clip_original_ckpt", clip_model_name, "original")

            image_encoder = CLIPModel(
                dtype=torch.float16,
                device=self.init_device,
                checkpoint_path=clip_original_ckpt,
                clip_quantized=clip_quantized,
                clip_quantized_ckpt=clip_quantized_ckpt,
                quant_scheme=clip_quant_scheme,
            )

        return image_encoder

    def load_text_encoder(self):
        # offload config
        t5_offload = self.config.get("t5_cpu_offload", False)
        if t5_offload:
            t5_device = torch.device("cpu")
        else:
            t5_device = torch.device("cuda")

        # quant_config
        t5_quantized = self.config.get("t5_quantized", False)
        if t5_quantized:
            t5_quant_scheme = self.config.get("t5_quant_scheme", None)
            assert t5_quant_scheme is not None
            tmp_t5_quant_scheme = t5_quant_scheme.split("-")[0]
            t5_model_name = f"models_t5_umt5-xxl-enc-{tmp_t5_quant_scheme}.pth"
            t5_quantized_ckpt = find_torch_model_path(self.config, "t5_quantized_ckpt", t5_model_name, tmp_t5_quant_scheme)
            t5_original_ckpt = None
            tokenizer_path = os.path.join(os.path.dirname(t5_quantized_ckpt), "google/umt5-xxl")
        else:
            t5_quant_scheme = None
            t5_quantized_ckpt = None
            t5_model_name = "models_t5_umt5-xxl-enc-bf16.pth"
            t5_original_ckpt = find_torch_model_path(self.config, "t5_original_ckpt", t5_model_name, "original")
            tokenizer_path = os.path.join(os.path.dirname(t5_original_ckpt), "google/umt5-xxl")

        text_encoder = T5EncoderModel(
            text_len=self.config["text_len"],
            dtype=torch.bfloat16,
            device=t5_device,
            checkpoint_path=t5_original_ckpt,
            tokenizer_path=tokenizer_path,
            shard_fn=None,
            cpu_offload=t5_offload,
            offload_granularity=self.config.get("t5_offload_granularity", "model"),
            t5_quantized=t5_quantized,
            t5_quantized_ckpt=t5_quantized_ckpt,
            quant_scheme=t5_quant_scheme,
        )
        text_encoders = [text_encoder]
        return text_encoders

    def load_vae_encoder(self):
        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
            "device": self.init_device,
            "parallel": self.config.parallel_vae,
            "use_tiling": self.config.get("use_tiling_vae", False),
        }
        if self.config.task != "i2v":
            return None
        else:
            return WanVAE(**vae_config)

    def load_vae_decoder(self):
        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
            "device": self.init_device,
            "parallel": self.config.parallel_vae,
            "use_tiling": self.config.get("use_tiling_vae", False),
        }
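        # A lightweight tiny-VAE decoder (taew2_1) can optionally replace the full VAE decoder for faster decoding.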
        if self.config.get("use_tiny_vae", False):
            tiny_vae_path = find_torch_model_path(self.config, "tiny_vae_path", "taew2_1.pth")
            vae_decoder = WanVAE_tiny(
                vae_pth=tiny_vae_path,
                device=self.init_device,
            ).to("cuda")
        else:
            vae_decoder = WanVAE(**vae_config)
        return vae_decoder

    def load_vae(self):
        vae_encoder = self.load_vae_encoder()
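        # t2v has no VAE encoder, and the tiny decoder is a separate module; otherwise the
        # encoder instance doubles as the decoder so the VAE weights are loaded only once.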
        if vae_encoder is None or self.config.get("use_tiny_vae", False):
            vae_decoder = self.load_vae_decoder()
        else:
            vae_decoder = vae_encoder
        return vae_encoder, vae_decoder

    def init_scheduler(self):
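        # Select the scheduler variant that matches the configured feature-caching strategy.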
        if self.config.feature_caching == "NoCaching":
            scheduler_class = WanScheduler
        elif self.config.feature_caching == "TaylorSeer":
            scheduler_class = WanSchedulerTaylorCaching
        elif self.config.feature_caching in ["Tea", "Ada", "Custom", "FirstBlock", "DualBlock", "DynamicBlock"]:
            scheduler_class = WanSchedulerCaching
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")

        if self.config.get("changing_resolution", False):
            scheduler = WanScheduler4ChangingResolutionInterface(scheduler_class, self.config)
        else:
            scheduler = scheduler_class(self.config)
        self.model.set_scheduler(scheduler)

    def run_text_encoder(self, text, img):
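        # With lazy_load/unload_modules enabled, the T5 encoder is loaded per call and freed
        # right after inference to reduce peak GPU memory.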
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.text_encoders = self.load_text_encoder()
        text_encoder_output = {}
        n_prompt = self.config.get("negative_prompt", "")
        context = self.text_encoders[0].infer([text])
        context_null = self.text_encoders[0].infer([n_prompt if n_prompt else ""])
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.text_encoders[0]
            torch.cuda.empty_cache()
            gc.collect()
        text_encoder_output["context"] = context
        text_encoder_output["context_null"] = context_null
        return text_encoder_output

    def run_image_encoder(self, img):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.image_encoder = self.load_image_encoder()
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
        clip_encoder_out = self.image_encoder.visual([img[None, :, :, :]], self.config).squeeze(0).to(torch.bfloat16)
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.image_encoder
            torch.cuda.empty_cache()
            gc.collect()
        return clip_encoder_out

    def run_vae_encoder(self, img):
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
        h, w = img.shape[1:]
        aspect_ratio = h / w
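        # Derive latent height/width from the target area and the input aspect ratio,
        # snapped to multiples of the patch size.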
        max_area = self.config.target_height * self.config.target_width
        lat_h = round(np.sqrt(max_area * aspect_ratio) // self.config.vae_stride[1] // self.config.patch_size[1] * self.config.patch_size[1])
        lat_w = round(np.sqrt(max_area / aspect_ratio) // self.config.vae_stride[2] // self.config.patch_size[2] * self.config.patch_size[2])

        if self.config.get("changing_resolution", False):
            self.config.lat_h, self.config.lat_w = lat_h, lat_w
            vae_encode_out_list = []
            for i in range(len(self.config["resolution_rate"])):
                lat_h, lat_w = (
                    int(self.config.lat_h * self.config.resolution_rate[i]) // 2 * 2,
                    int(self.config.lat_w * self.config.resolution_rate[i]) // 2 * 2,
                )
                vae_encode_out_list.append(self.get_vae_encoder_output(img, lat_h, lat_w))
            vae_encode_out_list.append(self.get_vae_encoder_output(img, self.config.lat_h, self.config.lat_w))
            return vae_encode_out_list
        else:
            self.config.lat_h, self.config.lat_w = lat_h, lat_w
            vae_encode_out = self.get_vae_encoder_output(img, lat_h, lat_w)
            return vae_encode_out

    def get_vae_encoder_output(self, img, lat_h, lat_w):
        h = lat_h * self.config.vae_stride[1]
        w = lat_w * self.config.vae_stride[2]

        msk = torch.ones(
            1,
            self.config.target_video_length,
            lat_h,
            lat_w,
            device=torch.device("cuda"),
        )
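        # Keep only the first (conditioning) frame in the mask; it is repeated 4x to match
        # the VAE's temporal compression, while all remaining frames are left to be generated.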
        msk[:, 1:] = 0
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_encoder = self.load_vae_encoder()
        vae_encode_out = self.vae_encoder.encode(
            [
                torch.concat(
                    [
                        torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                        torch.zeros(3, self.config.target_video_length - 1, h, w),
                    ],
                    dim=1,
                ).cuda()
            ],
            self.config,
        )[0]
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_encoder
            torch.cuda.empty_cache()
            gc.collect()
        vae_encode_out = torch.concat([msk, vae_encode_out]).to(torch.bfloat16)
        return vae_encode_out

    def get_encoder_output_i2v(self, clip_encoder_out, vae_encode_out, text_encoder_output, img):
        image_encoder_output = {
            "clip_encoder_out": clip_encoder_out,
            "vae_encode_out": vae_encode_out,
        }
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": image_encoder_output,
        }

    def set_target_shape(self):
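        # Latent target shape is (C, T, H, W): T follows the VAE temporal stride, H/W the
        # spatial strides; i2v reuses lat_h/lat_w computed during VAE encoding.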
        num_channels_latents = self.config.get("num_channels_latents", 16)
        if self.config.task == "i2v":
            self.config.target_shape = (
                num_channels_latents,
                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
                self.config.lat_h,
                self.config.lat_w,
            )
        elif self.config.task == "t2v":
            self.config.target_shape = (
                num_channels_latents,
                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
                int(self.config.target_height) // self.config.vae_stride[1],
                int(self.config.target_width) // self.config.vae_stride[2],
            )

    def save_video_func(self, images):
        cache_video(
            tensor=images,
            save_file=self.config.save_video_path,
            fps=self.config.get("fps", 16),
            nrow=1,
            normalize=True,
            value_range=(-1, 1),
        )


class MultiModelStruct:
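    # Wan2.2 MoE expert pair: index 0 is the high-noise model, index 1 the low-noise model.
    # The scheduler timestep picks the active expert; the inactive one is offloaded to CPU.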
    def __init__(self, model_list, config, boundary=0.875, num_train_timesteps=1000):
        self.model = model_list  # [high_noise_model, low_noise_model]
        assert len(self.model) == 2, "MultiModelStruct only supports 2 models now."
        self.config = config
        self.boundary = boundary
        self.boundary_timestep = self.boundary * num_train_timesteps
        self.cur_model_index = -1
        logger.info(f"boundary: {self.boundary}, boundary_timestep: {self.boundary_timestep}")

    def set_scheduler(self, shared_scheduler):
        self.scheduler = shared_scheduler
        for model in self.model:
            model.set_scheduler(shared_scheduler)

    def infer(self, inputs):
        self.get_current_model_index()
        self.model[self.cur_model_index].infer(inputs)

    def get_current_model_index(self):
        if self.scheduler.timesteps[self.scheduler.step_index] >= self.boundary_timestep:
            logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}")
            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[0]
            if self.cur_model_index == -1:
                self.to_cuda(model_index=0)
            elif self.cur_model_index == 1:  # 1 -> 0
                self.offload_cpu(model_index=1)
                self.to_cuda(model_index=0)
            self.cur_model_index = 0
        else:
            logger.info(f"using - LOW - noise model at step_index {self.scheduler.step_index + 1}")
            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[1]
            if self.cur_model_index == -1:
                self.to_cuda(model_index=1)
            elif self.cur_model_index == 0:  # 0 -> 1
                self.offload_cpu(model_index=0)
                self.to_cuda(model_index=1)
            self.cur_model_index = 1

    def offload_cpu(self, model_index):
        self.model[model_index].to_cpu()

    def to_cuda(self, model_index):
        self.model[model_index].to_cuda()


@RUNNER_REGISTER("wan2.2_moe")
class Wan22MoeRunner(WanRunner):
    def __init__(self, config):
        super().__init__(config)

    def load_transformer(self):
        # encoder -> high_noise_model -> low_noise_model -> vae -> video_output
        high_noise_model = Wan22MoeModel(
            os.path.join(self.config.model_path, "high_noise_model"),
            self.config,
            self.init_device,
        )
        low_noise_model = Wan22MoeModel(
            os.path.join(self.config.model_path, "low_noise_model"),
            self.config,
            self.init_device,
        )
        return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config.boundary)