phi4mm.py 44.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import math
4
from collections.abc import Iterable, Mapping, Sequence
5
from typing import Annotated, Any, Literal, TypeAlias
6
7
8
9

import numpy as np
import torch
import torch.nn as nn
10
11
12
13
14
15
16
from transformers import (
    BatchFeature,
    PretrainedConfig,
    ProcessorMixin,
    SequenceFeatureExtractor,
    SiglipVisionConfig,
)
17
18

from vllm.config import VllmConfig
19
from vllm.config.multimodal import BaseDummyOptions
20
from vllm.distributed import get_pp_group
21
22
23
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
24
25
26
    DEFAULT_VOCAB_PADDING_SIZE,
    ParallelLMHead,
)
27
from vllm.model_executor.models.llama import LlamaModel
28
from vllm.model_executor.models.module_mapping import MultiModelKeys
29
from vllm.multimodal import MULTIMODAL_REGISTRY
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    NestedTensors,
)
from vllm.multimodal.parse import (
    AudioProcessorItems,
    ImageEmbeddingItems,
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    ResolvedPromptUpdate,
)
51
52
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
53
from vllm.utils.tensor_schema import TensorSchema, TensorShape
54

55
from .idefics2_vision_model import Idefics2VisionTransformer
56
from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
57
from .phi4mm_audio import AudioEmbedding
58
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
59
60
61
62
63
64
65
66
67
68

# Placeholder token ids (see vocab.json in the HF model).
_IMAGE_PLACEHOLDER_TOKEN_ID = 200010  # <|endoftext10|>
_AUDIO_PLACEHOLDER_TOKEN_ID = 200011  # <|endoftext11|>

# Length handed to the dummy-audio builder during profiling
# (presumably a sample count — confirm against _get_dummy_audios).
_AUDIO_MAX_SOUNDFILE_SIZE = 241_000

# Default vision encoder and the preprocessing geometry it implies.
SIGLIP_NAME = "siglip-so400m-patch14-448"
VISION_ENCODER_TO_PROCESSING_CONFIG = {
    SIGLIP_NAME: dict(
        vit_image_size=448,
        vit_patch_size=14,
        token_compression_factor=2,
    ),
}


77
78
79
def _get_padding_size(
    orig_width: int, orig_height: int, target_height: int, target_width: int
) -> tuple[int, int]:
    """Return ``(padding_height, padding_width)`` for an aspect-preserving resize.

    The image is scaled by the smaller of the two target/original ratios so it
    fits inside ``target_width x target_height``; the leftover along the other
    axis is the padding. At most one of the returned values is non-zero.
    """
    scale_w = target_width / orig_width
    scale_h = target_height / orig_height

    if scale_w >= scale_h:
        # Height is the limiting side: pad horizontally.
        return 0, target_width - int(orig_width * scale_h)
    # Width is the limiting side: pad vertically.
    return target_height - int(orig_height * scale_w), 0


92
93
94
95
96
97
98
99
100
101
102
103
104
def get_navit_vision_model(layer_idx: int = -1, **kwargs):
    """Build the NaViT (SigLIP-so400m style) vision tower used for images.

    Args:
        layer_idx: Index of the last encoder layer to keep. Negative values
            count from the end: ``-1`` keeps all layers, ``-2`` drops the last.
        **kwargs: Extra ``SiglipVisionConfig`` arguments. They may also
            override any of the default settings below.

    Returns:
        An ``Idefics2VisionTransformer`` truncated to the requested depth and
        built with ``require_post_norm=False``.
    """
    vision_config = {
        "hidden_size": 1152,
        "image_size": 448,
        "intermediate_size": 4304,
        "model_type": "siglip_vision_model",
        "num_attention_heads": 16,
        "num_hidden_layers": 27,
        "patch_size": 14,
    }

    # Merge so caller kwargs override the defaults; passing both mappings as
    # separate ** arguments raised TypeError on any overlapping key.
    model_config = SiglipVisionConfig(**{**vision_config, **kwargs})
    if layer_idx < 0:
        num_hidden_layers = model_config.num_hidden_layers + layer_idx + 1
    else:
        num_hidden_layers = layer_idx + 1

    vision_model = Idefics2VisionTransformer(
        config=model_config,
        require_post_norm=False,
        num_hidden_layers_override=num_hidden_layers,
    )

    return vision_model


118
119
120
class Phi4MMImageEncoder(nn.Module):
    """Image embedding.

    Runs the NaViT vision tower over a global view plus HD sub-crops of each
    image, average-pools the patch grid, inserts learnable separator
    embeddings (``sub_GN`` appended after every feature row, ``glb_GN``
    between the sub-crop and global segments) and projects the result to the
    language model's hidden size with a small MLP.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None,
        prefix: str = "",
        model_dir: str = "",
    ) -> None:
        super().__init__()

        # n_embed or hidden_size
        hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size

        # layer_idx to output the img features; a dict-valued
        # config.img_processor may override these defaults.
        if isinstance(config.img_processor, dict):
            self.layer_idx = config.img_processor.get("layer_idx", -2)
            self.type_feature = config.img_processor.get("type_feature", "patch")
        else:
            self.layer_idx = -2
            self.type_feature = "patch"

        self.img_processor = get_navit_vision_model(layer_idx=self.layer_idx)

        pe_weight = self.img_processor.embeddings.position_embedding.weight
        L, D = pe_weight.size()
        H = int(math.sqrt(L))
        assert H**2 == L, f"position embedding size {L} is not square"
        if H % 2 != 0:
            # Make the patch grid even so the 2x2 pooling below tiles it.
            self.img_processor_padding = nn.ReflectionPad2d((0, 1, 0, 1))
            H += 1
        image_dim_out = D
        # ((448/14)//2)**2
        self.num_img_tokens = (H // 2) ** 2
        self.base_feat_height_target = H

        self.image_dim_out = image_dim_out
        self.img_sizes = None
        self.image_attention_mask = None

        # global_gn and sub_gn for hd transform, serves as line separator
        self.use_hd_transform = True
        self.with_learnable_separator = True
        self.hd_transform_order = "sub_glb"
        self.freeze_img_processor = False
        self.crop_size = 448

        # image token compression
        self.image_token_compression_cls = "avg_pool_2d"
        self.image_token_compression = nn.AvgPool2d(kernel_size=2, stride=2)
        self.base_feat_height_reduction = 1
        self.base_feat_height_target = self.base_feat_height_target // 2

        # with_hd_transform and with_learnable_separator should have same value
        assert self.use_hd_transform == self.with_learnable_separator, (
            "use_hd_transform and with_learnable_separator should have same value"
        )
        assert self.use_hd_transform, "learnable separator is only for hd transform"
        # Separator embeddings; their width is image_dim_out scaled by the
        # spatial-to-channel merge factor (base_feat_height_reduction**2).
        self.glb_GN = nn.Parameter(
            torch.zeros([1, 1, self.image_dim_out * self.base_feat_height_reduction**2])
        )
        self.sub_GN = nn.Parameter(
            torch.zeros(
                [1, 1, 1, self.image_dim_out * self.base_feat_height_reduction**2]
            )
        )

        dim_projection = hidden_size
        depth = 2
        layers = [
            nn.Linear(
                image_dim_out * self.base_feat_height_reduction**2, dim_projection
            )
        ]
        for _ in range(1, depth):
            layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)])
        self.img_projection = nn.Sequential(*layers)

        self.vocab_size = config.vocab_size
        self.img_features = None

        self.use_out_place_operations = False

    def get_img_features(
        self, img_embeds: torch.FloatTensor, attention_mask=None
    ) -> torch.FloatTensor:
        """Run the vision tower and pool its patch features.

        Args:
            img_embeds: Pixel values for every crop, flattened across images.
            attention_mask: Optional patch attention mask for the tower.

        Returns:
            A 3-D tensor ``(num_crops_total, tokens_per_crop, feature_dim)``
            after optional reflection padding and 2x2 average pooling.

        Raises:
            NotImplementedError: If ``type_feature`` is not ``"patch"``.
        """
        img_feature = self.img_processor(
            img_embeds, patch_attention_mask=attention_mask
        )

        if self.type_feature != "patch":
            raise NotImplementedError

        patch_feature = img_feature
        use_token_compression = self.image_token_compression is not None
        use_padding = getattr(self, "img_processor_padding", None) is not None
        if use_token_compression or use_padding:
            # reshape the flat patch sequence into a square 2D grid
            width = int(math.sqrt(patch_feature.size(1)))
            patch_feature = patch_feature.view(
                -1, width, width, patch_feature.size(-1)
            )
            # convert to NCHW
            patch_feature = patch_feature.permute(0, 3, 1, 2)

            if use_padding:
                patch_feature = self.img_processor_padding(patch_feature)
            if use_token_compression:
                patch_feature = self.image_token_compression(patch_feature)

            # convert back to NHWC and flatten the grid
            patch_feature = patch_feature.permute(0, 2, 3, 1)
            patch_feature = patch_feature.view(
                -1,
                patch_feature.size(1) * patch_feature.size(2),
                patch_feature.size(-1),
            )

        return patch_feature

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        image_sizes: torch.Tensor,
        image_attention_mask: torch.Tensor,
    ) -> list[torch.FloatTensor]:
        """
        process image and return vision embeddings.

        pixel_values: (num_images, num_crops, c, h, w); crop 0 of each image
            is the global view, crops 1.. are the HD sub-crops.
        image_sizes: [[h1, w1], [h2, w2]]
        image_attention_mask: num_images x num_crops x 32 x 32
        output: list with one (num_img_tokens_i, hidden_size) tensor per image;
            the token count varies with each image's size.
        """

        # eg
        # pixel_values: torch.Size([1, 7, 3, 448, 448])
        # image_sizes: tensor([[ 896, 1344]], device='cuda:0')
        # output: [torch.Size([1841, 3072])]

        first_linear = (
            self.img_projection[0]
            if isinstance(self.img_projection, nn.Sequential)
            else self.img_projection
        )
        target_device = first_linear.bias.device
        target_dtype = first_linear.bias.dtype

        img_sizes = image_sizes
        # Unpacking also enforces a 5-D pixel_values input.
        num_images, _, _, _, _ = pixel_values.shape
        bs = num_images
        pixel_values = pixel_values.flatten(0, 1)

        # Cast straight to bool on the target device; `.type(torch.BoolTensor)`
        # produced a CPU tensor and forced a device -> CPU -> device round trip.
        img_features = self.get_img_features(
            pixel_values,
            image_attention_mask.to(device=target_device, dtype=torch.bool).flatten(
                0, 1
            ),
        )

        base_feat_height_target = self.base_feat_height_target
        base_resolution = self.crop_size
        r = self.base_feat_height_reduction

        base_feat_height = base_feat_width = int(np.sqrt(img_features.shape[1]))
        assert (
            base_feat_height == base_feat_height_target
            and base_feat_width == base_feat_height_target
        ), (
            f"base_feat_height: {base_feat_height}, "
            f"base_feat_width: {base_feat_width}, "
            f"expect {base_feat_height_target} features for hd transform"
        )

        # bs x max_num_crops x (H*H) x C
        img_features = img_features.view(
            bs, -1, base_feat_height * base_feat_width, self.image_dim_out
        )
        C = self.image_dim_out
        H = base_feat_height

        output_imgs = []
        # training is tensor, inference is list
        if isinstance(img_sizes, torch.Tensor):
            img_sizes = img_sizes.view(-1, 2)
        for _bs in range(bs):
            h, w = img_sizes[_bs]
            h = h // base_resolution
            w = w // base_resolution
            B_ = h * w

            # Global view: 1 x (H*H) x C -> 1 x H/r x H/r x (r*r*C)
            global_img_feature = img_features[_bs, :1]
            glb_img = (
                global_img_feature.reshape(1, H, H, C)
                .reshape(1, H // r, r, H // r, r, C)
                .contiguous()
                .permute(0, 1, 3, 2, 4, 5)
                .reshape(1, H // r, H // r, r * r * C)
                .contiguous()
            )
            # The global view also uses sub_GN as its per-row separator.
            temp_glb_GN = self.sub_GN.repeat(1, H // r, 1, 1)
            # -> 1 x (H/r * (H/r + 1)) x (r*r*C)
            glb_img = torch.cat([glb_img, temp_glb_GN], dim=2).reshape(
                1, -1, r * r * C
            )

            # Sub-crops: (max_num_crops-1) x (H*H) x C; drop padding crops.
            sub_img = img_features[_bs, 1:]
            sub_img = sub_img[:B_]

            # (B_, H/r, r, H/r, r, C) -> (B_, H/r, H/r, r, r, C)
            # -> (B_, (H/r)^2, r*r*C)
            sub_img = (
                sub_img.reshape(B_, H, H, C)
                .reshape(B_, H // r, r, H // r, r, C)
                .contiguous()
                .permute(0, 1, 3, 2, 4, 5)
                .reshape(B_, -1, r * r * C)
                .contiguous()
            )
            # Stitch the h x w crop grid into one
            # (h*H/r) x (w*W/r) x (r*r*C) feature map.
            sub_img = (
                sub_img.reshape(
                    1,
                    h,
                    w,
                    base_feat_height // r,
                    base_feat_width // r,
                    -1,
                )
                .permute(0, 1, 3, 2, 4, 5)
                .reshape(
                    1,
                    h * base_feat_height // r,
                    w * base_feat_width // r,
                    r * r * C,
                )
            )

            if image_attention_mask is not None and len(image_attention_mask) > 0:
                # Stride-2 sampling of the per-crop mask, laid out the same way
                # as the stitched sub-crop feature map above.
                reshaped_image_attention_mask = (
                    image_attention_mask[_bs, 1 : B_ + 1, 0::2, 0::2]
                    .reshape(
                        1,
                        h,
                        w,
                        base_feat_height // r,
                        base_feat_width // r,
                    )
                    .permute(0, 1, 3, 2, 4)
                    .reshape(
                        1,
                        h * base_feat_height // r,
                        w * base_feat_width // r,
                    )
                )
                useful_height = int(reshaped_image_attention_mask[0, :, 0].sum().item())
                useful_width = int(reshaped_image_attention_mask[0, 0, :].sum().item())
                # Keep only the rows/columns that cover real (unpadded) pixels.
                sub_img = sub_img[:, :useful_height, :useful_width]
                temp_sub_GN = self.sub_GN.repeat(1, useful_height, 1, 1)
                temp_len = (
                    int(image_attention_mask[_bs, : B_ + 1, 0::2, 0::2].sum().item())
                    + (useful_height + 1)
                    + base_feat_height // r
                )
            else:
                temp_sub_GN = self.sub_GN.repeat(1, h * base_feat_height // r, 1, 1)
                temp_len = int(
                    (h * w + 1) * self.num_img_tokens
                    + 1
                    + (h + 1) * base_feat_height // r
                )

            # Append a sub_GN separator to every row, then flatten to tokens.
            sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape(
                1, -1, r * r * C
            )

            # glb + sub
            if self.hd_transform_order == "glb_sub":
                output_imgs.append(torch.cat([glb_img, self.glb_GN, sub_img], dim=1))
            elif self.hd_transform_order == "sub_glb":
                output_imgs.append(torch.cat([sub_img, self.glb_GN, glb_img], dim=1))
            else:
                raise NotImplementedError(
                    f"hd_transform_order = {self.hd_transform_order}, not implemented"
                )

            # Sanity check: the expected token count (patch tokens, row
            # separators and the single glb_GN) matches what was assembled.
            assert temp_len == output_imgs[-1].shape[1], (
                f"temp_len: {temp_len}, "
                f"output_imgs[-1].shape[1]: {output_imgs[-1].shape[1]}"
            )

        return [
            self.img_projection(
                output_img.to(device=target_device, dtype=target_dtype)
            ).squeeze(0)
            for output_img in output_imgs
        ]


454
class Phi4MMImagePixelInputs(TensorSchema):
    """
    Image inputs given as preprocessed pixel values.

    Dimensions:
        - bn: Batch size * number of images
        - p: Number of patches (1 + num_patches)
        - c: Number of channels (3)
        - h: Height of each image patch
        - w: Width of each image patch
        - nc: Number of crops
        - H_mask: Height of attention mask
        - W_mask: Width of attention mask
    """

    type: Literal["pixel_values"]

    # "p" is dynamic: the patch count may differ per batch and per image.
    pixel_values: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"}),
    ]

    # One (height, width) pair per image.
    image_sizes: Annotated[torch.Tensor, TensorShape("bn", 2)]

    num_img_tokens: Annotated[list[int], TensorShape("bn")]

    # Per-crop mask of fixed size H_mask x W_mask = 32 x 32.
    image_attention_mask: Annotated[
        torch.Tensor,
        TensorShape("bn", "nc", 32, 32),
    ]
490
491


492
493
494
495
496
497
498
class Phi4MMAudioFeatureInputs(TensorSchema):
    """
    Audio inputs given as extracted feature frames.

    Dimensions:
        - bn: Batch size * number of audios
        - t: Time frames (M)
    """

    type: Literal["audio_features"]

    # "t" varies per clip; every frame carries 80 feature values.
    audio_features: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", "t", 80, dynamic_dims={"t"}),
    ]
505
506


507
508
509
510
511
512
513
514
class Phi4MMAudioEmbeddingInputs(TensorSchema):
    """
    Audio inputs given as precomputed embeddings.

    Dimensions:
        - b: Batch size
        - n: Number of audios
        - f: Audio feature size
        - h: Hidden size (must match language model backbone)
    """

    type: Literal["audio_embeds"]

    data: Annotated[NestedTensors, TensorShape("b", "n", "f", "h")]
521
522


523
# Audio arrives either as extracted features (type="audio_features") or as
# precomputed embeddings (type="audio_embeds").
Phi4MMAudioInputs: TypeAlias = Phi4MMAudioFeatureInputs | Phi4MMAudioEmbeddingInputs
524
525


526
def cat_with_pad(tensors, dim, padding_value=0):
    """
    Concatenate ``tensors`` along ``dim``, padding every other dimension up to
    the largest size found among the inputs.

    Args:
        tensors: Non-empty sequence of tensors with the same number of
            dimensions. The output takes its dtype/device from ``tensors[0]``.
        dim: Dimension to concatenate along (negative values allowed).
        padding_value: Fill value for positions not covered by any input.

    Returns:
        A new tensor whose size along ``dim`` is the sum of the inputs' sizes
        and whose size along every other dimension is the maximum.

    Raises:
        ValueError: If ``tensors`` is empty or the inputs differ in rank.
    """
    if len(tensors) == 0:
        raise ValueError("cat_with_pad requires at least one tensor")
    ndim = tensors[0].dim()
    if any(t.dim() != ndim for t in tensors[1:]):
        raise ValueError("All tensors must have the same number of dimensions")

    out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
    out_size[dim] = sum(t.shape[dim] for t in tensors)
    output = tensors[0].new_full(out_size, padding_value)

    offset = 0
    for t in tensors:
        # Full extent of t on every dim; the concat dim starts at `offset`.
        region = [slice(0, size) for size in t.shape]
        region[dim] = slice(offset, offset + t.shape[dim])
        # Index with a tuple: a *list* of slices is a deprecated
        # multidimensional-indexing form.
        output[tuple(region)] = t
        offset += t.shape[dim]

    return output
550
551


552
553
554
class Phi4MMProcessingInfo(BaseProcessingInfo):
    @property
    def image_tokens(self) -> list[str]:
        """Prompt placeholders ``<|image_1|>`` .. ``<|image_100|>``."""
        return [f"<|image_{i + 1}|>" for i in range(100)]

    @property
    def audio_tokens(self) -> list[str]:
        """Prompt placeholders ``<|audio_1|>`` .. ``<|audio_100|>``."""
        return [f"<|audio_{i + 1}|>" for i in range(100)]

    def get_dynamic_hd(
        self,
        processor: ProcessorMixin | None = None,
    ) -> int:
        """Return the HF image processor's ``dynamic_hd`` (max crop count)."""
        if processor is None:
            processor = self.get_hf_processor()
        image_processor = processor.image_processor
        return image_processor.dynamic_hd

    def get_feature_extractor(self, **kwargs: object) -> SequenceFeatureExtractor:
        return self.get_hf_processor(**kwargs).audio_processor

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        return {"audio": None, "image": None}

    def _get_vision_processing_config(self) -> dict[str, int]:
        """Look up the preprocessing geometry for the configured vision encoder.

        ``hf_config.img_processor`` may be missing, ``None`` or a dict of
        encoder overrides (``Phi4MMImageEncoder`` accepts all of these and
        always builds the SigLIP tower), so any non-string value falls back to
        ``SIGLIP_NAME`` instead of being used as a dict key.
        """
        hf_config = self.get_hf_config()
        vision_encoder_name = getattr(hf_config, "img_processor", None)
        if not isinstance(vision_encoder_name, str):
            vision_encoder_name = SIGLIP_NAME
        return VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name]

    def _find_target_aspect_ratio(
        self,
        orig_width: int,
        orig_height: int,
        image_size: int,
        max_num: int,
        min_num: int,
        processor: ProcessorMixin | None = None,
    ):
        """Choose the crop grid used to tile an image.

        Returns:
            ``(target_aspect_ratio, target_height, target_width)`` where
            ``target_aspect_ratio`` is ``(crops_wide, crops_high)`` and the
            target sizes are that grid multiplied by ``image_size``.
        """
        w_crop_num = math.ceil(orig_width / float(image_size))
        h_crop_num = math.ceil(orig_height / float(image_size))
        if w_crop_num * h_crop_num > max_num:
            # calculate the existing image aspect ratio
            aspect_ratio = orig_width / orig_height

            # Candidate grids whose crop count lies in [min_num, max_num].
            # NOTE: keep this exact set -> sorted() construction; the candidate
            # order is passed unchanged to the HF helper below.
            target_ratios = set(
                (i, j)
                for i in range(1, max_num + 1)
                for j in range(1, max_num + 1)
                if i * j <= max_num and i * j >= min_num
            )
            target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

            # find the closest aspect ratio to the target
            if processor is None:
                processor = self.get_hf_processor()
            image_processor = processor.image_processor
            target_aspect_ratio = image_processor.find_closest_aspect_ratio(
                aspect_ratio,
                target_ratios,
                orig_width,
                orig_height,
                image_size,
            )

            # calculate the target width and height
            target_width = image_size * target_aspect_ratio[0]
            target_height = image_size * target_aspect_ratio[1]
        else:
            target_width = image_size * w_crop_num
            target_height = image_size * h_crop_num
            target_aspect_ratio = (w_crop_num, h_crop_num)
        return target_aspect_ratio, target_height, target_width

    def _compute_num_image_tokens(
        self,
        orig_width: int,
        orig_height: int,
        dynamic_hd_size: int,
        vit_image_size: int,
        vit_patch_size: int,
        token_compression_factor: int = 2,
        processor: ProcessorMixin | None = None,
    ) -> int:
        """
        compute the number of tokens an image is expected to take up considering
        the image encoder architecture and exclude output features containing
        only padding pixels

        for siglip, vit_image_size=448, vit_patch_size=14, so output will be
        32x32 feature map
        NOTE right now, Phi4MM uses hard-coded token_compression_factor=2
        """
        assert vit_image_size % vit_patch_size == 0, (
            "vit_image_size must be divisible by vit_patch_size"
        )
        assert vit_image_size // vit_patch_size % token_compression_factor == 0, (
            "vit_image_size // vit_patch_size must be divisible by "
            "token_compression_factor"
        )

        target_aspect_ratio, target_height, target_width = (
            self._find_target_aspect_ratio(
                orig_width,
                orig_height,
                vit_image_size,
                dynamic_hd_size,
                min_num=1,
                processor=processor,
            )
        )
        assert target_aspect_ratio[0] * vit_image_size == target_width, (
            f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}"
        )
        assert target_aspect_ratio[1] * vit_image_size == target_height, (
            f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}"
        )
        assert (
            target_height % vit_image_size == 0 and target_width % vit_image_size == 0
        )

        padding_height, padding_width = _get_padding_size(
            orig_width, orig_height, target_height, target_width
        )
        assert padding_width == 0 or padding_height == 0, (
            "padding_width or padding_height must be 0"
        )

        target_feat_width = target_width // vit_patch_size
        target_feat_height = target_height // vit_patch_size
        if padding_width >= vit_patch_size:
            assert padding_height == 0, "padding_height not 0"
            non_pad_feat_width = target_feat_width - math.floor(
                padding_width / vit_patch_size
            )
            non_pad_feat_height = target_feat_height
        elif padding_height >= vit_patch_size:
            assert padding_width == 0, "padding_width not 0"
            non_pad_feat_height = target_feat_height - math.floor(
                padding_height / vit_patch_size
            )
            non_pad_feat_width = target_feat_width
        else:
            # small padding shorter than a vit patch
            non_pad_feat_width = target_feat_width
            non_pad_feat_height = target_feat_height

        feat_width = non_pad_feat_width // token_compression_factor
        feat_height = non_pad_feat_height // token_compression_factor
        # NOTE it's possible that the non-padding feature is not divisible
        if non_pad_feat_width % token_compression_factor != 0:
            feat_width += 1
        if non_pad_feat_height % token_compression_factor != 0:
            feat_height += 1
        num_hd_patch_tokens = feat_width * feat_height
        num_hd_newline_tokens = feat_height
        vit_feature_size = vit_image_size // vit_patch_size
        num_global_image_tokens = (vit_feature_size // token_compression_factor) ** 2
        num_sep_tokens = 1
        num_global_image_newline_tokens = vit_feature_size // token_compression_factor

        return (
            num_global_image_tokens
            + num_sep_tokens
            + num_hd_patch_tokens
            + num_hd_newline_tokens
            + num_global_image_newline_tokens
        )

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: ProcessorMixin | None = None,
    ) -> int:
        """Number of placeholder tokens an image of the given size expands to."""
        prepro_config = self._get_vision_processing_config()

        dynamic_hd_size = self.get_dynamic_hd(processor=processor)

        return self._compute_num_image_tokens(
            image_width,
            image_height,
            dynamic_hd_size=dynamic_hd_size,
            vit_image_size=prepro_config["vit_image_size"],
            vit_patch_size=prepro_config["vit_patch_size"],
            token_compression_factor=prepro_config["token_compression_factor"],
            # Reuse the caller's processor instead of re-fetching it.
            processor=processor,
        )

    def get_image_size_with_most_features(
        self,
        processor: ProcessorMixin | None = None,
    ) -> ImageSize:
        """Image size used for profiling: one crop wide, ``dynamic_hd`` crops tall."""
        vit_image_size = self._get_vision_processing_config()["vit_image_size"]

        max_side = vit_image_size * self.get_dynamic_hd(processor=processor)
        return ImageSize(height=max_side, width=vit_image_size)

    def get_audio_num_frames(self, audio_len: int, sr: float) -> int:
        """
        Compute the number of time frames the audio feature extractor
        (``extract_features``) produces for a waveform.

        Args:
            audio_len (int): Length of the input waveform in samples.
            sr (float): Sampling rate of the waveform. Rates above 16000 are
                reduced by the integer factor ``sr // 16000``; rates in
                [8000, 16000) are treated as upsampled 2x to 16 kHz.

        Returns:
            int: Number of time frames T (each frame has 80 Mel filterbank
            bins).

        Raises:
            RuntimeError: If ``sr`` is below 8000.
            ValueError: If the waveform is shorter than one analysis window.
        """
        # Resample to 16000 or 8000 if needed
        if sr > 16000:
            audio_len //= sr // 16000
        elif 8000 <= sr < 16000:
            # We'll resample to 16K from 8K
            audio_len *= 2
        elif sr < 8000:
            raise RuntimeError(f"Unsupported sample rate {sr}")

        # Spectrogram parameters for 16 kHz
        win_length = 400  # Frame length in samples
        hop_length = 160  # Frame shift in samples

        # Calculate number of frames (T)
        num_frames = (audio_len - win_length) // hop_length + 1
        if num_frames < 1:
            raise ValueError("Waveform too short for given parameters.")

        # A float `sr` makes the arithmetic above float-valued; honor the
        # `-> int` contract so callers can use the result as a token count.
        return int(num_frames)

    def _compute_audio_embed_size(self, audio_frames: int) -> int:
        """
        Compute the audio embedding size based on the audio frames and
        compression rate.
        """
        hf_config = self.get_hf_config()
        compression_rate = hf_config.embd_layer["audio_embd_layer"]["compression_rate"]
        # NOTE: this is a hard-coded value but might be configurable
        # in the future
        qformer_compression_rate = 1
        integer = audio_frames // compression_rate
        remainder = audio_frames % compression_rate

        result = integer if remainder == 0 else integer + 1

        integer = result // qformer_compression_rate
        remainder = result % qformer_compression_rate
        # qformer compression
        result = integer if remainder == 0 else integer + 1

        return result
805
806


807
808
809
810
class Phi4MMDummyInputsBuilder(BaseDummyInputsBuilder[Phi4MMProcessingInfo]):
    """Builds dummy prompts and multimodal data for Phi-4-multimodal profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Concatenate one placeholder per requested image, then per audio.

        Placeholders are taken in order from ``self.info.image_tokens`` and
        ``self.info.audio_tokens``.
        """
        image_placeholders: list[str] = self.info.image_tokens[
            : mm_counts.get("image", 0)
        ]
        audio_placeholders: list[str] = self.info.audio_tokens[
            : mm_counts.get("audio", 0)
        ]
        return "".join(image_placeholders) + "".join(audio_placeholders)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Create dummy images (largest feature size) and audios.

        Images use the size returned by
        ``get_image_size_with_most_features``; audios have length
        ``_AUDIO_MAX_SOUNDFILE_SIZE``. Per-modality overrides come from
        ``mm_options`` when given. ``seq_len`` is not used here.
        """
        audio_count = mm_counts.get("audio", 0)
        image_count = mm_counts.get("image", 0)

        width, height = self.info.get_image_size_with_most_features()

        options = mm_options or {}

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
            overrides=options.get("image"),
        )
        dummy_audios = self._get_dummy_audios(
            length=_AUDIO_MAX_SOUNDFILE_SIZE,
            num_audios=audio_count,
            overrides=options.get("audio"),
        )

        return {"image": dummy_images, "audio": dummy_audios}
846
847


848
849
850
class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
    """Multimodal (image + audio) processor for Phi-4-multimodal."""

    def _get_data_parser(self) -> MultiModalDataParser:
        """Parser that resamples audio (via scipy) to the extractor's rate."""
        feature_extractor = self.info.get_feature_extractor()

        return MultiModalDataParser(
            target_sr=feature_extractor.sampling_rate, audio_resample_method="scipy"
        )

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and post-process its multimodal outputs.

        Text-only prompts are tokenized directly. Otherwise each audio clip is
        paired with the feature extractor's sampling rate before calling the
        HF processor. Afterwards a per-image ``num_img_tokens`` list is added,
        and each row of ``input_audio_embeds`` is trimmed to that clip's
        frame count.
        """
        if not mm_data:
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        sr = self.info.get_feature_extractor(**mm_kwargs).sampling_rate
        if audio_data := mm_data.get("audios", []):
            # Build a new mapping instead of mutating the caller's `mm_data`.
            mm_data = {**mm_data, "audios": [(data, sr) for data in audio_data]}

        processed_outputs = super()._call_hf_processor(
            prompt, mm_data, mm_kwargs, tok_kwargs
        )

        # Count image tokens with the same kwargs-configured HF processor that
        # `_get_prompt_updates` uses, so both places agree on the count.
        hf_processor = self.info.get_hf_processor(**mm_kwargs)
        processed_outputs["num_img_tokens"] = [
            self.info.get_num_image_tokens(
                image_width=img_size[0],
                image_height=img_size[1],
                processor=hf_processor,
            )
            for img_size in processed_outputs["image_sizes"]
        ]

        # Keep only the frames that belong to each clip (rows are presumably
        # padded to a common length by the HF processor — TODO confirm).
        audio_features = processed_outputs["input_audio_embeds"]
        feature_sizes = [
            self.info.get_audio_num_frames(len(audio), sr) for audio in audio_data
        ]
        processed_outputs["input_audio_embeds"] = [
            audio_features[idx, :size] for idx, size in enumerate(feature_sizes)
        ]

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare every output field as batched along its modality."""
        return dict(
            input_image_embeds=MultiModalFieldConfig.batched("image"),
            image_attention_mask=MultiModalFieldConfig.batched("image"),
            image_sizes=MultiModalFieldConfig.batched("image"),
            num_img_tokens=MultiModalFieldConfig.batched("image"),
            input_audio_embeds=MultiModalFieldConfig.batched("audio"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image/audio placeholder with its token-id run.

        The i-th image (audio) placeholder string is ``image_tokens[i]``
        (``audio_tokens[i]``). It is replaced by as many placeholder ids as that
        item yields embeddings.
        """
        image_tokens: list[str] = self.info.image_tokens  # type: ignore
        audio_tokens: list[str] = self.info.audio_tokens  # type: ignore
        feature_extractor = self.info.get_feature_extractor(**hf_processor_mm_kwargs)
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        def get_image_replacement_phi4mm(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if isinstance(images, ImageEmbeddingItems):
                num_image_tokens = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    processor=hf_processor,
                )

            return [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_image_tokens

        def get_audio_replacement_phi4mm(item_idx: int):
            audios = mm_items.get_items("audio", AudioProcessorItems)
            # TODO(Isotr0py): support embedding inputs
            audio_len = audios.get_audio_length(item_idx)
            audio_frames = self.info.get_audio_num_frames(
                audio_len, feature_extractor.sampling_rate
            )
            audio_embed_size = self.info._compute_audio_embed_size(audio_frames)

            return [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size

        return [
            PromptReplacement(
                modality="image",
                target=image_tokens.__getitem__,
                replacement=get_image_replacement_phi4mm,
            ),
            PromptReplacement(
                modality="audio",
                target=audio_tokens.__getitem__,
                replacement=get_audio_replacement_phi4mm,
            ),
        ]

    def _recompute_cached_prompt_update(
        self,
        cached_update: ResolvedPromptUpdate,
        new_item_idx: int,
    ) -> ResolvedPromptUpdate:
        """Retarget a cached update to the placeholder of ``new_item_idx``.

        Each item has its own placeholder string, so the target must follow
        the item's new index.
        """
        new_update = super()._recompute_cached_prompt_update(
            cached_update,
            new_item_idx,
        )

        if cached_update.modality == "image":
            image_tokens: list[str] = self.info.image_tokens  # type: ignore
            new_update = new_update.with_target(image_tokens[new_item_idx])
        elif cached_update.modality == "audio":
            audio_tokens: list[str] = self.info.audio_tokens  # type: ignore
            new_update = new_update.with_target(audio_tokens[new_item_idx])

        return new_update

977

978
979
980
981
982
983
@MULTIMODAL_REGISTRY.register_processor(
    Phi4MMMultiModalProcessor,
    info=Phi4MMProcessingInfo,
    dummy_inputs=Phi4MMDummyInputsBuilder,
)
class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
984
    """
    vLLM implementation of the Phi-4-multimodal-instruct model.
    """

    merge_by_field_config = True

    # Each fused module name maps to the list of checkpoint modules packed
    # into it; here every fused module packs only itself.
    packed_modules_mapping = {
        "qkv_proj": ["qkv_proj"],
        "gate_up_proj": ["gate_up_proj"],
    }

    # Renames HF checkpoint weight names to this module's attribute layout.
    # NOTE(review): the specific `audio_projection.{vision,speech}.` prefixes
    # come before the generic `audio_embed.` prefix; presumably the order
    # matters if prefixes are applied in sequence — keep it.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={"base_layer.": ""},
        orig_to_new_prefix={
            "model.embed_tokens_extend.audio_embed.audio_projection.vision.": "embed_tokens_extend.audio_projection_for_vision.",  # noqa: E501
            "model.embed_tokens_extend.audio_embed.audio_projection.speech.": "embed_tokens_extend.audio_projection.",  # noqa: E501
            "model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.",
            "model.embed_tokens_extend.image_embed.": "vision_encoder.",
        },
    )

1011
    @classmethod
1012
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
1013
1014
1015
1016
1017
1018
1019
        if modality.startswith("image"):
            return f"<|image_{i}|>"
        if modality.startswith("audio"):
            return f"<|audio_{i}|>"

        raise ValueError("Only image or audio modality is supported")

1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the vision encoder, audio embedding, Llama decoder and LM head.

        Args:
            vllm_config: Engine configuration; the HF config, multimodal,
                quantization and LoRA configs are read from it.
            prefix: Parameter-name prefix for this module.
        """
        super().__init__()
        model_config = vllm_config.model_config
        config = model_config.hf_config
        multimodal_config = model_config.multimodal_config
        assert multimodal_config, "multimodal_config is required"
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.multimodal_config = multimodal_config
        self.quant_config = quant_config
        self.lora_config = lora_config

        # Only pipeline parallelism is checked here.
        assert get_pp_group().world_size == 1, "pipeline parallel is not supported"

        self.vision_encoder = Phi4MMImageEncoder(
            config,
            quant_config,
            prefix="model.vision_embed_tokens",
            model_dir=config._name_or_path,
        )

        # The audio embedding settings come either from a dedicated
        # `audio_embd_layer` dict, or only the shared `embedding_cls` is used.
        audio_layer_cfg = config.embd_layer["audio_embd_layer"]
        if isinstance(audio_layer_cfg, dict):
            embedding_config = {
                "embedding_cls": audio_layer_cfg["embedding_cls"],
                **audio_layer_cfg,
            }
        else:
            embedding_config = {"embedding_cls": config.embd_layer["embedding_cls"]}

        self.embed_tokens_extend = AudioEmbedding(config, **embedding_config)

        self.model = LlamaModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        # LoRA may reserve extra vocabulary slots beyond the base vocab.
        extra_vocab = lora_config.lora_extra_vocab_size if lora_config else 0
        self.unpadded_vocab_size = config.vocab_size + extra_vocab

        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if config.tie_word_embeddings:
            self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)

        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size,
            config.vocab_size,
            getattr(config, "logit_scale", 1.0),
        )
1075
1076

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> Phi4MMAudioInputs | None:
        """
        Build the audio input from ``kwargs``, if any.

        Raw features (``input_audio_embeds``) take precedence over
        precomputed embeddings (``audio_embeds``); in practice only the
        former is produced for now.

        Args:
            kwargs (object): Keyword arguments.

        Returns:
            Optional[Phi4MMAudioInputs]: The parsed audio input, or ``None``
            when neither key carries a value.
        """
        audio_features = kwargs.pop("input_audio_embeds", None)
        audio_embeds = kwargs.pop("audio_embeds", None)

        if audio_features is not None:
            return Phi4MMAudioFeatureInputs(
                type="audio_features",
                audio_features=audio_features,
            )
        if audio_embeds is not None:
            return Phi4MMAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds)
        return None

1107
1108
1109
    def _process_audio_input(
        self, audio_input: Phi4MMAudioInputs, audio_projection_mode: str
    ) -> NestedTensors:
        """
        Turn a parsed audio input into audio embeddings.

        Precomputed embeddings (``type == "audio_embeds"``) are returned as
        given. Otherwise each element of ``audio_features`` is cast to the
        dtype of ``embed_tokens_extend``'s parameters and embedded separately.

        Args:
            audio_input (Phi4MMAudioInputs): Parsed audio input.
            audio_projection_mode (str): Forwarded to the audio embedding
                module; callers in this file pass "speech" or "vision".

        Returns:
            NestedTensors: One embedding per element of ``audio_features``,
            or the precomputed embeddings.
        """
        if audio_input["type"] == "audio_embeds":
            return audio_input["data"]

        per_item_features = audio_input["audio_features"]
        target_dtype = next(self.embed_tokens_extend.parameters()).dtype

        embeddings = []
        for features in per_item_features:
            embeddings.append(
                self.embed_tokens_extend(
                    features.to(target_dtype),
                    audio_projection_mode=audio_projection_mode,
                )
            )
        return embeddings
1137

1138
    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Phi4MMImagePixelInputs | None:
        """Build the pixel-value image input from ``kwargs``.

        Returns ``None`` when ``input_image_embeds`` is absent. Otherwise
        ``image_sizes``, ``image_attention_mask`` and ``num_img_tokens`` must
        all be present.
        """
        pixel_values = kwargs.get("input_image_embeds")
        if pixel_values is None:
            return None

        image_sizes = kwargs.get("image_sizes")
        image_attention_mask = kwargs.get("image_attention_mask")
        num_img_tokens = kwargs.get("num_img_tokens")
        assert all(
            value is not None
            for value in (image_sizes, image_attention_mask, num_img_tokens)
        ), "Missing image inputs"

        return Phi4MMImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            image_sizes=image_sizes,
            image_attention_mask=image_attention_mask,
            num_img_tokens=num_img_tokens,
        )

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse image and audio inputs from ``kwargs``, keyed by modality.

        Keys are ``"images"`` / ``"audios"``, inserted in the order in which a
        matching input key first appears in ``kwargs`` so the modality order
        is preserved. A modality whose parser returns ``None`` (e.g. only an
        unsupported ``image_embeds`` key is present) is omitted. This keeps
        callers from processing a ``None`` input.
        """
        modalities = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if (
                input_key in ("input_image_embeds", "image_embeds")
                and "images" not in modalities
            ):
                image_input = self._parse_and_validate_image_input(**kwargs)
                if image_input is not None:
                    modalities["images"] = image_input
            if (
                input_key in ("input_audio_embeds", "audio_embeds")
                and "audios" not in modalities
            ):
                audio_input = self._parse_and_validate_audio_input(**kwargs)
                if audio_input is not None:
                    modalities["audios"] = audio_input

        return modalities

    def _process_image_input(
        self, image_input: Phi4MMImagePixelInputs
    ) -> list[torch.Tensor]:
        """Run the vision encoder on the pixel values (cast to its dtype)."""
        encoder_dtype = next(self.vision_encoder.parameters()).dtype
        return self.vision_encoder(
            image_input["pixel_values"].to(encoder_dtype),
            image_input["image_sizes"],
            image_input["image_attention_mask"],
        )

1193
    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        """Embed all image and audio items found in ``kwargs``.

        Returns a tuple with one tensor per multimodal item (image or audio),
        in the modality order of the parsed inputs, or ``[]`` when there is
        no multimodal input. Audio uses the ``"vision"`` projection mode
        whenever images are present and ``"speech"`` otherwise.
        """
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return []

        # Decide the audio projection mode from the presence of images, not
        # from iteration order. Audio listed before images still gets "vision".
        has_images = modalities.get("images") is not None
        audio_projection_mode = "vision" if has_images else "speech"

        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: iterate in dict order to preserve the order of the modalities.
        for modality, mm_input in modalities.items():
            if mm_input is None:
                # Nothing parseable for this modality; skip instead of crashing.
                continue
            if modality == "images":
                image_embeddings = self._process_image_input(mm_input)
                multimodal_embeddings += tuple(image_embeddings)
            elif modality == "audios":
                audio_embeddings = self._process_audio_input(
                    mm_input, audio_projection_mode=audio_projection_mode
                )
                multimodal_embeddings += tuple(audio_embeddings)

        return multimodal_embeddings

1221
1222
1223
1224
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """Run the language model and return its hidden states.

        ``inputs_embeds`` is ignored whenever ``intermediate_tensors`` is
        given; extra ``kwargs`` are accepted and unused.
        """
        embeds = None if intermediate_tensors is not None else inputs_embeds
        return self.model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits through ``lm_head``."""
        return self.logits_processor(self.lm_head, hidden_states)

1248
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, renamed through ``hf_to_vllm_mapper``.

        Checkpoint entries matching the ``"lora"`` skip substring are
        presumably skipped by ``AutoWeightsLoader`` (``skip_substrs``).

        Returns:
            The value of ``AutoWeightsLoader.load_weights`` (the set of
            loaded parameter names in vLLM). The previous ``-> None``
            annotation contradicted this returned value.
        """
        loader = AutoWeightsLoader(self, skip_substrs=["lora"])
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

1252
1253
1254
1255
1256
1257
1258
1259
    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Describe which module prefixes form the language model, the
        connectors and the (vision/audio) tower models.
        """
        connectors = ["audio_projection_for_vision", "audio_projection"]
        towers = ["vision_encoder", "embed_tokens_extend"]
        return MultiModelKeys.from_string_field(
            language_model="model.",
            connector=connectors,
            tower_model=towers,
        )
1261
1262
1263

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped Llama text decoder (``self.model``)."""
        language_model = self.model
        return language_model