minicpmo.py 27.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniCPM-O model compatible with HuggingFace weights."""
26

27
28
from collections.abc import Callable, Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal, TypeAlias
29
30
31

import torch
from torch import nn
32
from transformers import BatchFeature
33
from transformers.modeling_outputs import BaseModelOutputWithPast
34
35
36
37
38
39
from transformers.models.whisper.modeling_whisper import (
    ACT2FN,
    WhisperAttention,
    WhisperConfig,
    WhisperEncoder,
)
40
41

from vllm.config import VllmConfig
42
from vllm.config.multimodal import BaseDummyOptions
43
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    NestedTensors,
)
from vllm.multimodal.parse import (
    AudioItem,
    AudioProcessorItems,
    DictEmbeddingItems,
    ModalityData,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
63
from vllm.utils.tensor_schema import TensorSchema, TensorShape
64

65
66
67
68
69
70
71
72
73
74
from .minicpmv import (
    _MAX_FRAMES_PER_VIDEO,
    MiniCPMV2_6,
    MiniCPMVDummyInputsBuilder,
    MiniCPMVMultiModalDataParser,
    MiniCPMVMultiModalProcessor,
    MiniCPMVProcessingInfo,
    _minicpmv_field_config,
)
from .utils import AutoWeightsLoader, cast_overflow_tensors, flatten_bn, maybe_prefix
75
76
77
78

# Device used by default when building attention masks
# (see MiniCPMO.subsequent_chunk_mask).
CPU_DEVICE = torch.device(type="cpu")


79
class MiniCPMOAudioFeatureInputs(TensorSchema):
    """
    Audio features that still have to be run through the audio encoder.

    Dimensions:
        - bns: Batch size * number of audios * number of slices
        - bn: Batch size * number of audios
        - c: Number of channels
        - l: Length
        - s: Number of slices
    """

    type: Literal["audio_features"] = "audio_features"

    # A "slice" is a chunk: audio that is too long is split into slices, the
    # same way large images are. Because `l` is a dynamic dimension, the
    # slices may arrive either as one padded tensor or as a list of tensors
    # of differing lengths.
    audio_features: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bns", "c", "l", dynamic_dims={"l"}),
    ]

    # Valid (unpadded) feature length of each audio slice, i.e. the length of
    # that slice along the last axis of `audio_features` before any padding.
    audio_feature_lens: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", "s"),
    ]


111
class MiniCPMOAudioEmbeddingInputs(TensorSchema):
    """
    Precomputed audio embeddings that are used without running the encoder.

    Dimensions:
        - bn: Batch size * number of audios
        - s: Number of slices
        - h: Hidden size (must match language model backbone)

    Length of each slice may vary, so pass it as a list.
    """

    type: Literal["audio_embeds"] = "audio_embeds"

    # `s` is dynamic, so audios whose slice counts differ can be supplied as a
    # list of tensors instead of a single stacked tensor.
    audio_embeds: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", "s", "h", dynamic_dims={"s"}),
    ]
127

128

129
130
131
# Either raw audio features (encoded by the model) or precomputed audio
# embeddings (used as-is); the ``type`` field of each schema tells them apart.
MiniCPMOAudioInputs: TypeAlias = (
    MiniCPMOAudioFeatureInputs
    | MiniCPMOAudioEmbeddingInputs
)
132
133


134
def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
    """Describe how MiniCPM-O processor outputs are split per multimodal item.

    Starts from the MiniCPM-V (image/video) field config and adds the audio
    fields: per-audio batched features, feature lengths and embeddings, plus
    a single ``audio_token_id`` shared by every audio item.
    """
    # The shared token id needs to know how many audio items there are; with
    # no audio features present this is len(torch.empty(0)) == 0.
    num_audios = len(hf_inputs.get("audio_features", torch.empty(0)))

    base_config = _minicpmv_field_config(hf_inputs)
    per_audio_fields = {
        field_name: MultiModalFieldConfig.batched("audio")
        for field_name in ("audio_features", "audio_feature_lens", "audio_embeds")
    }

    return dict(
        **base_config,
        **per_audio_fields,
        audio_token_id=MultiModalFieldConfig.shared("audio", num_audios),
    )
145
146


147
148
149
150
class MiniCPMOAudioEmbeddingItems(DictEmbeddingItems):
    """Audio items given as precomputed embeddings (a dict of tensors).

    ``audio_embeds`` is the only required key; how each key is split into
    per-item data is described by ``fields_factory``.
    """

    def __init__(
        self,
        data: Mapping[str, torch.Tensor],
        fields_factory: Callable[
            [Mapping[str, torch.Tensor]],
            Mapping[str, MultiModalFieldConfig],
        ],
    ) -> None:
        super().__init__(
            data,
            # These items hold audio embeddings, so they must be tagged with
            # the audio modality. Tagging them as "image" made per-item
            # lookups keyed by this modality miss the audio entries, whose
            # field configs are all batched/shared under "audio".
            modality="audio",
            required_fields={"audio_embeds"},
            fields_factory=fields_factory,
        )
162
163
164
165
166


class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser):
    """Data parser that also accepts precomputed audio embeddings as a dict."""

    def _parse_audio_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[AudioItem],
    ) -> ModalityDataItems[Any, Any] | None:
        # Anything that is not a dict is raw audio: use the default parsing.
        if not isinstance(data, dict):
            return super()._parse_audio_data(data)

        # A dict carries precomputed embeddings (and related fields).
        return MiniCPMOAudioEmbeddingItems(
            data,
            fields_factory=_minicpmo_field_config,
        )


class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
    """Processing info for MiniCPM-O: the MiniCPM-V modalities plus audio."""

    # Prompt text marking where an audio item is inserted.
    audio_pattern = "(<audio>./</audio>)"

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the parent limits plus an unbounded number of audios."""
        limits = dict(super().get_supported_mm_limits())
        limits["audio"] = None
        return limits

    def get_audio_placeholder(
        self,
        audio_lens: int,
        chunk_input: bool = True,
        chunk_length: int = 1,
    ) -> str:
        """Build the audio placeholder text via the HF processor."""
        processor = self.get_hf_processor()
        return processor.get_audio_placeholder(
            audio_lens,
            chunk_input=chunk_input,
            chunk_length=chunk_length,
        )

    def get_default_audio_pool_step(self) -> int:
        return 2

    def get_default_audio_sampling_rate(self) -> int:
        return 16000

    def get_chunk_length(self) -> int:
        return self.get_hf_config().audio_chunk_length

    def get_max_audio_tokens_per_chunk(self) -> int:
        """Number of audio tokens produced from one chunk of fbank features.

        A chunk holds 100 fbank frames. The CNN stage maps ``n`` frames to
        ``(n - 1) // 2 + 1``, then pooling with the default pool step reduces
        the count further.
        """
        step = self.get_default_audio_pool_step()
        fbank_frames_per_chunk = 100
        frames_after_cnn = (fbank_frames_per_chunk - 1) // 2 + 1
        return (frames_after_cnn - step) // step + 1

    def get_max_audio_chunks_with_most_features(self) -> int:
        return 30

    def get_max_audio_tokens(self) -> int:
        """Upper bound on tokens for a single audio of maximal length."""
        chunk_count = self.get_max_audio_chunks_with_most_features()
        return self.get_max_audio_tokens_per_chunk() * chunk_count

    def get_audio_len_by_num_chunks(self, num_chunks: int) -> int:
        """Estimate an audio length (in samples) from ``num_chunks``.

        NOTE(review): the in-file caller passes a summed embedding length,
        which looks like a token count rather than a chunk count — confirm.
        """
        rate = self.get_default_audio_sampling_rate()
        tokens_per_chunk = self.get_max_audio_tokens_per_chunk()
        return int(num_chunks * rate / tokens_per_chunk) + 1

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Frames per video that fit after reserving image and audio tokens."""
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)
        num_audios = mm_counts.get("audio", 0)

        image_budget = self.get_max_image_tokens() * num_images
        audio_budget = self.get_max_audio_tokens() * num_audios
        remaining_tokens = seq_len - image_budget - audio_budget

        total_frames = self.get_max_video_frames(remaining_tokens)
        frames_per_video = min(
            total_frames // max(num_videos, 1),
            _MAX_FRAMES_PER_VIDEO,
        )
        # Always allow at least one frame per video.
        return max(frames_per_video, 1)
244
245


246
class MiniCPMODummyInputsBuilder(MiniCPMVDummyInputsBuilder[MiniCPMOProcessingInfo]):
    """Dummy-input builder adding audio on top of MiniCPM-V's image/video."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Parent dummy text followed by one audio placeholder per audio."""
        audio_placeholders = self.info.audio_pattern * mm_counts.get("audio", 0)
        return super().get_dummy_text(mm_counts) + audio_placeholders

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Parent dummy image/video data plus dummy audios.

        Each dummy audio has ``get_max_audio_chunks_with_most_features() *
        get_default_audio_sampling_rate()`` samples.
        """
        info = self.info
        num_audios = mm_counts.get("audio", 0)
        samples_per_audio = (
            info.get_max_audio_chunks_with_most_features()
            * info.get_default_audio_sampling_rate()
        )

        overrides = None
        if mm_options:
            overrides = mm_options.get("audio")

        dummy_audios = self._get_dummy_audios(
            length=samples_per_audio, num_audios=num_audios, overrides=overrides
        )

        mm_data = dict(super().get_dummy_mm_data(seq_len, mm_counts, mm_options))
        mm_data["audio"] = dummy_audios
        return mm_data
278
279


280
class MiniCPMOMultiModalProcessor(MiniCPMVMultiModalProcessor[MiniCPMOProcessingInfo]):
    """Multimodal processor extending MiniCPM-V processing with audio."""

    def _get_data_parser(self) -> MultiModalDataParser:
        """Parser that resamples audio to the default sampling rate."""
        sampling_rate = self.info.get_default_audio_sampling_rate()
        return MiniCPMOMultiModalDataParser(target_sr=sampling_rate)

    def get_audio_prompt_texts(
        self,
        audio_lens: int,
        chunk_input: bool = True,
        chunk_length: int = 1,
    ) -> str:
        """Placeholder text for an audio of ``audio_lens`` (delegates to info)."""
        info = self.info
        return info.get_audio_placeholder(
            audio_lens,
            chunk_input=chunk_input,
            chunk_length=chunk_length,
        )

    def process_audios(
        self,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> Mapping[str, NestedTensors]:
        """Run the HF processor over the ``"audios"`` entry of ``mm_data``.

        Returns an empty mapping when there is no audio. Raw audio yields
        unpadded ``audio_features`` plus ``audio_feature_lens``; precomputed
        embeddings yield no features. In both cases ``audio_token_id`` is set
        to the tokenizer's ``<unk>`` id.
        """
        audios = mm_data.get("audios")
        if audios is None:
            return {}

        parser = self._get_data_parser()
        parsed_audios = parser.parse_mm_data({"audio": audios}).get_items(
            "audio", (MiniCPMOAudioEmbeddingItems, AudioProcessorItems)
        )

        if isinstance(parsed_audios, MiniCPMOAudioEmbeddingItems):
            audio_inputs = {}
        else:
            audio_inputs = self._base_call_hf_processor(
                prompts=[self.info.audio_pattern] * len(parsed_audios),
                mm_data={"audios": [[single] for single in parsed_audios]},
                mm_kwargs={**mm_kwargs, "chunk_input": True},
                tok_kwargs=tok_kwargs,
                out_keys={"audio_features", "audio_feature_lens"},
            )

            # Strip the padding so each audio's features do not depend on
            # the other audios in the batch (required for correct caching).
            audio_inputs["audio_features"] = [
                padded[:, :valid_len]
                for padded, valid_len in zip(
                    audio_inputs["audio_features"],
                    audio_inputs["audio_feature_lens"],
                )
            ]

        unk_token_id = self.info.get_tokenizer().get_vocab()["<unk>"]
        audio_inputs["audio_token_id"] = torch.tensor(unk_token_id)
        return audio_inputs

    def process_mm_inputs(
        self,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> Mapping[str, NestedTensors]:
        """Merge the parent's image/video outputs with the audio outputs."""
        merged = dict(super().process_mm_inputs(mm_data, mm_kwargs, tok_kwargs))
        merged.update(self.process_audios(mm_data, mm_kwargs, tok_kwargs))
        return merged

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Parent prompt updates plus one replacement for audio placeholders."""
        base_updates = super()._get_prompt_updates(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            out_mm_kwargs=out_mm_kwargs,
        )

        def get_audio_replacement(item_idx: int):
            # Looked up lazily: this only runs when audio items exist.
            audio_items = mm_items.get_items(
                "audio", (MiniCPMOAudioEmbeddingItems, AudioProcessorItems)
            )

            if not isinstance(audio_items, MiniCPMOAudioEmbeddingItems):
                audio_len = audio_items.get_audio_length(item_idx)
            else:
                embeds = audio_items.get(item_idx)["audio_embeds"]
                audio_len = self.info.get_audio_len_by_num_chunks(
                    sum(len(part) for part in embeds)
                )

            return PromptUpdateDetails.select_text(
                self.get_audio_prompt_texts(audio_len),
                "<unk>",
            )

        audio_update = PromptReplacement(
            modality="audio",
            target=self.info.audio_pattern,
            replacement=get_audio_replacement,
        )
        return [*base_updates, audio_update]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Field config covering image, video and audio outputs."""
        return _minicpmo_field_config(hf_inputs)
399
400
401
402
403


class MultiModalProjector(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) mapping ``in_dim`` features
    to ``out_dim``."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # Attribute names determine parameter names, so they stay fixed.
        self.linear1 = nn.Linear(in_dim, out_dim, bias=True)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(out_dim, out_dim, bias=True)

    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        activated = self.relu(self.linear1(audio_features))
        return self.linear2(activated)


class MiniCPMWhisperEncoderLayer(nn.Module):
    """Whisper encoder layer taking an explicit additive attention mask.

    Pre-norm structure: LayerNorm -> self-attention -> dropout -> residual,
    then LayerNorm -> fc1 -> activation -> dropout -> fc2 -> dropout ->
    residual.
    """

    def __init__(self, config: WhisperConfig, layer_idx: int):
        super().__init__()
        dim = config.d_model
        self.embed_dim = dim

        # Submodules are registered in the same order as the reference
        # implementation so parameter naming/ordering is unchanged.
        self.self_attn = WhisperAttention(
            embed_dim=dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
            config=config,
            layer_idx=layer_idx,
        )
        self.self_attn_layer_norm = nn.LayerNorm(dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, dim)
        self.final_layer_norm = nn.LayerNorm(dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the layer; returns a 1-tuple ``(hidden_states,)``."""
        dropout = nn.functional.dropout

        # --- self-attention sub-block (no KV cache is used) ---
        shortcut = hidden_states
        attn_in = self.self_attn_layer_norm(hidden_states)
        attn_out, _attn_weights, _cache = self.self_attn(
            hidden_states=attn_in,
            attention_mask=attention_mask,
            past_key_value=None,
        )
        attn_out = dropout(attn_out, p=self.dropout, training=self.training)
        hidden_states = shortcut + attn_out

        # --- feed-forward sub-block ---
        shortcut = hidden_states
        ffn = self.activation_fn(self.fc1(self.final_layer_norm(hidden_states)))
        ffn = dropout(ffn, p=self.activation_dropout, training=self.training)
        ffn = dropout(self.fc2(ffn), p=self.dropout, training=self.training)
        hidden_states = shortcut + ffn

        # Guard against fp16 overflow in the residual stream.
        if hidden_states.dtype == torch.float16:
            hidden_states = cast_overflow_tensors(hidden_states)

        return (hidden_states,)


class MiniCPMWhisperEncoder(WhisperEncoder):
    """Whisper encoder whose layers accept an explicit attention mask."""

    def __init__(self, config: WhisperConfig):
        super().__init__(config)
        # Replace the stock encoder layers with the MiniCPM variant.
        self.layers = nn.ModuleList(
            MiniCPMWhisperEncoderLayer(config, layer_idx=layer_no)
            for layer_no in range(config.encoder_layers)
        )

    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> BaseModelOutputWithPast:
        """Encode ``input_features``.

        ``hidden_states`` of the result holds the input to every layer followed
        by the final layer-normed output (which is also ``last_hidden_state``).
        """
        # Ignore copy
        conv_weight = self.conv1.weight
        features = input_features.to(dtype=conv_weight.dtype, device=conv_weight.device)

        gelu = nn.functional.gelu
        embeds = gelu(self.conv2(gelu(self.conv1(features))))
        embeds = embeds.permute(0, 2, 1)

        positions = self.embed_positions.weight[: embeds.shape[1], :]
        hidden_states = nn.functional.dropout(
            embeds + positions, p=self.dropout, training=self.training
        )

        collected_states: tuple[torch.Tensor, ...] = ()
        for layer in self.layers:
            collected_states += (hidden_states,)

            # LayerDrop: while training, skip each layer with probability
            # self.layerdrop (the random draw only happens in training).
            if self.training and torch.rand([]) < self.layerdrop:
                continue

            # Ignore copy
            hidden_states = layer(hidden_states, attention_mask)[0]

        hidden_states = self.layer_norm(hidden_states)
        collected_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            hidden_states=collected_states,
        )


@MULTIMODAL_REGISTRY.register_processor(
    MiniCPMOMultiModalProcessor,
    info=MiniCPMOProcessingInfo,
    dummy_inputs=MiniCPMODummyInputsBuilder,
)
class MiniCPMO(MiniCPMV2_6):
    """MiniCPM-O: MiniCPM-V 2.6 extended with a Whisper-style audio encoder.

    Audio features are encoded by ``self.apm``, projected by
    ``self.audio_projection_layer``, average-pooled over time by
    ``self.audio_avg_pooler`` and appended to the multimodal embeddings
    produced by the parent class.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for ``modality``.

        Raises:
            ValueError: if ``modality`` is not image, video or audio.
        """
        if modality.startswith("image"):
            return "(<image>./</image>)"
        if modality.startswith("video"):
            return "(<video>./</video>)"
        if modality.startswith("audio"):
            return "(<audio>./</audio>)"

        raise ValueError("Only image, video or audio modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        self.apm = self.init_audio_module(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "apm")
        )

        self.audio_token_id = None

    def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build and return the audio encoder.

        As side effects this also creates ``self.audio_avg_pooler``,
        ``self.audio_projection_layer`` and sets ``self.audio_encoder_layer``
        (the index into the encoder's ``hidden_states`` that is used).
        ``vllm_config`` and ``prefix`` are currently unused.
        """
        audio_config = self.config.audio_config
        model = MiniCPMWhisperEncoder(audio_config)
        # Projector input width. NOTE(review): presumably equals d_model for
        # Whisper configs where encoder_ffn_dim == 4 * d_model — confirm.
        audio_output_dim = int(audio_config.encoder_ffn_dim // 4)
        self.audio_avg_pooler = nn.AvgPool1d(
            self.config.audio_pool_step, stride=self.config.audio_pool_step
        )
        self.audio_projection_layer = MultiModalProjector(
            in_dim=audio_output_dim, out_dim=self.embed_dim
        )
        # -1 selects the final (layer-normed) encoder output.
        self.audio_encoder_layer = -1
        return model

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, skipping everything under ``tts``."""
        loader = AutoWeightsLoader(self, skip_prefixes=["tts"])
        return loader.load_weights(weights)

    def subsequent_chunk_mask(
        self,
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = CPU_DEVICE,
        num_lookhead: int = 0,
    ) -> torch.Tensor:
        """Build a ``(size, size)`` boolean chunk-attention mask.

        Entry ``[i, j]`` is True when position ``i`` may attend to ``j``:
        ``j`` must lie in ``[start_i, end_i)`` where
        ``end_i = min((i // chunk_size + 1) * chunk_size + num_lookhead, size)``
        and ``start_i`` is 0 when ``num_left_chunks < 0``, otherwise the start
        of the chunk ``num_left_chunks`` chunks before row ``i``'s chunk
        (clamped at 0).
        """
        row_indices = torch.arange(size, device=device)
        chunk_indices = row_indices // chunk_size

        if num_left_chunks < 0:
            # Unlimited left context: every row starts at column 0.
            start_indices = torch.zeros_like(row_indices)
        else:
            start_chunk_indices = torch.clamp(chunk_indices - num_left_chunks, min=0)
            start_indices = start_chunk_indices * chunk_size

        # Each row can see up to the end of its own chunk plus the lookahead.
        end_indices = torch.clamp(
            (chunk_indices + 1) * chunk_size + num_lookhead, max=size
        )

        # Broadcast column indices against per-row [start, end) bounds.
        col_indices = torch.arange(size, device=device).unsqueeze(0)
        return (col_indices >= start_indices.unsqueeze(1)) & (
            col_indices < end_indices.unsqueeze(1)
        )

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
        """Map feature lengths to lengths after the CNN and after pooling.

        Returns ``(lengths_after_cnn, lengths_after_pooling)``; the second is
        cast to int32.
        """
        input_lengths_after_cnn = (input_lengths - 1) // 2 + 1
        input_lengths_after_pooling = (
            input_lengths_after_cnn - self.config.audio_pool_step
        ) // self.config.audio_pool_step + 1
        input_lengths_after_pooling = input_lengths_after_pooling.to(dtype=torch.int32)

        return input_lengths_after_cnn, input_lengths_after_pooling

    def get_audio_hidden_states(
        self, data: MiniCPMOAudioFeatureInputs
    ) -> list[torch.Tensor]:
        """Encode audio features into one embedding tensor per audio.

        Slices are padded to a common length, encoded with a padding mask
        (and a chunk mask when ``audio_chunk_length > 0``), projected, pooled,
        truncated to each slice's valid token count and finally concatenated
        per audio.
        """
        chunk_length = self.config.audio_chunk_length

        # Pad a list of slices into one (num_slices, channels, max_len) tensor.
        wavforms_raw = data["audio_features"]
        if isinstance(wavforms_raw, list):
            B = len(wavforms_raw)
            C = wavforms_raw[0].shape[-2]
            L = max(item.shape[-1] for item in wavforms_raw)
            device = wavforms_raw[0].device
            dtype = wavforms_raw[0].dtype

            wavforms = torch.zeros((B, C, L), dtype=dtype, device=device)
            for i, wavforms_item in enumerate(wavforms_raw):
                L_item = wavforms_item.shape[-1]
                wavforms[i, ..., :L_item] = wavforms_item
        else:
            wavforms = wavforms_raw

        # Per-audio slice lengths, e.g. [[x1, x2], [y1], [z1]].
        audio_feature_lens_raw = data["audio_feature_lens"]
        if isinstance(audio_feature_lens_raw, torch.Tensor):
            audio_feature_lens_raw = audio_feature_lens_raw.unbind(0)

        audio_feature_lens = torch.hstack(audio_feature_lens_raw)
        batch_size, _, max_mel_seq_len = wavforms.shape
        max_seq_len = (max_mel_seq_len - 1) // 2 + 1

        # Position index for every (slice, encoder step).
        seq_range = (
            torch.arange(
                0,
                max_seq_len,
                dtype=audio_feature_lens.dtype,
                device=audio_feature_lens.device,
            )
            .unsqueeze(0)
            .expand(batch_size, max_seq_len)
        )
        lengths_expand = audio_feature_lens.unsqueeze(1).expand(batch_size, max_seq_len)
        # True marks padded positions.
        padding_mask = seq_range >= lengths_expand

        audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
            batch_size, 1, max_seq_len, max_seq_len
        )
        audio_attention_mask = audio_attention_mask_.to(
            dtype=self.apm.conv1.weight.dtype, device=self.apm.conv1.weight.device
        )

        if chunk_length > 0:
            # 50 encoder steps per unit of chunk_length (presumably seconds:
            # 100 fbank frames/s halved by the CNN) — TODO confirm.
            chunk_num_frame = int(chunk_length * 50)
            chunk_mask = self.subsequent_chunk_mask(
                size=max_seq_len,
                chunk_size=chunk_num_frame,
                num_left_chunks=-1,
                device=audio_attention_mask_.device,
            )
            # Also block attention outside the allowed chunk window.
            audio_attention_mask_ = torch.logical_or(
                audio_attention_mask_, torch.logical_not(chunk_mask)
            )

        # Convert the boolean "blocked" mask into an additive -inf mask.
        audio_attention_mask[audio_attention_mask_] = float("-inf")
        audio_states = self.apm(
            wavforms, attention_mask=audio_attention_mask
        ).hidden_states[self.audio_encoder_layer]
        audio_embeds = self.audio_projection_layer(audio_states)

        # AvgPool1d pools over the last axis, so pool with time last.
        audio_embeds = audio_embeds.transpose(1, 2)
        audio_embeds = self.audio_avg_pooler(audio_embeds)
        audio_embeds = audio_embeds.transpose(1, 2)

        _, num_audio_tokens = self._get_feat_extract_output_lengths(audio_feature_lens)

        # Regroup the flat slice dimension back into per-audio tensors.
        final_audio_embeds = list[torch.Tensor]()
        idx = 0
        for slice_lens in audio_feature_lens_raw:
            target_audio_embeds_lst = list[torch.Tensor]()
            for _ in range(len(slice_lens)):
                target_audio_embeds_lst.append(
                    audio_embeds[idx, : num_audio_tokens[idx], :]
                )
                idx += 1

            final_audio_embeds.append(torch.cat(target_audio_embeds_lst))

        return final_audio_embeds

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> MiniCPMOAudioInputs | None:
        """Extract audio inputs from ``kwargs``; None when no audio is given.

        Raises:
            ValueError: if an audio field has an unexpected type or
                ``audio_feature_lens`` is missing for feature inputs.
        """
        audio_features = kwargs.pop("audio_features", None)
        audio_embeds = kwargs.pop("audio_embeds", None)

        if audio_features is None and audio_embeds is None:
            return None

        # Optional: absent and None are treated the same way.
        audio_token_id = kwargs.pop("audio_token_id", None)
        if audio_token_id is not None:
            assert isinstance(audio_token_id, torch.Tensor)
            self.mm_token_ids.add(audio_token_id.flatten().unique().item())

        if audio_embeds is not None:
            if not isinstance(audio_embeds, (torch.Tensor, list)):
                raise ValueError(
                    f"Incorrect type of audio_embeds. Got type: {type(audio_embeds)}"
                )

            audio_embeds_flat = flatten_bn(audio_embeds)

            return MiniCPMOAudioEmbeddingInputs(
                type="audio_embeds",
                audio_embeds=audio_embeds_flat,
            )

        if not isinstance(audio_features, (torch.Tensor, list)):
            raise ValueError(
                f"Incorrect type of audio_features. Got type: {type(audio_features)}"
            )

        # A missing value falls through to the type check below and is
        # reported as a ValueError rather than a bare KeyError.
        audio_feature_lens = kwargs.pop("audio_feature_lens", None)
        if not isinstance(audio_feature_lens, (torch.Tensor, list)):
            raise ValueError(
                "Incorrect type of audio_feature_lens. "
                f"Got type: {type(audio_feature_lens)}"
            )

        audio_features_flat = flatten_bn(audio_features)
        audio_feature_lens_flat = flatten_bn(audio_feature_lens)

        return MiniCPMOAudioFeatureInputs(
            type="audio_features",
            audio_features=audio_features_flat,
            audio_feature_lens=audio_feature_lens_flat,
        )

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parent modalities plus ``"audios"`` when audio kwargs are present."""
        modalities = super()._parse_and_validate_multimodal_inputs(**kwargs)

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if (
                input_key in ("audio_features", "audio_embeds")
                and "audios" not in modalities
            ):
                modalities["audios"] = self._parse_and_validate_audio_input(**kwargs)

        return modalities

    def _process_audio_input(
        self,
        audio_input: MiniCPMOAudioInputs,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Return embeddings as-is, or encode raw features."""
        if audio_input["type"] == "audio_embeds":
            return audio_input["audio_embeds"]

        return self.get_audio_hidden_states(audio_input)

    def _process_multimodal_inputs(self, modalities: dict):
        """Parent embeddings with the audio embeddings appended."""
        multimodal_embeddings = super()._process_multimodal_inputs(modalities)

        if "audios" in modalities:
            audio_features = self._process_audio_input(modalities["audios"])
            multimodal_embeddings += tuple(audio_features)

        return multimodal_embeddings