# pre_infer.py — pre-inference input preparation for the Wan diffusion transformer.
import torch
import math
from .utils import rope_params, sinusoidal_embedding_1d
import torch.cuda.amp as amp


class WanPreInfer:
    """Prepare the inputs consumed by a Wan diffusion-transformer forward pass.

    Given a scheduler (latents + timesteps) and encoder outputs, this builds:
    the patchified/padded latent sequence, the sinusoidal+MLP time embeddings,
    the projected text context, and (for image-to-video) the projected CLIP
    image context prepended to the text context.

    Expected ``config`` keys: ``dim``, ``num_heads``, ``task``, ``freq_dim``,
    ``text_len``. RoPE frequency tables are precomputed on CUDA at construction.
    """

    def __init__(self, config):
        # Head dim must divide evenly and be even, so it can be split into the
        # three rotary sub-bands below.
        assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
        d = config["dim"] // config["num_heads"]

        self.task = config["task"]
        # RoPE tables with max position 1024: the per-head dim is partitioned
        # into one band of size d - 4*(d//6) and two equal bands of 2*(d//6)
        # each — presumably temporal + two spatial axes (TODO confirm).
        self.freqs = torch.cat(
            [
                rope_params(1024, d - 4 * (d // 6)),
                rope_params(1024, 2 * (d // 6)),
                rope_params(1024, 2 * (d // 6)),
            ],
            dim=1,
        ).cuda()
        self.freq_dim = config["freq_dim"]
        self.dim = config["dim"]
        self.text_len = config["text_len"]

    def set_scheduler(self, scheduler):
        """Attach the sampling scheduler whose latents/timesteps drive infer()."""
        self.scheduler = scheduler

    def infer(self, weights, inputs, positive):
        """Build the transformer inputs for the current scheduler step.

        Args:
            weights: object exposing the embedding/projection layers via
                ``.apply`` (patch_embedding, time_embedding_*, text_embedding_*,
                proj_* for i2v).
            inputs: dict with ``text_encoder_output`` (``context`` /
                ``context_null``) and, for the "i2v" task,
                ``image_encoder_output`` (``clip_encoder_out`` /
                ``vae_encode_out``).
            positive: if True use the positive-prompt context, otherwise the
                null (negative) context — classifier-free-guidance branches.

        Returns:
            (embed, grid_sizes, (x, embed0, seq_lens, freqs, context)) — the
            time embedding, the per-sample latent grid shapes, and the tuple
            of per-token transformer inputs.
        """
        x = [self.scheduler.latents]
        # Single-element batch of the current diffusion timestep.
        t = torch.stack([self.scheduler.timesteps[self.scheduler.step_index]])
        if positive:
            context = inputs["text_encoder_output"]["context"]
        else:
            context = inputs["text_encoder_output"]["context_null"]
        seq_len = self.scheduler.seq_len

        if self.task == "i2v":
            # Image-to-video: concatenate the VAE-encoded conditioning frames
            # onto the latents along the channel axis (dim 0, pre-batching).
            clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
            y = [inputs["image_encoder_output"]["vae_encode_out"]]
            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

        # Patch embedding: add batch dim, then record each sample's spatial
        # grid (shape after the conv patchify, minus batch/channel dims).
        x = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in x]
        grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
        # Flatten the grid into a token sequence: (1, C, ...) -> (1, L, C).
        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long).cuda()
        assert seq_lens.max() <= seq_len
        # Right-pad every sample to the scheduler's fixed seq_len and batch.
        x = torch.cat([torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x])

        # Time embeddings: sinusoidal features -> MLP (Linear, SiLU, Linear),
        # then a projection to 6 per-block modulation vectors of size dim.
        embed = sinusoidal_embedding_1d(self.freq_dim, t)
        embed = weights.time_embedding_0.apply(embed)
        embed = torch.nn.functional.silu(embed)
        embed = weights.time_embedding_2.apply(embed)
        embed0 = torch.nn.functional.silu(embed)

        embed0 = weights.time_projection_1.apply(embed0).unflatten(1, (6, self.dim))

        # Text embeddings: pad each context to text_len tokens, then project
        # through Linear -> GELU(tanh) -> Linear.
        stacked = torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context])
        out = weights.text_embedding_0.apply(stacked.squeeze(0))
        out = torch.nn.functional.gelu(out, approximate="tanh")
        context = weights.text_embedding_2.apply(out)

        if self.task == "i2v":
            # CLIP image features through the proj MLP (note: exact GELU here,
            # unlike the tanh approximation used for text above).
            context_clip = weights.proj_0.apply(clip_fea)
            context_clip = weights.proj_1.apply(context_clip)
            context_clip = torch.nn.functional.gelu(context_clip, approximate="none")
            context_clip = weights.proj_3.apply(context_clip)
            context_clip = weights.proj_4.apply(context_clip)

            # Prepend image tokens to the text tokens.
            context = torch.cat([context_clip, context], dim=0)

        return (
            embed,
            grid_sizes,
            (x.squeeze(0), embed0.squeeze(0), seq_lens, self.freqs, context),
        )