import torch

from lightx2v.models.networks.ltx2.infer.module_io import LTX2PreInferModuleOutput, TransformerArgs
from lightx2v.models.networks.ltx2.infer.utils import *
from lightx2v.utils.envs import *


class LTX2PreInfer:
    """
    Pre-inference module for LTX2 transformer.

    Handles all input preprocessing before transformer blocks.
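
    Typical flow (sketch):

        pre_infer = LTX2PreInfer(config)
        pre_infer.set_scheduler(scheduler)
        output = pre_infer.infer(weights, inputs)  # -> LTX2PreInferModuleOutput(video_args, audio_args)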
    """

    def __init__(self, config):
        self.config = config
        # Video config
        self.num_attention_heads = self.config["num_attention_heads"]
        self.inner_dim = self.config["num_attention_heads"] * config["attention_head_dim"]
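        # RoPE grid extents per axis, presumably [frames, height, width]; the first
        # entry (20) reappears below as the temporal-only max_pos for cross-attention PE.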
        self.positional_embedding_max_pos = [20, 2048, 2048]

        # Audio config
        self.audio_num_attention_heads = self.config.get("audio_num_attention_heads", 32)
        self.audio_attention_head_dim = self.config.get("audio_attention_head_dim", 64)
        self.audio_inner_dim = self.audio_num_attention_heads * self.audio_attention_head_dim
        self.audio_cross_attention_dim = self.config["audio_cross_attention_dim"]
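        # Audio RoPE spans a single (temporal) axis, hence the one-element max_pos list.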
        self.audio_positional_embedding_max_pos = [config["audio_pos_embed_max_pos"]]

        # Common config
        self.timestep_scale_multiplier = self.config["timestep_scale_multiplier"]
        self.av_ca_timestep_scale_multiplier = self.config["cross_attn_timestep_scale_multiplier"]
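        # When enabled, RoPE frequency grids are generated in double precision via
        # NumPy instead of PyTorch (see _prepare_positional_embeddings).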
        self.double_precision_rope = self.config.get("double_precision_rope", False)

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler

    def _prepare_positional_embeddings(
        self,
        positions: torch.Tensor,
        inner_dim: int,
        max_pos: list[int],
        use_middle_indices_grid: bool,
        num_attention_heads: int,
        x_dtype: torch.dtype,
    ) -> torch.Tensor:
        """Prepare positional embeddings."""
        freq_grid_generator = generate_freq_grid_np if self.double_precision_rope else generate_freq_grid_pytorch
        pe = precompute_freqs_cis(
            positions,
            dim=inner_dim,
            out_dtype=x_dtype,
            theta=10000.0,
            max_pos=max_pos,
            use_middle_indices_grid=use_middle_indices_grid,
            num_attention_heads=num_attention_heads,
            rope_type="split",
            freq_grid_generator=freq_grid_generator,
        )
        return pe

    def __infer_video(self, weights, inputs, av_ca_factor):
        """Process video modality data."""
        # Get video modality data
        v_latent = self.scheduler.video_latent_state.latent
        v_positions = self.scheduler.video_latent_state.positions
        v_timesteps = self.scheduler.video_timesteps_from_mask()
        if self.scheduler.infer_condition:
            v_context = inputs["text_encoder_output"]["v_context_p"]
        else:
            v_context = inputs["text_encoder_output"]["v_context_n"]

        # 1. Patchify projection
        video_x = weights.patchify_proj.apply(v_latent)

        # 2. Timestep embeddings (AdaLN)
        v_timestep = v_timesteps * self.timestep_scale_multiplier
        v_timesteps_proj = get_timestep_embedding(v_timestep.flatten()).to(GET_DTYPE())

        v_emb_linear_1_out = weights.adaln_single_emb_timestep_embedder_linear_1.apply(v_timesteps_proj)
        v_emb_linear_1_out = torch.nn.functional.silu(v_emb_linear_1_out)
        v_embedded_timestep = weights.adaln_single_emb_timestep_embedder_linear_2.apply(v_emb_linear_1_out)
        v_adaln_timestep = weights.adaln_single_linear.apply(torch.nn.functional.silu(v_embedded_timestep))

        # 3. Caption projection
        v_context = weights.caption_projection_linear_1.apply(v_context.squeeze(0))
        v_context = torch.nn.functional.gelu(v_context, approximate="tanh")
        v_context = weights.caption_projection_linear_2.apply(v_context)

        # 4. Positional embeddings
        v_pe = self._prepare_positional_embeddings(
            positions=v_positions.unsqueeze(0),  # Add batch dim: [3, num_patches, 2] -> [1, 3, num_patches, 2]
            inner_dim=self.inner_dim,
            max_pos=self.positional_embedding_max_pos,
            use_middle_indices_grid=True,
            num_attention_heads=self.num_attention_heads,
            x_dtype=v_latent.dtype,
        )

        # 5. Cross-attention positional embeddings
        v_cross_pe = self._prepare_positional_embeddings(
            positions=v_positions.unsqueeze(0)[:, 0:1, :],  # Add batch dim, then keep only the temporal axis
            inner_dim=self.audio_cross_attention_dim,
            max_pos=[20],
            use_middle_indices_grid=True,
            num_attention_heads=self.num_attention_heads,
            x_dtype=v_latent.dtype,
        )

        # 6. Cross-attention timestep embeddings
        # Video cross scale-shift AdaLN
        v_cross_proj = get_timestep_embedding(v_timestep.flatten()).to(GET_DTYPE())
        v_cross_emb_1 = weights.av_ca_video_scale_shift_adaln_single_emb_linear_1.apply(v_cross_proj)
        v_cross_emb_1 = torch.nn.functional.silu(v_cross_emb_1)
        v_cross_emb_2 = weights.av_ca_video_scale_shift_adaln_single_emb_linear_2.apply(v_cross_emb_1)
        v_cross_scale_shift_timestep = weights.av_ca_video_scale_shift_adaln_single_linear.apply(torch.nn.functional.silu(v_cross_emb_2))

        # Video cross gate AdaLN
        v_gate_timestep_flat = v_timestep.flatten() * av_ca_factor
        v_gate_proj = get_timestep_embedding(v_gate_timestep_flat).to(GET_DTYPE())
        v_gate_emb_1 = weights.av_ca_a2v_gate_adaln_single_emb_linear_1.apply(v_gate_proj)
        v_gate_emb_1 = torch.nn.functional.silu(v_gate_emb_1)
        v_gate_emb_2 = weights.av_ca_a2v_gate_adaln_single_emb_linear_2.apply(v_gate_emb_1)
        v_cross_gate_timestep = weights.av_ca_a2v_gate_adaln_single_linear.apply(torch.nn.functional.silu(v_gate_emb_2))

        # Return TransformerArgs structure
        return TransformerArgs(
            x=video_x,
            context=v_context,
            context_mask=None,
            timesteps=v_adaln_timestep,
            embedded_timestep=v_embedded_timestep,
            positional_embeddings=v_pe,
            cross_positional_embeddings=v_cross_pe,
            cross_scale_shift_timestep=v_cross_scale_shift_timestep,
            cross_gate_timestep=v_cross_gate_timestep,
        )

    def __infer_audio(self, weights, inputs, av_ca_factor):
        """Process audio modality data."""
        # Get audio modality data
        a_latent = self.scheduler.audio_latent_state.latent
        a_positions = self.scheduler.audio_latent_state.positions
        a_timesteps = self.scheduler.audio_timesteps_from_mask()
        if self.scheduler.infer_condition:
            a_context = inputs["text_encoder_output"]["a_context_p"]
        else:
            a_context = inputs["text_encoder_output"]["a_context_n"]

        # 1. Audio patchify projection
        audio_x = weights.audio_patchify_proj.apply(a_latent)

        # 2. Audio timestep embeddings (AdaLN)
        a_timestep = a_timesteps * self.timestep_scale_multiplier
        a_timesteps_proj = get_timestep_embedding(a_timestep.flatten()).to(GET_DTYPE())

        a_emb_linear_1_out = weights.audio_adaln_single_emb_timestep_embedder_linear_1.apply(a_timesteps_proj)
        a_emb_linear_1_out = torch.nn.functional.silu(a_emb_linear_1_out)
        a_embedded_timestep = weights.audio_adaln_single_emb_timestep_embedder_linear_2.apply(a_emb_linear_1_out)
        a_adaln_timestep = weights.audio_adaln_single_linear.apply(torch.nn.functional.silu(a_embedded_timestep))

        # 3. Audio caption projection
        a_context = weights.audio_caption_projection_linear_1.apply(a_context.squeeze(0))
        a_context = torch.nn.functional.gelu(a_context, approximate="tanh")
        a_context = weights.audio_caption_projection_linear_2.apply(a_context)

        # 4. Audio positional embeddings
        # Note: audio positions already have batch dim [B, 1, T, 2], unlike video [3, num_patches, 2]
        a_pe = self._prepare_positional_embeddings(
            positions=a_positions,  # Already has batch dim
            inner_dim=self.audio_inner_dim,
            max_pos=self.audio_positional_embedding_max_pos,
            use_middle_indices_grid=True,
            num_attention_heads=self.audio_num_attention_heads,
            x_dtype=a_latent.dtype,
        )

        # 5. Audio cross-attention positional embeddings (for video cross attention)
        a_cross_pe = self._prepare_positional_embeddings(
            positions=a_positions[:, 0:1, :],  # Already has batch dim
            inner_dim=self.audio_cross_attention_dim,  # Matches the dim used for the video cross PE
            max_pos=[20],
            use_middle_indices_grid=True,
            num_attention_heads=self.audio_num_attention_heads,
            x_dtype=a_latent.dtype,
        )

        # 6. Audio cross-attention timestep embeddings
        # Audio cross scale-shift AdaLN
        a_cross_proj = get_timestep_embedding(a_timestep.flatten()).to(GET_DTYPE())
        a_cross_emb_1 = weights.av_ca_audio_scale_shift_adaln_single_emb_linear_1.apply(a_cross_proj)
        a_cross_emb_1 = torch.nn.functional.silu(a_cross_emb_1)
        a_cross_emb_2 = weights.av_ca_audio_scale_shift_adaln_single_emb_linear_2.apply(a_cross_emb_1)
        a_cross_scale_shift_timestep = weights.av_ca_audio_scale_shift_adaln_single_linear.apply(torch.nn.functional.silu(a_cross_emb_2))

        # Audio cross gate AdaLN
        a_gate_timestep_flat = a_timestep.flatten() * av_ca_factor
        a_gate_proj = get_timestep_embedding(a_gate_timestep_flat).to(GET_DTYPE())
        a_gate_emb_1 = weights.av_ca_v2a_gate_adaln_single_emb_linear_1.apply(a_gate_proj)
        a_gate_emb_1 = torch.nn.functional.silu(a_gate_emb_1)
        a_gate_emb_2 = weights.av_ca_v2a_gate_adaln_single_emb_linear_2.apply(a_gate_emb_1)
        a_cross_gate_timestep = weights.av_ca_v2a_gate_adaln_single_linear.apply(torch.nn.functional.silu(a_gate_emb_2))

        # Return TransformerArgs structure
        return TransformerArgs(
            x=audio_x,
            context=a_context,
            context_mask=None,
            timesteps=a_adaln_timestep,
            embedded_timestep=a_embedded_timestep,
            positional_embeddings=a_pe,
            cross_positional_embeddings=a_cross_pe,
            cross_scale_shift_timestep=a_cross_scale_shift_timestep,
            cross_gate_timestep=a_cross_gate_timestep,
        )

    @torch.no_grad()
    def infer(self, weights, inputs):
        """Main inference entry point."""
        # Calculate AV cross-attention factor (used by both video and audio)
        av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier
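        # The modality timesteps below are pre-scaled by timestep_scale_multiplier, so
        # multiplying by this factor re-expresses them on the cross-attention scale:
        # (t * timestep_scale_multiplier) * av_ca_factor == t * cross_attn_timestep_scale_multiplier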

        # Process video and audio modalities
        video_args = self.__infer_video(weights, inputs, av_ca_factor)
        audio_args = self.__infer_audio(weights, inputs, av_ca_factor)
        return LTX2PreInferModuleOutput(
            video_args=video_args,
            audio_args=audio_args,
        )