import gc
import os

import numpy as np
import torch
import torch.distributed as dist
import torchvision.transforms.functional as TF
from PIL import Image
from loguru import logger

from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.wan.changing_resolution.scheduler import (
    WanScheduler4ChangingResolutionInterface,
)
from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
    WanSchedulerCaching,
    WanSchedulerTaylorCaching,
)
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
from lightx2v.models.video_encoders.hf.wan.vae_tiny import Wan2_2_VAE_tiny, WanVAE_tiny
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import *
from lightx2v.utils.utils import best_output_size, cache_video


@RUNNER_REGISTER("wan2.1")
class WanRunner(DefaultRunner):
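    """Runner for Wan2.1 video generation.

    Wires together the T5 text encoder, the optional CLIP image encoder,
    the Wan VAE and the Wan DiT, and picks a scheduler implementation based
    on the configured feature-caching mode.
    """
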
    def __init__(self, config):
        super().__init__(config)
        self.vae_cls = WanVAE
        self.tiny_vae_cls = WanVAE_tiny
        self.vae_name = "Wan2.1_VAE.pth"
        self.tiny_vae_name = "taew2_1.pth"

    def load_transformer(self):
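        """Load the Wan DiT and, if configured, merge LoRA weights into it."""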
        model = WanModel(
            self.config.model_path,
            self.config,
            self.init_device,
        )
        if self.config.get("lora_configs") and self.config.lora_configs:
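            # LoRA merging assumes an unquantized DiT, or one whose weights are quantized on the fly (weight_auto_quant).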
            assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
            lora_wrapper = WanLoraWrapper(model)
            for lora_config in self.config.lora_configs:
                lora_path = lora_config["path"]
                strength = lora_config.get("strength", 1.0)
                lora_name = lora_wrapper.load_lora(lora_path)
                lora_wrapper.apply_lora(lora_name, strength)
                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
        return model

    def load_image_encoder(self):
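        """Load the CLIP image encoder, optionally from a quantized checkpoint.

        Only i2v/flf2v tasks with use_image_encoder enabled need it; other
        configurations return None.
        """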
        image_encoder = None
        if self.config.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True):
            # quantization config for the CLIP image encoder
            clip_quantized = self.config.get("clip_quantized", False)
            if clip_quantized:
                clip_quant_scheme = self.config.get("clip_quant_scheme", None)
                assert clip_quant_scheme is not None
                tmp_clip_quant_scheme = clip_quant_scheme.split("-")[0]
                clip_model_name = f"clip-{tmp_clip_quant_scheme}.pth"
                clip_quantized_ckpt = find_torch_model_path(self.config, "clip_quantized_ckpt", clip_model_name)
                clip_original_ckpt = None
            else:
                clip_quantized_ckpt = None
                clip_quant_scheme = None
                clip_model_name = "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"
                clip_original_ckpt = find_torch_model_path(self.config, "clip_original_ckpt", clip_model_name)

            image_encoder = CLIPModel(
                dtype=torch.float16,
                device=self.init_device,
                checkpoint_path=clip_original_ckpt,
                clip_quantized=clip_quantized,
                clip_quantized_ckpt=clip_quantized_ckpt,
                quant_scheme=clip_quant_scheme,
                cpu_offload=self.config.get("clip_cpu_offload", self.config.get("cpu_offload", False)),
                use_31_block=self.config.get("use_31_block", True),
            )

        return image_encoder

    def load_text_encoder(self):
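        """Load the UMT5-XXL text encoder, honoring the CPU-offload and quantization options in the config."""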
        # offload config
        t5_offload = self.config.get("t5_cpu_offload", self.config.get("cpu_offload"))
        if t5_offload:
            t5_device = torch.device("cpu")
        else:
            t5_device = torch.device("cuda")

        # quantization config for the T5 text encoder
        t5_quantized = self.config.get("t5_quantized", False)
        if t5_quantized:
            t5_quant_scheme = self.config.get("t5_quant_scheme", None)
            assert t5_quant_scheme is not None
            tmp_t5_quant_scheme = t5_quant_scheme.split("-")[0]
            t5_model_name = f"models_t5_umt5-xxl-enc-{tmp_t5_quant_scheme}.pth"
            t5_quantized_ckpt = find_torch_model_path(self.config, "t5_quantized_ckpt", t5_model_name)
            t5_original_ckpt = None
            tokenizer_path = os.path.join(os.path.dirname(t5_quantized_ckpt), "google/umt5-xxl")
        else:
            t5_quant_scheme = None
            t5_quantized_ckpt = None
            t5_model_name = "models_t5_umt5-xxl-enc-bf16.pth"
            t5_original_ckpt = find_torch_model_path(self.config, "t5_original_ckpt", t5_model_name)
            tokenizer_path = os.path.join(os.path.dirname(t5_original_ckpt), "google/umt5-xxl")

        text_encoder = T5EncoderModel(
            text_len=self.config["text_len"],
            dtype=torch.bfloat16,
            device=t5_device,
            checkpoint_path=t5_original_ckpt,
            tokenizer_path=tokenizer_path,
            shard_fn=None,
            cpu_offload=t5_offload,
            offload_granularity=self.config.get("t5_offload_granularity", "model"),  # support ['model', 'block']
            t5_quantized=t5_quantized,
            t5_quantized_ckpt=t5_quantized_ckpt,
            quant_scheme=t5_quant_scheme,
        )
        text_encoders = [text_encoder]
        return text_encoders

    def load_vae_encoder(self):
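        """Build the VAE encoder used for conditioning; tasks other than i2v/flf2v/vace return None."""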
        # offload config
        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
        if vae_offload:
            vae_device = torch.device("cpu")
        else:
            vae_device = torch.device("cuda")

        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", self.vae_name),
            "device": vae_device,
            "parallel": self.config.parallel,
            "use_tiling": self.config.get("use_tiling_vae", False),
            "cpu_offload": vae_offload,
            "dtype": GET_DTYPE(),
        }
        if self.config.task not in ["i2v", "flf2v", "vace"]:
            return None
        else:
            return self.vae_cls(**vae_config)

    def load_vae_decoder(self):
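        """Build the VAE decoder, using the lightweight tiny VAE when use_tiny_vae is set."""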
        # offload config
        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
        if vae_offload:
            vae_device = torch.device("cpu")
        else:
            vae_device = torch.device("cuda")

        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", self.vae_name),
            "device": vae_device,
            "parallel": self.config.parallel,
            "use_tiling": self.config.get("use_tiling_vae", False),
            "cpu_offload": vae_offload,
            "dtype": GET_DTYPE(),
        }
        if self.config.get("use_tiny_vae", False):
            tiny_vae_path = find_torch_model_path(self.config, "tiny_vae_path", self.tiny_vae_name)
            vae_decoder = self.tiny_vae_cls(vae_pth=tiny_vae_path, device=self.init_device, need_scaled=self.config.get("need_scaled", False)).to("cuda")
        else:
            vae_decoder = self.vae_cls(**vae_config)
        return vae_decoder

    def load_vae(self):
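        """Return (vae_encoder, vae_decoder); both share one VAE instance unless the encoder is absent or a tiny decoder is requested."""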
        vae_encoder = self.load_vae_encoder()
        if vae_encoder is None or self.config.get("use_tiny_vae", False):
            vae_decoder = self.load_vae_decoder()
        else:
            vae_decoder = vae_encoder
        return vae_encoder, vae_decoder

    def init_scheduler(self):
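        """Pick the scheduler implementation from feature_caching and attach it to the model, optionally wrapped for changing-resolution inference."""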
        if self.config.feature_caching == "NoCaching":
            scheduler_class = WanScheduler
        elif self.config.feature_caching == "TaylorSeer":
            scheduler_class = WanSchedulerTaylorCaching
        elif self.config.feature_caching in ["Tea", "Ada", "Custom", "FirstBlock", "DualBlock", "DynamicBlock", "Mag"]:
            scheduler_class = WanSchedulerCaching
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")

        if self.config.get("changing_resolution", False):
            scheduler = WanScheduler4ChangingResolutionInterface(scheduler_class, self.config)
        else:
            scheduler = scheduler_class(self.config)
        self.model.set_scheduler(scheduler)

    def run_text_encoder(self, text, img=None):
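        """Encode the prompt and negative prompt with T5.

        With cfg_parallel enabled, rank 0 of the CFG group only computes the
        positive context and the remaining rank only the negative one.
        """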
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.text_encoders = self.load_text_encoder()
        n_prompt = self.config.get("negative_prompt", "")

        if self.config["cfg_parallel"]:
            cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
            cfg_p_rank = dist.get_rank(cfg_p_group)
            if cfg_p_rank == 0:
                context = self.text_encoders[0].infer([text])
                text_encoder_output = {"context": context}
            else:
                context_null = self.text_encoders[0].infer([n_prompt])
                text_encoder_output = {"context_null": context_null}
        else:
            context = self.text_encoders[0].infer([text])
            context_null = self.text_encoders[0].infer([n_prompt])
            text_encoder_output = {
                "context": context,
                "context_null": context_null,
            }

        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.text_encoders[0]
            torch.cuda.empty_cache()
            gc.collect()

        return text_encoder_output

    def run_image_encoder(self, first_frame, last_frame=None):
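        """Return CLIP visual features for the first (and optionally the last) conditioning frame."""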
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.image_encoder = self.load_image_encoder()
        if last_frame is None:
            clip_encoder_out = self.image_encoder.visual([first_frame]).squeeze(0).to(GET_DTYPE())
        else:
            clip_encoder_out = self.image_encoder.visual([first_frame, last_frame]).squeeze(0).to(GET_DTYPE())
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.image_encoder
            torch.cuda.empty_cache()
            gc.collect()
        return clip_encoder_out

    def run_vae_encoder(self, first_frame, last_frame=None):
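        """Encode the conditioning frame(s) into latents.

        With changing_resolution enabled, one latent is produced per entry in
        resolution_rate plus one at the full target resolution.
        """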
        h, w = first_frame.shape[2:]
        aspect_ratio = h / w
        max_area = self.config.target_height * self.config.target_width
        lat_h = round(np.sqrt(max_area * aspect_ratio) // self.config.vae_stride[1] // self.config.patch_size[1] * self.config.patch_size[1])
        lat_w = round(np.sqrt(max_area / aspect_ratio) // self.config.vae_stride[2] // self.config.patch_size[2] * self.config.patch_size[2])

        if self.config.get("changing_resolution", False):
            assert last_frame is None
            self.config.lat_h, self.config.lat_w = lat_h, lat_w
            vae_encode_out_list = []
            for i in range(len(self.config["resolution_rate"])):
                lat_h, lat_w = (
                    int(self.config.lat_h * self.config.resolution_rate[i]) // 2 * 2,
                    int(self.config.lat_w * self.config.resolution_rate[i]) // 2 * 2,
                )
                vae_encode_out_list.append(self.get_vae_encoder_output(first_frame, lat_h, lat_w))
            vae_encode_out_list.append(self.get_vae_encoder_output(first_frame, self.config.lat_h, self.config.lat_w))
            return vae_encode_out_list
        else:
            if last_frame is not None:
                first_frame_size = first_frame.shape[2:]
                last_frame_size = last_frame.shape[2:]
                if first_frame_size != last_frame_size:
                    last_frame_resize_ratio = max(first_frame_size[0] / last_frame_size[0], first_frame_size[1] / last_frame_size[1])
                    last_frame_size = [
                        round(last_frame_size[0] * last_frame_resize_ratio),
                        round(last_frame_size[1] * last_frame_resize_ratio),
                    ]
                    last_frame = TF.center_crop(last_frame, last_frame_size)
            self.config.lat_h, self.config.lat_w = lat_h, lat_w
            vae_encoder_out = self.get_vae_encoder_output(first_frame, lat_h, lat_w, last_frame)
            return vae_encoder_out

    def get_vae_encoder_output(self, first_frame, lat_h, lat_w, last_frame=None):
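        """Build the conditioning latents for the given latent grid size.

        The known frames are resized to the pixel grid, missing frames are
        zero-filled, the result is VAE-encoded, and a temporal mask marking
        the provided frames is concatenated in front of the latents.
        """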
        h = lat_h * self.config.vae_stride[1]
        w = lat_w * self.config.vae_stride[2]
        msk = torch.ones(
            1,
            self.config.target_video_length,
            lat_h,
            lat_w,
            device=torch.device("cuda"),
        )
        if last_frame is not None:
            msk[:, 1:-1] = 0
        else:
            msk[:, 1:] = 0

        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]

        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_encoder = self.load_vae_encoder()

        if last_frame is not None:
            vae_input = torch.concat(
                [
                    torch.nn.functional.interpolate(first_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                    torch.zeros(3, self.config.target_video_length - 2, h, w),
                    torch.nn.functional.interpolate(last_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                ],
                dim=1,
            ).cuda()
        else:
            vae_input = torch.concat(
                [
                    torch.nn.functional.interpolate(first_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                    torch.zeros(3, self.config.target_video_length - 1, h, w),
                ],
                dim=1,
            ).cuda()

        vae_encoder_out = self.vae_encoder.encode(vae_input.unsqueeze(0).to(GET_DTYPE()))

        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_encoder
            torch.cuda.empty_cache()
            gc.collect()
        vae_encoder_out = torch.concat([msk, vae_encoder_out]).to(GET_DTYPE())
        return vae_encoder_out

    def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img=None):
        image_encoder_output = {
            "clip_encoder_out": clip_encoder_out,
            "vae_encoder_out": vae_encoder_out,
        }
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": image_encoder_output,
        }

    def set_target_shape(self):
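        """Set config.target_shape, the latent tensor shape to be denoised, from the task and target resolution."""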
        num_channels_latents = self.config.get("num_channels_latents", 16)
        if self.config.task in ["i2v", "flf2v"]:
            self.config.target_shape = (
                num_channels_latents,
                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
                self.config.lat_h,
                self.config.lat_w,
            )
        elif self.config.task == "t2v":
            self.config.target_shape = (
                num_channels_latents,
                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
                int(self.config.target_height) // self.config.vae_stride[1],
                int(self.config.target_width) // self.config.vae_stride[2],
            )

    def save_video_func(self, images):
        cache_video(
            tensor=images,
            save_file=self.config.save_video_path,
            fps=self.config.get("fps", 16),
            nrow=1,
            normalize=True,
            value_range=(-1, 1),
        )


class MultiModelStruct:
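    """Pair of Wan2.2 MoE experts (high-noise and low-noise DiT) sharing one scheduler.

    The active expert is selected per denoising step by comparing the current
    timestep against a boundary; the inactive expert is offloaded to CPU.
    """
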
    def __init__(self, model_list, config, boundary=0.875, num_train_timesteps=1000):
        self.model = model_list  # [high_noise_model, low_noise_model]
        assert len(self.model) == 2, "MultiModelStruct only supports 2 models now."
        self.config = config
        self.boundary = boundary
        self.boundary_timestep = self.boundary * num_train_timesteps
        self.cur_model_index = -1
        logger.info(f"boundary: {self.boundary}, boundary_timestep: {self.boundary_timestep}")

    @property
    def device(self):
        return self.model[self.cur_model_index].device

    def set_scheduler(self, shared_scheduler):
        self.scheduler = shared_scheduler
        for model in self.model:
            model.set_scheduler(shared_scheduler)

    def infer(self, inputs):
        self.get_current_model_index()
        self.model[self.cur_model_index].infer(inputs)

    def get_current_model_index(self):
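        """Route to the high-noise expert while timestep >= boundary_timestep, otherwise to the low-noise expert, swapping models between GPU and CPU as needed."""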
        if self.scheduler.timesteps[self.scheduler.step_index] >= self.boundary_timestep:
            logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}")
            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[0]
            if self.cur_model_index == -1:
                self.to_cuda(model_index=0)
            elif self.cur_model_index == 1:  # 1 -> 0
                self.offload_cpu(model_index=1)
                self.to_cuda(model_index=0)
            self.cur_model_index = 0
        else:
            logger.info(f"using - LOW - noise model at step_index {self.scheduler.step_index + 1}")
            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[1]
            if self.cur_model_index == -1:
                self.to_cuda(model_index=1)
            elif self.cur_model_index == 0:  # 0 -> 1
                self.offload_cpu(model_index=0)
                self.to_cuda(model_index=1)
            self.cur_model_index = 1

    def offload_cpu(self, model_index):
        self.model[model_index].to_cpu()

    def to_cuda(self, model_index):
        self.model[model_index].to_cuda()


@RUNNER_REGISTER("wan2.2_moe")
class Wan22MoeRunner(WanRunner):
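    """Wan2.2 MoE runner that loads the high/low-noise DiT pair from subdirectories of model_path."""
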
    def __init__(self, config):
        super().__init__(config)

    def load_transformer(self):
        # encoder -> high_noise_model -> low_noise_model -> vae -> video_output
        high_noise_model = WanModel(
            os.path.join(self.config.model_path, "high_noise_model"),
            self.config,
            self.init_device,
        )
        low_noise_model = WanModel(
            os.path.join(self.config.model_path, "low_noise_model"),
            self.config,
            self.init_device,
        )
        return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config.boundary)


@RUNNER_REGISTER("wan2.2")
class Wan22DenseRunner(WanRunner):
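    """Wan2.2 dense runner that uses the Wan2.2 VAE and resizes/crops the conditioning image to the best output size before encoding."""
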
    def __init__(self, config):
        super().__init__(config)
        self.vae_encoder_need_img_original = True
        self.vae_cls = Wan2_2_VAE
        self.tiny_vae_cls = Wan2_2_VAE_tiny
        self.vae_name = "Wan2.2_VAE.pth"
        self.tiny_vae_name = "taew2_2.pth"

    def run_vae_encoder(self, img):
        max_area = self.config.target_height * self.config.target_width
        ih, iw = img.height, img.width
        dh, dw = self.config.patch_size[1] * self.config.vae_stride[1], self.config.patch_size[2] * self.config.vae_stride[2]
        ow, oh = best_output_size(iw, ih, dw, dh, max_area)

        scale = max(ow / iw, oh / ih)
        img = img.resize((round(iw * scale), round(ih * scale)), Image.LANCZOS)

        # center-crop
        x1 = (img.width - ow) // 2
        y1 = (img.height - oh) // 2
        img = img.crop((x1, y1, x1 + ow, y1 + oh))
        assert img.width == ow and img.height == oh

        # to tensor
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda().unsqueeze(1)
        vae_encoder_out = self.get_vae_encoder_output(img)
        self.config.lat_w, self.config.lat_h = ow // self.config.vae_stride[2], oh // self.config.vae_stride[1]

        return vae_encoder_out

    def get_vae_encoder_output(self, img):
        z = self.vae_encoder.encode(img.to(GET_DTYPE()))
        return z