import math

import torch
from einops import rearrange


class HunyuanPreInfer:
8
    def __init__(self, config):
helloyongyang's avatar
helloyongyang committed
9
        self.heads_num = 24
10
        self.config = config
helloyongyang's avatar
helloyongyang committed
11

12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
    def set_scheduler(self, scheduler):
        self.scheduler = scheduler

    def infer(self, weights, inputs):
        x = self.scheduler.latents
        t = self.scheduler.timesteps[self.scheduler.step_index]
        freqs_cos = self.scheduler.freqs_cos
        freqs_sin = self.scheduler.freqs_sin
        guidance = self.scheduler.guidance

        text_states = inputs["text_encoder_output"]["text_encoder_1_text_states"]
        text_mask = inputs["text_encoder_output"]["text_encoder_1_attention_mask"]
        text_states_2 = inputs["text_encoder_output"]["text_encoder_2_text_states"]

        if self.config["task"] == "i2v":
helloyongyang's avatar
helloyongyang committed
27
28
29
30
31
32
            token_replace_t = torch.zeros_like(t)
            token_replace_vec = self.infer_time_in(weights, token_replace_t)
            th = x.shape[-2] // 2
            tw = x.shape[-1] // 2
            frist_frame_token_num = th * tw

helloyongyang's avatar
helloyongyang committed
33
34
35
36
37
        time_out = self.infer_time_in(weights, t)
        img_out = self.infer_img_in(weights, x)
        infer_text_out = self.infer_text_in(weights, text_states, text_mask, t)
        infer_vector_out = self.infer_vector_in(weights, text_states_2)
        vec = time_out + infer_vector_out
helloyongyang's avatar
helloyongyang committed
38

39
        if self.config["task"] == "i2v":
helloyongyang's avatar
helloyongyang committed
40
41
            token_replace_vec = token_replace_vec + infer_vector_out

helloyongyang's avatar
helloyongyang committed
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
        guidance_out = self.infer_guidance_in(weights, guidance)
        vec = vec + guidance_out

        txt_seq_len = infer_text_out.shape[0]
        img_seq_len = img_out.shape[1]
        batch_size = text_mask.shape[0]
        text_len = text_mask.sum(dim=1)
        max_len = text_mask.shape[1] + img_seq_len

        cu_seqlens_qkv = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
        for i in range(batch_size):
            s = text_len[i] + img_seq_len
            s1 = i * max_len + s
            s2 = (i + 1) * max_len
            cu_seqlens_qkv[2 * i + 1] = s1
            cu_seqlens_qkv[2 * i + 2] = s2
Dongz's avatar
Dongz committed
58

helloyongyang's avatar
helloyongyang committed
59
        max_seqlen_qkv = img_seq_len + txt_seq_len
60
        if self.config["task"] == "i2v":
helloyongyang's avatar
helloyongyang committed
61
            return img_out[0], infer_text_out, vec, cu_seqlens_qkv, max_seqlen_qkv, (freqs_cos, freqs_sin), token_replace_vec, frist_frame_token_num
helloyongyang's avatar
helloyongyang committed
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
        return img_out[0], infer_text_out, vec, cu_seqlens_qkv, max_seqlen_qkv, (freqs_cos, freqs_sin)

    def infer_time_in(self, weights, t):
        freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=128, dtype=torch.float32) / 128).to(device=t.device)
        args = t.unsqueeze(0).unsqueeze(0).float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=torch.bfloat16)
        out = weights.time_in_mlp_0.apply(embedding)
        out = torch.nn.functional.silu(out)
        out = weights.time_in_mlp_2.apply(out)
        return out

    def infer_img_in(self, weights, x):
        out = weights.img_in_proj.apply(x)
        out = out.flatten(2).transpose(1, 2)
        return out

    def infer_text_in(self, weights, text_states, text_mask, t):
        freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=128, dtype=torch.float32) / 128).to(device=t.device)
        args = t.unsqueeze(0).unsqueeze(0).float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=torch.bfloat16)
        out = weights.txt_in_t_embedder_mlp_0.apply(embedding)
        out = torch.nn.functional.silu(out)
        timestep_aware_representations = weights.txt_in_t_embedder_mlp_2.apply(out)
Dongz's avatar
Dongz committed
85

helloyongyang's avatar
helloyongyang committed
86
87
88
        mask_float = text_mask.float().unsqueeze(-1).to(torch.bfloat16)  # [b, s1, 1]
        context_aware_representations = (text_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
        context_aware_representations = context_aware_representations
Dongz's avatar
Dongz committed
89

helloyongyang's avatar
helloyongyang committed
90
91
92
93
        out = weights.txt_in_c_embedder_linear_1.apply(context_aware_representations)
        out = torch.nn.functional.silu(out)
        context_aware_representations = weights.txt_in_c_embedder_linear_2.apply(out)
        c = timestep_aware_representations + context_aware_representations
Dongz's avatar
Dongz committed
94

helloyongyang's avatar
helloyongyang committed
95
        txt_in_input_embed = weights.txt_in_input_embedder.apply(text_states[0])
Dongz's avatar
Dongz committed
96

helloyongyang's avatar
helloyongyang committed
97
98
        batch_size = text_mask.shape[0]
        seq_len = text_mask.shape[1]
Dongz's avatar
Dongz committed
99
        self_attn_mask_1 = text_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
helloyongyang's avatar
helloyongyang committed
100
101
102
103
104
105
106
107
108
109
        self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
        self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
        self_attn_mask[:, :, :, 0] = True

        cx = torch.nn.functional.silu(c)
        cx = weights.txt_in_individual_token_refiner_blocks_0_adaLN_modulation_1.apply(cx)
        gate_msa, gate_mlp = cx.chunk(2, dim=1)
        normx = weights.txt_in_individual_token_refiner_blocks_0_norm1.apply(txt_in_input_embed)
        qkv = weights.txt_in_individual_token_refiner_blocks_0_self_attn_qkv.apply(normx)
        q, k, v = rearrange(qkv.unsqueeze(0), "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
110
        attn = weights.txt_in_attn_1.apply(q=q, k=k, v=v, attn_mask=self_attn_mask)[0]
helloyongyang's avatar
helloyongyang committed
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
        out = weights.txt_in_individual_token_refiner_blocks_0_self_attn_proj.apply(attn)
        out_1 = txt_in_input_embed + out * gate_msa
        out = weights.txt_in_individual_token_refiner_blocks_0_norm2.apply(out_1)
        # mlp
        out = weights.txt_in_individual_token_refiner_blocks_0_mlp_fc1.apply(out)
        out = torch.nn.functional.silu(out)
        out = weights.txt_in_individual_token_refiner_blocks_0_mlp_fc2.apply(out)
        txt_in_input_embed = out_1 + out * gate_mlp

        cx = torch.nn.functional.silu(c)
        cx = weights.txt_in_individual_token_refiner_blocks_1_adaLN_modulation_1.apply(cx)
        gate_msa, gate_mlp = cx.chunk(2, dim=1)

        normx = weights.txt_in_individual_token_refiner_blocks_1_norm1.apply(txt_in_input_embed)
        qkv = weights.txt_in_individual_token_refiner_blocks_1_self_attn_qkv.apply(normx)

        q, k, v = rearrange(qkv.unsqueeze(0), "B L (K H D) -> K B L H D", K=3, H=self.heads_num)

129
        attn = weights.txt_in_attn_1.apply(q=q, k=k, v=v, attn_mask=self_attn_mask)[0]
helloyongyang's avatar
helloyongyang committed
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
        out = weights.txt_in_individual_token_refiner_blocks_1_self_attn_proj.apply(attn)
        out_1 = txt_in_input_embed + out * gate_msa

        out = weights.txt_in_individual_token_refiner_blocks_1_norm2.apply(out_1)
        # mlp
        out = weights.txt_in_individual_token_refiner_blocks_1_mlp_fc1.apply(out)
        out = torch.nn.functional.silu(out)
        out = weights.txt_in_individual_token_refiner_blocks_1_mlp_fc2.apply(out)

        out = out_1 + out * gate_mlp
        return out

    def infer_vector_in(self, weights, text_states_2):
        out = weights.vector_in_in_layer.apply(text_states_2)
        out = torch.nn.functional.silu(out)
        out = weights.vector_in_out_layer.apply(out)
        return out

    def infer_guidance_in(self, weights, guidance):
        freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=128, dtype=torch.float32) / 128).to(device=guidance.device)
        args = guidance.float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=torch.bfloat16)
        out = weights.guidance_in_mlp_0.apply(embedding)
        out = torch.nn.functional.silu(out)
        out = weights.guidance_in_mlp_2.apply(out)
        return out