import gc
import os

import numpy as np
import torch
import torch.distributed as dist
import torchvision.transforms.functional as TF
from PIL import Image
from loguru import logger

from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.wan.changing_resolution.scheduler import (
    WanScheduler4ChangingResolutionInterface,
)
from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
    WanSchedulerCaching,
    WanSchedulerTaylorCaching,
)
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
from lightx2v.models.video_encoders.hf.wan.vae_tiny import WanVAE_tiny
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import *
from lightx2v.utils.utils import best_output_size, cache_video


@RUNNER_REGISTER("wan2.1")
class WanRunner(DefaultRunner):
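    # Runner for Wan2.1: builds the DiT, T5/CLIP encoders, VAE and scheduler for t2v/i2v inference.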
    def __init__(self, config):
        super().__init__(config)
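        # sequence-parallel process group (None when no device mesh is configured)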
        device_mesh = self.config.get("device_mesh")
        if device_mesh is not None:
            self.seq_p_group = device_mesh.get_group(mesh_dim="seq_p")
        else:
            self.seq_p_group = None

    def load_transformer(self):
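        # build the Wan DiT; LoRA weights are merged below when lora_configs is set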
        model = WanModel(
            self.config.model_path,
            self.config,
            self.init_device,
            self.seq_p_group,
        )
        if self.config.get("lora_configs") and self.config.lora_configs:
            assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
            lora_wrapper = WanLoraWrapper(model)
            for lora_config in self.config.lora_configs:
                lora_path = lora_config["path"]
                strength = lora_config.get("strength", 1.0)
                lora_name = lora_wrapper.load_lora(lora_path)
                lora_wrapper.apply_lora(lora_name, strength)
                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
        return model

    def load_image_encoder(self):
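        # CLIP vision encoder is only needed for i2v; a quantized checkpoint is used when clip_quantized is set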
        image_encoder = None
        if self.config.task == "i2v" and self.config.get("use_image_encoder", True):
            # quant_config
            clip_quantized = self.config.get("clip_quantized", False)
            if clip_quantized:
                clip_quant_scheme = self.config.get("clip_quant_scheme", None)
                assert clip_quant_scheme is not None
                tmp_clip_quant_scheme = clip_quant_scheme.split("-")[0]
                clip_model_name = f"clip-{tmp_clip_quant_scheme}.pth"
                clip_quantized_ckpt = find_torch_model_path(self.config, "clip_quantized_ckpt", clip_model_name)
                clip_original_ckpt = None
            else:
                clip_quantized_ckpt = None
                clip_quant_scheme = None
                clip_model_name = "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"
                clip_original_ckpt = find_torch_model_path(self.config, "clip_original_ckpt", clip_model_name)

            image_encoder = CLIPModel(
                dtype=torch.float16,
                device=self.init_device,
                checkpoint_path=clip_original_ckpt,
                clip_quantized=clip_quantized,
                clip_quantized_ckpt=clip_quantized_ckpt,
                quant_scheme=clip_quant_scheme,
                seq_p_group=self.seq_p_group,
            )

        return image_encoder

    def load_text_encoder(self):
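        # umt5-xxl text encoder; supports CPU offload and quantized checkpoints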
        # offload config
        t5_offload = self.config.get("t5_cpu_offload", self.config.get("cpu_offload"))
        if t5_offload:
            t5_device = torch.device("cpu")
        else:
            t5_device = torch.device("cuda")

        # quant_config
        t5_quantized = self.config.get("t5_quantized", False)
        if t5_quantized:
            t5_quant_scheme = self.config.get("t5_quant_scheme", None)
            assert t5_quant_scheme is not None
            tmp_t5_quant_scheme = t5_quant_scheme.split("-")[0]
            t5_model_name = f"models_t5_umt5-xxl-enc-{tmp_t5_quant_scheme}.pth"
            t5_quantized_ckpt = find_torch_model_path(self.config, "t5_quantized_ckpt", t5_model_name)
            t5_original_ckpt = None
            tokenizer_path = os.path.join(os.path.dirname(t5_quantized_ckpt), "google/umt5-xxl")
        else:
            t5_quant_scheme = None
            t5_quantized_ckpt = None
            t5_model_name = "models_t5_umt5-xxl-enc-bf16.pth"
            t5_original_ckpt = find_torch_model_path(self.config, "t5_original_ckpt", t5_model_name)
            tokenizer_path = os.path.join(os.path.dirname(t5_original_ckpt), "google/umt5-xxl")

        text_encoder = T5EncoderModel(
            text_len=self.config["text_len"],
            dtype=torch.bfloat16,
            device=t5_device,
            checkpoint_path=t5_original_ckpt,
            tokenizer_path=tokenizer_path,
            shard_fn=None,
            cpu_offload=t5_offload,
            offload_granularity=self.config.get("t5_offload_granularity", "model"),  # support ['model', 'block']
            t5_quantized=t5_quantized,
            t5_quantized_ckpt=t5_quantized_ckpt,
            quant_scheme=t5_quant_scheme,
            seq_p_group=self.seq_p_group,
        )
        text_encoders = [text_encoder]
        return text_encoders

    def load_vae_encoder(self):
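        # the VAE encoder is only needed for i2v (it encodes the conditioning image)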
        # offload config
        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
        if vae_offload:
            vae_device = torch.device("cpu")
        else:
            vae_device = torch.device("cuda")

        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
            "device": vae_device,
            "parallel": self.config.parallel and self.config.parallel.get("vae_p_size", False) and self.config.parallel.vae_p_size > 1,
            "use_tiling": self.config.get("use_tiling_vae", False),
            "seq_p_group": self.seq_p_group,
            "cpu_offload": vae_offload,
        }
        if self.config.task != "i2v":
            return None
        else:
            return WanVAE(**vae_config)

    def load_vae_decoder(self):
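        # decode latents with the full WanVAE, or with the lightweight tiny VAE when use_tiny_vae is set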
        # offload config
        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
        if vae_offload:
            vae_device = torch.device("cpu")
        else:
            vae_device = torch.device("cuda")

        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
            "device": vae_device,
            "parallel": self.config.parallel and self.config.parallel.get("vae_p_size", False) and self.config.parallel.vae_p_size > 1,
            "use_tiling": self.config.get("use_tiling_vae", False),
            "cpu_offload": vae_offload,
        }
        if self.config.get("use_tiny_vae", False):
            tiny_vae_path = find_torch_model_path(self.config, "tiny_vae_path", "taew2_1.pth")
            vae_decoder = WanVAE_tiny(
                vae_pth=tiny_vae_path,
                device=self.init_device,
            ).to("cuda")
        else:
            vae_decoder = WanVAE(**vae_config)
        return vae_decoder

    def load_vae(self):
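        # reuse the encoder instance as the decoder when possible; otherwise load a separate decoder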
        vae_encoder = self.load_vae_encoder()
        if vae_encoder is None or self.config.get("use_tiny_vae", False):
            vae_decoder = self.load_vae_decoder()
        else:
            vae_decoder = vae_encoder
        return vae_encoder, vae_decoder

    def init_scheduler(self):
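        # pick the scheduler implementation from the feature-caching mode, optionally wrapped for changing resolution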
        if self.config.feature_caching == "NoCaching":
            scheduler_class = WanScheduler
        elif self.config.feature_caching == "TaylorSeer":
            scheduler_class = WanSchedulerTaylorCaching
        elif self.config.feature_caching in ["Tea", "Ada", "Custom", "FirstBlock", "DualBlock", "DynamicBlock"]:
            scheduler_class = WanSchedulerCaching
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")

        if self.config.get("changing_resolution", False):
            scheduler = WanScheduler4ChangingResolutionInterface(scheduler_class, self.config)
        else:
            scheduler = scheduler_class(self.config)
        self.model.set_scheduler(scheduler)

    def run_text_encoder(self, text, img):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.text_encoders = self.load_text_encoder()
        n_prompt = self.config.get("negative_prompt", "")

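        # with cfg parallelism, rank 0 of the cfg_p group encodes the prompt and the other rank the negative prompt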
        if self.config["cfg_parallel"]:
            cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
            cfg_p_rank = dist.get_rank(cfg_p_group)
            if cfg_p_rank == 0:
                context = self.text_encoders[0].infer([text])
                text_encoder_output = {"context": context}
            else:
                context_null = self.text_encoders[0].infer([n_prompt])
                text_encoder_output = {"context_null": context_null}
        else:
            context = self.text_encoders[0].infer([text])
            context_null = self.text_encoders[0].infer([n_prompt])
            text_encoder_output = {
                "context": context,
                "context_null": context_null,
            }

        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.text_encoders[0]
            torch.cuda.empty_cache()
            gc.collect()

        return text_encoder_output

    def run_image_encoder(self, img):
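        # normalize the image to [-1, 1] and run the CLIP visual encoder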
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.image_encoder = self.load_image_encoder()
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
        clip_encoder_out = self.image_encoder.visual([img[None, :, :, :]], self.config).squeeze(0).to(GET_DTYPE())
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.image_encoder
            torch.cuda.empty_cache()
            gc.collect()
        return clip_encoder_out

    def run_vae_encoder(self, img):
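        # derive latent height/width from the target area while keeping the input aspect ratio, aligned to the patch size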
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
        h, w = img.shape[1:]
        aspect_ratio = h / w
        max_area = self.config.target_height * self.config.target_width
        lat_h = round(np.sqrt(max_area * aspect_ratio) // self.config.vae_stride[1] // self.config.patch_size[1] * self.config.patch_size[1])
        lat_w = round(np.sqrt(max_area / aspect_ratio) // self.config.vae_stride[2] // self.config.patch_size[2] * self.config.patch_size[2])

        if self.config.get("changing_resolution", False):
            self.config.lat_h, self.config.lat_w = lat_h, lat_w
            vae_encode_out_list = []
            for i in range(len(self.config["resolution_rate"])):
                lat_h, lat_w = (
                    int(self.config.lat_h * self.config.resolution_rate[i]) // 2 * 2,
                    int(self.config.lat_w * self.config.resolution_rate[i]) // 2 * 2,
                )
                vae_encode_out_list.append(self.get_vae_encoder_output(img, lat_h, lat_w))
            vae_encode_out_list.append(self.get_vae_encoder_output(img, self.config.lat_h, self.config.lat_w))
            return vae_encode_out_list
        else:
            self.config.lat_h, self.config.lat_w = lat_h, lat_w
            vae_encoder_out = self.get_vae_encoder_output(img, lat_h, lat_w)
            return vae_encoder_out

    def get_vae_encoder_output(self, img, lat_h, lat_w):
        h = lat_h * self.config.vae_stride[1]
        w = lat_w * self.config.vae_stride[2]

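        # i2v conditioning mask: keep the first frame (repeated 4x along the temporal axis), zero out the frames to generate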
        msk = torch.ones(
            1,
            self.config.target_video_length,
            lat_h,
            lat_w,
            device=torch.device("cuda"),
        )
        msk[:, 1:] = 0
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_encoder = self.load_vae_encoder()
        vae_encoder_out = self.vae_encoder.encode(
            [
                torch.concat(
                    [
                        torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                        torch.zeros(3, self.config.target_video_length - 1, h, w),
                    ],
                    dim=1,
                ).cuda()
            ],
            self.config,
        )[0]
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_encoder
            torch.cuda.empty_cache()
            gc.collect()
        vae_encoder_out = torch.concat([msk, vae_encoder_out]).to(GET_DTYPE())
        return vae_encoder_out

    def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img):
        image_encoder_output = {
            "clip_encoder_out": clip_encoder_out,
            "vae_encoder_out": vae_encoder_out,
        }
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": image_encoder_output,
        }

    def set_target_shape(self):
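        # latent shape: channels, temporally compressed frame count, and spatially downsampled height/width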
        num_channels_latents = self.config.get("num_channels_latents", 16)
        if self.config.task == "i2v":
            self.config.target_shape = (
                num_channels_latents,
                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
                self.config.lat_h,
                self.config.lat_w,
            )
        elif self.config.task == "t2v":
            self.config.target_shape = (
                num_channels_latents,
                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
                int(self.config.target_height) // self.config.vae_stride[1],
                int(self.config.target_width) // self.config.vae_stride[2],
            )

    def save_video_func(self, images):
        cache_video(
            tensor=images,
            save_file=self.config.save_video_path,
            fps=self.config.get("fps", 16),
            nrow=1,
            normalize=True,
            value_range=(-1, 1),
        )


class MultiModelStruct:
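    # Wraps the Wan2.2 high-/low-noise experts and routes each denoising step to one of them by timestep boundary.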
    def __init__(self, model_list, config, boundary=0.875, num_train_timesteps=1000):
        self.model = model_list  # [high_noise_model, low_noise_model]
        assert len(self.model) == 2, "MultiModelStruct only supports 2 models now."
        self.config = config
        self.boundary = boundary
        self.boundary_timestep = self.boundary * num_train_timesteps
        self.cur_model_index = -1
        logger.info(f"boundary: {self.boundary}, boundary_timestep: {self.boundary_timestep}")

    @property
    def device(self):
        return self.model[self.cur_model_index].device

    def set_scheduler(self, shared_scheduler):
        self.scheduler = shared_scheduler
        for model in self.model:
            model.set_scheduler(shared_scheduler)

    def infer(self, inputs):
        self.get_current_model_index()
        self.model[self.cur_model_index].infer(inputs)

    def get_current_model_index(self):
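        # timesteps >= boundary_timestep use the high-noise expert; the idle expert is offloaded to CPU on switch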
        if self.scheduler.timesteps[self.scheduler.step_index] >= self.boundary_timestep:
            logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}")
            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[0]
            if self.cur_model_index == -1:
                self.to_cuda(model_index=0)
            elif self.cur_model_index == 1:  # 1 -> 0
                self.offload_cpu(model_index=1)
                self.to_cuda(model_index=0)
            self.cur_model_index = 0
        else:
            logger.info(f"using - LOW - noise model at step_index {self.scheduler.step_index + 1}")
            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[1]
            if self.cur_model_index == -1:
                self.to_cuda(model_index=1)
            elif self.cur_model_index == 0:  # 0 -> 1
                self.offload_cpu(model_index=0)
                self.to_cuda(model_index=1)
            self.cur_model_index = 1

    def offload_cpu(self, model_index):
        self.model[model_index].to_cpu()

    def to_cuda(self, model_index):
        self.model[model_index].to_cuda()


@RUNNER_REGISTER("wan2.2_moe")
class Wan22MoeRunner(WanRunner):
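    # Wan2.2 MoE: two DiT experts (high-noise / low-noise) selected per step via MultiModelStruct.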
    def __init__(self, config):
        super().__init__(config)

    def load_transformer(self):
        # encoder -> high_noise_model -> low_noise_model -> vae -> video_output
        high_noise_model = WanModel(
            os.path.join(self.config.model_path, "high_noise_model"),
            self.config,
            self.init_device,
        )
        low_noise_model = WanModel(
            os.path.join(self.config.model_path, "low_noise_model"),
            self.config,
            self.init_device,
        )
        return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config.boundary)


@RUNNER_REGISTER("wan2.2")
class Wan22DenseRunner(WanRunner):
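    # Wan2.2 dense model: single DiT with the Wan2.2 VAE (supports offload cache).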
    def __init__(self, config):
        super().__init__(config)

    def load_vae_decoder(self):
        # offload config
        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
        if vae_offload:
            vae_device = torch.device("cpu")
        else:
            vae_device = torch.device("cuda")
        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
            "device": vae_device,
            "cpu_offload": vae_offload,
            "offload_cache": self.config.get("vae_offload_cache", False),
        }
        vae_decoder = Wan2_2_VAE(**vae_config)
        return vae_decoder

    def load_vae_encoder(self):
        # offload config
        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
        if vae_offload:
            vae_device = torch.device("cpu")
        else:
            vae_device = torch.device("cuda")
        vae_config = {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
            "device": vae_device,
            "cpu_offload": vae_offload,
            "offload_cache": self.config.get("vae_offload_cache", False),
        }
        if self.config.task != "i2v":
            return None
        else:
            return Wan2_2_VAE(**vae_config)

    def load_vae(self):
        vae_encoder = self.load_vae_encoder()
        vae_decoder = self.load_vae_decoder()
        return vae_encoder, vae_decoder

    def run_vae_encoder(self, img):
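        # resize and center-crop the image to the best output size, then encode; latent dims follow the crop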
        max_area = self.config.target_height * self.config.target_width
        ih, iw = img.height, img.width
        dh, dw = self.config.patch_size[1] * self.config.vae_stride[1], self.config.patch_size[2] * self.config.vae_stride[2]
        ow, oh = best_output_size(iw, ih, dw, dh, max_area)

        scale = max(ow / iw, oh / ih)
        img = img.resize((round(iw * scale), round(ih * scale)), Image.LANCZOS)

        # center-crop
        x1 = (img.width - ow) // 2
        y1 = (img.height - oh) // 2
        img = img.crop((x1, y1, x1 + ow, y1 + oh))
        assert img.width == ow and img.height == oh

        # to tensor
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda().unsqueeze(1)
        vae_encoder_out = self.get_vae_encoder_output(img)
        self.config.lat_w, self.config.lat_h = ow // self.config.vae_stride[2], oh // self.config.vae_stride[1]

        return vae_encoder_out

    def get_vae_encoder_output(self, img):
        z = self.vae_encoder.encode(img, self.config)
        return z