minicpmo.py 30 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniCPM-O model compatible with HuggingFace weights."""
26

tc-mb's avatar
tc-mb committed
27
import os
28
29
from collections.abc import Callable, Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal, TypeAlias
30
31
32

import torch
from torch import nn
33
from transformers import BatchFeature
34
from transformers.modeling_outputs import BaseModelOutputWithPast
35
36
37
38
39
40
from transformers.models.whisper.modeling_whisper import (
    ACT2FN,
    WhisperAttention,
    WhisperConfig,
    WhisperEncoder,
)
41
42

from vllm.config import VllmConfig
43
from vllm.config.multimodal import BaseDummyOptions
44
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    NestedTensors,
)
from vllm.multimodal.parse import (
    AudioItem,
    AudioProcessorItems,
    DictEmbeddingItems,
    ModalityData,
    ModalityDataItems,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
63
from vllm.utils.tensor_schema import TensorSchema, TensorShape
64

65
66
67
from .minicpmv import (
    _MAX_FRAMES_PER_VIDEO,
    MiniCPMV2_6,
tc-mb's avatar
tc-mb committed
68
    MiniCPMV4_5,
69
70
71
72
73
74
    MiniCPMVDummyInputsBuilder,
    MiniCPMVMultiModalDataParser,
    MiniCPMVMultiModalProcessor,
    MiniCPMVProcessingInfo,
    _minicpmv_field_config,
)
75
from .utils import AutoWeightsLoader, cast_overflow_tensors, maybe_prefix
76
77
78

# Default device for helper tensors built before the model's own device is
# known (used as a default argument below).
CPU_DEVICE = torch.device("cpu")

# Opt-in FlagGems integration: only when the USE_FLAGOS environment variable is
# exactly "1" is flag_gems imported and restricted to the operator names below
# (presumably replacing the default kernels for these ops — see flag_gems).
if os.environ.get("USE_FLAGOS") == "1":
    import flag_gems

    FLAG_GEMS_CONFIG = [
        "sort",
        "sort_stable",
        "layer_norm",
        "clamp_",
        "cos",
        "embedding",
        "exp",
        "exponential_",
        "full",
        "gather",
        "gelu",
        "index",
        "le",
        "lt",
        "lt_scalar",
        "masked_fill_",
        "max",
        "ones",
        "pow_scalar",
        "prod_dim",
        "rand_like",
        "reciprocal",
        "repeat",
        "scatter",
        "scatter_",
        "sin",
        "sub",
        "true_divide",
        "true_divide_",
        "uniform_",
        "where_scalar_self",
        "where_self_out",
        "zeros",
        "zeros_like",
    ]
    flag_gems.only_enable(record=False, include=FLAG_GEMS_CONFIG)

120

121
class MiniCPMOAudioFeatureInputs(TensorSchema):
    """
    Raw audio features that still have to go through the audio encoder.

    Dimensions:
        - bns: Batch size * number of audios * number of slices
        - bn: Batch size * number of audios
        - c: Number of channels
        - l: Length
        - s: Number of slices
    """

    type: Literal["audio_features"] = "audio_features"

    audio_features: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bns", "c", "l", dynamic_dims={"l"}),
    ]
    """
    Slice here means chunk: audio that is too long is split into slices, the
    same way images are. This is either a single (right-padded) tensor or a
    list of per-slice tensors whose lengths may differ (the multimodal
    processor deliberately produces unpadded features); lists are padded to
    a common length before being encoded.
    """

    audio_feature_lens: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", "s"),
    ]
    """
    Valid (unpadded) feature length of each audio slice, grouped per audio.
    It can be smaller than `audio_features.shape[-1]` when features are
    padded; it is used to build the padding mask and to compute how many
    output tokens each slice produces.
    """


153
class MiniCPMOAudioEmbeddingInputs(TensorSchema):
    """
    Precomputed audio embeddings; these are used as-is and skip the audio
    encoder entirely.

    Dimensions:
        - bn: Batch size * number of audios
        - s: Number of slices
        - h: Hidden size (must match language model backbone)

    Length of each slice may vary, so pass it as a list.
    """

    type: Literal["audio_embeds"] = "audio_embeds"
    audio_embeds: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", "s", "h", dynamic_dims={"s"}),
    ]
169

170

171
172
173
# Either raw features (encoded by the audio tower) or precomputed embeddings
# (passed through unchanged); discriminated by the schema's `type` field.
MiniCPMOAudioInputs: TypeAlias = (
    MiniCPMOAudioFeatureInputs | MiniCPMOAudioEmbeddingInputs
)
174
175


176
177
178
def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
    """Return the multimodal field config for MiniCPM-O inputs.

    This is the MiniCPM-V field config extended with the three audio fields,
    each batched along the "audio" modality (one entry per audio item).
    Duplicate keys raise ``TypeError``, as with keyword arguments.
    """
    audio_fields = dict(
        audio_features=MultiModalFieldConfig.batched("audio"),
        audio_feature_lens=MultiModalFieldConfig.batched("audio"),
        audio_embeds=MultiModalFieldConfig.batched("audio"),
    )
    return dict(**_minicpmv_field_config(hf_inputs), **audio_fields)
183
184


185
186
187
188
class MiniCPMOAudioEmbeddingItems(DictEmbeddingItems):
    """Precomputed audio embeddings supplied as a dict of tensors.

    The dict must contain ``audio_embeds``; ``fields_factory`` describes how
    its fields are batched (see ``_minicpmo_field_config``).
    """

    def __init__(
        self,
        data: Mapping[str, torch.Tensor],
        fields_factory: Callable[
            [Mapping[str, torch.Tensor]],
            Mapping[str, MultiModalFieldConfig],
        ],
    ) -> None:
        super().__init__(
            data,
            # These items hold audio data, and the field config batches
            # `audio_embeds` under the "audio" modality, so the items must be
            # registered as "audio" (previously copy-pasted as "image").
            modality="audio",
            required_fields={"audio_embeds"},
            fields_factory=fields_factory,
        )
200
201
202
203
204


class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser):
    """MiniCPM-V data parser that also accepts audio embeddings as a dict."""

    def _parse_audio_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[AudioItem],
    ) -> ModalityDataItems[Any, Any] | None:
        # Anything that is not a dict is ordinary audio: defer to the parent.
        if not isinstance(data, dict):
            return super()._parse_audio_data(data)

        # A dict carries precomputed audio embeddings.
        return MiniCPMOAudioEmbeddingItems(
            data,
            fields_factory=_minicpmo_field_config,
        )


class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
    """Processing info for MiniCPM-O: MiniCPM-V behaviour plus audio."""

    audio_pattern = "(<audio>./</audio>)"

    def get_data_parser(self):
        return MiniCPMOMultiModalDataParser(
            target_sr=self.get_default_audio_sampling_rate(),
            expected_hidden_size=self._get_expected_hidden_size(),
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # Audio is added with no fixed per-prompt limit (None).
        limits = dict(super().get_supported_mm_limits())
        limits["audio"] = None
        return limits

    def get_audio_placeholder(
        self,
        audio_lens: int,
        chunk_input: bool = True,
        chunk_length: int = 1,
    ) -> str:
        """Delegate placeholder construction to the HF processor."""
        return self.get_hf_processor().get_audio_placeholder(
            audio_lens,
            chunk_input=chunk_input,
            chunk_length=chunk_length,
        )

    def get_default_audio_pool_step(self) -> int:
        # MiniCPM-o 4.5 uses pool_step=5; configs without the attribute
        # (older versions) fall back to 2.
        return getattr(self.get_hf_config(), "audio_pool_step", 2)

    def get_default_audio_sampling_rate(self) -> int:
        return 16000

    def get_chunk_length(self) -> int:
        return self.get_hf_config().audio_chunk_length

    def get_max_audio_tokens_per_chunk(self) -> int:
        """Number of audio tokens produced by one full chunk."""
        pool_step = self.get_default_audio_pool_step()
        fbank_frames = 100
        # The CNN front end halves the frame count (rounding up) ...
        frames_after_cnn = (fbank_frames - 1) // 2 + 1
        # ... and average pooling then reduces it by `pool_step`.
        return (frames_after_cnn - pool_step) // pool_step + 1

    def get_max_audio_chunks_with_most_features(self) -> int:
        return 30

    def get_max_audio_tokens(self) -> int:
        max_chunks = self.get_max_audio_chunks_with_most_features()
        tokens_per_chunk = self.get_max_audio_tokens_per_chunk()
        return tokens_per_chunk * max_chunks

    def get_audio_len_by_num_chunks(self, num_chunks: int) -> int:
        """Estimate a waveform length (in samples) from ``num_chunks``.

        NOTE(review): the processor passes a summed embedding length here and
        the formula divides by tokens-per-chunk, so the argument is effectively
        treated as a token count — confirm before renaming.
        """
        sampling_rate = self.get_default_audio_sampling_rate()
        tokens_per_chunk = self.get_max_audio_tokens_per_chunk()
        return int(num_chunks * sampling_rate / tokens_per_chunk) + 1

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Frames per video that fit once images and audios are accounted for.

        Always returns at least 1 and at most ``_MAX_FRAMES_PER_VIDEO``.
        """
        image_budget = self.get_max_image_tokens() * mm_counts.get("image", 0)
        audio_budget = self.get_max_audio_tokens() * mm_counts.get("audio", 0)
        video_count = mm_counts.get("video", 0)

        total_frames = self.get_max_video_frames(
            seq_len - image_budget - audio_budget
        )
        frames_per_video = min(
            total_frames // max(video_count, 1), _MAX_FRAMES_PER_VIDEO
        )
        return max(frames_per_video, 1)
290
291


292
class MiniCPMODummyInputsBuilder(MiniCPMVDummyInputsBuilder[MiniCPMOProcessingInfo]):
    """Dummy-input builder that adds audio on top of the MiniCPM-V builder."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        # One audio placeholder per requested audio, appended after whatever
        # text the parent builder produces for images/videos.
        audio_text = self.info.audio_pattern * mm_counts.get("audio", 0)
        return super().get_dummy_text(mm_counts) + audio_text

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
        mm_processor_kwargs: Mapping[str, object] | None = None,
    ) -> MultiModalDataDict:
        """Parent dummy data plus ``mm_counts["audio"]`` max-length audios.

        ``mm_processor_kwargs`` is accepted for interface compatibility and is
        not used here.
        """
        info = self.info
        # Dummy audio length in samples: the maximum chunk count times the
        # sampling rate.
        audio_len = (
            info.get_max_audio_chunks_with_most_features()
            * info.get_default_audio_sampling_rate()
        )
        audio_overrides = None if not mm_options else mm_options.get("audio")
        dummy_audios = self._get_dummy_audios(
            length=audio_len,
            num_audios=mm_counts.get("audio", 0),
            overrides=audio_overrides,
        )

        mm_data = dict(super().get_dummy_mm_data(seq_len, mm_counts, mm_options))
        mm_data["audio"] = dummy_audios
        return mm_data
325
326


327
class MiniCPMOMultiModalProcessor(MiniCPMVMultiModalProcessor[MiniCPMOProcessingInfo]):
    """Multimodal processor extending MiniCPM-V processing with audio."""

    def get_audio_prompt_texts(
        self,
        audio_lens: int,
        chunk_input: bool = True,
        chunk_length: int = 1,
    ) -> str:
        """Return the placeholder text for an audio of ``audio_lens``."""
        return self.info.get_audio_placeholder(
            audio_lens,
            chunk_input=chunk_input,
            chunk_length=chunk_length,
        )

    def process_audios(
        self,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> Mapping[str, NestedTensors]:
        """Run the HF processor on raw audios, one audio per prompt.

        Returns an empty mapping when there are no audios or when the audios
        are precomputed embeddings.
        """
        audios = mm_data.get("audios")
        if audios is None:
            return {}

        parsed = self.info.parse_mm_data({"audio": audios}, validate=False)
        audio_items = parsed.get_items(
            "audio", (MiniCPMOAudioEmbeddingItems, AudioProcessorItems)
        )
        if isinstance(audio_items, MiniCPMOAudioEmbeddingItems):
            return {}

        audio_inputs = self._base_call_hf_processor(
            prompts=[self.info.audio_pattern] * len(audio_items),
            mm_data={"audios": [[audio] for audio in audio_items]},
            mm_kwargs={**mm_kwargs, "chunk_input": True},
            tok_kwargs=tok_kwargs,
            out_keys={"audio_features", "audio_feature_lens"},
        )

        # Avoid padding since we need the output for each audio to be
        # independent of other audios for the cache to work correctly
        audio_inputs["audio_features"] = [
            features[:, :valid_len]
            for features, valid_len in zip(
                audio_inputs["audio_features"],
                audio_inputs["audio_feature_lens"],
            )
        ]
        return audio_inputs

    def process_mm_inputs(
        self,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> Mapping[str, NestedTensors]:
        """Merge the parent's image/video outputs with the audio outputs."""
        vision_inputs = super().process_mm_inputs(mm_data, mm_kwargs, tok_kwargs)
        audio_inputs = self.process_audios(mm_data, mm_kwargs, tok_kwargs)
        return {**vision_inputs, **audio_inputs}

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Parent prompt updates plus one replacement for audio placeholders."""
        prompt_updates = list(
            super()._get_prompt_updates(
                mm_items=mm_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                out_mm_kwargs=out_mm_kwargs,
            )
        )

        def get_audio_replacement(item_idx: int):
            # Looked up lazily: `get_items` is only valid when audio exists.
            audio_items = mm_items.get_items(
                "audio", (MiniCPMOAudioEmbeddingItems, AudioProcessorItems)
            )
            if isinstance(audio_items, MiniCPMOAudioEmbeddingItems):
                embeds = audio_items.get(item_idx)["audio_embeds"]
                audio_len = self.info.get_audio_len_by_num_chunks(
                    sum(map(len, embeds))
                )
            else:
                audio_len = audio_items.get_audio_length(item_idx)

            return PromptUpdateDetails.select_text(
                self.get_audio_prompt_texts(audio_len),
                "<unk>",
            )

        prompt_updates.append(
            PromptReplacement(
                modality="audio",
                target=self.info.audio_pattern,
                replacement=get_audio_replacement,
            )
        )
        return prompt_updates

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return _minicpmo_field_config(hf_inputs)
436
437
438
439
440


class MultiModalProjector(nn.Module):
    """Linear -> ReLU -> Linear projection of audio encoder features.

    Maps ``in_dim`` features to ``out_dim`` (the model uses the LLM embedding
    size as ``out_dim``).
    """

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # Attribute names become parameter names (e.g. `linear1.weight`);
        # keep them stable so checkpoints still load.
        self.linear1 = nn.Linear(in_features=in_dim, out_features=out_dim, bias=True)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(in_features=out_dim, out_features=out_dim, bias=True)

    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        projected = self.linear1(audio_features)
        activated = self.relu(projected)
        return self.linear2(activated)


class MiniCPMWhisperEncoderLayer(nn.Module):
    """Pre-norm Whisper encoder layer driven by an explicit attention mask.

    ``forward`` returns a 1-tuple ``(hidden_states,)``.
    """

    def __init__(self, config: WhisperConfig, layer_idx: int):
        super().__init__()
        self.embed_dim = config.d_model
        # Submodules are created in a fixed order; attribute names are the
        # parameter names used when loading weights.
        self.self_attn = WhisperAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
            config=config,
            layer_idx=layer_idx,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        dropout = nn.functional.dropout

        # Self-attention sub-block: norm -> attention -> dropout -> residual.
        attn_out, _ = self.self_attn(
            hidden_states=self.self_attn_layer_norm(hidden_states),
            attention_mask=attention_mask,
        )
        attn_out = dropout(attn_out, p=self.dropout, training=self.training)
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block: norm -> fc1 -> act -> fc2 -> residual.
        ffn_out = self.activation_fn(self.fc1(self.final_layer_norm(hidden_states)))
        ffn_out = dropout(ffn_out, p=self.activation_dropout, training=self.training)
        ffn_out = dropout(self.fc2(ffn_out), p=self.dropout, training=self.training)
        hidden_states = hidden_states + ffn_out

        # fp16 only: presumably clamps values that overflowed (helper is
        # defined in .utils).
        if hidden_states.dtype == torch.float16:
            hidden_states = cast_overflow_tensors(hidden_states)

        return (hidden_states,)


class MiniCPMWhisperEncoder(WhisperEncoder):
    """Whisper encoder whose layers take an explicit attention mask.

    ``forward`` returns the final (layer-normed) state as
    ``last_hidden_state`` and, in ``hidden_states``, the input to every layer
    followed by that final state.
    """

    def __init__(self, config: WhisperConfig):
        super().__init__(config)
        # Swap the stock layers for mask-aware MiniCPM layers.
        self.layers = nn.ModuleList(
            MiniCPMWhisperEncoderLayer(config, layer_idx=layer_idx)
            for layer_idx in range(config.encoder_layers)
        )

    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> BaseModelOutputWithPast:
        # Ignore copy
        conv_weight = self.conv1.weight
        features = input_features.to(
            dtype=conv_weight.dtype, device=conv_weight.device
        )

        # Conv front end, then (batch, channels, time) -> (batch, time, channels).
        features = nn.functional.gelu(self.conv1(features))
        features = nn.functional.gelu(self.conv2(features))
        features = features.permute(0, 2, 1)

        # Positional embeddings truncated to the actual sequence length.
        num_positions = features.shape[1]
        hidden_states = features + self.embed_positions.weight[:num_positions, :]
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )

        all_hidden_states: list[torch.Tensor] = []
        for layer in self.layers:
            all_hidden_states.append(hidden_states)
            # LayerDrop: during training, skip this layer with probability
            # `self.layerdrop`.
            if self.training and torch.rand([]) < self.layerdrop:
                continue
            hidden_states = layer(hidden_states, attention_mask)[0]

        hidden_states = self.layer_norm(hidden_states)
        all_hidden_states.append(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            hidden_states=tuple(all_hidden_states),
        )


tc-mb's avatar
tc-mb committed
570
571
572
class MiniCPMOBaseModel:
    """Base mixin class for MiniCPM-O models with audio support.

    Listed before a ``MiniCPMV*`` backbone in the class bases, it builds an
    audio tower (``apm``) with pooling/projection layers and extends
    multimodal input parsing and processing with an "audios" modality.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for ``modality``.

        ``i`` is unused: every item of a modality shares one placeholder.

        Raises:
            ValueError: If ``modality`` is not image, video or audio.
        """
        if modality.startswith("image"):
            return "(<image>./</image>)"
        if modality.startswith("video"):
            return "(<video>./</video>)"
        if modality.startswith("audio"):
            return "(<audio>./</audio>)"

        raise ValueError("Only image, video or audio modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # Built after the backbone: `init_audio_module` reads `self.config`
        # and `self.embed_dim`, which this mixin does not set itself.
        with self._mark_tower_model(vllm_config, "audio"):
            self.apm = self.init_audio_module(
                vllm_config=vllm_config, prefix=maybe_prefix(prefix, "apm")
            )

    def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create and return the audio encoder.

        As a side effect, also creates ``audio_avg_pooler`` and
        ``audio_projection_layer`` on ``self``. ``vllm_config`` and ``prefix``
        are currently unused.
        """
        # Do not use parameters temporarily
        audio_config = self.config.audio_config
        model = MiniCPMWhisperEncoder(audio_config)
        audio_output_dim = int(audio_config.encoder_ffn_dim // 4)
        self.audio_avg_pooler = nn.AvgPool1d(
            self.config.audio_pool_step, stride=self.config.audio_pool_step
        )
        self.audio_projection_layer = MultiModalProjector(
            in_dim=audio_output_dim, out_dim=self.embed_dim
        )
        # Use the encoder's last entry in `hidden_states` (the layer-normed
        # final state).
        self.audio_encoder_layer = -1
        return model

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, skipping names under the ``tts`` prefix."""
        loader = AutoWeightsLoader(self, skip_prefixes=["tts"])
        return loader.load_weights(weights)

    def subsequent_chunk_mask(
        self,
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = CPU_DEVICE,
        num_lookhead: int = 0,
    ) -> torch.Tensor:
        """Build a ``(size, size)`` boolean chunk attention mask.

        ``mask[i, j]`` is True when position ``i`` may attend to ``j``:
        ``j`` must be at or after the start of the chunk ``num_left_chunks``
        chunks before ``i``'s chunk (or 0 when ``num_left_chunks < 0``), and
        before the end of ``i``'s chunk plus ``num_lookhead`` positions
        (capped at ``size``).
        """
        row_indices = torch.arange(size, device=device)
        chunk_indices = row_indices // chunk_size

        if num_left_chunks < 0:
            # Unlimited left context: every row starts at column 0.
            start_indices = torch.zeros_like(row_indices)
        else:
            start_chunk_indices = torch.clamp(chunk_indices - num_left_chunks, min=0)
            start_indices = start_chunk_indices * chunk_size

        end_indices = torch.clamp(
            (chunk_indices + 1) * chunk_size + num_lookhead, max=size
        )

        # Broadcast per-row [start, end) bounds against all column indices.
        col_indices = torch.arange(size, device=device).unsqueeze(0)
        return (col_indices >= start_indices.unsqueeze(1)) & (
            col_indices < end_indices.unsqueeze(1)
        )

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
        """Map input feature lengths to (after-CNN, after-pooling) lengths.

        The pooled lengths are returned as int32.
        """
        input_lengths_after_cnn = (input_lengths - 1) // 2 + 1
        input_lengths_after_pooling = (
            input_lengths_after_cnn - self.config.audio_pool_step
        ) // self.config.audio_pool_step + 1
        input_lengths_after_pooling = input_lengths_after_pooling.to(dtype=torch.int32)

        return input_lengths_after_cnn, input_lengths_after_pooling

    def get_audio_hidden_states(
        self, data: MiniCPMOAudioFeatureInputs
    ) -> list[torch.Tensor]:
        """Encode audio features into one embedding tensor per audio.

        All slices are padded into one batch, encoded under a padding + chunk
        attention mask, projected, average-pooled and trimmed to each slice's
        token count. The slices of each audio are then concatenated.
        """
        chunk_length = self.config.audio_chunk_length

        # (bs, 80, frames) or [], multi audios need filled in advance
        wavforms_raw = data["audio_features"]
        if isinstance(wavforms_raw, list):
            # Right-pad variable-length slices into one (B, C, L_max) tensor.
            B = len(wavforms_raw)
            C = wavforms_raw[0].shape[-2]
            L = max(item.shape[-1] for item in wavforms_raw)
            device = wavforms_raw[0].device
            dtype = wavforms_raw[0].dtype

            wavforms = torch.zeros((B, C, L), dtype=dtype, device=device)
            for i, wavforms_item in enumerate(wavforms_raw):
                L_item = wavforms_item.shape[-1]
                wavforms[i, ..., :L_item] = wavforms_item
        else:
            wavforms = wavforms_raw

        # list, [[x1, x2], [y1], [z1]] -- slice lengths grouped per audio
        audio_feature_lens_raw = data["audio_feature_lens"]
        if isinstance(audio_feature_lens_raw, torch.Tensor):
            audio_feature_lens_raw = audio_feature_lens_raw.unbind(0)

        audio_feature_lens = torch.hstack(audio_feature_lens_raw)
        batch_size, _, max_mel_seq_len = wavforms.shape
        # Sequence length after the encoder's CNN front end.
        max_seq_len = (max_mel_seq_len - 1) // 2 + 1

        # NOTE(review): `seq_range` is in post-CNN frames while
        # `audio_feature_lens` is in input frames (compare
        # `_get_feat_extract_output_lengths`). This appears to mirror the
        # reference implementation; confirm before changing.
        seq_range = (
            torch.arange(
                0,
                max_seq_len,
                dtype=audio_feature_lens.dtype,
                device=audio_feature_lens.device,
            )
            .unsqueeze(0)
            .expand(batch_size, max_seq_len)
        )
        lengths_expand = audio_feature_lens.unsqueeze(1).expand(batch_size, max_seq_len)
        padding_mask = seq_range >= lengths_expand  # True for padded values

        audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
            batch_size, 1, max_seq_len, max_seq_len
        )
        audio_attention_mask = audio_attention_mask_.to(
            dtype=self.apm.conv1.weight.dtype, device=self.apm.conv1.weight.device
        )

        if chunk_length > 0:
            # 50 encoder frames per unit of chunk_length (the same 100 -> 50
            # frame reduction used by the processing info).
            chunk_num_frame = int(chunk_length * 50)
            chunk_mask = self.subsequent_chunk_mask(
                size=max_seq_len,
                chunk_size=chunk_num_frame,
                num_left_chunks=-1,
                device=audio_attention_mask_.device,
            )
            # Also block attention to positions outside the allowed chunks.
            audio_attention_mask_ = torch.logical_or(
                audio_attention_mask_, torch.logical_not(chunk_mask)
            )

        # Blocked positions become -inf; allowed positions stay 0.0.
        audio_attention_mask[audio_attention_mask_] = float("-inf")
        audio_states = self.apm(
            wavforms, attention_mask=audio_attention_mask
        ).hidden_states[self.audio_encoder_layer]
        audio_embeds = self.audio_projection_layer(audio_states)

        # AvgPool1d pools the last dim, so move the sequence dim there.
        audio_embeds = audio_embeds.transpose(1, 2)
        audio_embeds = self.audio_avg_pooler(audio_embeds)
        audio_embeds = audio_embeds.transpose(1, 2)

        _, num_audio_tokens = self._get_feat_extract_output_lengths(
            audio_feature_lens
        )

        # Trim each slice to its token count and join the slices per audio.
        final_audio_embeds = list[torch.Tensor]()
        idx = 0
        for slice_lens in audio_feature_lens_raw:
            slice_embeds = list[torch.Tensor]()
            for _ in range(len(slice_lens)):
                slice_embeds.append(audio_embeds[idx, : num_audio_tokens[idx], :])
                idx += 1
            final_audio_embeds.append(torch.cat(slice_embeds))

        return final_audio_embeds

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> MiniCPMOAudioInputs | None:
        """Build audio inputs from kwargs, preferring ``audio_embeds``.

        Returns None when neither ``audio_features`` nor ``audio_embeds`` is
        given. ``audio_feature_lens`` is required with ``audio_features``.
        """
        audio_features = kwargs.pop("audio_features", None)
        audio_embeds = kwargs.pop("audio_embeds", None)

        if audio_features is None and audio_embeds is None:
            return None

        if audio_embeds is not None:
            return MiniCPMOAudioEmbeddingInputs(
                type="audio_embeds",
                audio_embeds=audio_embeds,
            )

        audio_feature_lens = kwargs.pop("audio_feature_lens")

        return MiniCPMOAudioFeatureInputs(
            type="audio_features",
            audio_features=audio_features,
            audio_feature_lens=audio_feature_lens,
        )

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parent modalities plus an "audios" entry when audio kwargs exist.

        The audio entry is added after the parent's modalities.
        """
        modalities = super()._parse_and_validate_multimodal_inputs(**kwargs)

        has_audio_kwargs = "audio_features" in kwargs or "audio_embeds" in kwargs
        if has_audio_kwargs and "audios" not in modalities:
            modalities["audios"] = self._parse_and_validate_audio_input(**kwargs)

        return modalities

    def _process_audio_input(
        self,
        audio_input: MiniCPMOAudioInputs,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Return embeddings as-is, or encode raw features."""
        if audio_input["type"] == "audio_embeds":
            return audio_input["audio_embeds"]

        return self.get_audio_hidden_states(audio_input)

    def _process_multimodal_inputs(self, modalities: dict):
        """Parent embeddings followed by one embedding per audio."""
        multimodal_embeddings = super()._process_multimodal_inputs(modalities)

        # The audio parser may yield None (audio kwargs present but empty);
        # there is nothing to embed in that case.
        audio_input = modalities.get("audios")
        if audio_input is not None:
            audio_embeddings = self._process_audio_input(audio_input)
            multimodal_embeddings += tuple(audio_embeddings)

        return multimodal_embeddings
tc-mb's avatar
tc-mb committed
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891


class MiniCPMO2_6(MiniCPMOBaseModel, MiniCPMV2_6):
    """MiniCPM-O 2.6 model with Qwen2 backbone.

    No ``__init__`` override is needed. ``MiniCPMOBaseModel.__init__`` already
    builds the audio tower (``apm``) together with its pooler and projector
    after the backbone is initialised. Building it a second time here
    allocated a duplicate Whisper encoder and discarded the first one.
    """


class MiniCPMO4_5(MiniCPMOBaseModel, MiniCPMV4_5):
    """MiniCPM-O 4.5 model with Qwen3 backbone.

    No ``__init__`` override is needed. ``MiniCPMOBaseModel.__init__`` already
    builds the audio tower (``apm``) together with its pooler and projector
    after the backbone is initialised. Building it a second time here
    allocated a duplicate Whisper encoder and discarded the first one.
    """


# (major, minor) version parsed from the HF config -> concrete model class
# instantiated by `MiniCPMO.__new__`.
_MINICPMO_SUPPORT_VERSION: dict[tuple[int, int], type] = {
    (2, 6): MiniCPMO2_6,
    (4, 5): MiniCPMO4_5,
}


@MULTIMODAL_REGISTRY.register_processor(
    MiniCPMOMultiModalProcessor,
    info=MiniCPMOProcessingInfo,
    dummy_inputs=MiniCPMODummyInputsBuilder,
)
class MiniCPMO(MiniCPMOBaseModel, MiniCPMV2_6):
    """
    MiniCPM-O model with audio support.
    Different versions use different LLM backbones:
    - Version 2.6: Uses Qwen2
    - Version 4.5: Uses Qwen3

    Constructing this class returns an instance of the version-specific class
    from ``_MINICPMO_SUPPORT_VERSION`` instead of a ``MiniCPMO`` instance.
    """

    def __new__(cls, *, vllm_config: VllmConfig, prefix: str = ""):
        hf_config = vllm_config.model_config.hf_config

        # Configs without a `version` attribute default to 2.6 for backward
        # compatibility; otherwise the leading "major.minor" is parsed.
        version = (2, 6)
        if hasattr(hf_config, "version"):
            raw_version = hf_config.version
            try:
                major_minor = str(raw_version).split(".")[:2]
                version = tuple(int(part) for part in major_minor)
            except (ValueError, TypeError) as e:
                raise ValueError(
                    f"Invalid model version format in config: {raw_version}. "
                    "Expected a dot-separated version string like '4.5'."
                ) from e

        instance_cls = _MINICPMO_SUPPORT_VERSION.get(version)
        if instance_cls is not None:
            return instance_cls(vllm_config=vllm_config, prefix=prefix)

        supported_versions = ", ".join(
            f"{major}.{minor}" for major, minor in sorted(_MINICPMO_SUPPORT_VERSION)
        )
        raise ValueError(
            f"Currently, MiniCPMO only supports versions "
            f"{supported_versions}. Got version: {version}"
        )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Never runs: `__new__` returns an instance of a class that is not a
        # subclass of MiniCPMO, so Python skips this initializer.
        pass