pre_infer.py 6.3 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
import torch
PengGao's avatar
PengGao committed
2

gushiqiao's avatar
gushiqiao committed
3
from lightx2v.utils.envs import *
helloyongyang's avatar
helloyongyang committed
4

helloyongyang's avatar
helloyongyang committed
5
from .module_io import WanPreInferModuleOutput
PengGao's avatar
PengGao committed
6
7
from .utils import guidance_scale_embedding, rope_params, sinusoidal_embedding_1d

helloyongyang's avatar
helloyongyang committed
8
9
10

class WanPreInfer:
    """Pre-inference stage for Wan video diffusion transformers.

    Builds everything the downstream transformer blocks consume from the raw
    scheduler/encoder inputs: patchified latents, timestep (+ optional dynamic
    CFG) embeddings, rotary-position frequency table, and the text (plus
    optional CLIP image) context sequence.
    """

    def __init__(self, config):
        """Initialize from a model config dict.

        Required keys: "dim", "num_heads", "freq_dim", "text_len", "task".
        Optional keys: "clean_cuda_cache", "enable_dynamic_cfg", "cfg_scale",
        "model_cls", "use_image_encoder", "changing_resolution".
        """
        # Per-head dim must exist and be even so RoPE can pair up channels.
        assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
        self.config = config
        d = config["dim"] // config["num_heads"]  # per-head channel count
        self.clean_cuda_cache = config.get("clean_cuda_cache", False)
        self.task = config["task"]  # e.g. "i2v" (image-to-video) vs. others
        # RoPE frequency table split across one temporal band (d - 4*(d//6))
        # and two equal spatial bands (2*(d//6) each); concatenated on the
        # frequency axis and kept resident on GPU for the whole run.
        self.freqs = torch.cat(
            [
                rope_params(1024, d - 4 * (d // 6)),
                rope_params(1024, 2 * (d // 6)),
                rope_params(1024, 2 * (d // 6)),
            ],
            dim=1,
        ).cuda()
        self.freq_dim = config["freq_dim"]
        self.dim = config["dim"]
        self.text_len = config["text_len"]  # fixed context length; shorter prompts are zero-padded

        # Dynamic classifier-free-guidance embedding (optional).
        self.enable_dynamic_cfg = config.get("enable_dynamic_cfg", False)
        self.cfg_scale = config.get("cfg_scale", 4.0)

        # Sensitive layers (time/text embeddings) may run in a higher
        # precision than the main inference dtype.
        self.infer_dtype = GET_DTYPE()
        self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()

    def set_scheduler(self, scheduler):
        """Attach the sampling scheduler that `infer` reads latents/timesteps from."""
        self.scheduler = scheduler

    @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
    def infer(self, weights, inputs, kv_start=0, kv_end=0):
        """Run the pre-inference pass for the current scheduler step.

        Args:
            weights: pre-infer weight container (patch/time/text/clip projections),
                each exposing an `.apply(tensor)` method.
            inputs: dict with "text_encoder_output" (keys "context",
                "context_null") and, for i2v, "image_encoder_output".
            kv_start, kv_end: KV-cache window (token indices) used by
                CausalVid-style inference to slice the image-encoder latents;
                both default to 0, which leaves the latents unsliced.

        Returns:
            WanPreInferModuleOutput with embeddings, grid sizes, flattened
            latents, sequence lengths, RoPE freqs and the final context.
        """
        x = self.scheduler.latents

        if self.scheduler.flag_df:
            t = self.scheduler.df_timesteps[self.scheduler.step_index].unsqueeze(0)
            assert t.dim() == 2  # diffusion-forcing models use a 2-D timestep
        else:
            timestep = self.scheduler.timesteps[self.scheduler.step_index]
            t = torch.stack([timestep])
            if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v":
                # wan2.2 i2v: broadcast the scalar timestep over the (spatially
                # subsampled) mask so conditioned positions get per-token t.
                t = (self.scheduler.mask[0][:, ::2, ::2] * t).flatten()

        # Pick conditional vs. unconditional text context for CFG.
        if self.scheduler.infer_condition:
            context = inputs["text_encoder_output"]["context"]
        else:
            context = inputs["text_encoder_output"]["context_null"]

        if self.task == "i2v":
            if self.config.get("use_image_encoder", True):
                clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]

            if self.config.get("changing_resolution", False):
                image_encoder = inputs["image_encoder_output"]["vae_encoder_out"][self.scheduler.changing_resolution_index]
            else:
                image_encoder = inputs["image_encoder_output"]["vae_encoder_out"]

            if image_encoder is not None:
                # Tokens per frame after the 2x2 spatial patching below.
                frame_seq_length = (image_encoder.size(2) // 2) * (image_encoder.size(3) // 2)
                if kv_end - kv_start >= frame_seq_length:
                    # CausalVid: keep only the frames covered by this KV window.
                    idx_s = kv_start // frame_seq_length
                    idx_e = kv_end // frame_seq_length
                    image_encoder = image_encoder[:, idx_s:idx_e, :, :]
                y = image_encoder
                # Concatenate image-condition latents with noise latents on the
                # channel/frame axis (dim 0 of the unbatched latent tensor).
                x = torch.cat([x, y], dim=0)

        # --- patch + positional bookkeeping ---
        x = weights.patch_embedding.apply(x.unsqueeze(0))
        grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long).unsqueeze(0)  # (F, H, W) token grid
        x = x.flatten(2).transpose(1, 2).contiguous()  # (1, seq, dim)
        seq_lens = torch.tensor(x.size(1), dtype=torch.long).cuda().unsqueeze(0)

        # --- timestep embedding (optionally fused with dynamic CFG) ---
        embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
        if self.enable_dynamic_cfg:
            s = torch.tensor([self.cfg_scale], dtype=torch.float32).to(x.device)
            cfg_embed = guidance_scale_embedding(s, embedding_dim=256, cfg_range=(1.0, 6.0), target_range=1000.0, dtype=torch.float32).type_as(x)
            cfg_embed = weights.cfg_cond_proj_1.apply(cfg_embed)
            cfg_embed = torch.nn.functional.silu(cfg_embed)
            cfg_embed = weights.cfg_cond_proj_2.apply(cfg_embed)
            embed = embed + cfg_embed
        if self.sensitive_layer_dtype != self.infer_dtype:
            # Time embedding is precision-sensitive: upcast before the MLP.
            embed = weights.time_embedding_0.apply(embed.to(self.sensitive_layer_dtype))
        else:
            embed = weights.time_embedding_0.apply(embed)
        embed = torch.nn.functional.silu(embed)
        embed = weights.time_embedding_2.apply(embed)
        embed0 = torch.nn.functional.silu(embed)

        # Project to 6 per-block modulation vectors of size `dim`.
        embed0 = weights.time_projection_1.apply(embed0).unflatten(1, (6, self.dim))

        if self.scheduler.flag_df:
            # Diffusion-forcing: expand per-frame embeddings to per-token ones
            # by repeating over the spatial grid, then flatten to sequence form.
            b, f = t.shape
            assert b == len(x)  # batch_size == 1
            embed = embed.view(b, f, 1, 1, self.dim)
            embed0 = embed0.view(b, f, 1, 1, 6, self.dim)
            embed = embed.repeat(1, 1, grid_sizes[0][1], grid_sizes[0][2], 1).flatten(1, 3)
            embed0 = embed0.repeat(1, 1, grid_sizes[0][1], grid_sizes[0][2], 1, 1).flatten(1, 3)
            embed0 = embed0.transpose(1, 2).contiguous()

        # --- text embeddings ---
        # Zero-pad every prompt to the fixed `text_len` before projecting.
        stacked = torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context])
        if self.sensitive_layer_dtype != self.infer_dtype:
            out = weights.text_embedding_0.apply(stacked.squeeze(0).to(self.sensitive_layer_dtype))
        else:
            out = weights.text_embedding_0.apply(stacked.squeeze(0))
        out = torch.nn.functional.gelu(out, approximate="tanh")
        context = weights.text_embedding_2.apply(out)

        if self.clean_cuda_cache:
            del out, stacked
            torch.cuda.empty_cache()

        # --- CLIP image context (i2v only), prepended to the text context ---
        if self.task == "i2v" and self.config.get("use_image_encoder", True):
            context_clip = weights.proj_0.apply(clip_fea)

            if self.clean_cuda_cache:
                del clip_fea
                torch.cuda.empty_cache()

            context_clip = weights.proj_1.apply(context_clip)
            context_clip = torch.nn.functional.gelu(context_clip, approximate="none")

            if self.clean_cuda_cache:
                torch.cuda.empty_cache()

            context_clip = weights.proj_3.apply(context_clip)
            context_clip = weights.proj_4.apply(context_clip)
            context = torch.concat([context_clip, context], dim=0)

        if self.clean_cuda_cache:
            # BUGFIX: `context_clip` only exists on the i2v + use_image_encoder
            # path; previously this checked use_image_encoder alone and raised
            # NameError for non-i2v tasks with clean_cuda_cache enabled.
            if self.task == "i2v" and self.config.get("use_image_encoder", True):
                del context_clip
            torch.cuda.empty_cache()

        return WanPreInferModuleOutput(
            embed=embed,
            grid_sizes=grid_sizes,
            x=x.squeeze(0),
            embed0=embed0.squeeze(0),
            seq_lens=seq_lens,
            freqs=self.freqs,
            context=context,
        )