import gc
import os

import numpy as np
import torch
import torch.distributed as dist
import torchvision.transforms.functional as TF
from PIL import Image
from loguru import logger

from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.wan.changing_resolution.scheduler import (
    WanScheduler4ChangingResolutionInterface,
)
from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
    WanSchedulerCaching,
    WanSchedulerTaylorCaching,
)
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
from lightx2v.models.video_encoders.hf.wan.vae_tiny import WanVAE_tiny
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import *
from lightx2v.utils.utils import best_output_size, cache_video


@RUNNER_REGISTER("wan2.1")
class WanRunner(DefaultRunner):
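    """Runner for the Wan2.1 video generation pipeline.

    Loads the Wan DiT transformer, T5 text encoder, optional CLIP image encoder
    and WanVAE, and provides the encoding, scheduling and video-saving hooks on
    top of DefaultRunner. When a device mesh is configured, the sequence-parallel
    group is taken from its "seq_p" dimension.
    """
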
    def __init__(self, config):
        super().__init__(config)
        device_mesh = self.config.get("device_mesh")
        if device_mesh is not None:
            self.seq_p_group = device_mesh.get_group(mesh_dim="seq_p")
        else:
            self.seq_p_group = None

    def load_transformer(self):
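        """Build the Wan DiT and, if config.lora_configs is set, load and apply each LoRA with its optional strength."""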
        model = WanModel(
            self.config.model_path,
            self.config,
            self.init_device,
            self.seq_p_group,
        )
        if self.config.get("lora_configs") and self.config.lora_configs:
            assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
            lora_wrapper = WanLoraWrapper(model)
            for lora_config in self.config.lora_configs:
                lora_path = lora_config["path"]
                strength = lora_config.get("strength", 1.0)
                lora_name = lora_wrapper.load_lora(lora_path)
                lora_wrapper.apply_lora(lora_name, strength)
                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
        return model

    def load_image_encoder(self):
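        """Load the CLIP (XLM-RoBERTa) image encoder for i2v runs, picking a quantized or original checkpoint from the config; returns None when no image encoder is needed."""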
        image_encoder = None
        if self.config.task == "i2v" and self.config.get("use_image_encoder", True):
            # quant_config
            clip_quantized = self.config.get("clip_quantized", False)
            if clip_quantized:
                clip_quant_scheme = self.config.get("clip_quant_scheme", None)
                assert clip_quant_scheme is not None
                tmp_clip_quant_scheme = clip_quant_scheme.split("-")[0]
                clip_model_name = f"clip-{tmp_clip_quant_scheme}.pth"
                clip_quantized_ckpt = find_torch_model_path(self.config, "clip_quantized_ckpt", clip_model_name)
                clip_original_ckpt = None
            else:
                clip_quantized_ckpt = None
                clip_quant_scheme = None
                clip_model_name = "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"
                clip_original_ckpt = find_torch_model_path(self.config, "clip_original_ckpt", clip_model_name)

            image_encoder = CLIPModel(
                dtype=torch.float16,
                device=self.init_device,
                checkpoint_path=clip_original_ckpt,
                clip_quantized=clip_quantized,
                clip_quantized_ckpt=clip_quantized_ckpt,
                quant_scheme=clip_quant_scheme,
                seq_p_group=self.seq_p_group,
                cpu_offload=self.config.get("clip_cpu_offload", self.config.get("cpu_offload", False)),
                use_31_block=self.config.get("use_31_block", True),
            )

        return image_encoder

    def load_text_encoder(self):
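        """Load the umt5-xxl T5 text encoder, resolving the checkpoint (quantized or original), tokenizer path, device placement and offload options from the config; returns a one-element list."""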
        # offload config
        t5_offload = self.config.get("t5_cpu_offload", self.config.get("cpu_offload"))
        if t5_offload:
            t5_device = torch.device("cpu")
        else:
            t5_device = torch.device("cuda")

        # quant_config
        t5_quantized = self.config.get("t5_quantized", False)
        if t5_quantized:
            t5_quant_scheme = self.config.get("t5_quant_scheme", None)
            assert t5_quant_scheme is not None
            tmp_t5_quant_scheme = t5_quant_scheme.split("-")[0]
            t5_model_name = f"models_t5_umt5-xxl-enc-{tmp_t5_quant_scheme}.pth"
            t5_quantized_ckpt = find_torch_model_path(self.config, "t5_quantized_ckpt", t5_model_name)
            t5_original_ckpt = None
            tokenizer_path = os.path.join(os.path.dirname(t5_quantized_ckpt), "google/umt5-xxl")
        else:
            t5_quant_scheme = None
            t5_quantized_ckpt = None
            t5_model_name = "models_t5_umt5-xxl-enc-bf16.pth"
            t5_original_ckpt = find_torch_model_path(self.config, "t5_original_ckpt", t5_model_name)
            tokenizer_path = os.path.join(os.path.dirname(t5_original_ckpt), "google/umt5-xxl")

        text_encoder = T5EncoderModel(
            text_len=self.config["text_len"],
            dtype=torch.bfloat16,
            device=t5_device,
            checkpoint_path=t5_original_ckpt,
            tokenizer_path=tokenizer_path,
            shard_fn=None,
            cpu_offload=t5_offload,
            offload_granularity=self.config.get("t5_offload_granularity", "model"),  # support ['model', 'block']
            t5_quantized=t5_quantized,
            t5_quantized_ckpt=t5_quantized_ckpt,
            quant_scheme=t5_quant_scheme,
            seq_p_group=self.seq_p_group,
        )
        text_encoders = [text_encoder]
        return text_encoders

    def load_vae_encoder(self):
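        """Build the Wan2.1 VAE encoder for i2v tasks; other tasks need no image conditioning, so None is returned."""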
        # offload config
        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
        if vae_offload:
            vae_device = torch.device("cpu")
        else:
            vae_device = torch.device("cuda")

        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
            "device": vae_device,
            "parallel": self.config.parallel and self.config.parallel.get("vae_p_size", False) and self.config.parallel.vae_p_size > 1,
            "use_tiling": self.config.get("use_tiling_vae", False),
            "seq_p_group": self.seq_p_group,
            "cpu_offload": vae_offload,
        }
        if self.config.task != "i2v":
            return None
        else:
            return WanVAE(**vae_config)

    def load_vae_decoder(self):
        # offload config
        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
        if vae_offload:
            vae_device = torch.device("cpu")
        else:
            vae_device = torch.device("cuda")

        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
            "device": vae_device,
            "parallel": self.config.parallel and self.config.parallel.get("vae_p_size", False) and self.config.parallel.vae_p_size > 1,
            "use_tiling": self.config.get("use_tiling_vae", False),
            "cpu_offload": vae_offload,
        }
        if self.config.get("use_tiny_vae", False):
            tiny_vae_path = find_torch_model_path(self.config, "tiny_vae_path", "taew2_1.pth")
            vae_decoder = WanVAE_tiny(
                vae_pth=tiny_vae_path,
                device=self.init_device,
            ).to("cuda")
        else:
            vae_decoder = WanVAE(**vae_config)
        return vae_decoder

    def load_vae(self):
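        """Return (vae_encoder, vae_decoder); the same WanVAE instance serves both roles unless there is no encoder or a tiny VAE decoder is requested."""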
        vae_encoder = self.load_vae_encoder()
        if vae_encoder is None or self.config.get("use_tiny_vae", False):
            vae_decoder = self.load_vae_decoder()
        else:
            vae_decoder = vae_encoder
        return vae_encoder, vae_decoder

    def init_scheduler(self):
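        """Pick the scheduler class from config.feature_caching, optionally wrap it for changing-resolution sampling, and attach it to the model."""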
        if self.config.feature_caching == "NoCaching":
            scheduler_class = WanScheduler
        elif self.config.feature_caching == "TaylorSeer":
            scheduler_class = WanSchedulerTaylorCaching
        elif self.config.feature_caching in ["Tea", "Ada", "Custom", "FirstBlock", "DualBlock", "DynamicBlock"]:
            scheduler_class = WanSchedulerCaching
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")

        if self.config.get("changing_resolution", False):
            scheduler = WanScheduler4ChangingResolutionInterface(scheduler_class, self.config)
        else:
            scheduler = scheduler_class(self.config)
        self.model.set_scheduler(scheduler)

    def run_text_encoder(self, text, img):
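        """Encode the prompt and negative prompt. Under cfg_parallel, rank 0 of the cfg_p group encodes the prompt while the other ranks encode the negative prompt."""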
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.text_encoders = self.load_text_encoder()
        n_prompt = self.config.get("negative_prompt", "")

        if self.config["cfg_parallel"]:
            cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
            cfg_p_rank = dist.get_rank(cfg_p_group)
            if cfg_p_rank == 0:
                context = self.text_encoders[0].infer([text])
                text_encoder_output = {"context": context}
            else:
                context_null = self.text_encoders[0].infer([n_prompt])
                text_encoder_output = {"context_null": context_null}
        else:
            context = self.text_encoders[0].infer([text])
            context_null = self.text_encoders[0].infer([n_prompt])
            text_encoder_output = {
                "context": context,
                "context_null": context_null,
            }

        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.text_encoders[0]
            torch.cuda.empty_cache()
            gc.collect()

        return text_encoder_output

    def run_image_encoder(self, img):
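        """Normalize the input image and return CLIP visual features, loading/unloading the encoder on demand when lazy_load or unload_modules is set."""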
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.image_encoder = self.load_image_encoder()
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
        clip_encoder_out = self.image_encoder.visual([img[None, :, :, :]]).squeeze(0).to(GET_DTYPE())
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.image_encoder
            torch.cuda.empty_cache()
            gc.collect()
        return clip_encoder_out

    def run_vae_encoder(self, img):
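        """Derive patch-aligned latent dims (lat_h, lat_w) from the image aspect ratio and target area, then encode the image; with changing_resolution, one encoding is produced per resolution_rate entry plus the full resolution."""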
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
        h, w = img.shape[1:]
        aspect_ratio = h / w
        max_area = self.config.target_height * self.config.target_width
        lat_h = round(np.sqrt(max_area * aspect_ratio) // self.config.vae_stride[1] // self.config.patch_size[1] * self.config.patch_size[1])
        lat_w = round(np.sqrt(max_area / aspect_ratio) // self.config.vae_stride[2] // self.config.patch_size[2] * self.config.patch_size[2])

        if self.config.get("changing_resolution", False):
            self.config.lat_h, self.config.lat_w = lat_h, lat_w
            vae_encode_out_list = []
            for i in range(len(self.config["resolution_rate"])):
                lat_h, lat_w = (
                    int(self.config.lat_h * self.config.resolution_rate[i]) // 2 * 2,
                    int(self.config.lat_w * self.config.resolution_rate[i]) // 2 * 2,
                )
                vae_encode_out_list.append(self.get_vae_encoder_output(img, lat_h, lat_w))
            vae_encode_out_list.append(self.get_vae_encoder_output(img, self.config.lat_h, self.config.lat_w))
            return vae_encode_out_list
        else:
            self.config.lat_h, self.config.lat_w = lat_h, lat_w
            vae_encoder_out = self.get_vae_encoder_output(img, lat_h, lat_w)
            return vae_encoder_out

    def get_vae_encoder_output(self, img, lat_h, lat_w):
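        """Encode the conditioning frame at (lat_h * stride, lat_w * stride): the image is bicubic-resized, padded with zero frames to target_video_length, VAE-encoded, and concatenated with a temporal mask that marks only the first frame as known."""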
        h = lat_h * self.config.vae_stride[1]
        w = lat_w * self.config.vae_stride[2]

        msk = torch.ones(
            1,
            self.config.target_video_length,
            lat_h,
            lat_w,
            device=torch.device("cuda"),
        )
        msk[:, 1:] = 0
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_encoder = self.load_vae_encoder()
        vae_encoder_out = self.vae_encoder.encode(
            [
                torch.concat(
                    [
                        torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                        torch.zeros(3, self.config.target_video_length - 1, h, w),
                    ],
                    dim=1,
                ).cuda()
            ],
            self.config,
        )[0]
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_encoder
            torch.cuda.empty_cache()
            gc.collect()
        vae_encoder_out = torch.concat([msk, vae_encoder_out]).to(GET_DTYPE())
        return vae_encoder_out

    def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img):
        image_encoder_output = {
            "clip_encoder_out": clip_encoder_out,
            "vae_encoder_out": vae_encoder_out,
        }
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": image_encoder_output,
        }

    def set_target_shape(self):
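        """Set config.target_shape to the latent (C, T, H, W): i2v reuses the precomputed lat_h/lat_w, t2v derives H and W from the target resolution and VAE strides."""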
        num_channels_latents = self.config.get("num_channels_latents", 16)
        if self.config.task == "i2v":
            self.config.target_shape = (
                num_channels_latents,
                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
                self.config.lat_h,
                self.config.lat_w,
            )
        elif self.config.task == "t2v":
            self.config.target_shape = (
                num_channels_latents,
                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
                int(self.config.target_height) // self.config.vae_stride[1],
                int(self.config.target_width) // self.config.vae_stride[2],
            )

    def save_video_func(self, images):
        cache_video(
            tensor=images,
            save_file=self.config.save_video_path,
            fps=self.config.get("fps", 16),
            nrow=1,
            normalize=True,
            value_range=(-1, 1),
        )


class MultiModelStruct:
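    """Wraps the two Wan2.2 MoE experts (high-noise, low-noise) behind a single model interface.

    Both experts share one scheduler; the active expert is chosen per step by
    comparing the current timestep against boundary * num_train_timesteps, and
    the inactive expert is offloaded to CPU.
    """
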
    def __init__(self, model_list, config, boundary=0.875, num_train_timesteps=1000):
        self.model = model_list  # [high_noise_model, low_noise_model]
        assert len(self.model) == 2, "MultiModelStruct only supports 2 models now."
        self.config = config
        self.boundary = boundary
        self.boundary_timestep = self.boundary * num_train_timesteps
        self.cur_model_index = -1
        logger.info(f"boundary: {self.boundary}, boundary_timestep: {self.boundary_timestep}")

    @property
    def device(self):
        return self.model[self.cur_model_index].device

    def set_scheduler(self, shared_scheduler):
        self.scheduler = shared_scheduler
        for model in self.model:
            model.set_scheduler(shared_scheduler)

    def infer(self, inputs):
        self.get_current_model_index()
        self.model[self.cur_model_index].infer(inputs)

    def get_current_model_index(self):
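        """Select the expert for the current step: timesteps at or above boundary_timestep use the high-noise model (index 0), otherwise the low-noise model (index 1); also switches sample_guide_scale and swaps models between CPU and GPU."""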
        if self.scheduler.timesteps[self.scheduler.step_index] >= self.boundary_timestep:
            logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}")
            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[0]
            if self.cur_model_index == -1:
                self.to_cuda(model_index=0)
            elif self.cur_model_index == 1:  # 1 -> 0
                self.offload_cpu(model_index=1)
                self.to_cuda(model_index=0)
            self.cur_model_index = 0
        else:
            logger.info(f"using - LOW - noise model at step_index {self.scheduler.step_index + 1}")
            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[1]
            if self.cur_model_index == -1:
                self.to_cuda(model_index=1)
            elif self.cur_model_index == 0:  # 0 -> 1
                self.offload_cpu(model_index=0)
                self.to_cuda(model_index=1)
            self.cur_model_index = 1

    def offload_cpu(self, model_index):
        self.model[model_index].to_cpu()

    def to_cuda(self, model_index):
        self.model[model_index].to_cuda()


@RUNNER_REGISTER("wan2.2_moe")
class Wan22MoeRunner(WanRunner):
    def __init__(self, config):
        super().__init__(config)

    def load_transformer(self):
        # encoder -> high_noise_model -> low_noise_model -> vae -> video_output
        high_noise_model = WanModel(
            os.path.join(self.config.model_path, "high_noise_model"),
            self.config,
            self.init_device,
        )
        low_noise_model = WanModel(
            os.path.join(self.config.model_path, "low_noise_model"),
            self.config,
            self.init_device,
        )
        return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config.boundary)


@RUNNER_REGISTER("wan2.2")
class Wan22DenseRunner(WanRunner):
    def __init__(self, config):
        super().__init__(config)

    def load_vae_decoder(self):
        # offload config
        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
        if vae_offload:
            vae_device = torch.device("cpu")
        else:
            vae_device = torch.device("cuda")
        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
            "device": vae_device,
            "cpu_offload": vae_offload,
            "offload_cache": self.config.get("vae_offload_cache", False),
        }
        vae_decoder = Wan2_2_VAE(**vae_config)
        return vae_decoder

    def load_vae_encoder(self):
        # offload config
        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
        if vae_offload:
            vae_device = torch.device("cpu")
        else:
            vae_device = torch.device("cuda")
        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
            "device": vae_device,
            "cpu_offload": vae_offload,
            "offload_cache": self.config.get("vae_offload_cache", False),
        }
        if self.config.task != "i2v":
            return None
        else:
            return Wan2_2_VAE(**vae_config)

    def load_vae(self):
        vae_encoder = self.load_vae_encoder()
        vae_decoder = self.load_vae_decoder()
        return vae_encoder, vae_decoder

    def run_vae_encoder(self, img):
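        """Resize and center-crop the image to the best patch-aligned size under the target area, encode it with the Wan2.2 VAE, and record lat_h/lat_w in the config."""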
        max_area = self.config.target_height * self.config.target_width
        ih, iw = img.height, img.width
        dh, dw = self.config.patch_size[1] * self.config.vae_stride[1], self.config.patch_size[2] * self.config.vae_stride[2]
        ow, oh = best_output_size(iw, ih, dw, dh, max_area)

        scale = max(ow / iw, oh / ih)
        img = img.resize((round(iw * scale), round(ih * scale)), Image.LANCZOS)

        # center-crop
        x1 = (img.width - ow) // 2
        y1 = (img.height - oh) // 2
        img = img.crop((x1, y1, x1 + ow, y1 + oh))
        assert img.width == ow and img.height == oh

        # to tensor
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda().unsqueeze(1)
        vae_encoder_out = self.get_vae_encoder_output(img)
        self.config.lat_w, self.config.lat_h = ow // self.config.vae_stride[2], oh // self.config.vae_stride[1]

        return vae_encoder_out

    def get_vae_encoder_output(self, img):
        z = self.vae_encoder.encode(img, self.config)
        return z