vae.py 5.81 KB
Newer Older
1
2
import json
import os
3
from typing import Optional
4

5
6
7
8
9
10
11
12
import torch

try:
    from diffusers import AutoencoderKLQwenImage
    from diffusers.image_processor import VaeImageProcessor
except ImportError:
    AutoencoderKLQwenImage = None
    VaeImageProcessor = None
13
14


15
16
17
18
19
20
21
22
23
24
25
26
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")


27
28
29
30
31
32
33
34
class AutoencoderKLQwenImageVAE:
    def __init__(self, config):
        self.config = config
        self.model = AutoencoderKLQwenImage.from_pretrained(os.path.join(config.model_path, "vae")).to(torch.device("cuda")).to(torch.bfloat16)
        self.image_processor = VaeImageProcessor(vae_scale_factor=config.vae_scale_factor * 2)
        with open(os.path.join(config.model_path, "vae", "config.json"), "r") as f:
            vae_config = json.load(f)
            self.vae_scale_factor = 2 ** len(vae_config["temperal_downsample"]) if "temperal_downsample" in vae_config else 8
35
            self.generator = torch.Generator(device="cuda").manual_seed(config.seed)
36
        self.dtype = torch.bfloat16
37
38
        self.device = torch.device("cuda")
        self.latent_channels = config.vae_z_dim
39
40
41

    @staticmethod
    def _unpack_latents(latents, height, width, vae_scale_factor):
42
        batchsize, num_patches, channels = latents.shape
43
44
45
46
47
48

        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
        height = 2 * (int(height) // (vae_scale_factor * 2))
        width = 2 * (int(width) // (vae_scale_factor * 2))

49
        latents = latents.view(batchsize, height // 2, width // 2, channels // 4, 2, 2)
50
51
        latents = latents.permute(0, 3, 1, 4, 2, 5)

52
        latents = latents.reshape(batchsize, channels // (2 * 2), 1, height, width)
53
54
55
56
57

        return latents

    @torch.no_grad()
    def decode(self, latents):
58
59
60
61
        if self.config.task == "t2i":
            width, height = self.config.aspect_ratios[self.config.aspect_ratio]
        elif self.config.task == "i2i":
            width, height = self.config.auto_width, self.config.auto_hight
62
63
64
65
66
67
68
69
        latents = self._unpack_latents(latents, height, width, self.config.vae_scale_factor)
        latents = latents.to(self.dtype)
        latents_mean = torch.tensor(self.config.vae_latents_mean).view(1, self.config.vae_z_dim, 1, 1, 1).to(latents.device, latents.dtype)
        latents_std = 1.0 / torch.tensor(self.config.vae_latents_std).view(1, self.config.vae_z_dim, 1, 1, 1).to(latents.device, latents.dtype)
        latents = latents / latents_std + latents_mean
        images = self.model.decode(latents, return_dict=False)[0][:, :, 0]
        images = self.image_processor.postprocess(images, output_type="pil")
        return images
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109

    @staticmethod
    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
    def _pack_latents(latents, batchsize, num_channels_latents, height, width):
        latents = latents.view(batchsize, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batchsize, (height // 2) * (width // 2), num_channels_latents * 4)
        return latents

    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
        if isinstance(generator, list):
            image_latents = [retrieve_latents(self.model.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") for i in range(image.shape[0])]
            image_latents = torch.cat(image_latents, dim=0)
        else:
            image_latents = retrieve_latents(self.model.encode(image), generator=generator, sample_mode="argmax")
        latents_mean = torch.tensor(self.model.config.latents_mean).view(1, self.latent_channels, 1, 1, 1).to(image_latents.device, image_latents.dtype)
        latents_std = torch.tensor(self.model.config.latents_std).view(1, self.latent_channels, 1, 1, 1).to(image_latents.device, image_latents.dtype)
        image_latents = (image_latents - latents_mean) / latents_std

        return image_latents

    def encode_vae_image(self, image):
        num_channels_latents = self.config.transformer_in_channels // 4
        image = image.to(self.device).to(self.dtype)
        if image.shape[1] != self.latent_channels:
            image_latents = self._encode_vae_image(image=image, generator=self.generator)
        else:
            image_latents = image
        if self.config.batchsize > image_latents.shape[0] and self.config.batchsize % image_latents.shape[0] == 0:
            # expand init_latents for batchsize
            additional_image_per_prompt = self.config.batchsize // image_latents.shape[0]
            image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
        elif self.config.batchsize > image_latents.shape[0] and self.config.batchsize % image_latents.shape[0] != 0:
            raise ValueError(f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {self.config.batchsize} text prompts.")
        else:
            image_latents = torch.cat([image_latents], dim=0)

        image_latent_height, image_latent_width = image_latents.shape[3:]
        image_latents = self._pack_latents(image_latents, self.config.batchsize, num_channels_latents, image_latent_height, image_latent_width)
        return image_latents