import gc
import json
import os
from typing import Optional
import torch

try:
    from diffusers import AutoencoderKLQwenImage
    from diffusers.image_processor import VaeImageProcessor
except ImportError:
    AutoencoderKLQwenImage = None
    VaeImageProcessor = None
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"):
    """Pull a latent tensor out of a VAE encoder output.

    Supports outputs that expose a ``latent_dist`` (sampled or argmax/mode,
    depending on ``sample_mode``) as well as outputs that carry precomputed
    ``latents``. Raises ``AttributeError`` when neither is available.
    """
    latent_dist = getattr(encoder_output, "latent_dist", None)
    if latent_dist is not None:
        if sample_mode == "sample":
            return latent_dist.sample(generator)
        if sample_mode == "argmax":
            return latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")


class AutoencoderKLQwenImageVAE:
    """Wrapper around diffusers' ``AutoencoderKLQwenImage`` for this pipeline.

    Handles loading the VAE from ``config["model_path"]/vae``, packing/unpacking
    latents between the transformer's flattened patch layout and the VAE's
    ``(batch, channels, 1, height, width)`` layout, and optional CPU offload
    (weights kept on CPU between calls and moved to CUDA only while used).
    """

    def __init__(self, config):
        """Store config, pick device/dtype, and load the VAE weights.

        config keys read here: ``cpu_offload`` (optional, default False) and
        ``vae_z_dim`` (number of latent channels). ``load()`` reads more.
        """
        self.config = config
        # With cpu_offload the model idles on CPU; decode/encode move it to
        # CUDA for the duration of the call and back afterwards.
        self.cpu_offload = config.get("cpu_offload", False)
        self.device = torch.device("cpu") if self.cpu_offload else torch.device("cuda")
        self.dtype = torch.bfloat16
        self.latent_channels = config["vae_z_dim"]
        self.load()

    def load(self):
        """Load VAE weights, the image processor, and the VAE's own config."""
        vae_dir = os.path.join(self.config["model_path"], "vae")
        self.model = AutoencoderKLQwenImage.from_pretrained(vae_dir).to(self.device).to(self.dtype)
        # Factor of 2 accounts for the 2x2 patch packing on top of the VAE's
        # own spatial compression.
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.config["vae_scale_factor"] * 2)
        with open(os.path.join(vae_dir, "config.json"), "r") as f:
            vae_config = json.load(f)
            # "temperal_downsample" (sic) is the literal key name used in the
            # checkpoint's config.json — do not "fix" the spelling.
            self.vae_scale_factor = 2 ** len(vae_config["temperal_downsample"]) if "temperal_downsample" in vae_config else 8

    @staticmethod
    def _unpack_latents(latents, height, width, vae_scale_factor):
        """Reshape packed ``(batch, num_patches, channels)`` latents back to
        the VAE layout ``(batch, channels/4, 1, latent_h, latent_w)``.

        ``height``/``width`` are the target image size in pixels.
        """
        batchsize, num_patches, channels = latents.shape

        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
        height = 2 * (int(height) // (vae_scale_factor * 2))
        width = 2 * (int(width) // (vae_scale_factor * 2))

        latents = latents.view(batchsize, height // 2, width // 2, channels // 4, 2, 2)
        latents = latents.permute(0, 3, 1, 4, 2, 5)
        latents = latents.reshape(batchsize, channels // (2 * 2), 1, height, width)

        return latents

    @torch.no_grad()
    def decode(self, latents, input_info):
        """Decode packed latents into a list of PIL images.

        Raises:
            ValueError: if ``config["task"]`` is neither "t2i" nor "i2i"
                (previously this path crashed with a NameError on ``width``).
        """
        if self.cpu_offload:
            self.model.to(torch.device("cuda"))
        task = self.config["task"]
        if task == "t2i":
            width, height = self.config["aspect_ratios"][self.config["aspect_ratio"]]
        elif task == "i2i":
            # NOTE(review): `auto_hight` (sic) is the attribute name as defined
            # on input_info — confirm against its producer before renaming.
            width, height = input_info.auto_width, input_info.auto_hight
        else:
            raise ValueError(f"Unsupported task {task!r}; expected 't2i' or 'i2i'.")
        latents = self._unpack_latents(latents, height, width, self.config["vae_scale_factor"])
        latents = latents.to(self.dtype)
        # Undo the normalization applied at encode time: x = z / (1/std) + mean.
        latents_mean = torch.tensor(self.config["vae_latents_mean"]).view(1, self.config["vae_z_dim"], 1, 1, 1).to(latents.device, latents.dtype)
        latents_std = 1.0 / torch.tensor(self.config["vae_latents_std"]).view(1, self.config["vae_z_dim"], 1, 1, 1).to(latents.device, latents.dtype)
        latents = latents / latents_std + latents_mean
        # [:, :, 0] drops the singleton temporal dimension of the video VAE.
        images = self.model.decode(latents, return_dict=False)[0][:, :, 0]
        images = self.image_processor.postprocess(images, output_type="pil")
        if self.cpu_offload:
            self.model.to(torch.device("cpu"))
            torch.cuda.empty_cache()
            gc.collect()
        return images

    @staticmethod
    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
    def _pack_latents(latents, batchsize, num_channels_latents, height, width):
        latents = latents.view(batchsize, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batchsize, (height // 2) * (width // 2), num_channels_latents * 4)
        return latents

    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
        """Encode an image tensor to normalized latents via the VAE.

        Uses argmax (distribution mode) rather than sampling, so the result is
        deterministic given the weights; ``generator`` is threaded through for
        API compatibility with ``retrieve_latents``.
        """
        if isinstance(generator, list):
            # Per-sample generators: encode each batch element separately.
            image_latents = [retrieve_latents(self.model.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") for i in range(image.shape[0])]
            image_latents = torch.cat(image_latents, dim=0)
        else:
            image_latents = retrieve_latents(self.model.encode(image), generator=generator, sample_mode="argmax")

        latents_mean = torch.tensor(self.model.config["latents_mean"]).view(1, self.latent_channels, 1, 1, 1).to(image_latents.device, image_latents.dtype)
        latents_std = torch.tensor(self.model.config["latents_std"]).view(1, self.latent_channels, 1, 1, 1).to(image_latents.device, image_latents.dtype)
        image_latents = (image_latents - latents_mean) / latents_std

        return image_latents

    @torch.no_grad()
    def encode_vae_image(self, image, input_info):
        """Encode ``image`` (pixels or pre-encoded latents) into packed latents.

        If ``image`` already has ``latent_channels`` channels it is treated as
        latents and not re-encoded. The batch is expanded to
        ``config["batchsize"]`` when it divides evenly.

        Raises:
            ValueError: for an unknown task (previously left ``self.generator``
                unset/stale), or when the batch cannot be evenly duplicated.
        """
        task = self.config["task"]
        if task == "i2i":
            self.generator = torch.Generator().manual_seed(input_info.seed)
        elif task == "t2i":
            self.generator = torch.Generator(device="cuda").manual_seed(input_info.seed)
        else:
            raise ValueError(f"Unsupported task {task!r}; expected 't2i' or 'i2i'.")

        if self.cpu_offload:
            self.model.to(torch.device("cuda"))
        # Each packed token carries a 2x2 patch, hence channels // 4.
        num_channels_latents = self.config["transformer_in_channels"] // 4
        image = image.to(self.model.device).to(self.dtype)

        if image.shape[1] != self.latent_channels:
            image_latents = self._encode_vae_image(image=image, generator=self.generator)
        else:
            image_latents = image

        batchsize = self.config["batchsize"]
        if batchsize > image_latents.shape[0] and batchsize % image_latents.shape[0] == 0:
            # expand init_latents for batchsize
            additional_image_per_prompt = batchsize // image_latents.shape[0]
            image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
        elif batchsize > image_latents.shape[0] and batchsize % image_latents.shape[0] != 0:
            raise ValueError(f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {self.config['batchsize']} text prompts.")
        else:
            image_latents = torch.cat([image_latents], dim=0)

        image_latent_height, image_latent_width = image_latents.shape[3:]
        image_latents = self._pack_latents(image_latents, self.config["batchsize"], num_channels_latents, image_latent_height, image_latent_width)
        if self.cpu_offload:
            self.model.to(torch.device("cpu"))
            torch.cuda.empty_cache()
            gc.collect()
        return image_latents