import gc
import os

import torch

from lightx2v.utils.envs import GET_DTYPE
from lightx2v_platform.base.global_var import AI_DEVICE

try:
    from diffusers import AutoencoderKLQwenImage
    from diffusers.image_processor import VaeImageProcessor
except ImportError:
    AutoencoderKLQwenImage = None
    VaeImageProcessor = None

# Device-specific torch module (e.g. torch.cuda or torch.npu), used to empty its cache on offload.
torch_device_module = getattr(torch, AI_DEVICE)


class AutoencoderKLQwenImageVAE:
    def __init__(self, config):
        self.config = config
        self.is_layered = config.get("layered", False)
        if self.is_layered:
            self.layers = config.get("layers", 4)

        self.cpu_offload = config.get("vae_cpu_offload", config.get("cpu_offload", False))
        if self.cpu_offload:
            self.device = torch.device("cpu")
        else:
            self.device = torch.device(AI_DEVICE)
        self.dtype = GET_DTYPE()
        self.latent_channels = 16
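        # Per-channel latent statistics used to de-normalize decoded latents;
        # they mirror the VAE config's latents_mean / latents_std.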
        self.vae_latents_mean = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921]
        self.vae_latents_std = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.916]
        self.load()

    def load(self):
        if AutoencoderKLQwenImage is None:
            raise ImportError("AutoencoderKLQwenImageVAE requires diffusers with AutoencoderKLQwenImage support")
        vae_path = self.config.get("vae_path", os.path.join(self.config["model_path"], "vae"))
        self.model = AutoencoderKLQwenImage.from_pretrained(vae_path).to(self.device).to(self.dtype)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.config["vae_scale_factor"] * 2)
        if self.config.get("use_tiling_vae", False):
            self.model.enable_tiling()

    @staticmethod
    def _unpack_latents(latents, height, width, vae_scale_factor, layers=None):
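        # Reverse of _pack_latents: reshape packed (b, num_patches, c * 4) tokens back into (b, c, f, h, w) latents.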
        batchsize, num_patches, channels = latents.shape

        # The VAE applies 8x spatial compression, and packing additionally requires the
        # latent height and width to be divisible by 2.
        height = 2 * (int(height) // (vae_scale_factor * 2))
        width = 2 * (int(width) // (vae_scale_factor * 2))
        if layers:
            latents = latents.view(batchsize, layers + 1, height // 2, width // 2, channels // 4, 2, 2)
            latents = latents.permute(0, 1, 4, 2, 5, 3, 6)
            latents = latents.reshape(batchsize, layers + 1, channels // (2 * 2), height, width)
            latents = latents.permute(0, 2, 1, 3, 4)  # (b, c, f, h, w)
        else:
            latents = latents.view(batchsize, height // 2, width // 2, channels // 4, 2, 2)
            latents = latents.permute(0, 3, 1, 4, 2, 5)
            latents = latents.reshape(batchsize, channels // (2 * 2), 1, height, width)

        return latents

    @torch.no_grad()
    def decode(self, latents, input_info):
        if self.cpu_offload:
            self.model.to(AI_DEVICE)
        width, height = input_info.auto_width, input_info.auto_height
        if self.is_layered:
            latents = self._unpack_latents(latents, height, width, self.config["vae_scale_factor"], self.layers)
        else:
            latents = self._unpack_latents(latents, height, width, self.config["vae_scale_factor"])
        latents = latents.to(self.dtype)
        latents_mean = torch.tensor(self.vae_latents_mean).view(1, self.latent_channels, 1, 1, 1).to(latents.device, latents.dtype)
        latents_std = 1.0 / torch.tensor(self.vae_latents_std).view(1, self.latent_channels, 1, 1, 1).to(latents.device, latents.dtype)
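        # latents_std holds the reciprocal std, so dividing by it multiplies by std;
        # together with adding the mean, this inverts the (x - mean) / std normalization.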
        latents = latents / latents_std + latents_mean
        if self.is_layered:
            b, c, f, h, w = latents.shape
            latents = latents[:, :, 1:]  # drop the first frame, which is the original input image
            num_frames = f - 1
            latents = latents.permute(0, 2, 1, 3, 4).view(-1, c, 1, h, w)
            image = self.model.decode(latents, return_dict=False)[0]  # ((b * num_frames), c, 1, h, w)
            image = image.squeeze(2)
            image = self.image_processor.postprocess(image, output_type="pt" if input_info.return_result_tensor else "pil")
            images = []
            for bidx in range(b):
                images.append(image[bidx * num_frames : (bidx + 1) * num_frames])
        else:
            images = self.model.decode(latents, return_dict=False)[0][:, :, 0]
            images = self.image_processor.postprocess(images, output_type="pt" if input_info.return_result_tensor else "pil")
        if self.cpu_offload:
            self.model.to(torch.device("cpu"))
            torch_device_module.empty_cache()
            gc.collect()
        return images

    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
    @staticmethod
    def _pack_latents(latents, batchsize, num_channels_latents, height, width, layers=None):
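        # Flatten (b, c, h, w) latents, or (b, c, layers, h, w) when layered, into (b, num_patches, c * 4) sequences of 2x2 patches.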
        if not layers:
            latents = latents.view(batchsize, num_channels_latents, height // 2, 2, width // 2, 2)
            latents = latents.permute(0, 2, 4, 1, 3, 5)
            latents = latents.reshape(batchsize, (height // 2) * (width // 2), num_channels_latents * 4)
        else:
            latents = latents.permute(0, 2, 1, 3, 4)
            latents = latents.view(batchsize, layers, num_channels_latents, height // 2, 2, width // 2, 2)
            latents = latents.permute(0, 1, 3, 5, 2, 4, 6)
            latents = latents.reshape(batchsize, layers * (height // 2) * (width // 2), num_channels_latents * 4)
        return latents

    def _encode_vae_image(self, image: torch.Tensor):
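        # Encode to latents via the distribution mode, then normalize with the model's latents_mean / latents_std.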
        image_latents = self.model.encode(image).latent_dist.mode()
        latents_mean = torch.tensor(self.model.config["latents_mean"]).view(1, self.latent_channels, 1, 1, 1).to(image_latents.device, image_latents.dtype)
        latents_std = torch.tensor(self.model.config["latents_std"]).view(1, self.latent_channels, 1, 1, 1).to(image_latents.device, image_latents.dtype)
        image_latents = (image_latents - latents_mean) / latents_std

        return image_latents

    @torch.no_grad()
    def encode_vae_image(self, image):
        if self.cpu_offload:
            self.model.to(AI_DEVICE)

        num_channels_latents = self.config["in_channels"] // 4
        image = image.to(self.model.device)

        if image.shape[1] != self.latent_channels:
            image_latents = self._encode_vae_image(image=image)
        else:
            # input is already in latent space; skip encoding
            image_latents = image
        image_latents = torch.cat([image_latents], dim=0)  # no-op for a single tensor
        image_latent_height, image_latent_width = image_latents.shape[3:]
        if not self.is_layered:
            image_latents = self._pack_latents(image_latents, 1, num_channels_latents, image_latent_height, image_latent_width)
        else:
            image_latents = self._pack_latents(image_latents, 1, num_channels_latents, image_latent_height, image_latent_width, 1)

        if self.cpu_offload:
            self.model.to(torch.device("cpu"))
            torch_device_module.empty_cache()
            gc.collect()
        return image_latents
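

# Minimal usage sketch (hypothetical config values and paths; `info` stands in
# for an object exposing the fields decode() reads):
#
#   config = {
#       "model_path": "/path/to/Qwen-Image",  # checkpoint dir containing a vae/ subfolder
#       "vae_scale_factor": 8,
#       "in_channels": 64,
#       "cpu_offload": True,
#   }
#   vae = AutoencoderKLQwenImageVAE(config)
#   packed = vae.encode_vae_image(image)         # image: (1, 3, 1, H, W) pixel tensor
#   images = vae.decode(denoised_latents, info)  # uses info.auto_width, info.auto_height, info.return_result_tensor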