# SPDX-License-Identifier: Apache-2.0

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniCPM-O model compatible with HuggingFace weights."""
from functools import partial
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
                    Tuple, TypedDict, Union)

import torch
import torch.types
from torch import nn
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.whisper.modeling_whisper import (
    ACT2FN, WHISPER_ATTENTION_CLASSES, WhisperConfig, WhisperEncoder)

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import MultiModalFieldConfig
from vllm.multimodal.parse import (ModalityData, ModalityDataItems,
                                   MultiModalDataItems, MultiModalDataParser,
                                   VideoItem)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        PromptReplacement)
from vllm.multimodal.profiling import ProcessorInputs
from vllm.sequence import IntermediateTensors

from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
                       MiniCPMVEmbeddingItems, MiniCPMVMultiModalDataParser,
                       MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo)
from .utils import AutoWeightsLoader, maybe_prefix

# Default device for helper tensors built on the host (used as the default
# `device` of `MiniCPMO.subsequent_chunk_mask`).
CPU_DEVICE = torch.device("cpu")

# MiniCPM-O reuses the MiniCPM-V embedding-item container as the base class
# for its audio embedding items.
MiniCPMOEmbeddingItems = MiniCPMVEmbeddingItems


class MiniCPMOAudioFeatureInputs(TypedDict):
    """Audio inputs given as (padded) acoustic features for the audio encoder."""

    type: Literal["audio_features"]
    data: torch.Tensor
    """
    Shape: `(batch_size * num_audios * num_slices, num_channels, length)`

    Slice here means chunk: audio that is too long is split into slices,
    the same way images are. Slices are padded to a common length, which is
    why `data` is a single `torch.Tensor`.
    """

    audio_feature_lens: torch.Tensor
    """
    Shape: `(batch_size * num_audios * num_slices)`

    The unpadded feature length of each audio slice. Because `data` is
    padded, each entry is at most `data.shape[-1]`.
    """

    audio_bounds: torch.Tensor
    """
    Shape: `(batch_size * num_audios * num_slices, 2)`

    This should be in `(start, stop)` format.
    """


class MiniCPMOAudioEmbeddingInputs(TypedDict):
    """Audio inputs given as precomputed embeddings (bypassing the encoder)."""

    type: Literal["audio_embeds"]
    data: List[torch.Tensor]
    """
    Shape: `(batch_size * num_audios * num_slices, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    Length of each slice may vary, so pass it as a list instead of a
    batched tensor.
    """
    audio_bounds: torch.Tensor
    """
    Shape: `(batch_size * num_audios * num_slices, 2)`

    This should be in `(start, stop)` format.
    """


# Audio reaches the model either as raw acoustic features or as
# precomputed embeddings.
MiniCPMOAudioInputs = Union[
    MiniCPMOAudioFeatureInputs,
    MiniCPMOAudioEmbeddingInputs,
]


class MiniCPMOAudioEmbeddingItems(MiniCPMOEmbeddingItems):
    """Audio items backed by precomputed embeddings.

    `data` must be a dict containing an ``"audio_embeds"`` entry; item ``i``
    is ``data["audio_embeds"][i]``.
    """

    def __init__(self, data: Dict) -> None:
        super().__init__(data, "audio")
        audio_embeds = self.data.get("audio_embeds", None)
        if audio_embeds is None:
            # Single formatted message (a second positional arg would make
            # the exception text render as a tuple).
            raise ValueError("Incorrect type of audio_embeds. Got type: None")
        self.data["audio_embeds"] = audio_embeds

    def get(self, index: int) -> object:
        """Return the embeddings of the ``index``-th audio item."""
        return self.data["audio_embeds"][index]


class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser):
    """MiniCPM-V data parser that also accepts audio embeddings as a dict."""

    def _parse_audio_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
    ) -> ModalityDataItems[Any, Any]:
        # NOTE(review): the non-dict branch carries raw audio, so the
        # `VideoItem` annotation looks like a copy-paste; confirm.
        if not isinstance(data, dict):
            # Raw audio: let the generic parser handle it.
            return super()._parse_audio_data(data)
        # A dict means precomputed embeddings (expects an "audio_embeds" key).
        return MiniCPMOAudioEmbeddingItems(data)


class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
    """Processing metadata for MiniCPM-O (MiniCPM-V image/video plus audio)."""

    audio_pattern = "(<audio>./</audio>)"

    def get_supported_mm_modalities(self) -> List[str]:
        return ["image", "video", "audio"]

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # None: no fixed per-prompt limit for any modality.
        return {"image": None, "video": None, "audio": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the worst-case placeholder token count per modality item."""
        return {
            "image": self.get_max_image_tokens(),
            "audio": self.get_max_audio_tokens(),
            "video": self.get_max_video_tokens(seq_len),
        }

    def get_default_audio_pool_step(self) -> int:
        return 2

    def get_default_audio_sampling_rate(self) -> int:
        return 16000

    def get_chunk_length(self) -> int:
        return self.get_hf_config().audio_chunk_length

    def get_max_audio_tokens_per_chunk(self) -> int:
        """Return the token count of one audio chunk, including its
        ``<audio>``/``</audio>`` markers."""
        pool_step = self.get_default_audio_pool_step()
        # Assumes 100 filter-bank frames per chunk — confirm against the
        # HF feature extractor settings.
        fbank_feat_in_chunk = 100
        # Ceil-halving of the frame count; presumably mirrors the stride-2
        # convolution of the Whisper-style encoder.
        cnn_feat_in_chunk = (fbank_feat_in_chunk - 1) // 2 + 1
        # Number of windows of the average pooling with stride == pool_step.
        num_audio_tokens = (cnn_feat_in_chunk - pool_step) // pool_step + 1
        return num_audio_tokens + 2  # <audio>(<unk>*N)</audio>

    def get_max_audio_chunks_with_most_features(self) -> int:
        return 30

    def get_max_audio_tokens(self) -> int:
        """Return the worst-case token count of a single audio item.

        This is the per-chunk token count times the maximum number of chunks
        the dummy (profiling) audio is built from.
        """
        num_chunks = self.get_max_audio_chunks_with_most_features()
        return self.get_max_audio_tokens_per_chunk() * num_chunks

    def get_audio_len_by_num_chunks(self, num_chunks: int) -> int:
        """Estimate an audio length (in samples) for placeholder sizing.

        Callers in this file pass the total number of embedding rows across
        an item's chunks. Dividing by the tokens per chunk and multiplying by
        the sampling rate treats each chunk as one second of audio (assumed —
        confirm against `audio_chunk_length`).
        """
        sampling_rate = self.get_default_audio_sampling_rate()
        # exclude <audio> </audio>
        num_tokens_per_chunk = self.get_max_audio_tokens_per_chunk() - 2
        return int(num_chunks * sampling_rate / num_tokens_per_chunk) + 1

    def get_num_frames_with_most_features(self, seq_len: int) -> int:
        """Return the number of frames per video for profiling.

        This is the largest number of frames that fits in ``seq_len`` after
        reserving room for the maximum number of images and audios.
        """
        mm_config = self.ctx.get_mm_config()
        max_images = mm_config.limit_per_prompt.get("image", 1)
        max_videos = mm_config.limit_per_prompt.get("video", 1)
        max_audios = mm_config.limit_per_prompt.get("audio", 1)

        # count <image_idx></image_idx> tokens
        # which are not in get_max_image_tokens
        max_image_tokens = self.get_max_image_tokens(
        ) * max_images + 4 * max_images
        max_audio_tokens = self.get_max_audio_tokens(
        ) * max_audios + 2 * max_audios
        max_total_frames = self.get_max_video_frames(seq_len -
                                                     max_image_tokens -
                                                     max_audio_tokens)

        num_frames = max(max_total_frames // max(max_videos, 1), 1)

        return num_frames


class MiniCPMODummyInputsBuilder(MiniCPMVDummyInputsBuilder):
    """Dummy (profiling) inputs: MiniCPM-V's image/video data plus audio."""

    def get_dummy_processor_inputs(
            self, seq_len: int, mm_counts: Mapping[str,
                                                   int]) -> ProcessorInputs:
        """Extend the parent's dummy inputs with ``mm_counts["audio"]``
        dummy audio clips and one audio placeholder per clip."""
        info = self.info
        audio_count = mm_counts.get("audio", 0)
        # Longest dummy clip: max chunk count times samples per second.
        dummy_audio_len = (info.get_max_audio_chunks_with_most_features() *
                           info.get_default_audio_sampling_rate())

        base = super().get_dummy_processor_inputs(seq_len, mm_counts)
        dummy_audios = self._get_dummy_audios(length=dummy_audio_len,
                                              num_audios=audio_count)
        combined_mm_data = {
            "image": base.mm_data["image"],
            "video": base.mm_data["video"],
            "audio": dummy_audios,
        }

        prompt = base.prompt_text + info.audio_pattern * audio_count
        return ProcessorInputs(prompt_text=prompt, mm_data=combined_mm_data)


class MiniCPMOMultiModalProcessor(
        MiniCPMVMultiModalProcessor,
        BaseMultiModalProcessor[MiniCPMOProcessingInfo]):
    """MiniCPM-V multimodal processor extended with an ``audio`` modality."""

    def _get_data_parser(self) -> MultiModalDataParser:
        # target_sr is the default audio sampling rate; presumably the
        # parser resamples raw audio to it — confirm in vLLM's parser.
        return MiniCPMOMultiModalDataParser(
            target_sr=self.info.get_default_audio_sampling_rate())

    def get_audio_prompt_texts(self,
                               audio_lens: int,
                               chunk_input: bool = True,
                               chunk_length: int = 1) -> str:
        """Return the HF processor's placeholder text for one audio clip.

        ``audio_lens`` is the clip length as computed by the callers (``len``
        of the raw audio, or an estimate derived from embedding rows).
        """
        return self.info.get_hf_processor().get_audio_placeholder(
            audio_lens, chunk_input, chunk_length)

    def get_special_tokens(self) -> Dict[str, torch.Tensor]:
        """Extend the MiniCPM-V special tokens with audio start/end ids.

        They are added only if the tokenizer has ``audio_start_id``;
        ``audio_end_id`` is then assumed to exist as well.
        """
        tokenizer = self.info.get_tokenizer()
        special_tokens = super().get_special_tokens()
        if hasattr(tokenizer, "audio_start_id"):
            special_tokens["audio_start_id"] = torch.tensor(
                tokenizer.audio_start_id)
            special_tokens["audio_end_id"] = torch.tensor(
                tokenizer.audio_end_id)
        return special_tokens

    def process_audios(self, mm_data: Mapping[str, object],
                       mm_kwargs: Mapping[str, object]) -> Dict[str, object]:
        """Preprocess the audio part of ``mm_data``.

        Pops ``"audios"`` and ``"audio_embeds"`` from ``mm_data``, so the
        mapping is mutated in place.

        * Raw audio: each clip goes through the HF processor separately
          (``chunk_input=True``). Returns ``audio_lens``, the flattened
          per-chunk ``audio_features``, the concatenated
          ``audio_feature_lens`` and the per-clip ``audio_num_segments``.
        * Precomputed embeddings: returns estimated ``audio_lens``, the
          flattened chunk ``audio_embeds`` and ``audio_num_segments``.
        * Neither: returns an empty dict.
        """
        audios = mm_data.pop("audios", [])
        audio_embeds = mm_data.pop("audio_embeds", [])
        if isinstance(audios, (list, torch.Tensor)) and len(audios) > 0:
            audio_outputs = {
                "audio_lens": [],
                "audio_features": [],
                "audio_feature_lens": [],
                "audio_num_segments": []
            }
            for audio in audios:
                single_audio_outputs = super().call_base_hf_processor(
                    prompt=self.info.audio_pattern,
                    mm_data={
                        "audios": audio,
                        "chunk_input": True
                    },
                    mm_kwargs=mm_kwargs)
                audio_outputs["audio_lens"].append(len(audio))
                audio_outputs["audio_features"].append(
                    single_audio_outputs["audio_features"])
                audio_outputs["audio_num_segments"].append(
                    len(single_audio_outputs["audio_feature_lens"][0]))
                audio_outputs["audio_feature_lens"] += \
                    single_audio_outputs["audio_feature_lens"]
            # Flatten per-clip feature lists into one list of chunks.
            audio_outputs["audio_features"] = [
                audio_feature for single_audio_features in \
                    audio_outputs["audio_features"]
                for audio_feature in single_audio_features
            ]
            audio_outputs["audio_feature_lens"] = torch.cat(
                audio_outputs["audio_feature_lens"])
        elif len(audio_embeds):
            audio_outputs = {
                "audio_lens": [
                    self.info.get_audio_len_by_num_chunks(
                        sum(chunk_embeds.shape[0]
                            for chunk_embeds in single_audio_embeds))
                    for single_audio_embeds in audio_embeds
                ],
                "audio_embeds": [
                    chunk_embeds for single_audio_embeds in audio_embeds
                    for chunk_embeds in single_audio_embeds
                ],
                "audio_num_segments": [
                    len(single_audio_embeds)
                    for single_audio_embeds in audio_embeds
                ]
            }
        else:
            audio_outputs = {}
        return audio_outputs

    def get_placeholder_match_pattern(self) -> str:
        # Matches "(<image>./</image>)", "(<video>./</video>)" or
        # "(<audio>./</audio>)"; the backreference keeps the tags paired.
        return r"\(<(image|video|audio)>./</\1>\)"

    def get_placeholder_split_pattern(self) -> str:
        # Non-capturing variant, so splitting does not emit the tag names.
        return r"\(<(?:image|video|audio)>./</(?:image|video|audio)>\)"

    def process_mm_inputs(self, mm_data, mm_kwargs) -> object:
        """Preprocess every modality and return the results keyed by name."""
        return {
            "image": self.process_images(mm_data, mm_kwargs),
            "video": self.process_videos(mm_data, mm_kwargs),
            "audio": self.process_audios(mm_data, mm_kwargs)
        }

    def get_modality_num_counter(self, modality: str) -> str:
        # Audio items are counted by their "audio_lens" entries.
        if modality == "audio":
            return "audio_lens"
        return super().get_modality_num_counter(modality)

    def get_num_slices_by_modality(self, inputs: Dict[str, object],
                                   modality: str, index: int) -> int:
        # For audio, the slice count is the number of chunks of clip `index`.
        if modality == "audio":
            return inputs["audio"]["audio_num_segments"][index]
        return super().get_num_slices_by_modality(inputs, modality, index)

    def get_prompt_texts_by_modality(self, inputs: Dict[str, object],
                                     modality: str, index: int) -> str:
        if modality == "audio":
            return self.get_audio_prompt_texts(
                inputs["audio"]["audio_lens"][index])
        return super().get_prompt_texts_by_modality(inputs, modality, index)

    def _get_prompt_replacements(
            self, mm_items: MultiModalDataItems,
            hf_processor_mm_kwargs: Mapping[str, Any],
            out_mm_kwargs: MultiModalKwargs) -> List[PromptReplacement]:
        """Replace each modality placeholder with its expanded token text."""
        placeholder = {
            "image": self.info.image_pattern,
            "video": self.info.video_pattern,
            "audio": self.info.audio_pattern
        }

        def get_replacement_minicpmv(item_idx: int, modality: str):
            if modality == "image":
                return self.get_image_prompt_texts(
                    mm_items["image"].get_image_size(item_idx), item_idx)
            elif modality == "video":
                return self.get_video_prompt_texts(
                    mm_items["video"].get_frame_size(item_idx),
                    mm_items["video"].get_num_frames(item_idx))
            else:  # audio
                if isinstance(mm_items["audio"], MiniCPMOAudioEmbeddingItems):
                    # No raw audio available: estimate its length from the
                    # total number of embedding rows over all chunks.
                    single_audio_embeds = mm_items["audio"].get(item_idx)
                    audio_len = self.info.get_audio_len_by_num_chunks(
                        sum(chunk_embeds.shape[0]
                            for chunk_embeds in single_audio_embeds))
                    return self.get_audio_prompt_texts(audio_len)
                return self.get_audio_prompt_texts(
                    len(mm_items["audio"].get(item_idx)))

        return [
            PromptReplacement(modality=modality,
                              target=placeholder[modality],
                              replacement=partial(get_replacement_minicpmv,
                                                  modality=modality))
            for modality in ("image", "video", "audio")
        ]

    def _get_mm_fields_config(
        self,
        hf_inputs,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Add the audio fields to the MiniCPM-V field configuration.

        Per-chunk audio tensors are flattened across items and split back
        using ``audio_num_slices`` as the per-item sizes.
        """
        # NOTE(review): process_audios emits "audio_num_segments" while this
        # reads "audio_num_slices"; presumably the base processor renames it.
        # Confirm in MiniCPMVMultiModalProcessor.
        audio_num_slices = hf_inputs.get("audio_num_slices", torch.empty(0))

        return dict(
            **super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs),
            audio_features=MultiModalFieldConfig.flat_from_sizes(
                "audio", audio_num_slices),
            audio_feature_lens=MultiModalFieldConfig.flat_from_sizes(
                "audio", audio_num_slices),
            audio_num_slices=MultiModalFieldConfig.batched("audio"),
            audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio"),
            audio_embeds=MultiModalFieldConfig.flat_from_sizes(
                "audio", audio_num_slices))


class MultiModalProjector(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) mapping audio features to the
    language model's embedding width."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # Submodule names are part of the checkpoint layout; keep them.
        self.linear1 = nn.Linear(in_dim, out_dim, bias=True)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(out_dim, out_dim, bias=True)

    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        return self.linear2(self.relu(self.linear1(audio_features)))


class MiniCPMWhisperEncoderLayer(nn.Module):
    """Whisper encoder layer with pre-norm attention and feed-forward blocks.

    Each block normalizes its input, transforms it and adds it back to the
    residual. No key/value cache is carried between calls.
    """

    def __init__(self, config: WhisperConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.embed_dim = config.d_model
        # Attention implementation selected by the HF config.
        self.self_attn = WHISPER_ATTENTION_CLASSES[
            config._attn_implementation](
                embed_dim=self.embed_dim,
                num_heads=config.encoder_attention_heads,
                dropout=config.attention_dropout,
                config=config,
                layer_idx=layer_idx,
            )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor]:
        """Apply the layer.

        Returns a 1-tuple ``(hidden_states,)``; the encoder reads ``[0]``.
        """
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        # No cache is passed in; attention weights and cache are discarded.
        hidden_states, _, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_value=None,
        )
        hidden_states = nn.functional.dropout(hidden_states,
                                              p=self.dropout,
                                              training=self.training)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states,
                                              p=self.activation_dropout,
                                              training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states,
                                              p=self.dropout,
                                              training=self.training)
        hidden_states = residual + hidden_states

        # fp16 overflow guard: clamp into the representable range when
        # inf/nan appear (clamp presumably leaves NaNs themselves unchanged).
        if hidden_states.dtype == torch.float16 and (
                torch.isinf(hidden_states).any()
                or torch.isnan(hidden_states).any()):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states,
                                        min=-clamp_value,
                                        max=clamp_value)

        return (hidden_states, )


class MiniCPMWhisperEncoder(WhisperEncoder):
    """HF Whisper encoder with its layers replaced by
    `MiniCPMWhisperEncoderLayer` and a simplified forward pass."""

    def __init__(self, config: WhisperConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [MiniCPMWhisperEncoderLayer(config, layer_idx=layer_no)
             for layer_no in range(config.encoder_layers)])

    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> BaseModelOutputWithPast:
        """Encode features and return the final states plus every
        intermediate state (the input of each layer and the normed output).
        """
        conv_weight = self.conv1.weight
        features = input_features.to(dtype=conv_weight.dtype,
                                      device=conv_weight.device)
        features = nn.functional.gelu(self.conv1(features))
        features = nn.functional.gelu(self.conv2(features))
        # Move the last two axes: presumably (batch, channels, time) ->
        # (batch, time, channels).
        features = features.permute(0, 2, 1)

        seq_len = features.shape[1]
        hidden_states = features + self.embed_positions.weight[:seq_len, :]
        hidden_states = nn.functional.dropout(hidden_states,
                                              p=self.dropout,
                                              training=self.training)

        collected = []
        for layer in self.layers:
            collected.append(hidden_states)
            # LayerDrop: only sampled in training mode.
            skip_layer = self.training and torch.rand([]) < self.layerdrop
            if not skip_layer:
                hidden_states = layer(hidden_states, attention_mask)[0]

        hidden_states = self.layer_norm(hidden_states)
        collected.append(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            hidden_states=tuple(collected),
        )


@MULTIMODAL_REGISTRY.register_processor(
    MiniCPMOMultiModalProcessor,
    info=MiniCPMOProcessingInfo,
    dummy_inputs=MiniCPMODummyInputsBuilder)
class MiniCPMO(MiniCPMV2_6):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Initialize the ``MiniCPMV2_6`` base, then attach the audio encoder.

        The audio encoder is stored as ``self.apm`` and built under the
        ``apm`` sub-prefix of ``prefix``.
        """
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        apm_prefix = maybe_prefix(prefix, "apm")
        self.apm = self.init_audio_module(vllm_config=vllm_config,
                                          prefix=apm_prefix)

    def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the audio encoder and its pooling / projection heads.

        Side effects: sets ``self.audio_avg_pooler``,
        ``self.audio_projection_layer`` and ``self.audio_encoder_layer``.
        ``vllm_config`` and ``prefix`` are accepted but currently unused.

        Returns:
            The ``MiniCPMWhisperEncoder`` built from ``config.audio_config``.
        """
        # Do not use parameters temporarily
        cfg = self.config
        audio_config = cfg.audio_config
        encoder = MiniCPMWhisperEncoder(audio_config)

        # Kernel size equals stride: non-overlapping average pooling.
        pool_step = cfg.audio_pool_step
        self.audio_avg_pooler = nn.AvgPool1d(pool_step, stride=pool_step)

        projector_in_dim = int(audio_config.encoder_ffn_dim // 4)
        self.audio_projection_layer = MultiModalProjector(
            in_dim=projector_in_dim, out_dim=self.embed_dim)

        # Index into the encoder's hidden_states tuple; -1 is the last entry.
        self.audio_encoder_layer = -1
        return encoder

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        ``"tts"`` is passed as a skip prefix, so weights under it are not
        loaded (presumably the checkpoint's text-to-speech module — confirm).

        Returns:
            The result of ``AutoWeightsLoader.load_weights`` (annotated as a
            set of weight names).
        """
        weights_loader = AutoWeightsLoader(self, skip_prefixes=["tts"])
        loaded_names = weights_loader.load_weights(weights)
        return loaded_names

    def subsequent_chunk_mask(
        self,
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = CPU_DEVICE,
        num_lookhead: int = 0,
    ) -> torch.Tensor:
        """Build a boolean chunk-wise attention mask of shape (size, size).

        Row ``i`` belongs to chunk ``c = i // chunk_size``. Entry ``[i, j]``
        is True exactly when ``start(i) <= j < end(i)``, where

        * ``end(i) = min((c + 1) * chunk_size + num_lookhead, size)``
        * ``start(i) = 0`` if ``num_left_chunks < 0``, otherwise
          ``max((c - num_left_chunks) * chunk_size, 0)``.

        So each position sees its own chunk, up to ``num_left_chunks``
        preceding chunks (all of them when negative), and ``num_lookhead``
        extra positions past the end of its chunk.

        The mask is produced with broadcast comparisons rather than a
        Python loop doing one slice assignment per row.

        Args:
            size: Sequence length; the mask is ``(size, size)``.
            chunk_size: Positions per chunk; must be positive.
            num_left_chunks: Number of preceding chunks each row may see;
                negative means unlimited.
            device: Device on which the mask is allocated.
            num_lookhead: Extra visible positions past the end of a row's
                chunk; expected to be non-negative.

        Returns:
            A ``torch.bool`` tensor of shape ``(size, size)`` on ``device``.

        Raises:
            ValueError: If ``chunk_size`` is not positive.
        """
        if chunk_size <= 0:
            raise ValueError(
                f"chunk_size must be a positive integer, got {chunk_size}")

        positions = torch.arange(size, device=device)
        chunk_ids = positions // chunk_size

        # Exclusive upper bound per row, clipped to the sequence length.
        ending = ((chunk_ids + 1) * chunk_size + num_lookhead).clamp(max=size)
        # Inclusive lower bound per row.
        if num_left_chunks < 0:
            start = torch.zeros_like(positions)
        else:
            start = ((chunk_ids - num_left_chunks) * chunk_size).clamp(min=0)

        # (1, size) columns vs (size, 1) per-row bounds -> (size, size) bool.
        cols = positions.unsqueeze(0)
        return (cols >= start.unsqueeze(1)) & (cols < ending.unsqueeze(1))

    def _get_feat_extract_output_lengths(self,
                                         input_lengths: torch.LongTensor):
        """Map input lengths to post-convolution and post-pooling lengths.

        Returns:
            ``(lengths_after_cnn, lengths_after_pooling)``, where
            ``lengths_after_cnn = (n - 1) // 2 + 1`` (presumably the
            stride-2 convolution of the encoder — confirm) and
            ``lengths_after_pooling`` applies the kernel == stride
            ``audio_pool_step`` pooling formula and is cast to int32.
        """
        step = self.config.audio_pool_step
        after_cnn = (input_lengths - 1) // 2 + 1
        after_pool = (after_cnn - step) // step + 1
        return after_cnn, after_pool.to(dtype=torch.int32)

    # Copied from HF repo of MiniCPM-o-2_6,
    # designed for batched inputs and outputs
    def get_audio_hidden_states(
            self, data: MiniCPMOAudioInputs,
            chunk_length: int) -> List[List[torch.Tensor]]:
        """Encode raw audio features into per-clip audio embeddings.

        Builds an additive attention mask (0 where attention is allowed,
        ``-inf`` where it is blocked), runs ``self.apm``, projects the
        selected hidden state, average-pools it along the sequence, and
        trims each clip to its pooled token count.

        Args:
            data: Audio inputs; reads ``data["data"]`` (the padded feature
                batch) and ``data["audio_feature_lens"]``.
            chunk_length: When > 0, additionally restricts attention to a
                chunk-wise mask built by ``subsequent_chunk_mask``.

        Returns:
            ``[]`` when there is no audio data; otherwise a list with one
            inner list per group in ``audio_feature_lens_raw`` (a single
            group here, see below), each holding one ``(tokens, dim)``
            tensor per clip.
        """
        wavforms = data.get(
            "data",
            [])  # 3-D padded batch (per the original note: (bs, 80, frames)), or []
        # Wrapped in a one-element list: with the 1-D length tensor built by
        # _parse_and_validate_audio_inputs this is a single group covering
        # every clip in the batch (not a [[x1, x2], [y1], ...] nesting).
        audio_feature_lens_raw = [data.get("audio_feature_lens",
                                           [])]

        # exist audio
        if len(wavforms) > 0:
            audio_feature_lens = torch.hstack(audio_feature_lens_raw)
            batch_size, _, max_mel_seq_len = wavforms.shape
            # Same (n - 1) // 2 + 1 reduction as _get_feat_extract_output_lengths.
            max_seq_len = (max_mel_seq_len - 1) // 2 + 1

            # Create a sequence tensor of shape (batch_size, max_seq_len)
            seq_range = (torch.arange(
                0,
                max_seq_len,
                dtype=audio_feature_lens.dtype,
                device=audio_feature_lens.device).unsqueeze(0).expand(
                    batch_size, max_seq_len))
            lengths_expand = audio_feature_lens.unsqueeze(1).expand(
                batch_size, max_seq_len)
            # Create mask
            # NOTE(review): seq_range counts post-reduction positions, but
            # audio_feature_lens is treated as pre-reduction lengths further
            # down (it is passed to _get_feat_extract_output_lengths). The
            # comparison below therefore masks fewer positions than the
            # reduced lengths would. Left unchanged to stay consistent with
            # the copied HF code — confirm before "fixing".
            padding_mask = seq_range >= lengths_expand  # True for padded values

            # Broadcast the key-padding mask to (batch, 1, query, key).
            audio_attention_mask_ = padding_mask.view(
                batch_size, 1, 1, max_seq_len).expand(batch_size, 1,
                                                      max_seq_len, max_seq_len)
            # Float copy on the encoder's dtype/device: False -> 0.0; every
            # True entry is overwritten with -inf below.
            audio_attention_mask = audio_attention_mask_.to(
                dtype=self.apm.conv1.weight.dtype,
                device=self.apm.conv1.weight.device)

            if chunk_length > 0:
                # Presumably 50 encoder frames per second of audio — confirm.
                chunk_num_frame = int(chunk_length * 50)
                chunk_mask = self.subsequent_chunk_mask(
                    size=max_seq_len,
                    chunk_size=chunk_num_frame,
                    num_left_chunks=-1,
                    device=audio_attention_mask_.device,
                )
                # Block a key if it is padding OR outside the allowed chunks.
                audio_attention_mask_ = torch.logical_or(
                    audio_attention_mask_, torch.logical_not(chunk_mask))

            # Must run after the optional chunk OR so chunk-blocked entries
            # (still 0.0 in the float copy) are also set to -inf.
            audio_attention_mask[audio_attention_mask_] = float("-inf")
            audio_states = self.apm(
                wavforms, attention_mask=audio_attention_mask).hidden_states[
                    self.audio_encoder_layer]
            audio_embeds = self.audio_projection_layer(audio_states)

            # AvgPool1d pools the last axis, so pool along the sequence by
            # moving it last and then back.
            audio_embeds = audio_embeds.transpose(1, 2)
            audio_embeds = self.audio_avg_pooler(audio_embeds)
            audio_embeds = audio_embeds.transpose(1, 2)

            _, feature_lens_after_pooling = \
                self._get_feat_extract_output_lengths(audio_feature_lens)

            num_audio_tokens = feature_lens_after_pooling

            # Split the batch back into per-group lists, trimming each clip's
            # padding to its pooled token count.
            final_audio_embeds = []
            idx = 0
            for i in range(len(audio_feature_lens_raw)):
                target_audio_embeds = []
                for _ in range(len(audio_feature_lens_raw[i])):
                    target_audio_embeds.append(
                        audio_embeds[idx, :num_audio_tokens[idx], :])
                    idx += 1
                final_audio_embeds.append(target_audio_embeds)
            return final_audio_embeds
        else:
            return []

    def get_embedding_with_audios(self, vlm_embedding: torch.Tensor,
                                  audio_inputs: Optional[MiniCPMOAudioInputs],
                                  chunk_length: int) -> torch.Tensor:
        """Overwrite the audio spans of ``vlm_embedding`` with audio embeddings.

        ``vlm_embedding`` is modified in place and also returned. Embeddings
        come either precomputed (``type == "audio_embeds"``) or from encoding
        raw features with ``get_audio_hidden_states``.

        With ``self.config.chunk_input`` set, all embeddings are concatenated
        and consumed sequentially across ``audio_bounds``. Otherwise each
        embedding is paired with one bound.

        Args:
            vlm_embedding: Token embeddings, indexed along dim 0 by the
                positions in ``audio_inputs["audio_bounds"]``.
            audio_inputs: Parsed audio inputs; None leaves the input as is.
            chunk_length: Forwarded to ``get_audio_hidden_states``.

        Returns:
            ``vlm_embedding`` with the audio spans filled in. It is returned
            unchanged when there is no audio.

        Raises:
            ValueError: In non-chunked mode, if an embedding's length differs
                from the length of its bound.
        """
        if audio_inputs is None:
            return vlm_embedding

        device, dtype = vlm_embedding.device, vlm_embedding.dtype
        if audio_inputs["type"] == "audio_embeds":
            audio_embeddings = [
                emb.to(device=device, dtype=dtype)
                for emb in audio_inputs["data"]
            ]
        else:
            audio_hidden_states = self.get_audio_hidden_states(
                audio_inputs, chunk_length)
            # get_audio_hidden_states returns [] when there is no audio data;
            # indexing it unconditionally would raise IndexError instead of
            # reaching the "nothing to insert" early return below.
            audio_embeddings = (audio_hidden_states[0]
                                if audio_hidden_states else [])
        if audio_embeddings is None or len(audio_embeddings) == 0:
            return vlm_embedding

        audio_bounds = audio_inputs["audio_bounds"]
        if self.config.chunk_input:
            audio_embs = torch.cat(audio_embeddings, dim=0).to(device=device,
                                                               dtype=dtype)
            audio_start_pos = 0
            for bound in audio_bounds:
                audio_len = bound[1] - bound[0]
                vlm_embedding[bound[0]:bound[1]] = audio_embs[
                    audio_start_pos:audio_start_pos + audio_len, :]
                audio_start_pos += audio_len
        else:
            for embs, bound in zip(audio_embeddings, audio_bounds):
                # Allocate the indices directly on the target device.
                audio_indices = torch.arange(bound[0],
                                             bound[1],
                                             dtype=torch.long,
                                             device=device)

                if embs.shape[0] != len(audio_indices):
                    raise ValueError(
                        "Shape mismatch: Trying to assign embeddings "
                        f"of shape {embs.shape} "
                        f"to input indices of length {len(audio_indices)}")
                vlm_embedding[audio_indices] = embs.to(dtype)
        return vlm_embedding

    def _get_audio_bounds(self, input_ids: torch.Tensor,
                          audio_start_id: torch.Tensor,
                          audio_end_id: torch.Tensor) -> torch.Tensor:
        """Locate audio spans in a 1-D ``input_ids`` tensor.

        Only the first element of each id tensor is used. Each row of the
        result is ``[start, end)``: ``start`` is the index just after an
        audio-start token and ``end`` is the index of an audio-end token.

        Returns:
            An ``(N, 2)`` tensor of span bounds.
        """
        (start_positions, ) = torch.where(input_ids == audio_start_id[0])
        (end_positions, ) = torch.where(input_ids == audio_end_id[0])
        # Spans begin right after the start marker.
        start_positions = start_positions + 1

        # Slicing to the larger count keeps both sequences whole, so a
        # start/end count mismatch surfaces as an error from torch.hstack.
        span_count = max(len(start_positions), len(end_positions))
        start_col = start_positions[:span_count].unsqueeze(-1)
        end_col = end_positions[:span_count].unsqueeze(-1)
        return torch.hstack([start_col, end_col])

    def _parse_and_validate_audio_inputs(
            self, input_ids: torch.Tensor,
            **kwargs: object) -> Optional[MiniCPMOAudioInputs]:
        """Parse audio keyword arguments into a single audio-inputs dict.

        Precomputed ``audio_embeds`` take precedence over raw
        ``audio_features``. In both cases the nested per-item collections
        are flattened into one batch. ``audio_bounds`` is derived from the
        ``audio_start_id`` / ``audio_end_id`` tokens in ``input_ids``.

        Returns:
            ``MiniCPMOAudioEmbeddingInputs`` if ``audio_embeds`` is given;
            otherwise ``MiniCPMOAudioFeatureInputs`` if ``audio_features`` is
            non-empty; otherwise None.
        """
        audio_features = kwargs.pop("audio_features", [])
        audio_feature_lens = kwargs.pop("audio_feature_lens", [])
        audio_embeds = kwargs.pop("audio_embeds", None)
        audio_start_id = kwargs.pop("audio_start_id", None)
        audio_end_id = kwargs.pop("audio_end_id", None)
        if audio_embeds is not None:
            # Flatten the nested [[emb, ...], ...] structure into one list.
            audio_embeds = [
                emb for item_embeds in audio_embeds for emb in item_embeds
            ]
            return MiniCPMOAudioEmbeddingInputs(
                audio_bounds=self._get_audio_bounds(input_ids, audio_start_id,
                                                    audio_end_id),
                data=audio_embeds,
                type="audio_embeds")
        if len(audio_features) > 0:
            # Transpose each 2-D feature so pad_sequence pads its original
            # second axis (presumably frames), then transpose back. The
            # padded batch is (num_audios, feature_dim, max_len).
            audio_features_all = [
                feature.permute(1, 0) for item_features in audio_features
                for feature in item_features
            ]
            audio_features = torch.nn.utils.rnn.pad_sequence(
                audio_features_all, batch_first=True,
                padding_value=0.0).permute(0, 2, 1)
            audio_feature_lens = torch.cat(list(audio_feature_lens))

            return MiniCPMOAudioFeatureInputs(
                audio_bounds=self._get_audio_bounds(input_ids, audio_start_id,
                                                    audio_end_id),
                data=audio_features,
                audio_feature_lens=audio_feature_lens,
                type="audio_features")
        return None

    def _parse_and_validate_inputs(self, input_ids: torch.Tensor,
                                   **kwargs: object):
        """Parse image inputs and, when audio kwargs are present, audio inputs.

        Returns:
            ``(image_inputs, audio_inputs)``. ``audio_inputs`` is None when
            no keyword-argument name contains ``"audio"``, and also whenever
            the audio parser itself finds nothing to return.
        """
        image_inputs = self._parse_and_validate_image_inputs(
            input_ids, **kwargs)

        audio_inputs = None
        has_audio_kwargs = any("audio" in name for name in kwargs)
        if has_audio_kwargs:
            audio_inputs = self._parse_and_validate_audio_inputs(
                input_ids, **kwargs)
        return image_inputs, audio_inputs

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Run the language model on vision- and audio-augmented embeddings.

        When ``intermediate_tensors`` is given, no embeddings are built and
        ``inputs_embeds`` is None (presumably a non-first pipeline stage —
        confirm). Otherwise image inputs are merged via
        ``get_embedding_with_vision``, then audio inputs (if any) via
        ``get_embedding_with_audios``.

        Returns:
            Whatever ``self.llm.model`` returns.
        """
        vlm_embeddings = None
        if intermediate_tensors is None:
            image_inputs, audio_inputs = self._parse_and_validate_inputs(
                input_ids, **kwargs)
            vlm_embeddings, _ = self.get_embedding_with_vision(
                input_ids, image_inputs)
            if audio_inputs is not None:
                vlm_embeddings = self.get_embedding_with_audios(
                    vlm_embeddings, audio_inputs,
                    self.config.audio_chunk_length)

        # The LLM always receives `inputs_embeds` and never `input_ids`, so
        # the computation graph stays the same for `torch.compile`.
        return self.llm.model(
            input_ids=None,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=vlm_embeddings,
        )