import gc
import os

import torch

from lightx2v.utils.envs import *
from lightx2v_platform.base.global_var import AI_DEVICE

try:
    from diffusers import AutoencoderKL
    from diffusers.image_processor import VaeImageProcessor
except ImportError:
    AutoencoderKL = None
    VaeImageProcessor = None

# Device-specific torch submodule (e.g. torch.cuda when AI_DEVICE == "cuda");
# used below to call empty_cache() after offloading the model back to CPU.
# NOTE(review): assumes AI_DEVICE names a torch accelerator module — confirm.
torch_device_module = getattr(torch, AI_DEVICE)


class LongCatImageVAE:
    """VAE wrapper for the LongCat Image model.

    Uses a standard diffusers ``AutoencoderKL`` with LongCat-specific
    scaling/shift factors (0.3611 / 0.1159, from the diffusers
    LongCatImagePipeline). Unlike QwenImage, which uses 2x2 packing on
    the VAE side, LongCat's VAE consumes unpacked ``[B, 16, H, W]``
    latents; ``decode`` unpacks the transformer's 2x2-packed sequence
    format before calling the VAE.
    """

    def __init__(self, config):
        """Build the wrapper and immediately load the VAE weights.

        Args:
            config: Dict-like model configuration. Keys read here:
                ``vae_cpu_offload`` / ``cpu_offload`` (bool, offload model
                to CPU between calls), ``latent_channels`` (default 16),
                ``vae_scale_factor`` (default 8), ``vae_path`` /
                ``model_path`` (weights location), ``use_tiling_vae``.
        """
        self.config = config
        # A VAE-specific offload flag takes precedence over the generic one.
        self.cpu_offload = config.get("vae_cpu_offload", config.get("cpu_offload", False))
        if self.cpu_offload:
            self.device = torch.device("cpu")
        else:
            self.device = torch.device(AI_DEVICE)
        self.dtype = GET_DTYPE()

        # LongCat VAE parameters.
        # The VAE uses 16 latent channels, which become 64 after 2x2
        # packing for the transformer input.
        self.latent_channels = config.get("latent_channels", 16)
        self.vae_scale_factor = config.get("vae_scale_factor", 8)

        # Scaling factors for LongCat, taken from the diffusers
        # LongCatImagePipeline: scaling_factor=0.3611, shift_factor=0.1159.
        self.scaling_factor = 0.3611
        self.shift_factor = 0.1159

        self.load()

    def load(self):
        """Load the AutoencoderKL weights and the image post-processor.

        Raises:
            ImportError: if ``diffusers`` is not installed (the guarded
                import at module top set the classes to ``None``).
        """
        # Fail with a clear message instead of an opaque AttributeError
        # on None when the optional diffusers dependency is missing.
        if AutoencoderKL is None or VaeImageProcessor is None:
            raise ImportError("LongCatImageVAE requires the `diffusers` package; install it with `pip install diffusers`.")
        # Explicit vae_path wins; otherwise use the "vae" subfolder of model_path.
        vae_path = self.config.get("vae_path", os.path.join(self.config["model_path"], "vae"))
        self.model = AutoencoderKL.from_pretrained(vae_path).to(self.device).to(self.dtype)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

        if self.config.get("use_tiling_vae", False):
            # Tiled decode trades speed for lower peak memory on large images.
            self.model.enable_tiling()

    @staticmethod
    def _unpack_latents(latents, height, width):
        """Unpack latents from ``[B, (H//2)*(W//2), C*4]`` to ``[B, C, H, W]``.

        Reverses the 2x2 packing used by Flux/LongCat. ``height`` and
        ``width`` are the *packed* spatial dims (half of the VAE latent
        dims); the returned tensor has spatial dims ``2*height, 2*width``.

        Args:
            latents: Packed latents, ``[B, L, C*4]`` or ``[L, C*4]``.
            height: Packed latent height (H // 2).
            width: Packed latent width (W // 2).

        Returns:
            Unpacked latents of shape ``[B, C, 2*height, 2*width]``.
        """
        # Accept both batched [B, L, C] and unbatched [L, C] inputs.
        if latents.dim() == 2:
            latents = latents.unsqueeze(0)

        batch_size = latents.shape[0]
        # Packed channels = 4 * original channels (e.g. 64 -> 16).
        channels = latents.shape[-1] // 4
        # -> [B, H//2, W//2, C, 2, 2]
        latents = latents.view(batch_size, height, width, channels, 2, 2)
        # -> [B, C, H//2, 2, W//2, 2]
        latents = latents.permute(0, 3, 1, 4, 2, 5)
        # -> [B, C, H, W] where H and W are the full latent dims.
        latents = latents.reshape(batch_size, channels, height * 2, width * 2)
        return latents

    @torch.no_grad()
    def decode(self, latents, input_info):
        """Decode packed latents into images.

        Args:
            latents: Latent tensor ``[B, (H//2)*(W//2), 64]`` — packed format.
            input_info: Input information; ``auto_width`` / ``auto_height``
                give the target image size and ``return_result_tensor``
                selects tensor vs PIL output.

        Returns:
            Post-processed images — a tensor batch when
            ``input_info.return_result_tensor`` is true, else PIL images.
        """
        if self.cpu_offload:
            # Bring the model onto the accelerator only for the decode call.
            self.model.to(AI_DEVICE)

        width, height = input_info.auto_width, input_info.auto_height
        # Full VAE latent dimensions.
        full_latent_height = height // self.vae_scale_factor
        full_latent_width = width // self.vae_scale_factor
        # Packed dimensions (halved by the 2x2 packing).
        packed_height = full_latent_height // 2
        packed_width = full_latent_width // 2

        # Unpack latents: [B, (H//2)*(W//2), 64] -> [B, 16, H, W].
        latents = self._unpack_latents(latents, packed_height, packed_width)
        latents = latents.to(self.dtype)

        # Invert the VAE-side scaling: latents / scaling_factor + shift_factor.
        latents = (latents / self.scaling_factor) + self.shift_factor

        # Decode — latents is now [B, 16, H, W].
        images = self.model.decode(latents, return_dict=False)[0]
        images = self.image_processor.postprocess(images, output_type="pt" if input_info.return_result_tensor else "pil")

        if self.cpu_offload:
            # Return the model to CPU and reclaim accelerator memory.
            self.model.to(torch.device("cpu"))
            torch_device_module.empty_cache()
            gc.collect()

        return images