vae.py 4.3 KB
Newer Older
litzh's avatar
litzh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import gc
import os

import torch

from lightx2v.utils.envs import *

try:
    from diffusers import AutoencoderKL
    from diffusers.image_processor import VaeImageProcessor
except ImportError:
    AutoencoderKL = None
    VaeImageProcessor = None
from lightx2v_platform.base.global_var import AI_DEVICE

# Torch namespace for the configured accelerator: ``getattr(torch, AI_DEVICE)``,
# e.g. ``torch.cuda`` when AI_DEVICE is "cuda".
torch_device_module = getattr(torch, AI_DEVICE)

# Preset output resolutions keyed by aspect-ratio label. Values appear to be
# [width, height] (landscape labels such as "16:9" list the larger number first).
# NOTE(review): "3:4" -> [768, 1024] (~0.79 MP) is not the mirror of
# "4:3" -> [1472, 1140] (~1.68 MP), unlike "16:9"/"9:16" which are exact swaps;
# confirm this is intentional rather than a typo for [1140, 1472].
# NOTE(review): 1140 is not a multiple of 16; confirm callers round sizes to the
# VAE/patch granularity before using these values.
ASPECT_RATIO_MAP = {
    "16:9": [1664, 928],
    "9:16": [928, 1664],
    "1:1": [1328, 1328],
    "4:3": [1472, 1140],
    "3:4": [768, 1024],
}


class AutoencoderKLZImageVAE:
    """Wrapper around a diffusers ``AutoencoderKL`` loaded from ``<model_path>/vae``.

    Exposes ``decode`` (latents -> post-processed images) and
    ``encode_vae_image`` (pixels -> scaled latents), plus static helpers that
    convert between spatial latents and a packed sequence of 2x2 patches.

    When ``vae_cpu_offload`` (falling back to ``cpu_offload``) is truthy, the
    weights rest on the CPU and are moved to ``AI_DEVICE`` only for the
    duration of each ``decode`` / ``encode_vae_image`` call. They are moved
    back even if that call raises.
    """

    def __init__(self, config):
        """Read offload/dtype settings from ``config`` and load the VAE.

        Args:
            config: Dict-like config. ``vae_cpu_offload`` / ``cpu_offload`` are
                read here; ``model_path`` / ``vae_scale_factor`` in ``load``.
        """
        self.config = config

        self.cpu_offload = config.get("vae_cpu_offload", config.get("cpu_offload", False))
        # Offloaded weights rest on the CPU; otherwise they live on the accelerator.
        self.device = torch.device("cpu") if self.cpu_offload else torch.device(AI_DEVICE)
        self.dtype = GET_DTYPE()
        # An input to encode_vae_image with this many channels (dim 1) is treated
        # as already-encoded latents and passed through.
        self.latent_channels = 16
        # Not read by any method of this class.
        self.vae_latents_mean = None
        self.vae_latents_std = None
        self.load()

    def load(self):
        """Load the VAE weights and build the matching ``VaeImageProcessor``.

        Raises:
            ImportError: if ``diffusers`` failed to import at module load time.
        """
        if AutoencoderKL is None or VaeImageProcessor is None:
            # The module-level import swallows ImportError; surface it here with a
            # clear message instead of "'NoneType' object has no attribute ...".
            raise ImportError("AutoencoderKLZImageVAE requires the `diffusers` package, which could not be imported.")
        vae_path = os.path.join(self.config["model_path"], "vae")
        # A single .to() places and casts the weights in one pass, using the dtype
        # already chosen in __init__ so the two never disagree.
        self.model = AutoencoderKL.from_pretrained(vae_path).to(device=self.device, dtype=self.dtype)
        # NOTE(review): the extra factor 2 presumably accounts for the 2x2 latent
        # patching done by _pack_latents; confirm against the pipeline.
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.config["vae_scale_factor"] * 2)

    @staticmethod
    def _unpack_latents(latents, latent_height, latent_width):
        """Inverse of ``_pack_latents``: patch sequence -> spatial latents.

        Args:
            latents: Tensor of shape ``(B, (H // 2) * (W // 2), C * 4)``.
            latent_height: Spatial latent height ``H`` (must be even).
            latent_width: Spatial latent width ``W`` (must be even).

        Returns:
            Tensor of shape ``(B, C, H, W)``.
        """
        batchsize, _, channels = latents.shape
        num_channels_latents = channels // 4

        patch_height = latent_height // 2
        patch_width = latent_width // 2

        # (B, H/2, W/2, C, 2, 2) -> (B, C, H/2, 2, W/2, 2) -> (B, C, H, W)
        latents = latents.view(batchsize, patch_height, patch_width, num_channels_latents, 2, 2)
        latents = latents.permute(0, 3, 1, 4, 2, 5)
        latents = latents.reshape(batchsize, num_channels_latents, latent_height, latent_width)

        return latents

    def _latent_scale_shift(self):
        """Return ``(scaling_factor, shift_factor)`` from the VAE config, or ``None``.

        ``None`` means the config lacks either attribute, and latents are used
        unscaled. A ``shift_factor`` whose value is ``None`` (the diffusers
        ``AutoencoderKL`` default) is treated as 0. Without this the latent
        arithmetic would raise ``TypeError``.
        """
        cfg = self.model.config
        if not (hasattr(cfg, "scaling_factor") and hasattr(cfg, "shift_factor")):
            return None
        shift_factor = 0.0 if cfg.shift_factor is None else cfg.shift_factor
        return cfg.scaling_factor, shift_factor

    def _offload_to_cpu(self):
        """Move the VAE back to the CPU and release cached accelerator memory."""
        self.model.to(torch.device("cpu"))
        # Use the namespace for the configured AI_DEVICE (module-level
        # ``torch_device_module``) rather than hard-coding ``torch.cuda``, so the
        # cache is actually freed on non-CUDA backends too.
        empty_cache = getattr(torch_device_module, "empty_cache", None)
        if empty_cache is not None:
            empty_cache()
        gc.collect()

    @torch.no_grad()
    def decode(self, latents, input_info):
        """Decode scaled latents into images.

        Args:
            latents: Latent tensor accepted by ``AutoencoderKL.decode``. It is
                moved to the device and dtype of the VAE weights first.
            input_info: Object whose ``return_result_tensor`` flag selects the
                post-processing output type: ``"pt"`` when truthy, else ``"pil"``.

        Returns:
            The result of ``VaeImageProcessor.postprocess`` for that output type.
        """
        try:
            if self.cpu_offload:
                self.model.to(torch.device(AI_DEVICE))
            weight = next(self.model.parameters())
            # Align device as well as dtype: the VAE may have just been onloaded.
            latents = latents.to(device=weight.device, dtype=weight.dtype)
            scale_shift = self._latent_scale_shift()
            if scale_shift is not None:
                scaling_factor, shift_factor = scale_shift
                # Undo encode-time normalisation z = (x - shift) * scale.
                latents = (latents / scaling_factor) + shift_factor
            images = self.model.decode(latents, return_dict=False)[0]
            output_type = "pt" if input_info.return_result_tensor else "pil"
            return self.image_processor.postprocess(images, output_type=output_type)
        finally:
            # Runs on success *and* on error (e.g. OOM), so the weights never stay
            # stranded on the accelerator in offload mode.
            if self.cpu_offload:
                self._offload_to_cpu()

    @staticmethod
    def _pack_latents(latents, batchsize, num_channels_latents, height, width):
        """Flatten spatial latents into a sequence of 2x2 patches.

        Args:
            latents: Tensor viewable as ``(batchsize, num_channels_latents, height, width)``.
            batchsize: Batch size ``B``.
            num_channels_latents: Latent channels ``C``.
            height: Latent height ``H`` (must be even).
            width: Latent width ``W`` (must be even).

        Returns:
            Tensor of shape ``(B, (H // 2) * (W // 2), C * 4)``.
        """
        latents = latents.view(batchsize, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)  # (batch_size, height//2, width//2, num_channels, 2, 2)
        latents = latents.reshape(batchsize, (height // 2) * (width // 2), num_channels_latents * 4)
        return latents

    def _encode_vae_image(self, image: torch.Tensor):
        """Run the VAE encoder and return deterministic (unscaled) latents.

        Uses the distribution mode when the output has ``latent_dist``, and
        otherwise its ``latents`` attribute.

        Raises:
            AttributeError: if the encoder output exposes neither attribute.
        """
        encoder_output = self.model.encode(image)
        if hasattr(encoder_output, "latent_dist"):
            image_latents = encoder_output.latent_dist.mode()
        elif hasattr(encoder_output, "latents"):
            image_latents = encoder_output.latents
        else:
            raise AttributeError("Could not access latents from VAE encoder output")

        return image_latents

    @torch.no_grad()
    def encode_vae_image(self, image):
        """Encode pixels into scaled latents, or pass latents through.

        If ``image.shape[1] == self.latent_channels``, the input is treated as
        latents. It is only moved to the VAE's device; its dtype is kept.
        Otherwise the input is cast to the VAE weight dtype, encoded, and
        normalised as ``(z - shift_factor) * scaling_factor``, which is the
        inverse of ``decode``.

        Returns:
            A new tensor, never the caller's own input object.
        """
        try:
            if self.cpu_offload:
                self.model.to(torch.device(AI_DEVICE))
            weight = next(self.model.parameters())
            image = image.to(weight.device)

            if image.shape[1] != self.latent_channels:
                # Match the weights' dtype: e.g. an fp32 image into bf16 weights
                # would otherwise fail with a dtype mismatch.
                image_latents = self._encode_vae_image(image=image.to(weight.dtype))
                scale_shift = self._latent_scale_shift()
                if scale_shift is not None:
                    scaling_factor, shift_factor = scale_shift
                    image_latents = (image_latents - shift_factor) * scaling_factor
            else:
                image_latents = image
            # cat of a single tensor yields a fresh copy (no alias of the input).
            return torch.cat([image_latents], dim=0)
        finally:
            if self.cpu_offload:
                self._offload_to_cpu()