vae.py 1.15 KB
Newer Older
litzh's avatar
litzh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import os

import torch
from PIL import Image

from .autoencoder import load_ae


class BagelVae:
    """Autoencoder wrapper that decodes packed latents into a PIL image.

    The weights are loaded with ``load_ae`` from the file ``ae.safetensors``
    located in ``config["model_path"]``.
    """

    def __init__(self, config):
        """Load the autoencoder described by *config*.

        Args:
            config: Mapping providing ``"model_path"``, the directory that
                contains ``ae.safetensors``. It is kept on ``self.config``.
        """
        self.config = config
        vae_path = os.path.join(config["model_path"], "ae.safetensors")
        # load_ae returns a pair; the second element is stored untouched.
        self.vae_model, self.vae_params = load_ae(vae_path)

    def decode(self, latents: torch.Tensor, decode_info) -> "list[Image.Image]":
        """Decode the first packed latent segment into an image.

        Args:
            latents: Packed latents; split along dim 0 into per-sequence
                segments, of which only the first one is decoded.
            decode_info: Mapping with the keys ``"packed_seqlens"`` (tensor of
                sequence lengths), ``"image_shape"`` (``(H, W)``),
                ``"latent_downsample"``, ``"latent_patch_size"`` and
                ``"latent_channel"``.

        Returns:
            A one-element list holding the decoded ``PIL.Image.Image``.
        """
        # Every segment is (seqlen - 2) rows long.
        # NOTE(review): presumably the two dropped positions are special
        # tokens that carry no latent -- confirm against the caller.
        segment_sizes = (decode_info["packed_seqlens"] - 2).tolist()
        first_segment = latents.split(segment_sizes)[0]

        height, width = decode_info["image_shape"]
        downsample = decode_info["latent_downsample"]
        grid_h, grid_w = height // downsample, width // downsample
        patch = decode_info["latent_patch_size"]
        channels = decode_info["latent_channel"]

        # Un-patchify: the segment must hold grid_h * grid_w * patch * patch *
        # channels elements.
        # (1, h, w, p, q, c) -> (1, c, h, p, w, q) -> (1, c, h * p, w * q)
        grid = first_segment.reshape(1, grid_h, grid_w, patch, patch, channels)
        grid = torch.einsum("nhwpqc->nchpwq", grid)
        grid = grid.reshape(1, channels, grid_h * patch, grid_w * patch)

        # The result is turned into a uint8 image right away, so no autograd
        # graph is needed for the decoder pass.
        with torch.no_grad():
            decoded = self.vae_model.decode(grid)
            # Assumes the decoder output lies roughly in [-1, 1] -- TODO
            # confirm; values are shifted to [0, 1], clamped, and scaled to
            # [0, 255] in height x width x channel layout.
            pixels = (decoded * 0.5 + 0.5).clamp(0, 1)[0].permute(1, 2, 0) * 255
            # The uint8 cast truncates the fractional part (no rounding).
            image = Image.fromarray(pixels.to(torch.uint8).cpu().numpy())
        return [image]