# post_infer.py
import torch


class HunyuanPostInfer:
    """Final-layer post-processing: modulate, project and un-patchify tokens.

    ``infer`` applies an adaLN-style shift/scale to the layer-normalised
    tokens, runs the final linear projection, and folds the per-token
    patches back into a ``(1, C, T, H, W)`` tensor. The T/H/W extents are
    read from ``scheduler.latents``, so ``set_scheduler`` must be called
    before ``infer``.
    """

    def __init__(self, config):
        """Keep a reference to ``config``; the methods of this class do not read it."""
        self.config = config

    def set_scheduler(self, scheduler):
        """Attach the scheduler whose ``latents`` shape defines the output grid."""
        self.scheduler = scheduler

    def infer(self, weights, img, vec):
        """Run the final layer and rebuild the latent volume from patch tokens.

        Args:
            weights: Object exposing ``final_layer_adaLN_modulation_1`` and
                ``final_layer_linear``; each must provide ``apply(tensor)``.
            img: Token tensor, expected to be 2-D ``(tokens, hidden)``; it is
                layer-normalised over a trailing dimension of size
                ``img.shape[1]``.
            vec: Conditioning tensor. After SiLU and the modulation layer the
                result is split in two halves along dim 1 (shift, then scale).

        Returns:
            Tensor of shape ``(1, C, T', H', W')`` where ``T'/H'/W'`` are the
            latent extents rounded down to a multiple of the patch size and
            ``C`` is the projection width divided by the patch volume.
        """
        # adaLN modulation: the first half of the projection is the shift,
        # the second half is the scale.
        modulation = weights.final_layer_adaLN_modulation_1.apply(
            torch.nn.functional.silu(vec)
        )
        shift, scale = modulation.chunk(2, dim=1)

        # Affine-free layer norm followed by the learned shift/scale.
        out = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
        out = out * (1 + scale) + shift
        # The final projection is always fed float32 input.
        out = weights.final_layer_linear.apply(out.to(torch.float32))

        _, _, ot, oh, ow = self.scheduler.latents.shape
        pt, ph, pw = 1, 2, 2  # patch size along (t, h, w)
        tt, th, tw = ot // pt, oh // ph, ow // pw

        # Each token carries c * pt * ph * pw values; derive the channel count
        # from the projection width instead of hard-coding it (64 -> 16).
        c = out.shape[-1] // (pt * ph * pw)

        # Un-patchify: (tokens, c*pt*ph*pw) -> (1, c, tt*pt, th*ph, tw*pw).
        out = out.reshape(shape=(1, tt, th, tw, c, pt, ph, pw))
        out = torch.einsum("nthwcopq->nctohpwq", out)
        return out.reshape(shape=(1, c, tt * pt, th * ph, tw * pw))