# vae_sf.py
import torch

from lightx2v.models.video_encoders.hf.wan.vae import _video_vae


class WanSFVAE:
    """Decode-side wrapper around the ``_video_vae`` model.

    The wrapper owns the VAE network plus the per-channel latent statistics
    (``mean`` / ``std``) that are handed to the network as a ``scale`` pair
    ``[mean, 1 / std]`` when decoding.

    Attributes set by ``__init__``:
        model: the network returned by ``_video_vae`` (eval mode, no grads).
        mean, std: float32 tensors with one entry per latent channel.
        inv_std: ``1 / std``; kept in sync with ``std`` by ``to_cpu`` / ``to_cuda``.
        scale: ``[mean, inv_std]``.
    """

    # Per-channel latent statistics (16 entries each).
    _LATENT_MEAN = (
        -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
        0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921,
    )
    _LATENT_STD = (
        2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
        3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160,
    )

    def __init__(
        self,
        z_dim=16,
        vae_pth="cache/vae_step_411000.pth",
        dtype=torch.float,
        device="cuda",
        parallel=False,
        use_tiling=False,
        cpu_offload=False,
        use_2d_split=True,
        load_from_rank0=False,
    ):
        """Build the VAE network and the latent normalization tensors.

        Args:
            z_dim: latent channel count forwarded to ``_video_vae``.
            vae_pth: checkpoint path forwarded as ``pretrained_path``.
            dtype: dtype the network is created with and cast to.
            device: device the network is moved to.
            parallel, use_tiling, use_2d_split: stored on the instance only;
                this class does not read them itself.
            cpu_offload: stored and forwarded to ``_video_vae``.
            load_from_rank0: forwarded to ``_video_vae``.
        """
        self.dtype = dtype
        self.device = device
        self.parallel = parallel
        self.use_tiling = use_tiling
        self.cpu_offload = cpu_offload
        self.use_2d_split = use_2d_split

        self.mean = torch.tensor(self._LATENT_MEAN, dtype=torch.float32)
        self.std = torch.tensor(self._LATENT_STD, dtype=torch.float32)
        # to_cpu()/to_cuda() rely on these two attributes, so they must exist
        # from construction time onwards.
        self.inv_std = 1.0 / self.std
        self.scale = [self.mean, self.inv_std]

        # init model
        model = _video_vae(
            pretrained_path=vae_pth,
            z_dim=z_dim,
            cpu_offload=cpu_offload,
            dtype=dtype,
            load_from_rank0=load_from_rank0,
        )
        self.model = model.eval().requires_grad_(False).to(device).to(dtype)
        self.model.clear_cache()

    def _move_to(self, device):
        """Move the network and the normalization tensors to ``device``."""
        self.model.encoder = self.model.encoder.to(device)
        self.model.decoder = self.model.decoder.to(device)
        self.model = self.model.to(device)
        self.mean = self.mean.to(device)
        self.std = self.std.to(device)
        # Recompute instead of moving so inv_std can never drift from std and
        # so this also works when inv_std has not been populated yet.
        self.inv_std = 1.0 / self.std
        self.scale = [self.mean, self.inv_std]

    def to_cpu(self):
        """Move the network, ``mean``/``std``/``inv_std`` and ``scale`` to the CPU."""
        self._move_to("cpu")

    def to_cuda(self):
        """Move the network, ``mean``/``std``/``inv_std`` and ``scale`` to ``"cuda"``."""
        self._move_to("cuda")

    def decode(self, latent: torch.Tensor, use_cache: bool = False) -> torch.Tensor:
        """Decode a single (un-batched) latent and clamp the result to [-1, 1].

        Args:
            latent: 4-D tensor. Following the layout comments of the original
                code this is ``[num_channels, num_frames, height, width]``;
                a batch axis of size 1 is added internally.
            use_cache: when true, ``self.model.cached_decode`` is used instead
                of ``self.model.decode``.

        Returns:
            The decoded float32 tensor with the first two non-batch axes
            swapped back and the batch axis removed, i.e.
            ``[num_frames, num_channels, height, width]`` under the same
            layout assumption.
        """
        # [C, T, H, W] -> [T, C, H, W] -> [1, T, C, H, W]
        latent = latent.transpose(0, 1).unsqueeze(0)
        # [1, T, C, H, W] -> [1, C, T, H, W]
        zs = latent.permute(0, 2, 1, 3, 4)
        if use_cache:
            # Always holds here because of the unsqueeze(0) above; kept as a
            # guard in case the batching logic changes.
            assert latent.shape[0] == 1, "Batch size must be 1 when using cache"

        # Match the statistics to the latent's device/dtype on every call so
        # decode works regardless of where mean/std currently live.
        device, dtype = latent.device, latent.dtype
        scale = [self.mean.to(device=device, dtype=dtype), 1.0 / self.std.to(device=device, dtype=dtype)]

        decode_function = self.model.cached_decode if use_cache else self.model.decode

        # Decode one batch element at a time, re-adding the batch axis for the
        # network and stripping it from its output.
        decoded = [decode_function(u.unsqueeze(0), scale).float().clamp_(-1, 1).squeeze(0) for u in zs]
        output = torch.stack(decoded, dim=0)
        # [1, C, T, H, W] -> [1, T, C, H, W] -> [T, C, H, W]
        output = output.permute(0, 2, 1, 3, 4).squeeze(0)
        return output