import gc
import os

import numpy as np
import torch
import torch.distributed as dist
import torchvision.transforms.functional as TF
from PIL import Image
from loguru import logger

from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.wan.changing_resolution.scheduler import (
    WanScheduler4ChangingResolutionInterface,
)
from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
    WanSchedulerCaching,
    WanSchedulerTaylorCaching,
)
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
from lightx2v.models.video_encoders.hf.wan.vae_tiny import WanVAE_tiny
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import *
from lightx2v.utils.utils import best_output_size, cache_video


@RUNNER_REGISTER("wan2.1")
class WanRunner(DefaultRunner):
    def __init__(self, config):
        super().__init__(config)

    def load_transformer(self):
        model = WanModel(
            self.config.model_path,
            self.config,
            self.init_device,
        )
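        # If LoRA configs are provided, load each LoRA checkpoint and fuse it into the transformer weights.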
        if self.config.get("lora_configs") and self.config.lora_configs:
            assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False), "Applying LoRA to a quantized DiT requires mm_config.weight_auto_quant"
            lora_wrapper = WanLoraWrapper(model)
            for lora_config in self.config.lora_configs:
                lora_path = lora_config["path"]
                strength = lora_config.get("strength", 1.0)
                lora_name = lora_wrapper.load_lora(lora_path)
                lora_wrapper.apply_lora(lora_name, strength)
                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
        return model

    def load_image_encoder(self):
        image_encoder = None
        if self.config.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True):
            # quant_config
            clip_quantized = self.config.get("clip_quantized", False)
            if clip_quantized:
                clip_quant_scheme = self.config.get("clip_quant_scheme", None)
                assert clip_quant_scheme is not None
                tmp_clip_quant_scheme = clip_quant_scheme.split("-")[0]
                clip_model_name = f"clip-{tmp_clip_quant_scheme}.pth"
                clip_quantized_ckpt = find_torch_model_path(self.config, "clip_quantized_ckpt", clip_model_name)
                clip_original_ckpt = None
            else:
                clip_quantized_ckpt = None
                clip_quant_scheme = None
                clip_model_name = "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"
                clip_original_ckpt = find_torch_model_path(self.config, "clip_original_ckpt", clip_model_name)

            image_encoder = CLIPModel(
                dtype=torch.float16,
                device=self.init_device,
                checkpoint_path=clip_original_ckpt,
                clip_quantized=clip_quantized,
                clip_quantized_ckpt=clip_quantized_ckpt,
                quant_scheme=clip_quant_scheme,
                cpu_offload=self.config.get("clip_cpu_offload", self.config.get("cpu_offload", False)),
                use_31_block=self.config.get("use_31_block", True),
            )

        return image_encoder

    def load_text_encoder(self):
        # offload config
        t5_offload = self.config.get("t5_cpu_offload", self.config.get("cpu_offload"))
        if t5_offload:
            t5_device = torch.device("cpu")
        else:
            t5_device = torch.device("cuda")

        # quant_config
        t5_quantized = self.config.get("t5_quantized", False)
        if t5_quantized:
            t5_quant_scheme = self.config.get("t5_quant_scheme", None)
            assert t5_quant_scheme is not None
            tmp_t5_quant_scheme = t5_quant_scheme.split("-")[0]
            t5_model_name = f"models_t5_umt5-xxl-enc-{tmp_t5_quant_scheme}.pth"
            t5_quantized_ckpt = find_torch_model_path(self.config, "t5_quantized_ckpt", t5_model_name)
            t5_original_ckpt = None
            tokenizer_path = os.path.join(os.path.dirname(t5_quantized_ckpt), "google/umt5-xxl")
        else:
            t5_quant_scheme = None
            t5_quantized_ckpt = None
            t5_model_name = "models_t5_umt5-xxl-enc-bf16.pth"
            t5_original_ckpt = find_torch_model_path(self.config, "t5_original_ckpt", t5_model_name)
            tokenizer_path = os.path.join(os.path.dirname(t5_original_ckpt), "google/umt5-xxl")

        text_encoder = T5EncoderModel(
            text_len=self.config["text_len"],
            dtype=torch.bfloat16,
            device=t5_device,
            checkpoint_path=t5_original_ckpt,
            tokenizer_path=tokenizer_path,
            shard_fn=None,
            cpu_offload=t5_offload,
            offload_granularity=self.config.get("t5_offload_granularity", "model"),  # support ['model', 'block']
            t5_quantized=t5_quantized,
            t5_quantized_ckpt=t5_quantized_ckpt,
            quant_scheme=t5_quant_scheme,
        )
        text_encoders = [text_encoder]
        return text_encoders

    def load_vae_encoder(self):
        # offload config
        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
        if vae_offload:
            vae_device = torch.device("cpu")
        else:
            vae_device = torch.device("cuda")

        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
            "device": vae_device,
            "parallel": self.config.parallel,
            "use_tiling": self.config.get("use_tiling_vae", False),
            "cpu_offload": vae_offload,
            "dtype": GET_DTYPE(),
        }
        if self.config.task not in ["i2v", "flf2v", "vace"]:
            return None
        else:
            return WanVAE(**vae_config)

    def load_vae_decoder(self):
        # offload config
        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
        if vae_offload:
            vae_device = torch.device("cpu")
        else:
            vae_device = torch.device("cuda")

        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
            "device": vae_device,
            "parallel": self.config.parallel,
            "use_tiling": self.config.get("use_tiling_vae", False),
            "cpu_offload": vae_offload,
            "dtype": GET_DTYPE(),
        }
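        # Optionally use the lightweight tiny VAE decoder (taew2_1) in place of the full WanVAE decoder.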
        if self.config.get("use_tiny_vae", False):
            tiny_vae_path = find_torch_model_path(self.config, "tiny_vae_path", "taew2_1.pth")
            vae_decoder = WanVAE_tiny(vae_pth=tiny_vae_path, device=self.init_device, need_scaled=self.config.get("need_scaled", False)).to("cuda")
        else:
            vae_decoder = WanVAE(**vae_config)
        return vae_decoder

    def load_vae(self):
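        # Reuse the encoder instance as the decoder when one was loaded and the tiny VAE is not
        # requested; otherwise load a separate decoder.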
        vae_encoder = self.load_vae_encoder()
        if vae_encoder is None or self.config.get("use_tiny_vae", False):
            vae_decoder = self.load_vae_decoder()
        else:
            vae_decoder = vae_encoder
        return vae_encoder, vae_decoder

    def init_scheduler(self):
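        # Pick the scheduler implementation that matches the feature-caching strategy,
        # optionally wrapped for changing-resolution sampling.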
        if self.config.feature_caching == "NoCaching":
            scheduler_class = WanScheduler
        elif self.config.feature_caching == "TaylorSeer":
            scheduler_class = WanSchedulerTaylorCaching
        elif self.config.feature_caching in ["Tea", "Ada", "Custom", "FirstBlock", "DualBlock", "DynamicBlock", "Mag"]:
            scheduler_class = WanSchedulerCaching
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")

        if self.config.get("changing_resolution", False):
            scheduler = WanScheduler4ChangingResolutionInterface(scheduler_class, self.config)
        else:
            scheduler = scheduler_class(self.config)
        self.model.set_scheduler(scheduler)

    def run_text_encoder(self, text, img=None):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.text_encoders = self.load_text_encoder()
        n_prompt = self.config.get("negative_prompt", "")

        if self.config["cfg_parallel"]:
            cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
            cfg_p_rank = dist.get_rank(cfg_p_group)
            if cfg_p_rank == 0:
                context = self.text_encoders[0].infer([text])
                text_encoder_output = {"context": context}
            else:
                context_null = self.text_encoders[0].infer([n_prompt])
                text_encoder_output = {"context_null": context_null}
        else:
            context = self.text_encoders[0].infer([text])
            context_null = self.text_encoders[0].infer([n_prompt])
            text_encoder_output = {
                "context": context,
                "context_null": context_null,
            }

        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.text_encoders[0]
            torch.cuda.empty_cache()
            gc.collect()

        return text_encoder_output

    def run_image_encoder(self, first_frame, last_frame=None):
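        # CLIP-encode the conditioning frame(s); flf2v passes both the first and last frames.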
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.image_encoder = self.load_image_encoder()
        if last_frame is None:
            clip_encoder_out = self.image_encoder.visual([first_frame]).squeeze(0).to(GET_DTYPE())
        else:
            clip_encoder_out = self.image_encoder.visual([first_frame, last_frame]).squeeze(0).to(GET_DTYPE())
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.image_encoder
            torch.cuda.empty_cache()
            gc.collect()
        return clip_encoder_out

    def run_vae_encoder(self, first_frame, last_frame=None):
        h, w = first_frame.shape[2:]
        aspect_ratio = h / w
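        # Derive the latent grid size from the target pixel area while preserving the input
        # aspect ratio, aligned to the VAE stride and patch size.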
        max_area = self.config.target_height * self.config.target_width
        lat_h = round(np.sqrt(max_area * aspect_ratio) // self.config.vae_stride[1] // self.config.patch_size[1] * self.config.patch_size[1])
        lat_w = round(np.sqrt(max_area / aspect_ratio) // self.config.vae_stride[2] // self.config.patch_size[2] * self.config.patch_size[2])

        if self.config.get("changing_resolution", False):
            assert last_frame is None
            self.config.lat_h, self.config.lat_w = lat_h, lat_w
            vae_encode_out_list = []
            for i in range(len(self.config["resolution_rate"])):
                lat_h, lat_w = (
                    int(self.config.lat_h * self.config.resolution_rate[i]) // 2 * 2,
                    int(self.config.lat_w * self.config.resolution_rate[i]) // 2 * 2,
                )
                vae_encode_out_list.append(self.get_vae_encoder_output(first_frame, lat_h, lat_w))
            vae_encode_out_list.append(self.get_vae_encoder_output(first_frame, self.config.lat_h, self.config.lat_w))
            return vae_encode_out_list
        else:
            if last_frame is not None:
                first_frame_size = first_frame.shape[2:]
                last_frame_size = last_frame.shape[2:]
                if first_frame_size != last_frame_size:
                    last_frame_resize_ratio = max(first_frame_size[0] / last_frame_size[0], first_frame_size[1] / last_frame_size[1])
                    last_frame_size = [
                        round(last_frame_size[0] * last_frame_resize_ratio),
                        round(last_frame_size[1] * last_frame_resize_ratio),
                    ]
                    last_frame = TF.center_crop(last_frame, last_frame_size)
            self.config.lat_h, self.config.lat_w = lat_h, lat_w
            vae_encoder_out = self.get_vae_encoder_output(first_frame, lat_h, lat_w, last_frame)
            return vae_encoder_out

    def get_vae_encoder_output(self, first_frame, lat_h, lat_w, last_frame=None):
        h = lat_h * self.config.vae_stride[1]
        w = lat_w * self.config.vae_stride[2]
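        # Temporal conditioning mask: 1 marks frames supplied as conditioning (the first frame,
        # plus the last frame for flf2v), 0 marks frames to generate. The first entry is repeated
        # 4x and regrouped to match the VAE's 4x temporal compression.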
        msk = torch.ones(
            1,
            self.config.target_video_length,
            lat_h,
            lat_w,
            device=torch.device("cuda"),
        )
        if last_frame is not None:
            msk[:, 1:-1] = 0
        else:
            msk[:, 1:] = 0

        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]

        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_encoder = self.load_vae_encoder()

        if last_frame is not None:
            vae_input = torch.concat(
                [
                    torch.nn.functional.interpolate(first_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                    torch.zeros(3, self.config.target_video_length - 2, h, w),
                    torch.nn.functional.interpolate(last_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                ],
                dim=1,
            ).cuda()
        else:
            vae_input = torch.concat(
                [
                    torch.nn.functional.interpolate(first_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                    torch.zeros(3, self.config.target_video_length - 1, h, w),
                ],
                dim=1,
            ).cuda()

        vae_encoder_out = self.vae_encoder.encode(vae_input.unsqueeze(0).to(GET_DTYPE()))

        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_encoder
            torch.cuda.empty_cache()
            gc.collect()
        vae_encoder_out = torch.concat([msk, vae_encoder_out]).to(GET_DTYPE())
        return vae_encoder_out

    def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img=None):
        image_encoder_output = {
            "clip_encoder_out": clip_encoder_out,
            "vae_encoder_out": vae_encoder_out,
        }
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": image_encoder_output,
        }

    def set_target_shape(self):
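        # Target latent shape: (num_channels_latents, 1 + (T - 1) // vae_stride_t, lat_h, lat_w).
        # i2v/flf2v take the spatial latent size from the encoded image; t2v derives it from the
        # configured target resolution.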
        num_channels_latents = self.config.get("num_channels_latents", 16)
        if self.config.task in ["i2v", "flf2v"]:
            self.config.target_shape = (
                num_channels_latents,
                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
                self.config.lat_h,
                self.config.lat_w,
            )
        elif self.config.task == "t2v":
            self.config.target_shape = (
                num_channels_latents,
                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
                int(self.config.target_height) // self.config.vae_stride[1],
                int(self.config.target_width) // self.config.vae_stride[2],
            )

    def save_video_func(self, images):
        cache_video(
            tensor=images,
            save_file=self.config.save_video_path,
            fps=self.config.get("fps", 16),
            nrow=1,
            normalize=True,
            value_range=(-1, 1),
        )


class MultiModelStruct:
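    # Wraps the Wan2.2 MoE high-/low-noise experts behind a single model interface: the shared
    # scheduler's timestep selects which expert runs at each step, and the inactive expert is
    # offloaded to CPU.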
    def __init__(self, model_list, config, boundary=0.875, num_train_timesteps=1000):
        self.model = model_list  # [high_noise_model, low_noise_model]
        assert len(self.model) == 2, "MultiModelStruct only supports 2 models now."
        self.config = config
        self.boundary = boundary
        self.boundary_timestep = self.boundary * num_train_timesteps
        self.cur_model_index = -1
        logger.info(f"boundary: {self.boundary}, boundary_timestep: {self.boundary_timestep}")

    @property
    def device(self):
        return self.model[self.cur_model_index].device

    def set_scheduler(self, shared_scheduler):
        self.scheduler = shared_scheduler
        for model in self.model:
            model.set_scheduler(shared_scheduler)

    def infer(self, inputs):
        self.get_current_model_index()
        self.model[self.cur_model_index].infer(inputs)

    def get_current_model_index(self):
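        # Steps with timestep >= boundary_timestep use the high-noise expert (index 0); later
        # steps switch to the low-noise expert (index 1). Switching also updates the guidance
        # scale and swaps the idle expert between CPU and CUDA.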
        if self.scheduler.timesteps[self.scheduler.step_index] >= self.boundary_timestep:
            logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}")
            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[0]
            if self.cur_model_index == -1:
                self.to_cuda(model_index=0)
            elif self.cur_model_index == 1:  # 1 -> 0
                self.offload_cpu(model_index=1)
                self.to_cuda(model_index=0)
            self.cur_model_index = 0
        else:
            logger.info(f"using - LOW - noise model at step_index {self.scheduler.step_index + 1}")
            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[1]
            if self.cur_model_index == -1:
                self.to_cuda(model_index=1)
            elif self.cur_model_index == 0:  # 0 -> 1
                self.offload_cpu(model_index=0)
                self.to_cuda(model_index=1)
            self.cur_model_index = 1

    def offload_cpu(self, model_index):
        self.model[model_index].to_cpu()

    def to_cuda(self, model_index):
        self.model[model_index].to_cuda()


@RUNNER_REGISTER("wan2.2_moe")
class Wan22MoeRunner(WanRunner):
    def __init__(self, config):
        super().__init__(config)

    def load_transformer(self):
        # encoder -> high_noise_model -> low_noise_model -> vae -> video_output
        high_noise_model = WanModel(
            os.path.join(self.config.model_path, "high_noise_model"),
            self.config,
            self.init_device,
        )
        low_noise_model = WanModel(
            os.path.join(self.config.model_path, "low_noise_model"),
            self.config,
            self.init_device,
        )
        return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config.boundary)


@RUNNER_REGISTER("wan2.2")
class Wan22DenseRunner(WanRunner):
    def __init__(self, config):
        super().__init__(config)
        self.vae_encoder_need_img_original = True

    def load_vae_decoder(self):
        # offload config
        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
        if vae_offload:
            vae_device = torch.device("cpu")
        else:
            vae_device = torch.device("cuda")
        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
            "device": vae_device,
            "cpu_offload": vae_offload,
            "offload_cache": self.config.get("vae_offload_cache", False),
            "dtype": GET_DTYPE(),
        }
        vae_decoder = Wan2_2_VAE(**vae_config)
        return vae_decoder

    def load_vae_encoder(self):
        # offload config
        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
        if vae_offload:
            vae_device = torch.device("cpu")
        else:
            vae_device = torch.device("cuda")
        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
            "device": vae_device,
            "cpu_offload": vae_offload,
            "offload_cache": self.config.get("vae_offload_cache", False),
            "dtype": GET_DTYPE(),
        }
        if self.config.task not in ["i2v", "flf2v"]:
            return None
        else:
            return Wan2_2_VAE(**vae_config)

    def load_vae(self):
        vae_encoder = self.load_vae_encoder()
        vae_decoder = self.load_vae_decoder()
        return vae_encoder, vae_decoder

    def run_vae_encoder(self, img):
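        # Pick an output size near the target pixel area that keeps the input aspect ratio and is
        # divisible by patch_size * vae_stride, then resize and center-crop the image to it.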
        max_area = self.config.target_height * self.config.target_width
        ih, iw = img.height, img.width
        dh, dw = self.config.patch_size[1] * self.config.vae_stride[1], self.config.patch_size[2] * self.config.vae_stride[2]
        ow, oh = best_output_size(iw, ih, dw, dh, max_area)

        scale = max(ow / iw, oh / ih)
        img = img.resize((round(iw * scale), round(ih * scale)), Image.LANCZOS)

        # center-crop
        x1 = (img.width - ow) // 2
        y1 = (img.height - oh) // 2
        img = img.crop((x1, y1, x1 + ow, y1 + oh))
        assert img.width == ow and img.height == oh

        # to tensor
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda().unsqueeze(1)
        vae_encoder_out = self.get_vae_encoder_output(img)
        self.config.lat_w, self.config.lat_h = ow // self.config.vae_stride[2], oh // self.config.vae_stride[1]

        return vae_encoder_out

    def get_vae_encoder_output(self, img):
        z = self.vae_encoder.encode(img.to(GET_DTYPE()))
        return z