pre_infer.py 5.63 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
import torch
PengGao's avatar
PengGao committed
2

gushiqiao's avatar
gushiqiao committed
3
from lightx2v.utils.envs import *
helloyongyang's avatar
helloyongyang committed
4

5
from .module_io import GridOutput, WanPreInferModuleOutput
PengGao's avatar
PengGao committed
6
7
from .utils import guidance_scale_embedding, rope_params, sinusoidal_embedding_1d

helloyongyang's avatar
helloyongyang committed
8
9
10

class WanPreInfer:
    """Pre-transformer stage of the Wan model.

    Patch-embeds the latent video (optionally concatenated with VAE image
    latents), builds the timestep embedding (plus an optional guidance-scale
    embedding), projects the text context and, for the ``i2v``/``flf2v``
    tasks, the CLIP image features, and packs the results into a
    ``WanPreInferModuleOutput``.
    """

    def __init__(self, config):
        """Read model hyper-parameters from ``config`` and precompute RoPE tables.

        Args:
            config: mapping with at least ``dim``, ``num_heads``, ``task`` and
                ``freq_dim``; optional keys read here are
                ``clean_cuda_cache`` (default False), ``enable_dynamic_cfg``
                (default False) and ``cfg_scale`` (default 4.0).
        """
        # The per-head dimension must be an integer and even.
        assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
        self.config = config
        d = config["dim"] // config["num_heads"]
        self.clean_cuda_cache = config.get("clean_cuda_cache", False)
        self.task = config["task"]
        # Frequency tables from ``rope_params``: the head dim is split into
        # three chunks (presumably one per video axis t/h/w — TODO confirm),
        # concatenated along dim=1 and moved to the current CUDA device.
        self.freqs = torch.cat(
            [
                rope_params(1024, d - 4 * (d // 6)),
                rope_params(1024, 2 * (d // 6)),
                rope_params(1024, 2 * (d // 6)),
            ],
            dim=1,
        ).cuda()
        self.freq_dim = config["freq_dim"]
        self.dim = config["dim"]

        self.enable_dynamic_cfg = config.get("enable_dynamic_cfg", False)
        self.cfg_scale = config.get("cfg_scale", 4.0)

        self.infer_dtype = GET_DTYPE()
        self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()

    def set_scheduler(self, scheduler):
        """Attach the scheduler whose latents/timestep ``infer`` will read."""
        self.scheduler = scheduler

    @torch.no_grad()
    def infer(self, weights, inputs, kv_start=0, kv_end=0):
        """Build the inputs of the transformer blocks for the current step.

        Args:
            weights: object exposing the pre-infer layers used below
                (``patch_embedding``, ``time_embedding_0/2``,
                ``time_projection_1``, ``text_embedding_0/2``, ``proj_0/1/3/4``,
                and optionally ``cfg_cond_proj_1/2`` and ``emb_pos``).
            inputs: dict with ``"text_encoder_output"`` (keys ``context`` and
                ``context_null``) and, for ``i2v``/``flf2v``,
                ``"image_encoder_output"`` (keys ``vae_encoder_out`` and, when
                the image encoder is enabled, ``clip_encoder_out``).
            kv_start, kv_end: token window. When it spans at least one frame's
                worth of tokens, only the VAE latent frames inside the window
                are used (CausalVid streaming).

        Returns:
            WanPreInferModuleOutput with the patch tokens, time embeddings,
            grid sizes, sequence length, RoPE tables and projected context.
        """
        x = self.scheduler.latents
        t = self.scheduler.timestep_input

        text_out = inputs["text_encoder_output"]
        context = text_out["context"] if self.scheduler.infer_condition else text_out["context_null"]

        is_image_task = self.task in ["i2v", "flf2v"]
        # Single source of truth for whether CLIP features are read and
        # projected. Every later use (including the cache-cleanup ``del``)
        # must be guarded by this flag, otherwise non-image tasks hit an
        # unbound ``context_clip``.
        use_clip = is_image_task and self.config.get("use_image_encoder", True)

        if is_image_task:
            image_out = inputs["image_encoder_output"]
            if use_clip:
                clip_fea = image_out["clip_encoder_out"]

            if self.config.get("changing_resolution", False):
                vae_latents = image_out["vae_encoder_out"][self.scheduler.changing_resolution_index]
            else:
                vae_latents = image_out["vae_encoder_out"]

            if vae_latents is not None:
                # Tokens per latent frame; the // 2 on dims 2 and 3 presumably
                # mirrors a 2x2 spatial patch size — TODO confirm.
                frame_seq_length = (vae_latents.size(2) // 2) * (vae_latents.size(3) // 2)
                # CausalVid: when the KV window covers at least one full frame,
                # keep only the latent frames that fall inside that window.
                if kv_end - kv_start >= frame_seq_length:
                    first_frame = kv_start // frame_seq_length
                    last_frame = kv_end // frame_seq_length
                    vae_latents = vae_latents[:, first_frame:last_frame, :, :]
                x = torch.cat([x, vae_latents], dim=0)

        # Patch embedding: add a batch dim, embed, then flatten the (t, h, w)
        # grid into a token sequence of shape (1, seq, channels).
        x = weights.patch_embedding.apply(x.unsqueeze(0))
        grid_sizes_t, grid_sizes_h, grid_sizes_w = x.shape[2:]
        x = x.flatten(2).transpose(1, 2).contiguous()
        seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0)

        # Timestep embedding, optionally augmented with a guidance-scale
        # embedding when dynamic CFG is enabled.
        embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
        if self.enable_dynamic_cfg:
            s = torch.tensor([self.cfg_scale], dtype=torch.float32, device=x.device)
            cfg_embed = guidance_scale_embedding(s, embedding_dim=256, cfg_range=(1.0, 6.0), target_range=1000.0, dtype=torch.float32).type_as(x)
            cfg_embed = weights.cfg_cond_proj_1.apply(cfg_embed)
            cfg_embed = torch.nn.functional.silu(cfg_embed)
            cfg_embed = weights.cfg_cond_proj_2.apply(cfg_embed)
            embed = embed + cfg_embed
        # Only cast when a separate "sensitive" dtype is configured; otherwise
        # the embedding is passed through in whatever dtype it already has.
        if self.sensitive_layer_dtype != self.infer_dtype:
            embed = weights.time_embedding_0.apply(embed.to(self.sensitive_layer_dtype))
        else:
            embed = weights.time_embedding_0.apply(embed)
        embed = torch.nn.functional.silu(embed)
        embed = weights.time_embedding_2.apply(embed)
        embed0 = torch.nn.functional.silu(embed)
        # Split the projection into 6 vectors of size ``dim``.
        embed0 = weights.time_projection_1.apply(embed0).unflatten(1, (6, self.dim))

        # Text embeddings.
        if self.sensitive_layer_dtype != self.infer_dtype:
            out = weights.text_embedding_0.apply(context.squeeze(0).to(self.sensitive_layer_dtype))
        else:
            out = weights.text_embedding_0.apply(context.squeeze(0))
        out = torch.nn.functional.gelu(out, approximate="tanh")
        context = weights.text_embedding_2.apply(out)
        if self.clean_cuda_cache:
            del out
            torch.cuda.empty_cache()

        # CLIP image context, prepended to the text context.
        if use_clip:
            if self.task == "flf2v":
                # view(2 * n, d) requires a leading dim of exactly 2
                # (presumably the first- and last-frame images — TODO confirm).
                _, n, d = clip_fea.shape
                clip_fea = clip_fea.view(2 * n, d)
                clip_fea = clip_fea + weights.emb_pos.tensor.squeeze()
            context_clip = weights.proj_0.apply(clip_fea)
            if self.clean_cuda_cache:
                del clip_fea
                torch.cuda.empty_cache()
            context_clip = weights.proj_1.apply(context_clip)
            context_clip = torch.nn.functional.gelu(context_clip, approximate="none")
            if self.clean_cuda_cache:
                torch.cuda.empty_cache()
            context_clip = weights.proj_3.apply(context_clip)
            context_clip = weights.proj_4.apply(context_clip)
            context = torch.concat([context_clip, context], dim=0)

        if self.clean_cuda_cache:
            # ``context_clip`` only exists when the CLIP branch above ran.
            if use_clip:
                del context_clip
            torch.cuda.empty_cache()

        grid_tuple = (grid_sizes_t, grid_sizes_h, grid_sizes_w)
        grid_sizes = GridOutput(
            tensor=torch.tensor([list(grid_tuple)], dtype=torch.int32, device=x.device),
            tuple=grid_tuple,
        )
        return WanPreInferModuleOutput(
            embed=embed,
            grid_sizes=grid_sizes,
            x=x.squeeze(0),
            embed0=embed0.squeeze(0),
            seq_lens=seq_lens,
            freqs=self.freqs,
            context=context,
        )