minicpmo.py 29.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniCPM-O model compatible with HuggingFace weights."""
26

27
28
from collections.abc import Callable, Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal, TypeAlias
29
30
31

import torch
from torch import nn
32
from transformers import BatchFeature
33
from transformers.modeling_outputs import BaseModelOutputWithPast
34
35
36
37
38
39
from transformers.models.whisper.modeling_whisper import (
    ACT2FN,
    WhisperAttention,
    WhisperConfig,
    WhisperEncoder,
)
40
41

from vllm.config import VllmConfig
42
from vllm.config.multimodal import BaseDummyOptions
43
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    NestedTensors,
)
from vllm.multimodal.parse import (
    AudioItem,
    AudioProcessorItems,
    DictEmbeddingItems,
    ModalityData,
    ModalityDataItems,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
62
from vllm.utils.tensor_schema import TensorSchema, TensorShape
63

64
65
66
from .minicpmv import (
    _MAX_FRAMES_PER_VIDEO,
    MiniCPMV2_6,
tc-mb's avatar
tc-mb committed
67
    MiniCPMV4_5,
68
69
70
71
72
73
    MiniCPMVDummyInputsBuilder,
    MiniCPMVMultiModalDataParser,
    MiniCPMVMultiModalProcessor,
    MiniCPMVProcessingInfo,
    _minicpmv_field_config,
)
74
from .utils import AutoWeightsLoader, cast_overflow_tensors, maybe_prefix
75
76
77
78

# Shared handle to the CPU device; used as a default ``device`` argument
# elsewhere in this module.
CPU_DEVICE: torch.device = torch.device("cpu")


79
class MiniCPMOAudioFeatureInputs(TensorSchema):
    """
    Audio features that still have to be run through the audio encoder.

    Dimensions:
        - bns: Batch size * number of audios * number of slices
        - bn: Batch size * number of audios
        - c: Number of channels
        - l: Length
        - s: Number of slices
    """

    type: Literal["audio_features"] = "audio_features"

    # Per-slice features. "Slice" means chunk: audio that is too long is split
    # into slices, the same way images are. The length dim ``l`` is declared
    # dynamic, so this may be a padded tensor or a list of tensors whose
    # lengths differ.
    audio_features: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bns", "c", "l", dynamic_dims={"l"}),
    ]

    # Valid (unpadded) feature length of each audio slice, i.e. how much of
    # ``audio_features`` along its last dimension belongs to that slice.
    audio_feature_lens: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", "s"),
    ]


111
class MiniCPMOAudioEmbeddingInputs(TensorSchema):
    """
    Precomputed audio embeddings, used as-is instead of running the encoder.

    Dimensions:
        - bn: Batch size * number of audios
        - s: Number of slices
        - h: Hidden size (must match language model backbone)

    Length of each slice may vary, so pass it as a list.
    """

    type: Literal["audio_embeds"] = "audio_embeds"

    # ``s`` is dynamic: different audios may contribute a different number of
    # embedding rows.
    audio_embeds: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", "s", "h", dynamic_dims={"s"}),
    ]
127

128

129
130
131
# Audio input accepted by the model: either raw features that still need to be
# encoded, or precomputed embeddings. Discriminated by the ``type`` field.
MiniCPMOAudioInputs: TypeAlias = (
    MiniCPMOAudioFeatureInputs | MiniCPMOAudioEmbeddingInputs
)
132
133


134
135
136
def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
    """Return the multimodal field config for MiniCPM-O processor outputs.

    Extends the MiniCPM-V field config with the three audio fields, each of
    which is batched along the ``"audio"`` modality.
    """
    base_fields = _minicpmv_field_config(hf_inputs)
    audio_fields = {
        field_name: MultiModalFieldConfig.batched("audio")
        for field_name in ("audio_features", "audio_feature_lens", "audio_embeds")
    }
    # Merge via keyword unpacking (not ``|``) so a clash between the base
    # config and an audio key still raises TypeError, as before.
    return dict(**base_fields, **audio_fields)
141
142


143
144
145
146
class MiniCPMOAudioEmbeddingItems(DictEmbeddingItems):
    """Audio items supplied as a dict of precomputed embeddings.

    Args:
        data: Mapping that must contain an ``"audio_embeds"`` entry.
        fields_factory: Callable mapping ``data`` to its multimodal field
            configuration (e.g. ``_minicpmo_field_config``).
    """

    def __init__(
        self,
        data: Mapping[str, torch.Tensor],
        fields_factory: Callable[
            [Mapping[str, torch.Tensor]],
            Mapping[str, MultiModalFieldConfig],
        ],
    ) -> None:
        super().__init__(
            data,
            # These are audio items. The audio fields are batched under the
            # "audio" modality, so the items must carry that same modality;
            # labelling them "image" made per-item lookups target the wrong
            # modality.
            modality="audio",
            required_fields={"audio_embeds"},
            fields_factory=fields_factory,
        )
158
159
160
161
162


class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser):
    """MiniCPM-V data parser that also accepts precomputed audio embeddings."""

    def _parse_audio_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[AudioItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse audio input.

        A ``dict`` is treated as precomputed embeddings and wrapped in
        ``MiniCPMOAudioEmbeddingItems``; anything else is delegated to the
        parent parser.
        """
        if not isinstance(data, dict):
            return super()._parse_audio_data(data)

        return MiniCPMOAudioEmbeddingItems(
            data,
            fields_factory=_minicpmo_field_config,
        )


class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
    """MiniCPM-V processing info extended with an (unbounded) audio modality."""

    # Placeholder inserted into prompts for each audio item.
    audio_pattern = "(<audio>./</audio>)"

    def get_data_parser(self):
        """Build a parser that also understands precomputed audio embeddings."""
        sampling_rate = self.get_default_audio_sampling_rate()
        hidden_size = self._get_expected_hidden_size()
        return MiniCPMOMultiModalDataParser(
            target_sr=sampling_rate,
            expected_hidden_size=hidden_size,
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Parent limits plus ``"audio"`` with no upper bound (``None``)."""
        limits = dict(super().get_supported_mm_limits())
        limits["audio"] = None
        return limits

    def get_audio_placeholder(
        self,
        audio_lens: int,
        chunk_input: bool = True,
        chunk_length: int = 1,
    ) -> str:
        """Delegate placeholder-string construction to the HF processor."""
        processor = self.get_hf_processor()
        return processor.get_audio_placeholder(
            audio_lens,
            chunk_input=chunk_input,
            chunk_length=chunk_length,
        )

    def get_default_audio_pool_step(self) -> int:
        """Pooling step of the audio head, read from the HF config.

        MiniCPM-o 4.5 uses pool_step=5, older versions use 2; configs without
        ``audio_pool_step`` fall back to 2.
        """
        return getattr(self.get_hf_config(), "audio_pool_step", 2)

    def get_default_audio_sampling_rate(self) -> int:
        """Sampling rate (Hz) audio is resampled to before processing."""
        return 16000

    def get_chunk_length(self) -> int:
        """``audio_chunk_length`` from the HF config."""
        hf_config = self.get_hf_config()
        return hf_config.audio_chunk_length

    def get_max_audio_tokens_per_chunk(self) -> int:
        """Number of audio tokens produced for one chunk of 100 feature frames.

        The front-end length formula ``(n - 1) // 2 + 1`` is applied first,
        then pooling with window and stride ``pool_step``.
        NOTE(review): 100 frames presumably corresponds to one 1-second chunk;
        confirm against the feature extractor.
        """
        step = self.get_default_audio_pool_step()
        frames_per_chunk = 100
        frames_after_cnn = (frames_per_chunk - 1) // 2 + 1
        return (frames_after_cnn - step) // step + 1

    def get_max_audio_chunks_with_most_features(self) -> int:
        """Number of chunks assumed for the largest audio item (profiling)."""
        return 30

    def get_max_audio_tokens(self) -> int:
        """Upper bound on audio tokens per item used for profiling."""
        chunks = self.get_max_audio_chunks_with_most_features()
        return chunks * self.get_max_audio_tokens_per_chunk()

    def get_audio_len_by_num_chunks(self, num_chunks: int) -> int:
        """Estimate an audio length in samples from an audio-token count.

        ``num_chunks`` is divided by tokens-per-chunk and multiplied by the
        sampling rate, so it is effectively a token count; one sample is
        added to the truncated result.
        """
        sr = self.get_default_audio_sampling_rate()
        tokens_per_chunk = self.get_max_audio_tokens_per_chunk()
        return int(num_chunks * sr / tokens_per_chunk) + 1

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Video frames per video that fit after reserving image/audio tokens.

        The result is capped at ``_MAX_FRAMES_PER_VIDEO`` and is at least 1.
        """
        image_budget = self.get_max_image_tokens() * mm_counts.get("image", 0)
        audio_budget = self.get_max_audio_tokens() * mm_counts.get("audio", 0)
        remaining = seq_len - image_budget - audio_budget

        total_frames = self.get_max_video_frames(remaining)
        num_videos = max(mm_counts.get("video", 0), 1)
        frames_each = min(total_frames // num_videos, _MAX_FRAMES_PER_VIDEO)

        return max(frames_each, 1)
248
249


250
class MiniCPMODummyInputsBuilder(MiniCPMVDummyInputsBuilder[MiniCPMOProcessingInfo]):
    """Dummy-input builder that adds audio placeholders and dummy audio clips."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Parent dummy text followed by one audio placeholder per audio."""
        audio_text = self.info.audio_pattern * mm_counts.get("audio", 0)
        return super().get_dummy_text(mm_counts) + audio_text

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Parent dummy image/video data plus ``mm_counts["audio"]`` clips.

        Each clip has ``max_chunks * sampling_rate`` samples, i.e.
        ``get_max_audio_chunks_with_most_features()`` seconds at the default
        sampling rate.
        """
        num_audios = mm_counts.get("audio", 0)
        info = self.info
        clip_len = (
            info.get_max_audio_chunks_with_most_features()
            * info.get_default_audio_sampling_rate()
        )
        overrides = None if not mm_options else mm_options.get("audio")

        dummy_audios = self._get_dummy_audios(
            length=clip_len, num_audios=num_audios, overrides=overrides
        )

        return {
            **super().get_dummy_mm_data(seq_len, mm_counts, mm_options),
            "audio": dummy_audios,
        }
282
283


284
class MiniCPMOMultiModalProcessor(MiniCPMVMultiModalProcessor[MiniCPMOProcessingInfo]):
    """MiniCPM-V processor extended with audio feature extraction and audio
    placeholder expansion."""

    def get_audio_prompt_texts(
        self,
        audio_lens: int,
        chunk_input: bool = True,
        chunk_length: int = 1,
    ) -> str:
        """Placeholder text for one audio item (delegates to the info object)."""
        return self.info.get_audio_placeholder(
            audio_lens,
            chunk_input=chunk_input,
            chunk_length=chunk_length,
        )

    def process_audios(
        self,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> Mapping[str, NestedTensors]:
        """Run HF audio feature extraction for ``mm_data["audios"]``.

        Returns ``{}`` when there are no audios or when they were given as
        precomputed embeddings; otherwise ``audio_features`` (one unpadded
        tensor per audio) and ``audio_feature_lens``.
        """
        audios = mm_data.get("audios")
        if audios is None:
            return {}

        parsed = self.info.parse_mm_data({"audio": audios}, validate=False).get_items(
            "audio", (MiniCPMOAudioEmbeddingItems, AudioProcessorItems)
        )
        if isinstance(parsed, MiniCPMOAudioEmbeddingItems):
            # Precomputed embeddings need no feature extraction.
            return {}

        # One prompt (and one single-audio list) per audio item.
        outputs = self._base_call_hf_processor(
            prompts=[self.info.audio_pattern] * len(parsed),
            mm_data={"audios": [[audio] for audio in parsed]},
            mm_kwargs={**mm_kwargs, "chunk_input": True},
            tok_kwargs=tok_kwargs,
            out_keys={"audio_features", "audio_feature_lens"},
        )

        # Avoid padding since we need the output for each audio to be
        # independent of other audios for the cache to work correctly
        paired = zip(outputs["audio_features"], outputs["audio_feature_lens"])
        outputs["audio_features"] = [
            features[:, :valid_len] for features, valid_len in paired
        ]

        return outputs

    def process_mm_inputs(
        self,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> Mapping[str, NestedTensors]:
        """Parent (image/video) outputs merged with the audio outputs."""
        merged = dict(super().process_mm_inputs(mm_data, mm_kwargs, tok_kwargs))
        merged.update(self.process_audios(mm_data, mm_kwargs, tok_kwargs))
        return merged

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Parent prompt updates plus replacement of each audio placeholder."""
        base_updates = super()._get_prompt_updates(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            out_mm_kwargs=out_mm_kwargs,
        )

        def get_audio_replacement(item_idx: int):
            # Looked up lazily, only when an audio placeholder is replaced.
            audio_items = mm_items.get_items(
                "audio", (MiniCPMOAudioEmbeddingItems, AudioProcessorItems)
            )

            if isinstance(audio_items, MiniCPMOAudioEmbeddingItems):
                embeds = audio_items.get(item_idx)["audio_embeds"]
                num_rows = sum(len(part) for part in embeds)
                audio_len = self.info.get_audio_len_by_num_chunks(num_rows)
            else:
                audio_len = audio_items.get_audio_length(item_idx)

            # Only the "<unk>" tokens of the placeholder receive embeddings.
            return PromptUpdateDetails.select_text(
                self.get_audio_prompt_texts(audio_len),
                "<unk>",
            )

        audio_update = PromptReplacement(
            modality="audio",
            target=self.info.audio_pattern,
            replacement=get_audio_replacement,
        )
        return [*base_updates, audio_update]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Field config covering image, video and audio outputs."""
        return _minicpmo_field_config(hf_inputs)
393
394
395
396
397


class MultiModalProjector(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) used to project audio features.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of both the hidden layer and the output.
    """

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # Attribute names determine parameter names (e.g. ``linear1.weight``)
        # used when loading weights; keep them unchanged.
        self.linear1 = nn.Linear(in_dim, out_dim, bias=True)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(out_dim, out_dim, bias=True)

    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        """Return ``linear2(relu(linear1(audio_features)))``."""
        projected = self.linear1(audio_features)
        return self.linear2(self.relu(projected))


class MiniCPMWhisperEncoderLayer(nn.Module):
    """Pre-norm Whisper encoder layer that takes an explicit attention mask.

    Forward structure:
        LayerNorm -> self-attention -> dropout -> residual add, then
        LayerNorm -> fc1 -> activation -> dropout -> fc2 -> dropout ->
        residual add.
    """

    def __init__(self, config: WhisperConfig, layer_idx: int):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = WhisperAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
            config=config,
            layer_idx=layer_idx,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None,
    ) -> tuple[torch.Tensor]:
        """Run one encoder layer.

        Args:
            hidden_states: Layer input; also used as the residual.
            attention_mask: Forwarded unchanged to ``WhisperAttention``; may be
                ``None`` (the encoder's default).

        Returns:
            A 1-tuple ``(hidden_states,)``. Callers index ``[0]``.
        """
        # Self-attention sub-block (pre-norm).
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
        )
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )
        hidden_states = residual + hidden_states

        # Feed-forward sub-block (pre-norm).
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.activation_dropout, training=self.training
        )
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )
        hidden_states = residual + hidden_states

        if hidden_states.dtype == torch.float16:
            # NOTE(review): presumably clamps inf/nan values caused by fp16
            # overflow; see ``cast_overflow_tensors``.
            hidden_states = cast_overflow_tensors(hidden_states)

        outputs = (hidden_states,)

        return outputs


class MiniCPMWhisperEncoder(WhisperEncoder):
    """Whisper encoder using ``MiniCPMWhisperEncoderLayer`` layers.

    The forward pass threads a caller-provided attention mask through every
    layer.
    """

    def __init__(self, config: WhisperConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [
                MiniCPMWhisperEncoderLayer(config, layer_idx=i)
                for i in range(config.encoder_layers)
            ]
        )

    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> BaseModelOutputWithPast:
        """Encode ``input_features``.

        Args:
            input_features: Encoder input; cast to ``conv1``'s dtype/device.
            attention_mask: Passed unchanged to every layer.

        Returns:
            ``last_hidden_state``: the final, layer-normed output.
            ``hidden_states``: the input to each layer, followed by the final
            layer-normed output.
        """
        # Ignore copy
        input_features = input_features.to(
            dtype=self.conv1.weight.dtype, device=self.conv1.weight.device
        )

        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        # Swap the last two axes so that dim 1 is the sequence axis the
        # position table is sliced against below.
        inputs_embeds = inputs_embeds.permute(0, 2, 1)

        # Truncate the position table to the actual sequence length.
        embed_pos = self.embed_positions.weight[: inputs_embeds.shape[1], :]

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )

        encoder_states: tuple[torch.Tensor, ...] = ()

        for encoder_layer in self.layers:
            encoder_states = encoder_states + (hidden_states,)

            # LayerDrop: during training, skip the layer with probability
            # ``self.layerdrop`` (hidden_states pass through unchanged).
            if self.training and torch.rand([]) < self.layerdrop:
                continue

            # Ignore copy
            hidden_states = encoder_layer(hidden_states, attention_mask)[0]

        hidden_states = self.layer_norm(hidden_states)
        encoder_states = encoder_states + (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
        )


tc-mb's avatar
tc-mb committed
527
528
529
class MiniCPMOBaseModel:
    """Base mixin class for MiniCPM-O models with audio support.

    Listed before a MiniCPM-V model class in the bases (as ``MiniCPMO2_6`` and
    ``MiniCPMO4_5`` do). ``__init__`` lets the next class in the MRO build the
    vision-language model, then attaches the audio tower: a Whisper-style
    encoder (``apm``), an average pooler and an audio projection layer.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for ``modality`` (``i`` is unused).

        Raises:
            ValueError: If ``modality`` is not image, video or audio.
        """
        if modality.startswith("image"):
            return "(<image>./</image>)"
        if modality.startswith("video"):
            return "(<video>./</video>)"
        if modality.startswith("audio"):
            return "(<audio>./</audio>)"

        raise ValueError("Only image, video or audio modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # Build the audio tower under the "audio" tower-model marker.
        with self._mark_tower_model(vllm_config, "audio"):
            self.apm = self.init_audio_module(
                vllm_config=vllm_config, prefix=maybe_prefix(prefix, "apm")
            )

    def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create the audio encoder and its pooling/projection head.

        ``vllm_config`` and ``prefix`` are currently unused. Besides returning
        the encoder, this sets ``audio_avg_pooler``, ``audio_projection_layer``
        and ``audio_encoder_layer`` on ``self``. ``audio_encoder_layer`` is the
        index into the encoder's ``hidden_states``; -1 selects the last entry.
        """
        audio_config = self.config.audio_config
        model = MiniCPMWhisperEncoder(audio_config)
        # NOTE(review): assumes encoder_ffn_dim == 4 * d_model, i.e. this
        # equals the encoder's output width. Confirm for new configs.
        audio_output_dim = int(audio_config.encoder_ffn_dim // 4)

        self.audio_avg_pooler = nn.AvgPool1d(
            self.config.audio_pool_step, stride=self.config.audio_pool_step
        )
        self.audio_projection_layer = MultiModalProjector(
            in_dim=audio_output_dim, out_dim=self.embed_dim
        )
        self.audio_encoder_layer = -1
        return model

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, skipping everything under ``tts``."""
        loader = AutoWeightsLoader(self, skip_prefixes=["tts"])
        return loader.load_weights(weights)

    def subsequent_chunk_mask(
        self,
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = CPU_DEVICE,
        num_lookhead: int = 0,
    ) -> torch.Tensor:
        """Build a ``(size, size)`` boolean chunk-attention mask.

        Row ``i`` belongs to chunk ``i // chunk_size``. Row ``i`` may attend to
        column ``j`` iff ``start_i <= j < end_i``, where:

        * ``end_i`` is the end of that chunk plus ``num_lookhead``, capped at
          ``size``;
        * ``start_i`` is the start of the chunk ``num_left_chunks`` chunks
          earlier (clamped at 0), or 0 when ``num_left_chunks < 0``.

        Returns:
            Bool tensor on ``device``; ``True`` marks allowed positions.
        """
        # Vectorized computation of row indices and chunk boundaries
        row_indices = torch.arange(size, device=device)
        chunk_indices = row_indices // chunk_size

        if num_left_chunks < 0:
            # Unlimited left context: every row starts at column 0.
            start_indices = torch.zeros_like(row_indices)
        else:
            start_chunk_indices = torch.clamp(chunk_indices - num_left_chunks, min=0)
            start_indices = start_chunk_indices * chunk_size

        end_chunk_indices = chunk_indices + 1
        end_indices = torch.clamp(
            end_chunk_indices * chunk_size + num_lookhead, max=size
        )

        # Broadcast (1, size) columns against (size, 1) row bounds.
        col_indices = torch.arange(size, device=device).unsqueeze(0)
        start_indices = start_indices.unsqueeze(1)
        end_indices = end_indices.unsqueeze(1)
        return (col_indices >= start_indices) & (col_indices < end_indices)

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
        """Map input frame counts to lengths after the front-end and pooling.

        Returns:
            ``(lengths_after_cnn, lengths_after_pooling)``; the latter is
            cast to int32.
        """
        input_lengths_after_cnn = (input_lengths - 1) // 2 + 1
        input_lengths_after_pooling = (
            input_lengths_after_cnn - self.config.audio_pool_step
        ) // self.config.audio_pool_step + 1
        input_lengths_after_pooling = input_lengths_after_pooling.to(dtype=torch.int32)

        return input_lengths_after_cnn, input_lengths_after_pooling

    def get_audio_hidden_states(
        self, data: MiniCPMOAudioFeatureInputs
    ) -> list[torch.Tensor]:
        """Encode audio features into one embedding sequence per audio.

        The steps are:

        1. Pad the features into one batch.
        2. Run the encoder with a padding mask, plus a chunk mask when
           ``audio_chunk_length > 0``.
        3. Project and average-pool the encoder output.
        4. Trim each slice to its pooled length.
        5. Concatenate the slices belonging to the same audio.

        Returns:
            One tensor per entry of ``data["audio_feature_lens"]``.
        """
        chunk_length = self.config.audio_chunk_length

        # (bs, 80, frames) or [], multi audios need filled in advance
        wavforms_raw = data["audio_features"]
        if isinstance(wavforms_raw, list):
            # Right-pad variable-length items with zeros into a single batch.
            B = len(wavforms_raw)
            C = wavforms_raw[0].shape[-2]
            L = max(item.shape[-1] for item in wavforms_raw)
            device = wavforms_raw[0].device
            dtype = wavforms_raw[0].dtype

            wavforms = torch.zeros((B, C, L), dtype=dtype, device=device)
            for i, wavforms_item in enumerate(wavforms_raw):
                L_item = wavforms_item.shape[-1]
                wavforms[i, ..., :L_item] = wavforms_item
        else:
            wavforms = wavforms_raw

        # list, [[x1, x2], [y1], [z1]]
        audio_feature_lens_raw = data["audio_feature_lens"]
        if isinstance(audio_feature_lens_raw, torch.Tensor):
            audio_feature_lens_raw = audio_feature_lens_raw.unbind(0)

        # Flatten to one length per slice.
        audio_feature_lens = torch.hstack(audio_feature_lens_raw)
        batch_size, _, max_mel_seq_len = wavforms.shape
        # Sequence length after the front-end (same formula as
        # ``_get_feat_extract_output_lengths``).
        max_seq_len = (max_mel_seq_len - 1) // 2 + 1

        # Create a sequence tensor of shape (batch_size, max_seq_len)
        seq_range = (
            torch.arange(
                0,
                max_seq_len,
                dtype=audio_feature_lens.dtype,
                device=audio_feature_lens.device,
            )
            .unsqueeze(0)
            .expand(batch_size, max_seq_len)
        )
        lengths_expand = audio_feature_lens.unsqueeze(1).expand(batch_size, max_seq_len)
        # Create mask
        padding_mask = seq_range >= lengths_expand  # 1 for padded values

        audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
            batch_size, 1, max_seq_len, max_seq_len
        )
        audio_attention_mask = audio_attention_mask_.to(
            dtype=self.apm.conv1.weight.dtype, device=self.apm.conv1.weight.device
        )

        if chunk_length > 0:
            # Also block attention across chunk boundaries.
            chunk_num_frame = int(chunk_length * 50)
            chunk_mask = self.subsequent_chunk_mask(
                size=max_seq_len,
                chunk_size=chunk_num_frame,
                num_left_chunks=-1,
                device=audio_attention_mask_.device,
            )
            audio_attention_mask_ = torch.logical_or(
                audio_attention_mask_, torch.logical_not(chunk_mask)
            )

        # Resulting float mask: 0 where attention is allowed, -inf where it is
        # blocked (padding or outside the chunk window).
        audio_attention_mask[audio_attention_mask_] = float("-inf")
        audio_states = self.apm(
            wavforms, attention_mask=audio_attention_mask
        ).hidden_states[self.audio_encoder_layer]
        audio_embeds = self.audio_projection_layer(audio_states)

        # AvgPool1d pools over the last axis, so pool along dim 1 via
        # transposes.
        audio_embeds = audio_embeds.transpose(1, 2)
        audio_embeds = self.audio_avg_pooler(audio_embeds)
        audio_embeds = audio_embeds.transpose(1, 2)

        _, feature_lens_after_pooling = self._get_feat_extract_output_lengths(
            audio_feature_lens
        )

        num_audio_tokens = feature_lens_after_pooling

        # Trim each slice to its valid length, then concatenate per audio.
        final_audio_embeds = list[torch.Tensor]()
        idx = 0
        for i in range(len(audio_feature_lens_raw)):
            target_audio_embeds_lst = list[torch.Tensor]()
            for _ in range(len(audio_feature_lens_raw[i])):
                target_audio_embeds_lst.append(
                    audio_embeds[idx, : num_audio_tokens[idx], :]
                )
                idx += 1

            final_audio_embeds.append(torch.cat(target_audio_embeds_lst))

        return final_audio_embeds

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> MiniCPMOAudioInputs | None:
        """Build audio inputs from kwargs.

        Returns ``None`` when neither ``audio_features`` nor ``audio_embeds``
        is present. ``audio_embeds`` takes precedence; otherwise
        ``audio_feature_lens`` is required (KeyError if missing).
        """
        audio_features = kwargs.pop("audio_features", None)
        audio_embeds = kwargs.pop("audio_embeds", None)

        if audio_features is None and audio_embeds is None:
            return None

        if audio_embeds is not None:
            return MiniCPMOAudioEmbeddingInputs(
                type="audio_embeds",
                audio_embeds=audio_embeds,
            )

        audio_feature_lens = kwargs.pop("audio_feature_lens")

        return MiniCPMOAudioFeatureInputs(
            type="audio_features",
            audio_features=audio_features,
            audio_feature_lens=audio_feature_lens,
        )

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parent modalities plus ``"audios"`` when audio kwargs are present."""
        modalities = super()._parse_and_validate_multimodal_inputs(**kwargs)

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if (
                input_key in ("audio_features", "audio_embeds")
                and "audios" not in modalities
            ):
                modalities["audios"] = self._parse_and_validate_audio_input(**kwargs)

        return modalities

    def _process_audio_input(
        self,
        audio_input: MiniCPMOAudioInputs,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Return precomputed embeddings as-is, or encode raw features."""
        if audio_input["type"] == "audio_embeds":
            return audio_input["audio_embeds"]

        return self.get_audio_hidden_states(audio_input)

    def _process_multimodal_inputs(self, modalities: dict):
        """Parent embeddings with the audio embeddings appended, in order."""
        multimodal_embeddings = super()._process_multimodal_inputs(modalities)

        for modality in modalities:
            if modality == "audios":
                audio_input = modalities["audios"]
                audio_embeddings = self._process_audio_input(audio_input)
                multimodal_embeddings += tuple(audio_embeddings)

        return multimodal_embeddings
tc-mb's avatar
tc-mb committed
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848


class MiniCPMO2_6(MiniCPMOBaseModel, MiniCPMV2_6):
    """MiniCPM-O 2.6 model with Qwen2 backbone.

    Construction is fully inherited. ``MiniCPMOBaseModel.__init__`` builds the
    MiniCPM-V 2.6 model via ``super()`` and then attaches the audio tower
    (``apm``) itself. Re-creating the audio tower here would build the encoder,
    pooler and projector a second time.
    """


class MiniCPMO4_5(MiniCPMOBaseModel, MiniCPMV4_5):
    """MiniCPM-O 4.5 model with Qwen3 backbone.

    Construction is fully inherited. ``MiniCPMOBaseModel.__init__`` builds the
    MiniCPM-V 4.5 model via ``super()`` and then attaches the audio tower
    (``apm``) itself. Re-creating the audio tower here would build the encoder,
    pooler and projector a second time.
    """


# (major, minor) of ``hf_config.version`` -> concrete model class that
# ``MiniCPMO.__new__`` instantiates.
_MINICPMO_SUPPORT_VERSION: dict[tuple[int, int], type] = {
    (2, 6): MiniCPMO2_6,
    (4, 5): MiniCPMO4_5,
}


@MULTIMODAL_REGISTRY.register_processor(
    MiniCPMOMultiModalProcessor,
    info=MiniCPMOProcessingInfo,
    dummy_inputs=MiniCPMODummyInputsBuilder,
)
class MiniCPMO(MiniCPMOBaseModel, MiniCPMV2_6):
    """
    Version-dispatching entry point for MiniCPM-O checkpoints.

    ``__new__`` reads ``hf_config.version`` (first two dot-separated parts)
    and returns an instance of the matching class from
    ``_MINICPMO_SUPPORT_VERSION``:
    - Version 2.6: Uses Qwen2
    - Version 4.5: Uses Qwen3
    Configs without a ``version`` attribute are treated as 2.6.
    """

    def __new__(cls, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config

        if not hasattr(config, "version"):
            # Default to 2.6 for backward compatibility
            version = (2, 6)
        else:
            try:
                major_minor = str(config.version).split(".")[:2]
                version = tuple(map(int, major_minor))
            except (ValueError, TypeError) as e:
                raise ValueError(
                    f"Invalid model version format in config: {config.version}. "
                    "Expected a dot-separated version string like '4.5'."
                ) from e

        target_cls = _MINICPMO_SUPPORT_VERSION.get(version)
        if target_cls is not None:
            return target_cls(vllm_config=vllm_config, prefix=prefix)

        supported_versions = ", ".join(
            f"{major}.{minor}" for major, minor in sorted(_MINICPMO_SUPPORT_VERSION)
        )
        raise ValueError(
            f"Currently, MiniCPMO only supports versions "
            f"{supported_versions}. Got version: {version}"
        )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Never runs: ``__new__`` returns an instance of a class that is not a
        # subclass of MiniCPMO, so Python skips this ``__init__``.
        pass