vae_tiny.py 1.03 KB
Newer Older
1
2
import torch
import torch.nn as nn
PengGao's avatar
PengGao committed
3

4
5
from lightx2v.utils.memory_profiler import peak_memory_decorator

PengGao's avatar
PengGao committed
6
7
from ..tae import TAEHV

8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29

class DotDict(dict):
    """``dict`` subclass whose keys can also be read and written as attributes.

    ``d.key`` reads ``d["key"]``, ``d.key = v`` writes ``d["key"] = v`` and
    ``del d.key`` removes the key. A missing key raises ``AttributeError``
    (not ``KeyError``) on attribute access. This keeps ``hasattr``,
    ``getattr(obj, name, default)`` and protocols that probe optional
    attributes working normally (for example ``copy.deepcopy`` looking up
    ``__deepcopy__``).
    """

    __setattr__ = dict.__setitem__

    def __getattr__(self, name):
        # Only called after normal lookup fails, so real dict methods
        # (``items``, ``keys``, ...) still take precedence over same-named keys.
        try:
            return self[name]
        except KeyError:
            # The attribute protocol requires AttributeError; a KeyError here
            # would escape hasattr/getattr-with-default and break copy/pickle.
            raise AttributeError(name) from None

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None


class WanVAE_tiny(nn.Module):
    """Decode-only Wan VAE replacement backed by a tiny ``TAEHV`` autoencoder.

    Exposes a small ``config`` object and a ``decode`` method with a
    diffusers-style signature, so it can be swapped in where a full Wan VAE
    decoder is expected.
    """

    def __init__(self, vae_pth="taew2_1.pth", dtype=torch.bfloat16, device="cuda"):
        """Load the TAEHV weights and set up the config.

        Args:
            vae_pth: Checkpoint path passed straight to ``TAEHV``.
            dtype: Dtype the TAEHV weights are cast to. ``decode`` also casts
                incoming latents to this dtype.
            device: Device recorded on ``self.device``. Defaults to "cuda"
                (also used when ``None`` is passed).
        """
        super().__init__()
        self.dtype = dtype
        # Previously hard-coded to CUDA regardless of the argument; honour it now.
        self.device = torch.device(device if device is not None else "cuda")
        # NOTE(review): weights are cast to ``dtype`` here but not moved to
        # ``device``; presumably the caller moves the module — confirm.
        self.taehv = TAEHV(vae_pth).to(self.dtype)
        # Attribute name (including its "temperal" spelling) kept as-is since
        # external code may read it.
        self.temperal_downsample = [True, True, False]
        # Zero mean, unit std and scaling_factor 1.0: presumably makes any
        # latent (de)normalization done by callers a no-op — verify against callers.
        self.config = DotDict(scaling_factor=1.0, latents_mean=torch.zeros(16), z_dim=16, latents_std=torch.ones(16))

    @peak_memory_decorator
    @torch.no_grad()
    def decode(self, latents, generator=None, return_dict=None, config=None, *, parallel=False):
        """Decode one un-batched latent video with TAEHV.

        Args:
            latents: 4-D tensor. A batch dim is prepended and the result is
                treated as (N, C, T, H, W); assumes the input is (C, T, H, W),
                so TODO confirm against callers.
            generator: Accepted for interface compatibility; unused.
            return_dict: Accepted for interface compatibility; unused.
            config: Accepted for interface compatibility; unused.
            parallel: Forwarded to ``TAEHV.decode_video``. ``False`` (the
                default) is the low-memory path; ``True`` is faster but uses
                more memory.

        Returns:
            Decoded tensor laid out as (N, C, T, H, W), rescaled in place by
            ``x * 2 - 1``. This presumably maps TAEHV's [0, 1] output to
            [-1, 1] — confirm.

        Raises:
            ValueError: if ``latents`` is not 4-D.
        """
        latents = latents.unsqueeze(0)
        if latents.dim() != 5:
            raise ValueError(f"expected 4-D latents (C, T, H, W), got shape {tuple(latents.shape[1:])}")
        # TAEHV consumes time-major input, so swap channel and time axes
        # (N, C, T, H, W) -> (N, T, C, H, W) and swap back on the way out.
        frames = self.taehv.decode_video(latents.transpose(1, 2).to(self.dtype), parallel=parallel)
        return frames.transpose(1, 2).mul_(2).sub_(1)