midashenglm.py 28.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 Horizon team, Xiaomi MiLM Plus.
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiDashengLM model compatible with HuggingFace weights."""
25

26
27
28
29
30
31
32
33
import collections
import collections.abc
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Callable, Optional, TypedDict, Union, cast

import numpy as np
import torch
import torch.nn as nn
34
35
import torchaudio.functional as F
from torch.nn.functional import scaled_dot_product_attention
36
37
38
from transformers import BatchFeature

from vllm.config import VllmConfig
39
from vllm.config.multimodal import BaseDummyOptions
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalKwargsItems)
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.midashenglm import DashengConfig

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
58
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150

_Tuple2 = Union[int, tuple[int, int], Sequence[int]]


def _resolve_tuple2(x: _Tuple2) -> tuple[int, int]:
    if isinstance(x, collections.abc.Sequence):
        assert len(x) == 2, (
            f"Expected a sequence of length 2, got {x} with length {len(x)}")
        return cast(tuple[int, int], tuple(x))
    return (x, x)


def calculate_mel_frames_dasheng(
    audio_length_samples: int,
    n_fft: int = 512,
    hop_size: int = 160,
    dasheng_subsampling: int = 4,
    center=True,
    model_subsampling: int = 5,
) -> int:
    """Return the number of output frames for an audio clip.

    The spectrogram frame count is ``1 + (length - n_fft) / hop_size``
    (truncated to an int), where ``length`` is extended by ``n_fft`` when
    ``center`` is set. That count is then floor-divided by
    ``dasheng_subsampling`` and by ``model_subsampling`` in turn.
    """
    effective_length = (audio_length_samples +
                        n_fft if center else audio_length_samples)
    num_spectrogram_frames = int(1 + (effective_length - n_fft) / hop_size)
    num_encoder_frames = num_spectrogram_frames // dasheng_subsampling
    return num_encoder_frames // model_subsampling


class AudioPatchEmbed(nn.Module):
    """Patch embedding implemented as a single strided ``nn.Conv2d``.

    ``input_size``, ``patch_size`` and ``patch_stride`` each accept an int
    or a pair. ``grid_size`` is the input size floor-divided by the stride
    along each of the two dimensions, and ``num_patches`` is their product.
    """

    def __init__(
        self,
        input_size: _Tuple2 = 64,
        patch_size: _Tuple2 = 16,
        patch_stride: _Tuple2 = 16,
        in_chans: int = 1,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten: bool = False,
    ):
        super().__init__()
        self.input_size = _resolve_tuple2(input_size)
        self.patch_size = _resolve_tuple2(patch_size)
        self.patch_stride = _resolve_tuple2(patch_stride)

        size_0, size_1 = self.input_size
        stride_0, stride_1 = self.patch_stride
        self.grid_size = (size_0 // stride_0, size_1 // stride_1)
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_stride,
        )
        # Optional normalisation over the embedding; a no-op when no
        # ``norm_layer`` factory is supplied.
        if norm_layer:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` into patches, optionally flattening the grid."""
        patches = self.proj(x)
        if self.flatten:
            # rearrange(x, "b c f t -> b (f t) c")
            patches = patches.flatten(2, 3).permute(0, 2, 1)
        return self.norm(patches)


class LayerScale(nn.Module):
    """Multiply the input by a learnable per-channel vector ``gamma``.

    ``gamma`` has ``dim`` entries, all initialised to ``init_values``.
    When ``inplace`` is set the input tensor itself is scaled and returned.
    """

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        initial_gamma = init_values * torch.ones(dim)
        self.gamma = nn.Parameter(initial_gamma)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma


class DashengMlp(nn.Module):
    """Two-layer feed-forward network: ``fc1`` -> GELU -> ``fc2``.

    ``hidden_features`` and ``out_features`` fall back to ``in_features``
    when they are not given (or are falsy).
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_dim = hidden_features or in_features
        output_dim = out_features or in_features

        self.fc1 = ColumnParallelLinear(
            input_size=in_features,
            output_size=hidden_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.act = get_act_fn("gelu")
        self.fc2 = RowParallelLinear(
            input_size=hidden_dim,
            output_size=output_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Both linear layers return a 2-tuple; only the first element
        # (the projected tensor) is used here.
        hidden, _ = self.fc1(x)
        projected, _ = self.fc2(self.act(hidden))
        return projected


class DashengAttention(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    The heads are divided evenly across the tensor-parallel world size:
    ``num_heads`` is the rank-local head count, ``head_dim`` the width of
    one head and ``q_size`` the rank-local width of the attention output.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.embed_dim = dim
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        # The heads must split evenly over the tensor-parallel ranks. This
        # also implies total_num_heads >= tp_size, so no "replicate heads"
        # case can occur.
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.num_kv_heads = max(1, self.total_num_heads // tp_size)
        self.head_dim = self.embed_dim // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scale = self.head_dim**-0.5

        self.qkv = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_heads,
            bias=qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
        )
        self.proj = RowParallelLinear(
            input_size=dim,
            output_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
        )

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """Apply self-attention to ``x`` of shape ``(B, N, C)``.

        ``mask``, when given, is indexed as ``mask[:, None, None, :]`` and
        forwarded as ``attn_mask`` to ``scaled_dot_product_attention``.
        """
        B, N, _ = x.shape

        qkv, _ = self.qkv(x)
        # Split using the rank-local head count and the true per-head
        # width. Deriving the width from the full input width C is only
        # right when tp_size == 1.
        # NOTE(review): relies on the fused output being laid out as
        # [q | k | v] along the last dimension - confirm against
        # QKVParallelLinear.
        qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        attn_mask = mask[:, None, None, :] if mask is not None else None
        x = scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

        # Merge the local heads back; the width is q_size (== C only when
        # tp_size == 1).
        # NOTE(review): assumes RowParallelLinear consumes the rank-local
        # slice of its input - confirm.
        x = x.transpose(1, 2).reshape(B, N, self.q_size)
        x, _ = self.proj(x)
        return x


class DashengBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual.

    Each branch output goes through a ``LayerScale`` when ``init_values``
    is truthy, otherwise through ``nn.Identity``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        init_values: Optional[float] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        def _scale_or_identity() -> nn.Module:
            # Per-branch output scaling, skipped when init_values is falsy.
            if init_values:
                return LayerScale(dim, init_values=init_values)
            return nn.Identity()

        # Attention branch.
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = DashengAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )
        self.ls1 = _scale_or_identity()

        # Feed-forward branch.
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.mlp = DashengMlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.ls2 = _scale_or_identity()

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run both residual branches; ``mask`` is passed to attention."""
        attn_out = self.attn(self.norm1(x), mask)
        x = x + self.ls1(attn_out)
        mlp_out = self.mlp(self.norm2(x))
        return x + self.ls2(mlp_out)


284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
class DashengFrontend(nn.Module):
    """Waveform -> log-mel spectrogram front end.

    The Hann window and the mel filter bank are built once from ``config``
    and stored as non-persistent buffers.
    """

    def __init__(self, config: DashengConfig):
        super().__init__()
        self.config = config

        self.register_buffer(
            "spectrogram_window",
            torch.hann_window(config.win_length),
            persistent=False,
        )
        self.spectrogram_window: torch.Tensor

        mel_filters = F.melscale_fbanks(
            n_freqs=config.n_fft // 2 + 1,
            f_min=config.f_min,
            f_max=config.f_max,
            n_mels=config.n_mels,
            sample_rate=config.sample_rate,
        )
        self.register_buffer("melscale_fbanks", mel_filters, persistent=False)
        self.melscale_fbanks: torch.Tensor

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Return the log-mel spectrogram cast back to ``waveform.dtype``."""
        cfg = self.config
        # The spectrogram and mel projection are computed in float32
        # regardless of the input dtype.
        power_spec = F.spectrogram(
            waveform=waveform.to(torch.float32),
            pad=0,
            window=self.spectrogram_window,
            n_fft=cfg.n_fft,
            hop_length=cfg.hop_length,
            win_length=cfg.win_length,
            power=2,
            normalized=False,
            center=cfg.center,
        )
        filters = self.melscale_fbanks.to(torch.float32)
        mel_spec = (power_spec.mT @ filters).mT

        # mel_spec has shape [batch, freq, time].
        # F.amplitude_to_DB accepts inputs shaped as:
        #   - [freq, time]
        #   - [channel, freq, time]
        #   - [..., channel, freq, time]
        # so a size-1 channel dimension is inserted for the call and
        # dropped again afterwards.
        mel_db = F.amplitude_to_DB(
            mel_spec.unsqueeze(1),
            multiplier=10,
            amin=1e-10,
            db_multiplier=0,
            top_db=120,
        )
        return mel_db.squeeze(1).to(waveform.dtype)


341
342
343
344
345
346
347
348
349
350
351
352
353
class DashengAudioTransformer(nn.Module):
    """Audio encoder: log-mel front end, patch embedding, transformer.

    The patch sequence is cut along the time axis into chunks of
    ``target_length // 4`` patches. Each chunk is encoded independently and
    the chunk outputs are concatenated along the sequence dimension.
    """

    def __init__(
        self,
        config: DashengConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        self.target_length = config.target_length
        self.hop_length = config.hop_length

        self.front_end = DashengFrontend(config)
        self.init_bn = nn.BatchNorm2d(config.n_mels, momentum=0.01)
        self.patch_embed = AudioPatchEmbed(
            input_size=(config.n_mels, config.target_length),
            embed_dim=config.embed_dim,
            in_chans=config.input_channels,
            patch_size=config.patch_size,
            flatten=False,
            patch_stride=config.patch_stride,
        )

        # Separate positional embeddings for the time and frequency axes.
        # They are allocated uninitialised here.
        grid_freq, grid_time = self.patch_embed.grid_size
        self.time_pos_embed = nn.Parameter(
            torch.empty(1, config.embed_dim, 1, grid_time))
        self.freq_pos_embed = nn.Parameter(
            torch.empty(1, config.embed_dim, grid_freq, 1))

        encoder_blocks = [
            DashengBlock(
                dim=config.embed_dim,
                num_heads=config.num_heads,
                mlp_ratio=config.mlp_ratio,
                qkv_bias=config.qkv_bias,
                init_values=config.init_values,
                quant_config=quant_config,
                prefix=f"{prefix}.blocks.{i}",
            ) for i in range(config.depth)
        ]
        self.blocks = nn.ModuleList(encoder_blocks)
        self.norm = nn.LayerNorm(config.embed_dim, eps=1e-6)

    def forward_features(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Add positional embeddings, flatten to a sequence, run blocks."""
        num_time_patches = x.shape[-1]
        x = x + self.time_pos_embed[:, :, :, :num_time_patches]
        # Full slice, kept just to support __getitem__ in posembed.
        x = x + self.freq_pos_embed[:, :, :, :]
        # rearrange(x, "b c f t -> b (f t) c")
        x = x.flatten(2, 3).permute(0, 2, 1)
        for block in self.blocks:
            x = block(x, mask)
        return self.norm(x)

    def _to_mask(self, lengths: torch.Tensor, max_length: int) -> torch.Tensor:
        """Build a boolean mask that is True where position < length."""
        batch_size = len(lengths)
        positions = torch.arange(max_length, device=lengths.device)
        positions = positions.unsqueeze(0).expand(batch_size, max_length)
        return (positions < lengths.unsqueeze(-1)).bool()

    def forward(
        self,
        x: torch.Tensor,
        x_length: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Encode waveforms; return ``(features, mask)``.

        ``mask`` is ``None`` when ``x_length`` is not supplied; otherwise it
        is the full (unsplit) boolean mask over time patches.
        """
        mel = self.front_end(x).to(self.time_pos_embed.dtype)
        chunk_patches = self.target_length // 4

        # Insert a channel axis, then swap it with the frequency axis so
        # that BatchNorm2d(n_mels) sees the mel bins as its channel
        # dimension; swap back afterwards.
        mel = mel.unsqueeze(1).permute(0, 2, 1, 3)
        mel = self.init_bn(mel).permute(0, 2, 1, 3)

        patches = self.patch_embed(mel)
        total_time_patches = patches.shape[-1]
        patch_chunks = patches.split(chunk_patches, dim=-1)

        if x_length is None:
            mask = None
            chunk_masks = [None] * len(patch_chunks)
        else:
            assert len(x_length) == len(patches), (
                "batchsizes of input x and x_length need to be same")
            assert x_length.ndim == 1, "Lengths are of size (B,)"
            # Convert lengths from samples to time patches.
            scaled_lengths = (x_length / (self.hop_length * 4)).long()
            mask = self._to_mask(max_length=total_time_patches,
                                 lengths=scaled_lengths)
            chunk_masks = mask.split(chunk_patches, dim=-1)

        encoded_chunks = [
            self.forward_features(chunk, mask=chunk_mask)
            for chunk, chunk_mask in zip(patch_chunks, chunk_masks)
        ]
        return torch.cat(encoded_chunks, dim=1), mask


class AudioProjectorSubsample(nn.Module):
    """Stack every ``downsample_rate`` frames and project them.

    Groups of ``k`` consecutive frames are concatenated along the feature
    axis and passed through linear -> GELU -> linear. Trailing frames that
    do not fill a complete group are dropped.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        downsample_rate=5,
        dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.k = downsample_rate

        stacked_in = ColumnParallelLinear(
            input_size=in_dim * self.k,
            output_size=out_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.net.0",
            return_bias=False,
        )
        hidden_out = RowParallelLinear(
            input_size=out_dim,
            output_size=out_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.net.2",
            return_bias=False,
        )
        self.net = nn.Sequential(stacked_in, get_act_fn("gelu"), hidden_out)

    def forward(self, x, mask=None):
        """Return ``(projected, group_mask)``.

        ``group_mask`` is 1 for a group when any of its ``k`` frames is
        unmasked. When no ``mask`` is given, every frame counts as valid.
        """
        batch_size, seq_len, dim = x.shape

        # Drop the tail that does not fill a whole group of k frames.
        leftover = seq_len % self.k
        if leftover > 0:
            x = x[:, :-leftover, :]
            if mask is not None:
                mask = mask[:, :-leftover]
        if mask is None:
            mask = torch.ones(x.shape[:-1], dtype=torch.long, device=x.device)

        # rearrange(x, "b (s k) d -> b s (k d)", k=self.k)
        x = x.reshape(batch_size, -1, self.k * dim)
        for module in self.net:
            x = module(x)

        # rearrange(mask, "b (s k) -> b s k", k=self.k)
        grouped_mask = mask.reshape(batch_size, -1, self.k)
        return x, grouped_mask.any(dim=-1).long()


# === Audio Inputs === #
class MiDashengLMAudioInputs(TypedDict):
    """Audio tensors handed from input parsing to the audio encoder."""

    # Passed as the first argument to the audio encoder.
    input_values: torch.Tensor
    """Shape: `(num_audios, num_sampling_points)`"""
    # Per-audio lengths, passed to the encoder alongside `input_values`.
    # NOTE(review): the encoder asserts a 1-D length tensor; confirm that
    # the documented `(num_audios, 1)` shape matches what reaches it after
    # the reshape in input validation.
    audio_length: torch.Tensor
    """Shape: `(num_audios, 1)`"""


class MiDashengLMProcessingInfo(BaseProcessingInfo):
    """Processing metadata for MiDashengLM (audio modality only)."""

    def get_hf_config(self):
        """Return the HF config obtained from the processing context."""
        return self.ctx.get_hf_config()

    def get_feature_extractor(self):
        """Return the feature extractor held by the HF processor."""
        return self.get_hf_processor().feature_extractor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Declare "audio" as the only modality, with a limit of None."""
        limits = {"audio": None}
        return limits

    def get_min_audio_len(self):
        """Minimum audio length; shorter arrays are zero-padded to it."""
        return 3200

    def get_max_audio_len(self):
        """Audio length used when building dummy inputs."""
        return 160000


class MiDashengLMDummyInputsBuilder(
        BaseDummyInputsBuilder[MiDashengLMProcessingInfo]):
    """Builds dummy prompt text and dummy audio data."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one ``bos + audio + eos`` placeholder per dummy audio."""
        audio_count = mm_counts.get("audio", 0)

        processor = self.info.get_hf_processor()
        audio_token = processor.audio_token
        bos_token = processor.audio_bos_token
        eos_token = processor.audio_eos_token

        placeholder = f"{bos_token}{audio_token}{eos_token}"
        return placeholder * audio_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """Return dummy audios of the maximum supported length."""
        audio_count = mm_counts.get("audio", 0)

        overrides = None
        if mm_options:
            overrides = mm_options.get("audio")

        dummy_audios = self._get_dummy_audios(
            length=self.info.get_max_audio_len(),
            num_audios=audio_count,
            overrides=overrides,
        )
        return {"audio": dummy_audios}


class MiDashengLMMultiModalProcessor(
        BaseMultiModalProcessor[MiDashengLMProcessingInfo]):
    """Multimodal processor that prepares audio inputs for MiDashengLM."""

    def _get_data_parser(self) -> MultiModalDataParser:
        """Return a parser targeting the feature extractor's sample rate."""
        feature_extractor = self.info.get_feature_extractor()
        return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, Any],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Pad short audios, then delegate to the HF processor.

        When there is no audio, the prompt is only tokenized and returned
        as a ``BatchFeature`` containing ``input_ids``.
        """
        # Work on a private copy: the argument is typed as a read-only
        # Mapping and must not be mutated on behalf of the caller.
        mm_data = dict(mm_data)
        audios = mm_data.pop("audios", [])

        # Zero-pad numpy audios that are shorter than the minimum length;
        # everything else is passed through unchanged.
        min_audio_len = self.info.get_min_audio_len()
        processed_audios = [
            np.pad(
                audio,
                (0, min_audio_len - audio.shape[-1]),
                mode="constant",
                constant_values=0,
            ) if isinstance(audio, np.ndarray)
            and audio.shape[-1] < min_audio_len else audio for audio in audios
        ]

        if processed_audios:
            mm_data["audio"] = processed_audios

        if not mm_data.get("audio", []):
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        mm_kwargs = dict(**mm_kwargs)

        return super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare both audio fields as batched over the audio modality."""
        return dict(
            input_values=MultiModalFieldConfig.batched("audio"),
            audio_length=MultiModalFieldConfig.batched("audio"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Expand each audio token into one token per output frame."""
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        audio_token = getattr(processor, "audio_token", "<|AUDIO|>")
        audio_token_id = vocab[audio_token]

        out_mm_data = out_mm_kwargs.get_data()
        audio_length = out_mm_data.get("audio_length")
        if audio_length is None:
            audio_output_lengths = []
        else:
            audio_length_np = (audio_length.cpu().numpy() if isinstance(
                audio_length, torch.Tensor) else audio_length)
            audio_output_lengths = [
                max(1, calculate_mel_frames_dasheng(
                    int(length)))  # at least one frame
                for length in audio_length_np
            ]

        def get_replacement_midashenglm(item_idx: int):
            num_features = audio_output_lengths[item_idx]
            audio_tokens = [audio_token_id] * num_features

            return PromptUpdateDetails.select_token_id(
                audio_tokens,
                embed_token_id=audio_token_id,
            )

        return [
            PromptReplacement(
                modality="audio",
                target=audio_token,
                replacement=get_replacement_midashenglm,
            )
        ]


@MULTIMODAL_REGISTRY.register_processor(
    MiDashengLMMultiModalProcessor,
    info=MiDashengLMProcessingInfo,
    dummy_inputs=MiDashengLMDummyInputsBuilder,
)
class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
    """MiDashengLM: an audio encoder and projector feeding a text decoder."""

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder for an audio item.

        Raises:
            ValueError: if ``modality`` does not start with "audio".
        """
        if modality.startswith("audio"):
            return "<|audio_bos|><|AUDIO|><|audio_eos|>"

        raise ValueError("Only audio modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config

        # Initialize audio components
        self.audio_encoder = DashengAudioTransformer(
            config.audio_encoder_config,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "audio_encoder"),
        )
        self.audio_projector = AudioProjectorSubsample(
            in_dim=config.audio_encoder_config.embed_dim,
            out_dim=config.text_config.hidden_size,
            downsample_rate=config.subsample_factor,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "audio_projector"),
        )

        # Initialize language model (decoder)
        self.decoder = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "decoder"),
            architectures=["Qwen2ForCausalLM"],
        )

        self.quant_config = quant_config
        self.make_empty_intermediate_tensors = (
            self.decoder.make_empty_intermediate_tensors)

    def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                        name: str) -> torch.Tensor:
        """Collapse a batched tensor, or a list of tensors, into one tensor.

        A tensor has its first two dimensions merged. A list is
        concatenated along dim 0; for ``input_values`` the entries are
        first right-padded with zeros to a common size along dim 1.

        Raises:
            ValueError: if ``mm_input`` is neither a tensor nor a list.
        """
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(
                f"Incorrect type of {name}. Got type: {type(mm_input)}")
        if isinstance(mm_input, torch.Tensor):
            return mm_input.reshape(-1, *mm_input.shape[2:])

        if name == "input_values":
            max_length = max(tensor.shape[1] for tensor in mm_input)
            padded_mm_input = [
                torch.nn.functional.pad(tensor,
                                        (0, max_length - tensor.shape[1]))
                if tensor.shape[1] < max_length else tensor
                for tensor in mm_input
            ]
            return torch.concat(padded_mm_input)

        return torch.concat(mm_input)

    def _parse_and_validate_audio_input(
            self, **kwargs: object) -> Optional[MiDashengLMAudioInputs]:
        """Extract audio tensors from ``kwargs``; ``None`` if no audio."""
        input_values = kwargs.pop("input_values", None)
        audio_length = kwargs.pop("audio_length", None)

        if input_values is None:
            return None

        # The helper validates the type and always returns a tensor, so no
        # further type check is needed afterwards.
        input_values = self._validate_and_reshape_mm_tensor(
            input_values, "input_values")
        audio_length = self._validate_and_reshape_mm_tensor(
            audio_length, "audio_length")

        return MiDashengLMAudioInputs(
            input_values=input_values,
            audio_length=audio_length,
        )

    def _process_audio_input(
            self,
            audio_input: MiDashengLMAudioInputs) -> tuple[torch.Tensor, ...]:
        """Encode audio; return one embedding tensor per audio item.

        Each returned tensor holds only the valid frames of its audio,
        i.e. ``max(1, calculate_mel_frames_dasheng(length))`` rows.
        """
        # Process audio through encoder and projector
        input_values = audio_input["input_values"]
        audio_length = audio_input["audio_length"]

        encoder_out, encoder_atts = self.audio_encoder(input_values,
                                                       audio_length)
        audio_embeddings, _ = self.audio_projector(encoder_out, encoder_atts)
        audio_embeddings = audio_embeddings.to(
            audio_input["input_values"].dtype)
        batch_size, max_audio_tokens, embed_dim = audio_embeddings.shape

        audio_length_np = (audio_length.cpu().numpy() if isinstance(
            audio_length, torch.Tensor) else audio_length)
        audio_output_lengths = [
            max(1, calculate_mel_frames_dasheng(
                int(length)))  # at least one frame
            for length in audio_length_np
        ]
        audio_output_lengths = torch.tensor(audio_output_lengths).to(
            audio_embeddings.device)

        # True for the first `audio_output_lengths[b]` positions of row b.
        token_positions = torch.arange(
            max_audio_tokens,
            device=audio_embeddings.device).unsqueeze(0).expand(
                batch_size, max_audio_tokens)
        audio_feature_mask = (token_positions
                              < audio_output_lengths.unsqueeze(1))

        masked_audio_features = audio_embeddings[audio_feature_mask].view(
            -1, embed_dim)

        # torch.split returns a tuple with one chunk per audio item.
        return torch.split(masked_audio_features,
                           audio_output_lengths.tolist())

    def get_language_model(self) -> torch.nn.Module:
        """Return the text decoder."""
        return self.decoder

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Return audio embeddings, or an empty list without audio input."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)

        if audio_input is None:
            return []
        return self._process_audio_input(audio_input)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder, building input embeddings first if needed."""
        if intermediate_tensors is not None:
            # Intermediate tensors take precedence over any embeddings.
            inputs_embeds = None
        elif inputs_embeds is None:
            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(
                input_ids,
                multimodal_embeddings,
                is_multimodal=input_ids == self.config.audio_token_id,
            )
            input_ids = None

        return self.decoder.model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the decoder."""
        return self.decoder.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load weights via ``AutoWeightsLoader``; return its result."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)