"examples/vscode:/vscode.git/clone" did not exist on "5fd24ec02e0365f96301ac73a31ef06976c256e8"
phi4_multimodal.py 53 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable, Mapping, Sequence
5
from typing import Annotated, Any, Literal, Optional, Union
6
7
8
9
10

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
11
12
13
14
15
16
17
from transformers import (
    BatchFeature,
    Phi4MultimodalAudioConfig,
    Phi4MultimodalConfig,
    Phi4MultimodalFeatureExtractor,
    Phi4MultimodalImageProcessorFast,
)
18
19
from transformers import Phi4MultimodalProcessor as Phi4MMProcessor
from transformers.models.phi4_multimodal.modeling_phi4_multimodal import (
20
21
22
23
24
25
    Phi4MultimodalAudioConvModule,
    Phi4MultimodalAudioNemoConvSubsampling,
    Phi4MultimodalAudioRelativeAttentionBias,
    adaptive_enc_mask,
    unfold_tensor,
)
26
27

from vllm.config import VllmConfig
28
from vllm.config.multimodal import BaseDummyOptions
29
30
31
32
33
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
34
from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
35
36
37
38
39
40
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
41
42
43
44
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    NestedTensors,
)
from vllm.multimodal.parse import (
    AudioProcessorItems,
    ImageEmbeddingItems,
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
65
66
67
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
68
from vllm.utils.tensor_schema import TensorSchema, TensorShape
69
70
71

from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
72
73
74
75
76
77
78
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    flatten_bn,
    init_vllm_registered_model,
    maybe_prefix,
)
79
80
81
82

# NOTE(review): presumably an upper bound on the length of an audio input
# (unit not visible in this chunk - samples? bytes?); confirm against its users.
_AUDIO_MAX_SOUNDFILE_SIZE = 241_000


83
84
85
def _get_padding_size(
    orig_width: int, orig_height: int, target_height: int, target_width: int
) -> tuple[int, int]:
    """Return the padding left after scaling an image into a target box.

    The image is scaled by the smaller of the two target/original ratios so
    that it fits inside ``target_width x target_height``; the limiting
    dimension gets no padding and the other one is padded up to the target.

    Args:
        orig_width: Width of the original image (must be non-zero).
        orig_height: Height of the original image (must be non-zero).
        target_height: Height of the target box.
        target_width: Width of the target box.

    Returns:
        ``(padding_height, padding_width)``; at least one of them is 0.
    """
    ratio_width = target_width / orig_width
    ratio_height = target_height / orig_height

    if ratio_width < ratio_height:
        # Width is the limiting dimension: pad only along the height.
        padding_width = 0
        padding_height = target_height - int(orig_height * ratio_width)
    else:
        # Height is the limiting dimension (or a tie): pad only the width.
        padding_width = target_width - int(orig_width * ratio_height)
        padding_height = 0
    return padding_height, padding_width


class Phi4MMProjector(nn.Module):
    """Two linear layers with a GELU in between.

    ``up`` maps ``input_size -> hidden_size`` and ``down`` maps
    ``hidden_size -> hidden_size``.
    """

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.up = ColumnParallelLinear(input_size, hidden_size)
        self.down = RowParallelLinear(hidden_size, hidden_size)
        self.act = get_act_fn("gelu")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply up-projection, activation and down-projection in order."""
        # Both linear layers return a pair; only the first element is used.
        projected, _ = self.up(x)
        activated = self.act(projected)
        output, _ = self.down(activated)
        return output


class Phi4MMImageEmbedding(nn.Module):
    """Image embedding.

    Runs image crops through an ``Idefics2VisionTransformer``, halves the
    patch grid with a 2x2 average pool, lays the crops of every image out
    with learned separator embeddings and projects the result to
    ``config.hidden_size`` with a ``Phi4MMProjector``.
    """

    def __init__(self, config: Phi4MultimodalConfig):
        super().__init__()
        self.config = config
        self.layer_idx = config.vision_config.feature_layer
        self.crop_size = config.vision_config.crop_size
        self.image_dim_out = config.vision_config.hidden_size

        n_patches = config.vision_config.image_size // config.vision_config.patch_size
        if n_patches % 2 != 0:
            # The 2x2 pooling needs an even grid; pad one row/column if odd.
            self.img_processor_padding = nn.ReflectionPad2d((0, 1, 0, 1))
            n_patches += 1
        self.num_img_tokens = (n_patches // 2) ** 2

        # Only build the vision layers up to (and including) the feature
        # layer; a negative index counts from the end.
        num_hidden_layers = (
            config.vision_config.num_hidden_layers + self.layer_idx + 1
            if self.layer_idx < 0
            else self.layer_idx + 1
        )
        self.img_processor = Idefics2VisionTransformer(
            config.vision_config,
            require_post_norm=False,
            num_hidden_layers_override=num_hidden_layers,
        )
        self.image_token_compression = nn.AvgPool2d(kernel_size=2, stride=2)
        self.img_projection = Phi4MMProjector(self.image_dim_out, config.hidden_size)
        # Learned separators: one between sub-image and global features, and
        # one appended at the end of every feature row.
        self.global_img_feature_extensor = nn.Parameter(
            torch.zeros([1, 1, self.image_dim_out])
        )
        self.sub_img_feature_extensor = nn.Parameter(
            torch.zeros([1, 1, 1, self.image_dim_out])
        )

    def get_img_features(
        self,
        img_embeds: torch.FloatTensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """Encode crops and compress their patch grid by 2x2 average pooling.

        Args:
            img_embeds: Crops passed to the vision tower.
            attention_mask: Forwarded as ``patch_attention_mask``.

        Returns:
            Tensor of shape ``(num_crops, pooled_h * pooled_w, channels)``.
        """
        img_feature = self.img_processor(
            img_embeds, patch_attention_mask=attention_mask
        )

        patch_feature = img_feature
        # reshape to 2D tensor (the patch sequence is treated as a square grid)
        width = int(math.sqrt(patch_feature.size(1)))
        patch_feature = patch_feature.view(-1, width, width, patch_feature.size(-1))
        # convert to NCHW
        patch_feature = patch_feature.permute(0, 3, 1, 2)
        if getattr(self, "img_processor_padding", None) is not None:
            patch_feature = self.img_processor_padding(patch_feature)
        patch_feature = self.image_token_compression(patch_feature)
        # convert to NHWC
        patch_feature = patch_feature.permute(0, 2, 3, 1)
        patch_feature = patch_feature.view(
            -1, patch_feature.size(1) * patch_feature.size(2), patch_feature.size(-1)
        )
        return patch_feature

    def forward(
        self,
        image_pixel_values: torch.FloatTensor,
        image_sizes: Optional[torch.Tensor] = None,
        image_attention_mask: Optional[torch.Tensor] = None,
    ) -> list[torch.Tensor]:
        """Embed a batch of images.

        Args:
            image_pixel_values: Crops per image; the first dimension is the
                batch and the first two dimensions are flattened before
                encoding.
            image_sizes: Viewed as ``(-1, 2)`` and read as ``(height, width)``
                per image.
            image_attention_mask: Per-crop mask, used to encode the crops and
                to trim padded rows/columns of the sub-image grid.

        Returns:
            One projected 2-D feature tensor per image.

        NOTE(review): ``image_sizes`` and ``image_attention_mask`` are
        declared optional, but both are dereferenced unconditionally below, so
        passing ``None`` currently fails. The later ``is not None`` branch is
        therefore unreachable with ``None``; confirm the intended contract.
        """
        image_pixel_values = image_pixel_values.to(
            self.img_processor.embeddings.patch_embedding.weight.dtype
        )

        target_device = self.img_projection.up.bias.device
        target_dtype = self.img_projection.up.bias.dtype

        batch_size = image_pixel_values.shape[0]

        img_features = self.get_img_features(
            image_pixel_values.flatten(0, 1),
            attention_mask=image_attention_mask.flatten(0, 1).to(
                dtype=bool, device=target_device
            ),
        )
        base_feat_size = int(np.sqrt(img_features.shape[1]))
        img_features = img_features.view(
            batch_size, -1, base_feat_size**2, self.image_dim_out
        )
        image_sizes = image_sizes.view(-1, 2)

        output_imgs = []
        for idx in range(batch_size):
            height, width = image_sizes[idx]
            height_ratio = height // self.crop_size
            width_ratio = width // self.crop_size
            area_ratio = height_ratio * width_ratio

            # Crop 0 is the global view; append one separator per row.
            global_img = img_features[idx, :1]
            global_img = global_img.reshape(
                1, base_feat_size, base_feat_size, self.image_dim_out
            ).contiguous()
            temporary_extensor = self.sub_img_feature_extensor.repeat(
                1, base_feat_size, 1, 1
            )
            global_img = torch.cat([global_img, temporary_extensor], dim=2).reshape(
                1, -1, self.image_dim_out
            )

            # Remaining crops are tiled into a
            # (height_ratio * base) x (width_ratio * base) grid.
            sub_img = img_features[idx, 1:]
            sub_img = sub_img[:area_ratio]
            sub_img = (
                sub_img.reshape(
                    height_ratio,
                    width_ratio,
                    base_feat_size,
                    base_feat_size,
                    self.image_dim_out,
                )
                .transpose(1, 2)
                .reshape(
                    1,
                    height_ratio * base_feat_size,
                    width_ratio * base_feat_size,
                    self.image_dim_out,
                )
                .contiguous()
            )

            if image_attention_mask is not None:
                # Subsample the mask by 2 to match the pooled grid, tile it
                # like the features, and crop to the unmasked extent.
                reshaped_image_attention_mask = (
                    image_attention_mask[idx, 1 : area_ratio + 1, 0::2, 0::2]
                    .reshape(height_ratio, width_ratio, base_feat_size, base_feat_size)
                    .transpose(1, 2)
                    .reshape(
                        1, height_ratio * base_feat_size, width_ratio * base_feat_size
                    )
                )
                useful_height = int(reshaped_image_attention_mask[0, :, 0].sum().item())
                useful_width = int(reshaped_image_attention_mask[0, 0, :].sum().item())
                sub_img = sub_img[:, :useful_height, :useful_width]
                temporary_extensor = self.sub_img_feature_extensor.repeat(
                    1, useful_height, 1, 1
                )
            else:
                temporary_extensor = self.sub_img_feature_extensor.repeat(
                    1, height_ratio * base_feat_size, 1, 1
                )

            sub_img = torch.cat([sub_img, temporary_extensor], dim=2).reshape(
                1, -1, self.image_dim_out
            )

            # Merge global and sub
            output_imgs.append(
                torch.cat(
                    [sub_img, self.global_img_feature_extensor, global_img], dim=1
                )
            )

        img_set_tensor = []
        for output_img in output_imgs:
            output_img = output_img.to(device=target_device, dtype=target_dtype)
            img_feature_proj = self.img_projection(output_img)
            img_set_tensor.append(img_feature_proj.flatten(0, 1))

        return img_set_tensor


class Phi4MultimodalAudioMLP(nn.Module):
    """Gated feed-forward block with a leading LayerNorm.

    The input is normalized, projected to two ``intermediate_size`` halves by
    one merged linear layer, combined by ``MulAndSilu`` and projected back to
    ``hidden_size``.
    """

    def __init__(
        self,
        config: Phi4MultimodalAudioConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.layer_norm = nn.LayerNorm(config.hidden_size)
        self.act_fn = MulAndSilu()
        self.gate_up_proj = MergedColumnParallelLinear(
            config.hidden_size,
            [config.intermediate_size] * 2,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply LayerNorm, merged gate/up projection, activation, down projection."""
        hidden_states = self.layer_norm(hidden_states)
        hidden_states, _ = self.gate_up_proj(hidden_states)
        hidden_states = self.act_fn(hidden_states)
        hidden_states, _ = self.down_proj(hidden_states)
        return hidden_states


class Phi4MultimodalAudioAttention(nn.Module):
    """Multi-head self-attention for the audio encoder.

    Heads are divided across tensor-parallel ranks; each rank runs
    ``F.scaled_dot_product_attention`` on its own heads with the matching
    slice of the attention mask.
    """

    def __init__(
        self,
        config: Phi4MultimodalAudioConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.total_num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.total_num_heads
        if self.head_dim * self.total_num_heads != self.embed_dim:
            # Report ``total_num_heads``: ``self.num_heads`` is only assigned
            # further down, so reading it here would raise AttributeError
            # instead of the intended ValueError.
            raise ValueError(
                "embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.total_num_heads})."
            )
        self.scale = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        # Number of heads handled by this rank.
        self.num_heads = divide(self.total_num_heads, self.tp_size)

    def split_attn_mask(self, attention_mask: torch.Tensor) -> torch.Tensor:
        """Slice dimension 1 of the mask down to this rank's heads."""
        start_idx = self.num_heads * self.tp_rank
        end_idx = self.num_heads * (self.tp_rank + 1)
        return attention_mask[:, start_idx:end_idx]

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: Input of shape ``(batch, seq_len, ...)``.
            attention_mask: Mask whose dimension 1 indexes all heads; it is
                sliced to this rank's heads before use.

        Returns:
            The output of ``o_proj`` applied to the attention result.
        """
        qkv_states, _ = self.qkv_proj(hidden_states)
        query, key, value = qkv_states.chunk(3, dim=-1)

        bsz, seq_len, _ = query.size()
        query = query.view(bsz, seq_len, self.num_heads, self.head_dim)
        key = key.view(bsz, seq_len, self.num_heads, self.head_dim)
        value = value.view(bsz, seq_len, self.num_heads, self.head_dim)
        # (bsz, seq_len, heads, head_dim) -> (bsz, heads, seq_len, head_dim)
        query, key, value = (x.transpose(1, 2) for x in (query, key, value))

        attention_mask = self.split_attn_mask(attention_mask)
        out = F.scaled_dot_product_attention(
            query,
            key,
            value,
            scale=self.scale,
            attn_mask=attention_mask,
        )
        out = out.transpose(1, 2).reshape(bsz, seq_len, -1)

        attn_output, _ = self.o_proj(out)

        return attn_output


class Phi4MultimodalAudioConformerEncoderLayer(nn.Module):
    """One audio encoder layer.

    Order of operations: half-weighted feed-forward, self-attention,
    convolution module, half-weighted feed-forward, final LayerNorm; every
    sub-module is added residually.
    """

    def __init__(self, config: Phi4MultimodalAudioConfig):
        super().__init__()

        self.feed_forward_in = Phi4MultimodalAudioMLP(config)
        self.self_attn = Phi4MultimodalAudioAttention(config)
        self.conv = Phi4MultimodalAudioConvModule(config)
        self.feed_forward_out = Phi4MultimodalAudioMLP(config)
        self.layer_norm_att = nn.LayerNorm(config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the layer to ``hidden_states`` using ``attention_mask``."""
        # The first feed-forward contributes with weight 0.5.
        residual = hidden_states + 0.5 * self.feed_forward_in(hidden_states)
        hidden_states = self.layer_norm_att(residual)

        # Attention sees the normalized input; the residual is pre-norm.
        hidden_states = residual + self.self_attn(hidden_states, attention_mask)
        hidden_states = hidden_states + self.conv(hidden_states)
        hidden_states = hidden_states + 0.5 * self.feed_forward_out(hidden_states)

        out = self.layer_norm(hidden_states)

        return out


class Phi4MMAudioMeanVarianceNormLayer(nn.Module):
    """Mean/variance normalization layer.

    Will subtract mean and multiply input by inverted standard deviation.
    Typically used as a very first layer in a model.

    Args:
        config: [Phi4MultimodalAudioConfig](https://huggingface.co/docs/transformers/model_doc/phi4_multimodal#transformers.Phi4MultimodalAudioConfig)
            object containing model parameters.
    """

    def __init__(self, config: Phi4MultimodalAudioConfig):
        super().__init__()
        # Both statistics have ``config.input_size`` entries; they start as an
        # identity transform (mean 0, inverse std 1).
        self.global_mean = nn.Parameter(torch.zeros(config.input_size))
        self.global_invstd = nn.Parameter(torch.ones(config.input_size))

    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        """MeanVarianceNormLayer Forward

        Args:
            input_: torch.Tensor
                input tensor.

        Returns:
            ``(input_ - global_mean) * global_invstd``.
        """
        return (input_ - self.global_mean) * self.global_invstd


class Phi4MultimodalAudioModel(nn.Module):
    """Audio encoder.

    Pipeline: mean/variance normalization, convolutional subsampling, then
    ``config.num_blocks`` encoder layers that share one attention mask built
    from a streaming mask, the padding mask and a relative attention bias.
    """

    def __init__(self, config: Phi4MultimodalAudioConfig):
        super().__init__()
        self.config = config

        self.encoder_embedding = Phi4MMAudioMeanVarianceNormLayer(config)
        self.embed = Phi4MultimodalAudioNemoConvSubsampling(config)
        self.relative_attention_bias_layer = Phi4MultimodalAudioRelativeAttentionBias(
            config
        )
        self.encoders = nn.ModuleList(
            [
                Phi4MultimodalAudioConformerEncoderLayer(config)
                for _ in range(config.num_blocks)
            ]
        )

    def _streaming_mask(
        self,
        seq_len: int,
        batch_size: int,
        chunk_size: int,
        left_chunk: int,
    ):
        """Build the streaming mask and expand it over the batch dimension."""
        # Create mask matrix for streaming
        # S stores start index. if chunksize is 18, s is [0,18,36,....]
        chunk_start_idx = np.arange(0, seq_len, chunk_size)

        enc_streaming_mask = (
            adaptive_enc_mask(seq_len, chunk_start_idx, left_window=left_chunk)
            .unsqueeze(0)
            .expand([batch_size, -1, -1])
        )
        return enc_streaming_mask

    def forward_embeddings(
        self,
        hidden_states: torch.Tensor,
        masks: torch.Tensor,
    ):
        """Forwarding the inputs through the top embedding layers

        Returns:
            ``(hidden_states, hs_mask, masks)`` where ``hidden_states`` and
            ``masks`` are the outputs of ``self.embed`` and ``hs_mask`` is the
            streaming mask combined with ``masks`` when ``masks`` is given.

        Raises:
            ValueError: If the sequence length after time reduction is not
                positive.
        """
        seq_len = math.ceil(hidden_states.shape[1] / self.config.time_reduction)
        if seq_len <= 0:
            raise ValueError(
                f"Sequence length after time reduction is invalid: {seq_len}. "
                "Your input feature is too short."
            )

        batch_size = hidden_states.shape[0]

        enc_streaming_mask = self._streaming_mask(
            seq_len, batch_size, self.config.chunk_size, self.config.left_chunk
        )
        enc_streaming_mask = enc_streaming_mask.to(hidden_states.device)

        hidden_states, masks = self.embed(hidden_states, masks)

        streaming_mask = enc_streaming_mask
        if streaming_mask is not None and masks is not None:
            hs_mask = masks & streaming_mask
        elif masks is not None:
            hs_mask = masks
        else:
            hs_mask = streaming_mask

        return hidden_states, hs_mask, masks

    def calculate_hs_mask(
        self, hidden_states: torch.Tensor, device: torch.device, mask: torch.Tensor
    ):
        """Combine the streaming mask with a padding mask derived from ``mask``.

        When ``mask`` is ``None`` only the streaming mask is returned.
        Otherwise every row of ``mask`` is summed to get the number of valid
        positions, and positions at or beyond that count are masked out.
        """
        max_audio_length = hidden_states.shape[1]
        batch_size = hidden_states.shape[0]
        enc_streaming_mask = self._streaming_mask(
            max_audio_length, batch_size, self.config.chunk_size, self.config.left_chunk
        )
        enc_streaming_mask = enc_streaming_mask.to(device)
        if mask is None:
            return enc_streaming_mask

        feature_lens = mask.sum(1)
        padding_length = feature_lens
        pad_mask = torch.arange(0, max_audio_length, device=device).expand(
            padding_length.size(0), -1
        ) < padding_length.unsqueeze(1)
        pad_mask = pad_mask.unsqueeze(1)
        pad_mask = pad_mask & enc_streaming_mask
        return pad_mask

    def forward(self, hidden_states: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """Encode audio features.

        Sequences longer than 500 frames after subsampling are zero-padded to
        a multiple of 500, unfolded into chunks of 500, encoded, then folded
        back with the padding removed.
        """
        hidden_states = self.encoder_embedding(hidden_states)
        hidden_states, hs_mask, mask = self.forward_embeddings(hidden_states, mask)

        unfolded = False
        # Only meaningful when ``unfolded`` is True; initialised so the name
        # is always bound.
        chunk_pad_size = 0
        bs, seq_len, _ = hidden_states.shape
        max_seq_len = 500  # maximum position for absolute positional encoding
        if seq_len > max_seq_len:
            # audio sequence is longer than max_seq_len,
            # unfold it into chunks of max_seq_len
            unfolded = True
            # the unfold op will drop residual frames,
            # pad it to the multiple of max_seq_len
            if seq_len % max_seq_len > 0:
                chunk_pad_size = max_seq_len - (seq_len % max_seq_len)
            else:
                chunk_pad_size = 0
            if chunk_pad_size > 0:
                hidden_states_pad = F.pad(
                    hidden_states, (0, 0, 0, chunk_pad_size), "constant", 0
                )
                hidden_states = hidden_states_pad.to(hidden_states.device)

            hidden_states = unfold_tensor(hidden_states, max_seq_len)
            masks_unfold = None
            if mask is not None:
                # revise hs_mask here because the previous calculated hs_mask
                # did not consider extra pad
                subsampled_pad_mask = mask.squeeze(1)  # [bz, subsampled_unmask_seq_len]
                extra_padded_subsampled_pad_mask = F.pad(
                    subsampled_pad_mask, (0, chunk_pad_size), "constant", False
                )  # extra padding to the pad mask
                extra_padded_subsampled_pad_mask = (
                    extra_padded_subsampled_pad_mask.unsqueeze(-1).float()
                )
                masks_unfold = unfold_tensor(
                    extra_padded_subsampled_pad_mask, max_seq_len
                )  # unfold the pad mask like we did to the input tensor
                masks_unfold = masks_unfold.squeeze(
                    -1
                ).bool()  # unfold op does not support bool tensor
            hs_mask = self.calculate_hs_mask(
                hidden_states, hidden_states.device, masks_unfold
            )  # calculate hs_mask based on the unfolded pad mask

        relative_attention_bias = self.relative_attention_bias_layer(hidden_states)
        # Insert a head dimension into the mask before adding the bias.
        attention_mask = hs_mask.unsqueeze(1) + relative_attention_bias

        for layer in self.encoders:
            hidden_states = layer(hidden_states, attention_mask)

        if unfolded:
            embed_dim = hidden_states.shape[-1]
            hidden_states = hidden_states.reshape(bs, -1, embed_dim)
            # if we ever padded before unfolding, we need to remove the padding
            if chunk_pad_size > 0:
                hidden_states = hidden_states[:, :-chunk_pad_size, :]

        return hidden_states


class Phi4MMAudioEmbedding(nn.Module):
    """Audio encoder plus a projection into the language model hidden size.

    Two projectors with identical shapes are kept; ``audio_projection_mode``
    selects which one is applied to the encoder output.
    """

    def __init__(self, config: Phi4MultimodalConfig):
        super().__init__()
        self.config = config
        self.layer_idx = config.audio_config.feature_layer

        self.encoder = Phi4MultimodalAudioModel(config.audio_config)

        audio_config = config.audio_config
        proj_input_size = audio_config.hidden_size * audio_config.downsample_rate
        self.vision_speech_projection = Phi4MMProjector(
            proj_input_size, config.hidden_size
        )
        self.speech_projection = Phi4MMProjector(proj_input_size, config.hidden_size)

    def get_projection(
        self,
        audio_projection_mode: Literal["speech", "vision"],
    ) -> Phi4MMProjector:
        """Return the projector for ``audio_projection_mode``.

        Raises:
            ValueError: If the mode is neither ``"speech"`` nor ``"vision"``.
                Previously this fell through and returned ``None``, which
                only failed later with an unrelated AttributeError.
        """
        if audio_projection_mode == "speech":
            return self.speech_projection
        if audio_projection_mode == "vision":
            return self.vision_speech_projection
        raise ValueError(
            "audio_projection_mode must be 'speech' or 'vision', "
            f"got {audio_projection_mode!r}."
        )

    def forward(
        self,
        audio_input_features: torch.FloatTensor,
        audio_embed_sizes=None,
        audio_attention_mask=None,
        audio_projection_mode="speech",
    ) -> torch.FloatTensor:
        """Encode and project audio features.

        Args:
            audio_input_features: Features passed to the audio encoder.
            audio_embed_sizes: Not used by this method.
            audio_attention_mask: Forwarded to the encoder as its mask.
            audio_projection_mode: ``"speech"`` or ``"vision"``.

        Returns:
            Projected embeddings with the first two dimensions flattened.
        """
        audio_projection = self.get_projection(audio_projection_mode)

        # Follow the device/dtype of the selected projector's parameters.
        target_device = audio_projection.up.bias.device
        target_dtype = audio_projection.up.bias.dtype

        audio_input_features = audio_input_features.to(
            device=target_device, dtype=target_dtype
        )

        audio_encoder_hidden_states = self.encoder(
            audio_input_features, audio_attention_mask
        )
        audio_embeds = audio_projection(audio_encoder_hidden_states)

        return audio_embeds.flatten(0, 1)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, merging q/k/v shards into ``qkv_proj``.

        Args:
            weights: ``(name, tensor)`` pairs.

        Returns:
            The parameter names that were loaded (after renaming).
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Shard of a stacked parameter: load it into the merged
                # parameter under its shard id.
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked shard: load by name as-is.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


660
class Phi4MMImagePixelInputs(TensorSchema):
    """
    Schema for image inputs given as pixel values.

    Dimensions:
        - bn: Batch size * number of images
        - p: Number of patches (1 + num_patches)
        - c: Number of channels (3)
        - h: Height of each image patch
        - w: Width of each image patch
        - nc: Number of crops
        - H_mask: Height of attention mask
        - W_mask: Width of attention mask
    """

    type: Literal["pixel_values"]

    data: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape(
            "bn", "p", 3, "h", "w", dynamic_dims={"p"}
        ),  # may be different per batch and image
    ]

    image_sizes: Annotated[
        torch.Tensor,
        TensorShape("bn", 2),  # (height, width)
    ]

    num_img_tokens: Annotated[
        list[int],
        TensorShape("bn"),
    ]

    image_attention_mask: Annotated[
        torch.Tensor,
        TensorShape("bn", "nc", 32, 32),  # H_mask, W_mask
    ]
696
697


698
class Phi4MMImageEmbeddingInputs(TensorSchema):
    """
    Schema for image inputs given as precomputed embeddings.

    Dimensions:
        - bn: Batch size * number of images
        - f: Image feature size
        - h: Hidden size (must match language model backbone)
    """

    type: Literal["image_embeds"]

    data: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("bn", "f", "h"),
    ]

713

714
715
716
717
718
719
720
class Phi4MMAudioFeatureInputs(TensorSchema):
    """
    Schema for audio inputs given as extracted features.

    Dimensions:
        - bn: Batch size * number of audios
        - f: Number of Mel filterbank bins (80)
        - t: Time frames (M)
    """

    type: Literal["audio_features"]

    data: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("bn", "t", 80, dynamic_dims={"t"}),
    ]


class Phi4MMAudioEmbeddingInputs(TensorSchema):
    """
    Schema for audio inputs given as precomputed embeddings.

    Dimensions:
        - b: Batch size
        - n: Number of audios
        - f: Audio feature size
        - h: Hidden size (must match language model backbone)
    """

    type: Literal["audio_embeds"]

    data: Annotated[
        NestedTensors,
        TensorShape("b", "n", "f", "h"),
    ]
745
746
747
748
749
750
751
752
753
754
755


# An image input is either raw pixel values or precomputed embeddings.
Phi4MMImageInput = Union[Phi4MMImagePixelInputs, Phi4MMImageEmbeddingInputs]
# An audio input is either extracted features or precomputed embeddings.
Phi4MMAudioInputs = Union[Phi4MMAudioFeatureInputs, Phi4MMAudioEmbeddingInputs]


def cat_with_pad(tensors, dim, padding_value=0):
    """
    cat along dim, while pad to max for all other dims

    Args:
        tensors: Non-empty sequence of tensors with the same number of
            dimensions.
        dim: Dimension to concatenate along.
        padding_value: Fill value for positions not covered by any input.

    Returns:
        A new tensor (created with ``tensors[0].new_full``) whose size along
        ``dim`` is the sum of the inputs' sizes and whose size along every
        other dimension is the maximum over the inputs.
    """
    ndim = tensors[0].dim()
    assert all(t.dim() == ndim for t in tensors[1:]), (
        "All tensors must have the same number of dimensions"
    )

    out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
    out_size[dim] = sum(t.shape[dim] for t in tensors)
    output = tensors[0].new_full(out_size, padding_value)

    index = 0
    for t in tensors:
        # Create a slice list where every dimension except dim is full slice
        slices = [slice(0, t.shape[d]) for d in range(ndim)]
        # Update only the concat dimension slice
        slices[dim] = slice(index, index + t.shape[dim])

        # Index with a tuple: a multi-dimensional index is a tuple of slices,
        # and passing a list relies on legacy sequence-as-tuple handling.
        output[tuple(slices)] = t
        index += t.shape[dim]

    return output


class Phi4MMProcessingInfo(BaseProcessingInfo):
    def get_hf_config(self) -> Phi4MultimodalConfig:
        """Fetch the HF config from the context as a `Phi4MultimodalConfig`."""
        hf_config = self.ctx.get_hf_config(Phi4MultimodalConfig)
        return hf_config

781
782
    def get_hf_processor(self, **kwargs: object) -> Phi4MMProcessor:
        """
        Fetch the HF `Phi4MMProcessor` from the processing context.

        Args:
            **kwargs: Forwarded unchanged to `self.ctx.get_hf_processor`.
        """
        hf_processor = self.ctx.get_hf_processor(Phi4MMProcessor, **kwargs)
        return hf_processor
783

784
    def get_feature_extractor(self, **kwargs: object) -> Phi4MultimodalFeatureExtractor:
        """
        Return the audio feature extractor of the HF processor.

        Args:
            **kwargs: Forwarded unchanged to `get_hf_processor`.
        """
        return self.get_hf_processor(**kwargs).audio_processor
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817

    def get_image_processor(
        self,
        processor: Optional[Phi4MMProcessor] = None,
    ) -> Phi4MultimodalImageProcessorFast:
        """
        Return the image processor of an HF processor.

        Args:
            processor: Processor to read from. When None, one is obtained
                via `get_hf_processor()`.
        """
        resolved = self.get_hf_processor() if processor is None else processor
        return resolved.image_processor

    def get_dynamic_hd(
        self,
        processor: Optional[Phi4MMProcessor] = None,
    ) -> int:
        """
        Return the `dynamic_hd` setting of the image processor.

        Args:
            processor: Optional HF processor, passed on to
                `get_image_processor`.
        """
        image_processor = self.get_image_processor(processor)
        return image_processor.dynamic_hd

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the per-modality limits; both modalities map to None."""
        limits: dict[str, Optional[int]] = dict.fromkeys(("audio", "image"))
        return limits

    def _find_target_aspect_ratio(
        self,
        orig_width: int,
        orig_height: int,
        image_size: int,
        max_num: int,
        min_num: int,
    ):
        w_crop_num = math.ceil(orig_width / float(image_size))
        h_crop_num = math.ceil(orig_height / float(image_size))
        if w_crop_num * h_crop_num > max_num:
            aspect_ratio = orig_width / orig_height

            # calculate the existing image aspect ratio
818
819
820
821
822
823
            target_ratios = set(
                (i, j)
                for i in range(1, max_num + 1)
                for j in range(1, max_num + 1)
                if i * j <= max_num and i * j >= min_num
            )
824
825
826
827
828
829
830
831
832
            target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

            # find the closest aspect ratio to the target
            image_processor = self.get_image_processor()
            target_aspect_ratio = image_processor.find_closest_aspect_ratio(
                aspect_ratio,
                target_ratios,
                orig_width,
                orig_height,
833
                image_size,
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
            )

            # calculate the target width and height
            target_width = image_size * target_aspect_ratio[0]
            target_height = image_size * target_aspect_ratio[1]
        else:
            target_width = image_size * w_crop_num
            target_height = image_size * h_crop_num
            target_aspect_ratio = (w_crop_num, h_crop_num)
        return target_aspect_ratio, target_height, target_width

    def _compute_num_image_tokens(
        self,
        orig_width: int,
        orig_height: int,
        dynamic_hd_size: int,
        vit_image_size: int,
        vit_patch_size: int,
        token_compression_factor: int = 2,
    ):
        """
        compute the number of tokens an image is expected to take up considering
        the image encoder architecture and exclude output features containing
        only padding pixels

        for siglip, vit_image_size=448, vit_patch_size=14, so output will be
        32x32 feature map
        NOTE right now, Phi4MM uses hard-coded token_compression_factor=2

        Args:
            orig_width: Original image width.
            orig_height: Original image height.
            dynamic_hd_size: Maximum number of crops, passed as `max_num` to
                `_find_target_aspect_ratio`.
            vit_image_size: Side length of one crop.
            vit_patch_size: Side length of one ViT patch.
            token_compression_factor: Factor by which the patch grid is
                reduced along each side (rounded up).

        Returns:
            The sum of global image tokens, one separator token, HD patch
            tokens, HD newline tokens and global-image newline tokens.
        """
        assert vit_image_size % vit_patch_size == 0, (
            "vit_image_size must be divisible by vit_patch_size"
        )
        assert vit_image_size // vit_patch_size % token_compression_factor == 0, (
            "vit_image_size // vit_patch_size must be divisible by "
            "token_compression_factor"
        )

        target_aspect_ratio, target_height, target_width = (
            self._find_target_aspect_ratio(
                orig_width, orig_height, vit_image_size, dynamic_hd_size, min_num=1
            )
        )
        assert target_aspect_ratio[0] * vit_image_size == target_width, (
            f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}"
        )
        assert target_aspect_ratio[1] * vit_image_size == target_height, (
            f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}"
        )
        assert (
            target_height % vit_image_size == 0 and target_width % vit_image_size == 0
        )

        padding_height, padding_width = _get_padding_size(
            orig_width, orig_height, target_height, target_width
        )
        assert padding_width == 0 or padding_height == 0, (
            "padding_width or padding_height must be 0"
        )

        target_feat_width = target_width // vit_patch_size
        target_feat_height = target_height // vit_patch_size
        if padding_width >= vit_patch_size:
            # Drop the feature columns that cover whole patches of padding.
            assert padding_height == 0, "padding_height not 0"
            non_pad_feat_width = target_feat_width - math.floor(
                padding_width / vit_patch_size
            )
            non_pad_feat_height = target_feat_height
        elif padding_height >= vit_patch_size:
            # Drop the feature rows that cover whole patches of padding.
            assert padding_width == 0, "padding_width not 0"
            non_pad_feat_height = target_feat_height - math.floor(
                padding_height / vit_patch_size
            )
            non_pad_feat_width = target_feat_width
        else:
            # small padding shorter than a vit patch
            non_pad_feat_width = target_feat_width
            non_pad_feat_height = target_feat_height

        feat_width = non_pad_feat_width // token_compression_factor
        feat_height = non_pad_feat_height // token_compression_factor
        # NOTE it's possible that the non-padding feature is not divisible
        if non_pad_feat_width % token_compression_factor != 0:
            feat_width += 1
        if non_pad_feat_height % token_compression_factor != 0:
            feat_height += 1

        num_hd_patch_tokens = feat_width * feat_height
        # One newline token per row of the compressed HD feature map.
        num_hd_newline_tokens = feat_height
        vit_feature_size = vit_image_size // vit_patch_size
        num_global_image_tokens = (vit_feature_size // token_compression_factor) ** 2
        num_sep_tokens = 1
        num_global_image_newline_tokens = vit_feature_size // token_compression_factor

        return (
            num_global_image_tokens
            + num_sep_tokens
            + num_hd_patch_tokens
            + num_hd_newline_tokens
            + num_global_image_newline_tokens
        )
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[Phi4MMProcessor] = None,
    ) -> int:
        """
        Return the number of placeholder tokens for an image of a given size.

        Args:
            image_width: Image width.
            image_height: Image height.
            processor: Optional HF processor used to look up `dynamic_hd`.
        """
        vision_config = self.get_hf_config().vision_config
        crop_side = vision_config.image_size
        patch_side = vision_config.patch_size
        max_crops = self.get_dynamic_hd(processor=processor)

        # `token_compression_factor` is left at its default of 2 because the
        # HF vision config does not carry it.
        return self._compute_num_image_tokens(
            image_width,
            image_height,
            dynamic_hd_size=max_crops,
            vit_image_size=crop_side,
            vit_patch_size=patch_side,
        )

    def get_image_size_with_most_features(
        self,
        processor: Optional[Phi4MMProcessor] = None,
    ) -> ImageSize:
        """
        Return the image size used as the largest-feature-count case.

        The result is one crop wide and `dynamic_hd` crops tall.

        Args:
            processor: Optional HF processor used to look up `dynamic_hd`.
        """
        hf_config = self.get_hf_config()
        crop_side = hf_config.vision_config.image_size
        max_crops = self.get_dynamic_hd(processor=processor)
        return ImageSize(height=crop_side * max_crops, width=crop_side)

    def get_audio_num_frames(self, audio_len: int, sr: float) -> int:
        """
        Compute the number of time frames produced for a waveform.

        Args:
            audio_len (int): Length of the input waveform in samples.
            sr (float): Sampling rate of the waveform. Rates above 16000 are
                reduced by the integer factor `sr // 16000`; rates in
                [8000, 16000) are treated as 8 kHz audio upsampled to 16 kHz.

        Returns:
            int: Number of time frames (T).

        Raises:
            RuntimeError: If `sr` is below 8000.
            ValueError: If the waveform is too short to yield one frame.
        """

        # Resample to 16000 or 8000 if needed
        if sr > 16000:
            # Cast the factor so a float `sr` cannot turn the count into a
            # float.
            audio_len //= int(sr // 16000)
        elif 8000 <= sr < 16000:
            # We'll resample to 16K from 8K
            audio_len *= 2
        elif sr < 8000:
            raise RuntimeError(f"Unsupported sample rate {sr}")

        # Spectrogram parameters for 16 kHz
        win_length = 400  # Frame length in samples
        hop_length = 160  # Frame shift in samples

        # Calculate number of frames (T)
        num_frames = (audio_len - win_length) // hop_length + 1
        if num_frames < 1:
            raise ValueError("Waveform too short for given parameters.")

        # Return time frames (T)
        return int(num_frames)

    def _compute_audio_embed_size(self, audio_frames: int) -> int:
        """
        Return the audio embedding length for a given number of audio frames.

        The frame count is divided by the feature extractor's
        `audio_compression_rate` and then by its `audio_downsample_rate`,
        adding one after each division when a positive remainder is left.
        """
        # The audio processor's own `_compute_audio_embed_size` works on
        # torch tensors; plain Python arithmetic is used here to avoid the
        # extra tensor conversion.
        extractor = self.get_feature_extractor()
        compression_rate = extractor.audio_compression_rate
        downsample_rate = extractor.audio_downsample_rate

        quotient, leftover = divmod(audio_frames, compression_rate)
        compressed = quotient + int(leftover > 0)

        # qformer compression
        quotient, leftover = divmod(compressed, downsample_rate)
        return quotient + int(leftover > 0)


class Phi4MMDummyInputsBuilder(BaseDummyInputsBuilder[Phi4MMProcessingInfo]):
    """Builds dummy prompt text and dummy image/audio data for Phi4MM."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """
        Return a prompt with one placeholder token per requested item.

        Image placeholders come first, followed by audio placeholders.
        """
        num_audios = mm_counts.get("audio", 0)
        num_images = mm_counts.get("image", 0)

        tokenizer = self.info.get_tokenizer()
        image_tokens: str = tokenizer.image_token * num_images
        audio_tokens: str = tokenizer.audio_token * num_audios

        return image_tokens + audio_tokens

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """
        Return dummy multimodal data for the requested item counts.

        Args:
            seq_len: Not used by this implementation.
            mm_counts: Number of items per modality ("audio", "image").
            mm_options: Optional per-modality overrides, forwarded to the
                dummy image/audio helpers.
        """
        num_audios = mm_counts.get("audio", 0)
        num_images = mm_counts.get("image", 0)

        target_width, target_height = self.info.get_image_size_with_most_features()

        image_overrides = mm_options.get("image") if mm_options else None
        audio_overrides = mm_options.get("audio") if mm_options else None

        mm_data = {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            ),
            "audio": self._get_dummy_audios(
                length=_AUDIO_MAX_SOUNDFILE_SIZE,
                num_audios=num_audios,
                overrides=audio_overrides,
            ),
        }

        return mm_data


class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
    def _get_data_parser(self) -> MultiModalDataParser:
        """Create a data parser targeting the feature extractor's sampling rate."""
        sampling_rate = self.info.get_feature_extractor().sampling_rate
        return MultiModalDataParser(target_sr=sampling_rate)

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Run the HF processor and post-process its image/audio outputs.

        Without multimodal data, only the prompt is tokenized. Otherwise the
        parent implementation is called, and then:
          * `num_img_tokens` is added when image pixel values are present;
          * `audio_input_features` is replaced by a list of per-audio
            tensors, each cut to its own number of frames.

        Args:
            prompt: Prompt text.
            mm_data: Multimodal data; an "audios" entry is moved to "audio".
            mm_kwargs: Processor kwargs, also used to fetch the feature
                extractor.
            tok_kwargs: Tokenizer kwargs, forwarded to the parent.
        """
        if not mm_data:
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        # The parent is called with the key "audio" rather than "audios".
        # NOTE(review): this mutates the caller's `mm_data` in place although
        # it is annotated as a read-only Mapping — confirm callers pass a
        # dict they own.
        audio_data = mm_data.pop("audios", [])
        if audio_data:
            mm_data["audio"] = audio_data

        processed_outputs = super()._call_hf_processor(
            prompt, mm_data, mm_kwargs, tok_kwargs
        )

        if "image_pixel_values" in processed_outputs:
            # NOTE(review): assumes each `image_sizes` entry is ordered
            # (width, height) — confirm against the HF processor output.
            num_img_tokens = [
                self.info.get_num_image_tokens(
                    image_width=img_size[0], image_height=img_size[1]
                )
                for img_size in processed_outputs["image_sizes"]
            ]
            processed_outputs["num_img_tokens"] = num_img_tokens

        if audio_data:
            audio_features = processed_outputs["audio_input_features"]
            sr = self.info.get_feature_extractor(**mm_kwargs).sampling_rate
            feature_sizes = [
                self.info.get_audio_num_frames(len(audio), sr) for audio in audio_data
            ]
            # Keep only the first `size` frames of each audio's features.
            processed_outputs["audio_input_features"] = [
                audio_features[idx, :size] for idx, size in enumerate(feature_sizes)
            ]

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """
        Declare which processor outputs belong to which modality.

        All image outputs are batched under "image" and the audio features
        are batched under "audio". Neither argument is inspected.
        """
        image_fields = (
            "image_pixel_values",
            "image_attention_mask",
            "image_sizes",
            "num_img_tokens",
        )
        fields_config = {
            name: MultiModalFieldConfig.batched("image") for name in image_fields
        }
        fields_config["audio_input_features"] = MultiModalFieldConfig.batched("audio")
        return fields_config

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """
        Build the prompt replacements for audio and image placeholders.

        Each single placeholder token is replaced by that token repeated
        once per embedding position of the corresponding item.

        Args:
            mm_items: Parsed multimodal items; sizes are read from here.
            hf_processor_mm_kwargs: Kwargs used to fetch the HF processor and
                the feature extractor.
            out_mm_kwargs: Not used by this implementation.

        Returns:
            Two `PromptReplacement`s: audio first, then image.
        """
        tokenizer = self.info.get_tokenizer()
        image_token_id: int = tokenizer.vocab[tokenizer.image_token]
        audio_token_id: int = tokenizer.vocab[tokenizer.audio_token]

        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        audio_processor = self.info.get_feature_extractor(**hf_processor_mm_kwargs)

        def get_image_replacement_phi4mm(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if isinstance(images, ImageEmbeddingItems):
                num_image_tokens = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    processor=hf_processor,
                )

            return [image_token_id] * num_image_tokens

        def get_audio_replacement_phi4mm(item_idx: int):
            audios = mm_items.get_items("audio", AudioProcessorItems)
            # TODO(Isotr0py): support embedding inputs
            audio_len = audios.get_audio_length(item_idx)
            audio_frames = self.info.get_audio_num_frames(
                audio_len, audio_processor.sampling_rate
            )
            audio_embed_size = self.info._compute_audio_embed_size(audio_frames)

            return [audio_token_id] * audio_embed_size

        return [
            PromptReplacement(
                modality="audio",
                target=[audio_token_id],
                replacement=get_audio_replacement_phi4mm,
            ),
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_image_replacement_phi4mm,
            ),
        ]


@MULTIMODAL_REGISTRY.register_processor(
    Phi4MMMultiModalProcessor,
    info=Phi4MMProcessingInfo,
    dummy_inputs=Phi4MMDummyInputsBuilder,
)
class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
    """
    Implements the Phi-4-multimodal-instruct model in vLLM.
    """
1191

1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
    # Packed module names mapped to the module names they are made of; here
    # each packed name maps to itself.
    # NOTE(review): presumably read by the LoRA support (the class implements
    # SupportsLoRA) — confirm.
    packed_modules_mapping = {
        "qkv_proj": [
            "qkv_proj",
        ],
        "gate_up_proj": [
            "gate_up_proj",
        ],
    }

    # Weight-name rewrites, passed as `mapper` to the loader in
    # `load_weights`.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # Multimodal embedding
            "model.embed_tokens_extend.": "",
            # LLM backbone
            "model.": "language_model.model.",
        },
        orig_to_new_substr={
            # projection
            ".img_projection_": ".img_projection.",
            ".up_proj_for_speech.": ".speech_projection.up.",
            ".up_proj_for_vision_speech.": ".vision_speech_projection.up.",
            ".down_proj_for_speech.": ".speech_projection.down.",
            ".down_proj_for_vision_speech.": ".vision_speech_projection.down.",
        },
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        if modality.startswith("image"):
            return "<|image|>"
        if modality.startswith("audio"):
            return "<|audio|>"

        raise ValueError("Only image or audio modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """
        Build the image embedding, audio embedding and language model.

        Args:
            vllm_config: vLLM config; the HF config and the multimodal config
                are read from its `model_config`.
            prefix: Module-name prefix, used for the language model only.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initializes these for supporting input embeddings.
        self.image_embed = Phi4MMImageEmbedding(
            config,
            # prefix=maybe_prefix(prefix, "image_embed"),
        )
        self.audio_embed = Phi4MMAudioEmbedding(
            config,
            # prefix=maybe_prefix(prefix, "audio_embed"),
        )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Phi3ForCausalLM"],
        )

        # Expose the language model's factory under the same name on this
        # wrapper.
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )
1253
1254

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> Optional[Phi4MMAudioInputs]:
        """
        Parse and validate the audio input to the model.  This handles both
        audio features and audio embeddings, but only the former is used for
        now.

        Args:
            kwargs (object): Keyword arguments. `audio_input_features` takes
                precedence over `audio_embeds` when both are given.

        Returns:
            Optional[Phi4MMAudioInputs]: Parsed and validated audio inputs,
            or None when neither key is present.
        """
        audio_features = kwargs.get("audio_input_features")
        if audio_features is not None:
            return Phi4MMAudioFeatureInputs(
                type="audio_features", data=flatten_bn(audio_features)
            )

        audio_embeds = kwargs.get("audio_embeds")
        if audio_embeds is not None:
            return Phi4MMAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds)

        return None

1284
1285
1286
    def _process_audio_input(
        self, audio_input: Phi4MMAudioInputs, audio_projection_mode: str
    ) -> NestedTensors:
        """
        Create the audio embeddings from the parsed audio input.

        Embedding inputs are returned unchanged. Feature inputs are run
        through `self.audio_embed` one audio at a time.

        Args:
            audio_input (Phi4MMAudioInputs): Audio input, as produced by
                `_parse_and_validate_audio_input`.
            audio_projection_mode (str): Forwarded to `self.audio_embed`;
                `get_multimodal_embeddings` passes "speech" or "vision".

        Returns:
            NestedTensors: Audio embeddings
        """
        if audio_input["type"] == "audio_embeds":
            return audio_input["data"]

        audio_features = audio_input["data"]

        dtype = next(self.audio_embed.parameters()).dtype
        audio_embeds = [
            # Each audio gets a leading dim of size 1 and is cast to the
            # embedding module's parameter dtype.
            self.audio_embed(
                features.unsqueeze(0).to(dtype),
                audio_projection_mode=audio_projection_mode,
            )
            for features in audio_features
        ]
        return audio_embeds

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Optional[Phi4MMImagePixelInputs]:
        """
        Parse and validate the image pixel inputs of the model.

        Every input may arrive as one batched tensor or as a list with one
        entry per example; both forms are merged into a single leading
        image dimension.

        Args:
            kwargs: Keyword arguments; reads `image_pixel_values`,
                `image_sizes`, `image_attention_mask` and `num_img_tokens`.

        Returns:
            The parsed pixel inputs, or None when `image_pixel_values` is
            absent.

        Raises:
            ValueError: If an input is neither a tensor nor a list.
        """
        image_pixel_values: NestedTensors = kwargs.get("image_pixel_values")
        if image_pixel_values is None:
            return None

        image_sizes = kwargs.get("image_sizes")
        image_attention_mask = kwargs.get("image_attention_mask")
        num_img_tokens = kwargs.get("num_img_tokens")
        assert (
            image_sizes is not None
            and image_attention_mask is not None
            and num_img_tokens is not None
        ), "Missing image inputs"

        if is_list_of(image_pixel_values, torch.Tensor):
            assert all(p.dim() == 5 for p in image_pixel_values), (
                "Incorrect image inputs"
            )
            # list len is batch_size.
            # each tensor has dimension: num_img_per_example, num_hd_patches,
            # channels, height, width.
            # need to pad along num_hd_patches.
            # mask size num_img_per_prompt, num_hd_patches, feat_h, feat_w.
            image_pixel_values = cat_with_pad(image_pixel_values, dim=0)
        elif isinstance(image_pixel_values, torch.Tensor):
            # dimension: batch_size, num_img_per_example, num_hd_patches,
            # channels, height, width.
            # we flatten first 2 dims to make it a single large batch for
            # SigLIP Encoder.
            assert image_pixel_values.dim() == 6, "Incorrect image inputs"
            image_pixel_values = image_pixel_values.flatten(0, 1)
        else:
            raise ValueError("Incorrect image_pixel_values inputs")

        if isinstance(image_attention_mask, list):
            image_attention_mask = cat_with_pad(image_attention_mask, dim=0)
        elif isinstance(image_attention_mask, torch.Tensor):
            image_attention_mask = image_attention_mask.flatten(0, 1)
        else:
            raise ValueError("Incorrect image_attention_mask inputs")

        if isinstance(image_sizes, list):
            image_sizes = torch.cat(image_sizes, dim=0)
        elif isinstance(image_sizes, torch.Tensor):
            image_sizes = image_sizes.flatten(0, 1)
        else:
            raise ValueError("Incorrect image_sizes inputs")

        # Token counts end up as one flat Python list.
        if isinstance(num_img_tokens, list):
            num_img_tokens = [
                n for num_tensor in num_img_tokens for n in num_tensor.tolist()
            ]
        elif isinstance(num_img_tokens, torch.Tensor):
            num_img_tokens = num_img_tokens.flatten(0, 1).tolist()
        else:
            raise ValueError("Incorrect num_img_tokens inputs")

        return Phi4MMImagePixelInputs(
            type="pixel_values",
            data=image_pixel_values,
            image_sizes=image_sizes,
            image_attention_mask=image_attention_mask,
            num_img_tokens=num_img_tokens,
        )

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """
        Parse the image and audio inputs found in `kwargs`.

        Returns:
            A dict with at most the keys "images" and "audios", inserted in
            the order their first matching key appears in `kwargs`. A value
            is whatever the modality's parser returned, which may be None.
        """
        modalities = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if (
                input_key in ("image_pixel_values", "image_embeds")
                and "images" not in modalities
            ):
                modalities["images"] = self._parse_and_validate_image_input(**kwargs)
            if (
                input_key in ("audio_input_features", "audio_embeds")
                and "audios" not in modalities
            ):
                modalities["audios"] = self._parse_and_validate_audio_input(**kwargs)

        return modalities

    def _process_image_input(
        self, image_input: Phi4MMImageInput
    ) -> list[torch.Tensor]:
        """
        Create the image embeddings from the parsed image input.

        Pixel inputs are run through `self.image_embed`. Embedding inputs
        are only cast to the embedding module's parameter dtype.

        Args:
            image_input: Parsed image input.

        Returns:
            The image embeddings.
        """
        dtype = next(self.image_embed.parameters()).dtype

        if image_input["type"] == "image_embeds":
            # The previous code read `self.visual.dtype`, but `__init__`
            # never sets `self.visual`.
            # NOTE(review): the payload is read from "data", matching the
            # other input schemas in this file — confirm against
            # Phi4MMImageEmbeddingInputs.
            embeds = image_input["data"]
            if isinstance(embeds, torch.Tensor):
                return embeds.to(dtype)
            return [embed.to(dtype) for embed in embeds]

        pixel_values = image_input["data"].to(dtype)
        image_sizes = image_input["image_sizes"]
        image_attention_mask = image_input["image_attention_mask"]
        image_embeds = self.image_embed(
            pixel_values, image_sizes, image_attention_mask
        )
        return image_embeds

1416
    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        """
        Compute the embeddings of all multimodal items in `kwargs`.

        Returns:
            An empty list when there is no multimodal input. Otherwise a
            tuple of tensors, grouped by modality in the order the
            modalities appear in `kwargs`.
        """
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return []

        # The result multimodal_embeddings is tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or audio).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # Audio uses the "vision" projection whenever images are part of the
        # request. This is decided before the loop so that it does not
        # depend on which modality comes first in `modalities`.
        audio_projection_mode = "vision" if "images" in modalities else "speech"

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in modalities:
            if modality == "images":
                image_input = modalities["images"]
                vision_embeddings = self._process_image_input(image_input)
                multimodal_embeddings += tuple(vision_embeddings)
            if modality == "audios":
                audio_input = modalities["audios"]
                audio_embeddings = self._process_audio_input(
                    audio_input, audio_projection_mode=audio_projection_mode
                )
                multimodal_embeddings += tuple(audio_embeddings)

        return multimodal_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """
        Run the language model and return its hidden states.

        `inputs_embeds` is dropped whenever `intermediate_tensors` is given.
        Extra keyword arguments are accepted but not used.
        """
        embeds = None if intermediate_tensors is not None else inputs_embeds
        return self.language_model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Compute logits by delegating to the language model."""
        return self.language_model.compute_logits(hidden_states)
1469

1470
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """
        Load weights through `AutoWeightsLoader`.

        Weight names are rewritten with `hf_to_vllm_mapper`.

        Args:
            weights: Iterable of `(name, tensor)` pairs.

        Returns:
            The set returned by the loader.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model.",
            connector=[
                "img_projection",
                "vision_speech_projection",
                "speech_projection",
            ],
            tower_model=["image_embed", "audio_embed"],
        )

    def get_language_model(self) -> torch.nn.Module:
        """Return the language model created in `__init__`."""
        language_model = self.language_model
        return language_model