import os
from types import SimpleNamespace

import torch

from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D, DiagonalGaussianDistribution
class HunyuanVAE:
    """Thin wrapper around the Hunyuan causal-3D video VAE checkpoint.

    Loads ``AutoencoderKLCausal3D`` weights from a Hunyuan model directory,
    pins the model to a fixed dtype/device in eval mode, and exposes
    ``encode``/``decode`` helpers plus optional CPU offload around decode.
    """

    def __init__(self, model_path, dtype, device, config):
        """
        Args:
            model_path: Root directory of the Hunyuan release; the VAE is
                expected under ``hunyuan-video-{t2v,i2v}-720p/vae``.
            dtype: torch dtype the model (and decode inputs) are cast to.
            device: torch device the model is placed on after loading.
            config: Options object; this class reads ``config.task``
                ("t2v" or anything else → i2v) and ``config.cpu_offload``.
        """
        self.model_path = model_path
        self.dtype = dtype
        self.device = device
        self.config = config
        # Eagerly load weights so the instance is usable immediately.
        self.load()

    def load(self):
        """Build the VAE from its on-disk config and load its weights.

        Side effects: sets ``self.vae_path`` and ``self.model``; the model is
        frozen (no grads), cast to ``self.dtype`` and moved to ``self.device``.
        """
        # Task selects which released sub-directory holds the VAE.
        subdir = "hunyuan-video-t2v-720p/vae" if self.config.task == "t2v" else "hunyuan-video-i2v-720p/vae"
        self.vae_path = os.path.join(self.model_path, subdir)

        vae_config = AutoencoderKLCausal3D.load_config(self.vae_path)
        self.model = AutoencoderKLCausal3D.from_config(vae_config)
        # weights_only=True restricts torch.load to tensor data (safe loading).
        ckpt = torch.load(os.path.join(self.vae_path, "pytorch_model.pt"), map_location="cpu", weights_only=True)
        self.model.load_state_dict(ckpt)
        self.model = self.model.to(dtype=self.dtype, device=self.device)
        self.model.requires_grad_(False)
        self.model.eval()

    def to_cpu(self):
        """Move the VAE to CPU (used between decodes when offloading)."""
        self.model = self.model.to("cpu")

    def to_cuda(self):
        """Move the VAE to the default CUDA device."""
        self.model = self.model.to("cuda")

    def decode(self, latents):
        """Decode latents to pixel-space frames in [0, 1].

        Args:
            latents: Latent tensor in the VAE's scaled latent space.

        Returns:
            Decoded tensor on CPU as float32, values clamped to [0, 1].
        """
        if self.config.cpu_offload:
            self.to_cuda()
        # Undo the scaling applied when latents were produced.
        latents = latents / self.model.config.scaling_factor
        # BUGFIX: previously hard-coded torch.device("cuda"); use the device
        # the model was actually placed on so non-default devices work.
        latents = latents.to(dtype=self.dtype, device=self.device)
        # Tiled decode keeps peak memory bounded for long/large videos.
        self.model.enable_tiling()
        image = self.model.decode(latents, return_dict=False)[0]
        # Map from [-1, 1] model output range to [0, 1].
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().float()
        if self.config.cpu_offload:
            self.to_cpu()
        return image

    def encode(self, x):
        """Encode pixel-space input ``x`` to a latent posterior.

        Returns:
            DiagonalGaussianDistribution over the latent space; callers
            sample or take the mode from it. Note: no scaling_factor is
            applied here — callers are expected to scale the sample.
        """
        h = self.model.encoder(x)
        moments = self.model.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior



if __name__ == "__main__":
    # Smoke test: construct the VAE wrapper against a local checkpoint tree.
    # NOTE(review): model_path must point at a directory containing
    # hunyuan-video-{t2v,i2v}-720p/vae — fill it in before running.
    model_path = ""
    # BUGFIX: the original call omitted the required `config` argument and
    # raised TypeError. Provide a minimal config with the two attributes
    # HunyuanVAE reads (task, cpu_offload).
    config = SimpleNamespace(task="t2v", cpu_offload=False)
    vae_model = HunyuanVAE(model_path, dtype=torch.float16, device=torch.device("cuda"), config=config)