import math

import torch

from lightx2v.utils.envs import *


class WanPostInfer:
    """Post-inference stage for the Wan model.

    Folds the transformer's flattened patch tokens back into dense
    per-sample tensors of shape [C, F, H, W] (channels, frames, height,
    width), optionally releasing cached CUDA memory afterwards.
    """

    def __init__(self, config):
        """Initialize from a config mapping.

        Args:
            config: dict-like with required key "out_dim" (output channels
                per patch element) and optional key "clean_cuda_cache"
                (bool, default False).
        """
        self.out_dim = config["out_dim"]
        # Patch extent along (frame, height, width); fixed for this model.
        self.patch_size = (1, 2, 2)
        # When True, empty the CUDA allocator cache after unpatchifying.
        self.clean_cuda_cache = config.get("clean_cuda_cache", False)

    def set_scheduler(self, scheduler):
        """Attach the scheduler object used by the surrounding pipeline."""
        self.scheduler = scheduler

    @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
    def infer(self, weights, x, pre_infer_out):
        """Convert flattened patch tokens back to dense float32 tensors.

        Args:
            weights: unused here; kept so all inference stages share a
                uniform (weights, x, pre_infer_out) call signature.
            x: flattened patch-token tensor produced by the transformer.
            pre_infer_out: object carrying `grid_sizes`, the per-sample
                (F, H, W) patch-grid dimensions recorded at pre-infer time.

        Returns:
            list[torch.Tensor]: one dense [C, F*p, H*q, W*r] tensor per
            sample, cast to float32.
        """
        x = self.unpatchify(x, pre_infer_out.grid_sizes)

        if self.clean_cuda_cache:
            # Best-effort release of cached allocator blocks between stages.
            torch.cuda.empty_cache()

        # Cast to float32 for downstream consumers.
        return [u.float() for u in x]

    def unpatchify(self, x, grid_sizes):
        """Reshape flattened patch tokens into dense video tensors.

        Args:
            x: token tensor; unsqueezed to add a leading batch axis, so each
                row of `grid_sizes` pairs with one sample's token sequence.
            grid_sizes: tensor of per-sample (F, H, W) patch-grid sizes
                (consumed via .tolist()).

        Returns:
            list[torch.Tensor]: per-sample tensors of shape
            [C, F*pf, H*ph, W*pw] where (pf, ph, pw) = self.patch_size.
        """
        x = x.unsqueeze(0)
        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes.tolist()):
            # Keep only the valid (non-padding) tokens for this grid, then
            # expose the patch structure: (F, H, W, pf, ph, pw, C).
            u = u[: math.prod(v)].view(*v, *self.patch_size, c)
            # Interleave grid and patch axes: (C, F, pf, H, ph, W, pw).
            u = torch.einsum("fhwpqrc->cfphqwr", u)
            # Merge each grid axis with its patch axis into the dense size.
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
            out.append(u)
        return out