import torch

from lightx2v.utils.envs import *

from .module_io import GridOutput, WanPreInferModuleOutput
from .utils import guidance_scale_embedding, sinusoidal_embedding_1d


class WanPreInfer:
    def __init__(self, config):
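        # The model width must split evenly across attention heads, and the per-head
        # dimension must itself be even.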
        assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
        self.config = config
        self.run_device = self.config.get("run_device", "cuda")
        self.clean_cuda_cache = config.get("clean_cuda_cache", False)
        self.task = config["task"]
        self.device = torch.device(self.config.get("run_device", "cuda"))
        self.freq_dim = config["freq_dim"]
        self.dim = config["dim"]
        self.enable_dynamic_cfg = config.get("enable_dynamic_cfg", False)
        self.cfg_scale = config.get("cfg_scale", 4.0)
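        # Main inference dtype; numerically sensitive layers (time/text embeddings)
        # may run in a separate, typically higher-precision, dtype.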
        self.infer_dtype = GET_DTYPE()
        self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler

    @torch.no_grad()
    def infer(self, weights, inputs, kv_start=0, kv_end=0):
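        """Build the shared inputs for one denoising step.

        Combines the scheduler's latents and timestep with the text (and optional
        CLIP image) encoder outputs, and returns patch tokens, timestep embeddings
        and context as a WanPreInferModuleOutput.
        """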
        x = self.scheduler.latents
        t = self.scheduler.timestep_input

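        # Pick the conditional or unconditional (null) text context for this CFG pass.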
        if self.scheduler.infer_condition:
            context = inputs["text_encoder_output"]["context"]
        else:
            context = inputs["text_encoder_output"]["context_null"]

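        # Image-conditioned tasks: fetch CLIP image features and the VAE-encoded reference latents.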
        if self.task in ["i2v", "flf2v", "animate", "s2v"]:
            if self.config.get("use_image_encoder", True):
                clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]

            if self.config.get("changing_resolution", False):
                image_encoder = inputs["image_encoder_output"]["vae_encoder_out"][self.scheduler.changing_resolution_index]
            else:
                image_encoder = inputs["image_encoder_output"]["vae_encoder_out"]

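            # Concatenate the conditioning latents onto the noisy latents (dim 0) before patch embedding.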
            if image_encoder is not None:
                frame_seq_length = (image_encoder.size(2) // 2) * (image_encoder.size(3) // 2)
                if kv_end - kv_start >= frame_seq_length:  # For CausalVid, take only the image_encoder segment covered by this kv window
                    idx_s = kv_start // frame_seq_length
                    idx_e = kv_end // frame_seq_length
                    image_encoder = image_encoder[:, idx_s:idx_e, :, :]
                y = image_encoder
                x = torch.cat([x, y], dim=0)

        # embeddings
        x = weights.patch_embedding.apply(x.unsqueeze(0))

        if hasattr(self, "after_patch_embedding"):
            x, motion_vec = self.after_patch_embedding(weights, x, inputs["image_encoder_output"]["pose_latents"], inputs["image_encoder_output"]["face_pixel_values"])
        else:
            motion_vec = None

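        # Record the (t, h, w) patch grid, then flatten the patches into a token sequence.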
        grid_sizes_t, grid_sizes_h, grid_sizes_w = x.shape[2:]
        x = x.flatten(2).transpose(1, 2).contiguous()
        # seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0)

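        # Sinusoidal timestep embedding; with dynamic CFG enabled, a guidance-scale embedding is added on top.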
        embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
        if self.enable_dynamic_cfg:
            s = torch.tensor([self.cfg_scale], dtype=torch.float32, device=x.device)
            cfg_embed = guidance_scale_embedding(s, embedding_dim=256, cfg_range=(1.0, 6.0), target_range=1000.0, dtype=torch.float32).type_as(x)
            cfg_embed = weights.cfg_cond_proj_1.apply(cfg_embed)
            cfg_embed = torch.nn.functional.silu(cfg_embed)
            cfg_embed = weights.cfg_cond_proj_2.apply(cfg_embed)
            embed = embed + cfg_embed
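        # Time-embedding MLP, cast to the sensitive dtype when it differs from the inference dtype.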
        if self.sensitive_layer_dtype != self.infer_dtype:
            embed = weights.time_embedding_0.apply(embed.to(self.sensitive_layer_dtype))
        else:
            embed = weights.time_embedding_0.apply(embed)
        embed = torch.nn.functional.silu(embed)
        embed = weights.time_embedding_2.apply(embed)
        embed0 = torch.nn.functional.silu(embed)

        embed0 = weights.time_projection_1.apply(embed0).unflatten(1, (6, self.dim))

        # text embeddings
        if self.sensitive_layer_dtype != self.infer_dtype:
            out = weights.text_embedding_0.apply(context.squeeze(0).to(self.sensitive_layer_dtype))
        else:
            out = weights.text_embedding_0.apply(context.squeeze(0))
        out = torch.nn.functional.gelu(out, approximate="tanh")
        context = weights.text_embedding_2.apply(out)
        if self.clean_cuda_cache:
            del out
            torch.cuda.empty_cache()

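        # Project CLIP image features and prepend them to the text context.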
        if self.task in ["i2v", "flf2v", "animate"] and self.config.get("use_image_encoder", True):
            if self.task == "flf2v":
                _, n, d = clip_fea.shape
                clip_fea = clip_fea.view(2 * n, d)
                clip_fea = clip_fea + weights.emb_pos.tensor.squeeze()
            context_clip = weights.proj_0.apply(clip_fea)
            if self.clean_cuda_cache:
                del clip_fea
                torch.cuda.empty_cache()
            context_clip = weights.proj_1.apply(context_clip)
            context_clip = torch.nn.functional.gelu(context_clip, approximate="none")
            if self.clean_cuda_cache:
                torch.cuda.empty_cache()
            context_clip = weights.proj_3.apply(context_clip)
            context_clip = weights.proj_4.apply(context_clip)
            context = torch.concat([context_clip, context], dim=0)

        if self.clean_cuda_cache:
            if self.config.get("use_image_encoder", True):
                del context_clip
            torch.cuda.empty_cache()

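        # Package the patch grid both as a tensor and as a plain tuple.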
        grid_sizes = GridOutput(tensor=torch.tensor([[grid_sizes_t, grid_sizes_h, grid_sizes_w]], dtype=torch.int32, device=x.device), tuple=(grid_sizes_t, grid_sizes_h, grid_sizes_w))
        return WanPreInferModuleOutput(
            embed=embed,
            grid_sizes=grid_sizes,
            x=x.squeeze(0),
            embed0=embed0.squeeze(0),
            context=context,
            adapter_args={"motion_vec": motion_vec},
        )