import math
import torch

from lightx2v.utils.envs import *
class WanPostInfer:
    def __init__(self, config):
        self.out_dim = config["out_dim"]
        self.patch_size = (1, 2, 2)
gushiqiao's avatar
gushiqiao committed
10
        self.clean_cuda_cache = config.get("clean_cuda_cache", False)
helloyongyang's avatar
helloyongyang committed
11

12
13
14
    def set_scheduler(self, scheduler):
        self.scheduler = scheduler

helloyongyang's avatar
helloyongyang committed
15
    @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
    def infer(self, weights, x, e, grid_sizes):
        """Apply the output head (modulation, norm, projection) and unpatchify.

        Args:
            weights: Weight holder providing ``head_modulation.tensor``
                (assumed shape ``(1, 2, dim)`` — TODO confirm), ``norm.apply``
                and ``head.apply``.
            x: Token feature tensor, presumably ``(seq, dim)`` — verify against
                caller.
            e: Time-embedding tensor; 2-D on the standard path, 3-D for
                diffusion forcing.
            grid_sizes: Per-sample ``(f, h, w)`` patch-grid sizes, consumed by
                ``unpatchify``.

        Returns:
            List of float32 tensors, one per sample (see ``unpatchify``).

        Raises:
            ValueError: If ``e`` is neither 2-D nor 3-D.
        """
        if e.dim() == 2:
            modulation = weights.head_modulation.tensor  # 1, 2, dim
            e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
        elif e.dim() == 3:  # For Diffusion forcing
            modulation = weights.head_modulation.tensor.unsqueeze(2)  # 1, 2, seq, dim
            e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
            e = [ei.squeeze(1) for ei in e]
        else:
            # Previously an unexpected rank fell through silently and e[0]/e[1]
            # below indexed the raw tensor's first dimension, yielding a wrong
            # result instead of an error. Fail loudly.
            raise ValueError(f"Expected e to be 2-D or 3-D, got {e.dim()}-D")

        x = weights.norm.apply(x)

        # Modulate in float32 for numerical stability unless running pure BF16.
        if GET_DTYPE() != "BF16":
            x = x.float()
        x.mul_(1 + e[1].squeeze()).add_(e[0].squeeze())
        if GET_DTYPE() != "BF16":
            x = x.to(torch.bfloat16)

        x = weights.head.apply(x)
        x = self.unpatchify(x, grid_sizes)

        if self.clean_cuda_cache:
            # Drop local refs before emptying the allocator cache.
            del e, grid_sizes
            torch.cuda.empty_cache()

        return [u.float() for u in x]

    def unpatchify(self, x, grid_sizes):
        x = x.unsqueeze(0)
        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes.tolist()):
            u = u[: math.prod(v)].view(*v, *self.patch_size, c)
            u = torch.einsum("fhwpqrc->cfphqwr", u)
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
            out.append(u)
        return out