"docs/vscode:/vscode.git/clone" did not exist on "30fe765e9facf8892bc0261a6ae234c286b7d802"
funasr.py 32.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Literal, cast

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import (
    BatchFeature,
    Qwen3Config,
)

from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.attention.mm_encoder_attention import (
    MMEncoderAttention,
)
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.whisper_utils import (
    ISO639_1_SUPPORTED_LANGS,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import (
    BaseDummyInputsBuilder,
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.transformers_utils.processors.funasr_processor import FunASRFeatureExtractor
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import (
    MultiModalEmbeddings,
    SupportsMultiModal,
    SupportsTranscription,
    _require_is_multimodal,
)
from .qwen3 import Qwen3Model
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    _merge_multimodal_embeddings,
    maybe_prefix,
)

logger = init_logger(__name__)


def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
    """Build a padding mask from per-sequence lengths.

    Args:
        lengths: Tensor of sequence lengths. A trailing axis is appended so
            that it broadcasts against the row of position indices.
        maxlen: Number of positions in the mask. Defaults to
            ``lengths.max()``.
        dtype: Dtype the boolean comparison result is converted to.
        device: Optional device the result is moved to. When ``None`` the
            mask stays on ``lengths.device``.

    Returns:
        Tensor of shape ``lengths.shape + (maxlen,)`` that is truthy where
        the position index is smaller than the matching length.
    """
    if maxlen is None:
        maxlen = lengths.max()
    # Allocate the position indices directly on the device of `lengths`
    # rather than creating them on the CPU and copying them over.
    positions = torch.arange(0, maxlen, 1, device=lengths.device)
    mask = (positions < lengths.unsqueeze(-1)).detach().type(dtype)
    return mask if device is None else mask.to(device)


class LayerNorm(torch.nn.LayerNorm):
    """``torch.nn.LayerNorm`` (eps=1e-12) that can normalize a chosen axis.

    Args:
        nout: Size of the normalized axis.
        dim: Axis to normalize. ``-1`` normalizes the last axis directly;
            any other value is swapped with the last axis for the
            normalization and swapped back afterwards.
    """

    def __init__(self, nout, dim=-1):
        super().__init__(nout, eps=1e-12)
        self.dim = dim

    def forward(self, x: torch.Tensor):
        if self.dim != -1:
            # Bring the target axis to the end, normalize, then restore it.
            swapped = x.transpose(self.dim, -1)
            normalized = super().forward(swapped)
            return normalized.transpose(self.dim, -1)
        return super().forward(x)


class EncoderLayerSANM(nn.Module):
    def __init__(
        self,
        in_size: int,
        size: int,
        self_attn: nn.Module,
        feed_forward: nn.Module,
        normalize_before=True,
    ):
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(in_size)
        self.norm2 = LayerNorm(size)
        self.in_size = in_size
        self.size = size
        self.normalize_before = normalize_before

    def forward(
        self,
        hidden_states: torch.Tensor,
        mask: torch.Tensor | None = None,
        cache=None,
Jiayi Yan's avatar
Jiayi Yan committed
118
        mask_shift_chunk=None,
119
120
121
122
123
124
125
126
127
        mask_att_chunk_encoder=None,
    ):
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)

        if self.in_size == self.size:
            hidden_states = residual + self.self_attn(
                hidden_states,
                mask,
Jiayi Yan's avatar
Jiayi Yan committed
128
                mask_shift_chunk=mask_shift_chunk,
129
130
131
132
133
134
                mask_att_chunk_encoder=mask_att_chunk_encoder,
            )
        else:
            hidden_states = self.self_attn(
                hidden_states,
                mask,
Jiayi Yan's avatar
Jiayi Yan committed
135
                mask_shift_chunk=mask_shift_chunk,
136
137
138
139
140
141
142
                mask_att_chunk_encoder=mask_att_chunk_encoder,
            )

        residual = hidden_states
        hidden_states = self.norm2(hidden_states)
        hidden_states = residual + self.feed_forward(hidden_states)

Jiayi Yan's avatar
Jiayi Yan committed
143
        return hidden_states, mask, cache, mask_shift_chunk, mask_att_chunk_encoder
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185


class MultiHeadedAttentionSANM(nn.Module):
    """Multi-head self-attention with an added depthwise-convolution branch.

    The output of ``forward`` is the sum of scaled dot-product attention and
    the result of ``forward_fsmn`` applied to the (un-split) value projection.

    Args:
        n_head: Number of attention heads; must divide ``n_feat``.
        in_feat: Feature size of the input.
        n_feat: Feature size of the query/key/value projections and output.
        kernel_size: Kernel size of the depthwise ``Conv1d``.
        sanm_shift: Extra left padding added when greater than zero.
    """

    def __init__(
        self,
        n_head: int,
        in_feat: int,
        n_feat: int,
        kernel_size: int,
        sanm_shift: int = 0,
    ):
        super().__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.out_proj = ReplicatedLinear(
            input_size=n_feat,
            output_size=n_feat,
            bias=True,
        )
        # Single fused projection producing query, key and value.
        self.linear_q_k_v = ReplicatedLinear(
            input_size=in_feat,
            output_size=n_feat * 3,
            bias=True,
        )
        self.attn = None

        # Depthwise convolution (one filter per channel, no bias).
        self.fsmn_block = nn.Conv1d(
            n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
        )
        # Pad by kernel_size - 1 in total so the convolution keeps the
        # sequence length; sanm_shift moves padding from the right to the left.
        left_padding = (kernel_size - 1) // 2
        if sanm_shift > 0:
            left_padding = left_padding + sanm_shift
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)

    def forward_fsmn(
        self,
        inputs: torch.Tensor,
        mask: torch.Tensor,
        mask_shift_chunk: torch.Tensor = None,
    ):
        """Apply the padded depthwise convolution with a residual connection.

        Args:
            inputs: Tensor of shape (batch, time, features).
            mask: Optional mask, reshaped to (batch, -1, 1) and multiplied
                into the input and the output.
            mask_shift_chunk: Optional tensor multiplied into ``mask``.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        b, _, _ = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shift_chunk is not None:
                mask = mask * mask_shift_chunk
            inputs = inputs * mask

        # Conv1d works on (batch, channels, time).
        x = inputs.transpose(1, 2)
        x = self.pad_fn(x)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        x += inputs
        if mask is not None:
            x = x * mask
        return x

    def forward_qkv(self, x: torch.Tensor):
        """Project ``x`` and split the result into per-head q, k and v.

        Returns:
            ``(q_h, k_h, v_h, v)`` where the first three have shape
            (batch, heads, time, d_k) and ``v`` is the un-split value
            projection of shape (batch, time, heads * d_k).
        """
        b, t, _ = x.size()

        q_k_v, _ = self.linear_q_k_v(x)
        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)
        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)
        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)

        return q_h, k_h, v_h, v

    def forward_attention(
        self,
        value: torch.Tensor,
        scores: torch.Tensor,
        mask: torch.Tensor,
        mask_att_chunk_encoder: torch.Tensor = None,
    ):
        """Turn attention scores into the projected context tensor.

        Args:
            value: Per-head values, (batch, heads, time, d_k).
            scores: Attention scores; softmax is taken over the last axis.
            mask: Optional mask; positions equal to zero are excluded.
            mask_att_chunk_encoder: Optional tensor multiplied into ``mask``.

        Returns:
            Output of ``out_proj`` applied to the merged heads.
        """
        n_batch = value.size(0)
        if mask is not None:
            if mask_att_chunk_encoder is not None:
                mask = mask * mask_att_chunk_encoder

            # Add a head axis and mark the positions to exclude.
            mask = mask.unsqueeze(1).eq(0)

            min_value = -float("inf")
            scores = scores.masked_fill(mask, min_value)
            # Zero masked positions again after the softmax so that fully
            # masked rows (NaN after softmax) do not leak into the output.
            attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
        else:
            attn = torch.softmax(scores, dim=-1)

        x = torch.matmul(attn, value)
        x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)

        out, _ = self.out_proj(x)
        return out

    def forward(
        self,
        hidden_states: torch.Tensor,
        mask: torch.Tensor,
        mask_shift_chunk: torch.Tensor = None,
        mask_att_chunk_encoder: torch.Tensor = None,
    ):
        """Return attention output plus the convolution branch output."""
        q_h, k_h, v_h, v = self.forward_qkv(hidden_states)
        fsmn_memory = self.forward_fsmn(v, mask, mask_shift_chunk)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
        return att_outs + fsmn_memory


class SinusoidalPositionEncoder(torch.nn.Module):
    """Adds a sine/cosine position encoding to a (batch, time, dim) tensor.

    The module has no parameters; ``d_model`` is accepted but not used.
    """

    def __init__(self, d_model=80):
        super().__init__()

    def encode(
        self,
        positions: torch.Tensor = None,
        depth: int = None,
        dtype: torch.dtype = torch.float32,
    ):
        """Compute the encoding table for the given positions.

        Args:
            positions: Position indices; ``forward`` passes shape (1, time).
            depth: Size of the encoding axis; half holds sines, half cosines.
            dtype: Dtype used for the whole computation and the result.

        Returns:
            Tensor of shape (1, number of positions, depth) for even ``depth``.
        """
        leading = positions.size(0)
        pos = positions.type(dtype)
        dev = pos.device
        half_depth = depth / 2

        # Geometric progression of inverse timescales from 1 down to 1/10000.
        base = torch.tensor([10000], dtype=dtype, device=dev)
        log_step = torch.log(base) / (half_depth - 1)
        steps = torch.arange(half_depth, device=dev).type(dtype)
        inv_timescales = torch.exp(steps * (-log_step))
        inv_timescales = inv_timescales.reshape(leading, -1)

        angles = pos.reshape(1, -1, 1) * inv_timescales.reshape(1, 1, -1)
        table = torch.cat([angles.sin(), angles.cos()], dim=2)
        return table.type(dtype)

    def forward(self, hidden_states: torch.Tensor):
        _, timesteps, input_dim = hidden_states.size()
        dev = hidden_states.device
        # Positions are 1-based.
        positions = torch.arange(1, timesteps + 1, device=dev)[None, :]
        encoding = self.encode(positions, input_dim, hidden_states.dtype)
        return hidden_states + encoding.to(dev)


class SenseVoiceEncoderSmall(nn.Module):
    """Stack of ``EncoderLayerSANM`` layers over position-encoded features.

    Layout: one input layer (``encoders0``, ``input_size`` -> ``output_size``),
    ``num_blocks - 1`` layers (``encoders``), ``after_norm``, ``tp_blocks``
    further layers (``tp_encoders``) and a final ``tp_norm``.

    Args:
        input_size: Feature size of the input.
        output_size: Feature size of every layer output.
        attention_heads: Number of heads of each attention module.
        linear_units: Hidden size of each feed-forward module.
        num_blocks: Total number of layers in ``encoders0`` + ``encoders``.
        tp_blocks: Number of layers in ``tp_encoders``.
        attention_dropout_rate: Accepted but not used here.
        normalize_before: Stored on the instance; not read by ``forward``.
        kernel_size: Kernel size passed to each attention module.
        sanm_shift: Shift passed to each attention module.
        **kwargs: Ignored.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        tp_blocks: int = 0,
        attention_dropout_rate: float = 0.0,
        normalize_before: bool = True,
        kernel_size: int = 11,
        sanm_shift: int = 0,
        **kwargs,
    ):
        super().__init__()
        self._output_size = output_size
        self.embed = SinusoidalPositionEncoder()

        self.normalize_before = normalize_before

        def build_layer(in_dim):
            # Attention is built before the feed-forward module, matching the
            # construction order of the layer arguments.
            return EncoderLayerSANM(
                in_dim,
                output_size,
                MultiHeadedAttentionSANM(
                    attention_heads,
                    in_dim,
                    output_size,
                    kernel_size,
                    sanm_shift,
                ),
                PositionwiseFeedForward(output_size, linear_units),
            )

        # Only the first layer maps input_size -> output_size.
        self.encoders0 = nn.ModuleList([build_layer(input_size)])
        self.encoders = nn.ModuleList(
            [build_layer(output_size) for _ in range(num_blocks - 1)]
        )
        self.tp_encoders = nn.ModuleList(
            [build_layer(output_size) for _ in range(tp_blocks)]
        )

        self.after_norm = LayerNorm(output_size)

        self.tp_norm = LayerNorm(output_size)

    def output_size(self) -> int:
        """Return the feature size of the encoder output."""
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
    ):
        """Encode a padded batch.

        Args:
            xs_pad: Padded input; axis 1 is used as the time axis.
            ilens: Per-item lengths used to build the padding mask.

        Returns:
            ``(encoded, olens)`` where ``olens`` is the per-item count of
            unmasked positions as an int tensor.
        """
        maxlen = xs_pad.shape[1]
        masks = sequence_mask(
            ilens, maxlen=maxlen, dtype=ilens.dtype, device=ilens.device
        )[:, None, :]

        # Scale out-of-place so the caller's tensor is not modified.
        xs_pad = xs_pad * self.output_size() ** 0.5

        xs_pad = self.embed(xs_pad)

        for encoder_layer in self.encoders0:
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]

        for encoder_layer in self.encoders:
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]

        xs_pad = self.after_norm(xs_pad)

        olens = masks.squeeze(1).sum(1).int()

        for encoder_layer in self.tp_encoders:
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]

        xs_pad = self.tp_norm(xs_pad)
        return xs_pad, olens


class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward block: ``w_2(relu(w_1(x)))``.

    Args:
        idim: Input and output feature size.
        hidden_units: Feature size between the two linear layers.
    """

    def __init__(self, idim: int, hidden_units: int):
        super().__init__()
        # Expansion is column-parallel, projection back is row-parallel.
        self.w_1 = ColumnParallelLinear(
            input_size=idim,
            output_size=hidden_units,
            bias=True,
        )
        self.w_2 = RowParallelLinear(
            input_size=hidden_units,
            output_size=idim,
            bias=True,
        )
        self.activation = _ACTIVATION_REGISTRY["relu"]

    def forward(self, hidden_states: torch.Tensor):
        # Each linear layer returns a tuple; only the first element is used.
        expanded, _ = self.w_1(hidden_states)
        activated = self.activation(expanded)
        projected, _ = self.w_2(activated)
        return projected


class EncoderLayer(nn.Module):
    """Pre-norm layer: residual self-attention, then residual feed-forward.

    Args:
        size: Feature size of both layer norms.
        self_attn: Attention module, called as ``self_attn(x, None, None)``.
        feed_forward: Feed-forward module, called with a single tensor.
    """

    def __init__(
        self,
        size: int,
        self_attn: nn.Module,
        feed_forward: nn.Module,
    ):
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)

    def forward(self, hidden_states: torch.Tensor):
        # Attention sub-layer; the two extra attention arguments are None.
        attn_input = self.norm1(hidden_states)
        hidden_states = hidden_states + self.self_attn(attn_input, None, None)

        # Feed-forward sub-layer.
        ffn_input = self.norm2(hidden_states)
        return hidden_states + self.feed_forward(ffn_input)


class FunASRAudioAttention(nn.Module):
    def __init__(
        self,
        num_heads: int,
        embed_dim: int,
        prefix: str = "",
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = self.embed_dim // self.num_heads
        tp_size = get_tensor_model_parallel_world_size()
        self.num_local_heads = self.num_heads // tp_size

        if (self.head_dim * self.num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: "
                f"{self.embed_dim} and `num_heads`: {self.num_heads})."
            )

        self.scaling = self.head_dim**-0.5

        self.qkv = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            total_num_kv_heads=self.num_heads,
            bias=True,
            prefix=f"{prefix}.qkv",
        )

        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            bias=True,
            prefix=f"{prefix}.out_proj",
        )

        self.attn = MMEncoderAttention(
            num_heads=self.num_local_heads,
            head_size=self.head_dim,
            scale=self.scaling,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: torch.Tensor | None,
    ) -> torch.Tensor:
        bs, seq_length, _ = hidden_states.size()
        qkv, _ = self.qkv(hidden_states)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.view(bs, seq_length, -1, self.head_dim)
        k = k.view(bs, seq_length, -1, self.head_dim)
        v = v.view(bs, seq_length, -1, self.head_dim)

        attn_output = self.attn(
            query=q,
            key=k,
            value=v,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )

        attn_output = attn_output.view(bs, seq_length, -1)
        output, _ = self.out_proj(attn_output)
        return output


class Transformer(nn.Module):
    """Adaptor: frame stacking, a two-layer MLP and optional encoder layers.

    Args:
        downsample_rate: Number of consecutive frames stacked into one.
        encoder_dim: Feature size of the incoming frames.
        llm_dim: Feature size of the output.
        ffn_dim: Hidden size of the MLP.
        prefix: Name prefix passed to the attention modules.
        **kwargs: ``n_layer`` (default 2) sets the number of encoder layers
            and ``attention_heads`` (default 8) their head count; any other
            key is ignored.
    """

    def __init__(
        self,
        downsample_rate=2,
        encoder_dim=1280,
        llm_dim=4096,
        ffn_dim: int = 2048,
        prefix: str = "",
        **kwargs,
    ):
        super().__init__()
        self.k = downsample_rate
        self.encoder_dim = encoder_dim
        self.llm_dim = llm_dim
        self.linear1 = ColumnParallelLinear(
            input_size=self.encoder_dim * self.k,
            output_size=ffn_dim,
            bias=True,
        )
        self.relu = nn.ReLU()
        self.linear2 = RowParallelLinear(
            input_size=ffn_dim,
            output_size=self.llm_dim,
            bias=True,
        )

        n_layer = kwargs.get("n_layer", 2)
        self.blocks = None
        if n_layer > 0:
            self.blocks = nn.ModuleList(
                [
                    EncoderLayer(
                        llm_dim,
                        FunASRAudioAttention(
                            kwargs.get("attention_heads", 8),
                            llm_dim,
                            prefix=f"{prefix}.self_attn",
                        ),
                        PositionwiseFeedForward(
                            llm_dim,
                            llm_dim // 4,
                        ),
                    )
                    for _ in range(n_layer)
                ]
            )

    def forward(self, hidden_states: torch.Tensor, ilens: int = 0):
        """Downsample and project ``hidden_states``.

        Args:
            hidden_states: Tensor of shape (batch, seq, dim).
            ilens: Input length(s) before downsampling.

        Returns:
            ``(hidden_states, olens)`` where ``olens`` is
            ``ceil(ilens / downsample_rate)``.
        """
        batch_size, seq_len, dim = hidden_states.size()
        # Zero-pad the time axis up to a multiple of k, then fold every k
        # consecutive frames into the feature axis.
        chunk_num = (seq_len - 1) // self.k + 1
        pad_num = chunk_num * self.k - seq_len
        hidden_states = F.pad(hidden_states, (0, 0, 0, pad_num, 0, 0), value=0.0)

        hidden_states = hidden_states.contiguous()
        hidden_states = hidden_states.view(batch_size, chunk_num, dim * self.k)
        hidden_states, _ = self.linear1(hidden_states)
        hidden_states = self.relu(hidden_states)
        hidden_states, _ = self.linear2(hidden_states)

        olens = (ilens - 1) // self.k + 1

        if self.blocks is not None:
            for block in self.blocks:
                hidden_states = block(hidden_states)
        return hidden_states, olens


class FunASRAudioInputs(TensorSchema):
    """
    Tensor schema for FunASR audio inputs.

    Dimensions:
        - b: Batch size
        - nmb: Number of mel bins
        - t: Time frames (M)
    """

    # Audio features, declared with shape (b, nmb, t).
    input_features: Annotated[
        list[torch.Tensor] | None,
        TensorShape("b", "nmb", "t"),
    ]
    # One value per batch item, declared with shape (b,).
    speech_lengths: Annotated[
        list[torch.Tensor] | None,
        TensorShape("b"),
    ]
    # One value per batch item, declared with shape (b,); the multimodal
    # processor in this file uses it as the number of audio tokens that
    # replace each audio placeholder.
    fake_token_lengths: Annotated[
        list[torch.Tensor] | None,
        TensorShape("b"),
    ]
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727


class FunASREncoder(nn.Module):
    """Audio encoder (``SenseVoiceEncoderSmall``) plus adaptor (``Transformer``).

    Args:
        vllm_config: Engine config; ``model_config.hf_config.audio_encoder_conf``
            supplies the keyword arguments of the audio encoder.
        prefix: Module name prefix.
        init_in_fp32: Accepted but not used here.
    """

    def __init__(
        self, *, vllm_config: VllmConfig, prefix: str = "", init_in_fp32: bool = False
    ):
        super().__init__()
        self.audio_encoder = SenseVoiceEncoderSmall(
            input_size=560, **vllm_config.model_config.hf_config.audio_encoder_conf
        )
        # `use_low_frame_rate` and `freeze` are absorbed by the adaptor's
        # **kwargs and have no effect there.
        # NOTE(review): the adaptor is given the "audio_encoder" prefix —
        # confirm whether "audio_adaptor" was intended.
        self.audio_adaptor = Transformer(
            downsample_rate=1,
            use_low_frame_rate=True,
            ffn_dim=2048,
            llm_dim=1024,
            encoder_dim=512,
            n_layer=2,
            freeze=True,
            prefix=maybe_prefix(prefix, "audio_encoder"),
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights with mapping from HuggingFace format.

        Separate ``q_proj``/``k_proj``/``v_proj`` tensors are loaded as shards
        of the fused ``qkv`` parameter. Tensors whose name matches no
        parameter are skipped.

        Returns:
            Names of the parameters that were actually loaded.

        Raises:
            KeyError: If a q/k/v tensor maps to a ``qkv`` parameter that does
                not exist.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("self_attn.qkv.", "self_attn.q_proj.", "q"),
            ("self_attn.qkv.", "self_attn.k_proj.", "k"),
            ("self_attn.qkv.", "self_attn.v_proj.", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict.get(name)
                if param is None:
                    # Nothing was loaded, so do not report this name.
                    continue
                weight_loader = getattr(
                    param, "weight_loader", default_weight_loader
                )
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class FunASRModel(nn.Module):
    """FunASR audio encoder paired with a Qwen3 decoder."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.encoder = FunASREncoder(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "encoder")
        )
        self.decoder = Qwen3Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "decoder")
        )
        # When True, axes 1 and 2 of the speech features are swapped before
        # encoding. Set once here instead of being reset on every call.
        self.feat_permute = False

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the decoder only; audio is encoded via ``get_encoder_outputs``."""
        decoder_outputs = self.decoder(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
        )
        return decoder_outputs

    def get_encoder_outputs(
        self,
        speech: torch.Tensor | list[torch.Tensor] | None,
        speech_lengths: torch.Tensor | list[torch.Tensor] | None,
    ) -> torch.Tensor | None:
        """Encode speech features and pass them through the adaptor.

        Args:
            speech: Speech features given to the audio encoder.
            speech_lengths: Per-item lengths given to the audio encoder.

        Returns:
            The adaptor output; the adaptor's output lengths are discarded.
        """
        if self.feat_permute:
            speech = speech.permute(0, 2, 1)

        encoder_out, encoder_out_lens = self.encoder.audio_encoder(
            speech, speech_lengths
        )

        encoder_out, _ = self.encoder.audio_adaptor(encoder_out, encoder_out_lens)
        return encoder_out


class FunASRProcessingInfo(BaseProcessingInfo):
    """Processing metadata for FunASR: config, limits and audio parsing."""

    def get_hf_config(self) -> Qwen3Config:
        """Return the HF config, requested as a ``Qwen3Config``."""
        return self.ctx.get_hf_config(Qwen3Config)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the per-modality item limit: one audio item."""
        return {"audio": 1}

    def get_feature_extractor(self, **kwargs: object) -> FunASRFeatureExtractor:
        """Return the feature extractor of the HF processor.

        ``kwargs`` are forwarded to ``get_hf_processor``.
        """
        hf_processor = self.get_hf_processor(**kwargs)
        feature_extractor = hf_processor.feature_extractor  # type: ignore
        assert isinstance(feature_extractor, FunASRFeatureExtractor)
        return feature_extractor

    def get_data_parser(self) -> MultiModalDataParser:
        """Return a parser targeting the extractor's sampling rate and
        the channel count from ``get_target_channels``."""
        feature_extractor = self.get_feature_extractor()
        return MultiModalDataParser(
            target_sr=feature_extractor.sampling_rate,
            target_channels=self.get_target_channels(),
        )

    def get_target_channels(self) -> int:
        """Return the number of audio channels requested from the parser."""
        return 1


class FunASRDummyInputsBuilder(BaseDummyInputsBuilder[FunASRProcessingInfo]):
    """Builds dummy prompt text and dummy audio for FunASR."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one ``<|AUDIO|>`` placeholder per requested audio item."""
        num_audios = mm_counts.get("audio", 0)

        return "<|AUDIO|>" * num_audios

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Return dummy audio items.

        Args:
            seq_len: Not used here.
            mm_counts: Number of items per modality; only ``"audio"`` is read.
            mm_options: Per-modality options; the ``"audio"`` entry (if any)
                is forwarded as ``overrides``.
        """
        feature_extractor = self.info.get_feature_extractor()

        sampling_rate = feature_extractor.sampling_rate
        # Each dummy clip has chunk_length * sampling_rate samples.
        audio_len = feature_extractor.chunk_length * sampling_rate
        num_audios = mm_counts.get("audio", 0)

        audio_overrides = mm_options.get("audio")

        return {
            "audio": self._get_dummy_audios(
                length=audio_len,
                num_audios=num_audios,
                overrides=audio_overrides,
            ),
        }


class FunASRMultiModalProcessor(BaseMultiModalProcessor[FunASRProcessingInfo]):
    """Multimodal processor wiring FunASR audio into the prompt."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Call the HF processor with FunASR's argument names.

        When ``mm_data`` is non-empty, its ``"audios"`` entry is passed on
        under the key ``"audio"`` and the extractor's sampling rate is added
        to ``mm_kwargs``. If the result contains ``"labels"``, that entry is
        moved to ``"input_ids"``.
        """
        if mm_data:
            feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
            # NOTE: pop() removes "audios" from the mapping passed in.
            mm_data = dict(audio=mm_data.pop("audios"))
            mm_kwargs = dict(
                **mm_kwargs,
                sampling_rate=feature_extractor.sampling_rate,
            )
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        if "labels" in processed_outputs:
            processed_outputs["input_ids"] = processed_outputs.pop("labels")
        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare all three audio fields as batched over ``"audio"``."""
        return dict(
            input_features=MultiModalFieldConfig.batched("audio"),
            speech_lengths=MultiModalFieldConfig.batched("audio"),
            fake_token_lengths=MultiModalFieldConfig.batched("audio"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each audio token with ``fake_token_lengths[i]`` copies.

        The replacement callback indexes the per-item lengths; if
        ``fake_token_lengths`` is absent the list is empty, so requesting a
        replacement for any item raises ``IndexError``.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        audio_token_id = processor.audio_token_id

        out_mm_data = out_mm_kwargs.get_data()

        fake_token_lengths = out_mm_data.get("fake_token_lengths")
        if fake_token_lengths is None:
            audio_output_lengths = []
        else:
            assert isinstance(fake_token_lengths, torch.Tensor)

            audio_output_lengths = fake_token_lengths.tolist()

        def get_replacement_qwen2_audio(item_idx: int):
            num_features = audio_output_lengths[item_idx]
            return [audio_token_id] * num_features

        return [
            PromptReplacement(
                modality="audio",
                target=[audio_token_id],
                replacement=get_replacement_qwen2_audio,
            )
        ]


@MULTIMODAL_REGISTRY.register_processor(
    FunASRMultiModalProcessor,
    info=FunASRProcessingInfo,
    dummy_inputs=FunASRDummyInputsBuilder,
)
class FunASRForConditionalGeneration(
    nn.Module, SupportsTranscription, SupportsMultiModal
):
    """FunASR model wrapper: a ``FunASRModel`` plus LM head and logits processor.

    Registered with the multimodal registry together with its processor,
    processing-info and dummy-input builder classes.
    """

    # Substring renames handed to the weight loader in ``load_weights``.
    # Entry order is kept exactly as declared.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            "linear_q.": "q_proj.",
            "linear_k.": "k_proj.",
            "linear_v.": "v_proj.",
            "linear_out.": "out_proj.",
            "audio_adaptor.": "model.encoder.audio_adaptor.",
            "audio_encoder.": "model.encoder.audio_encoder.",
            "llm.model.": "model.decoder.",
            "llm.lm_head": "lm_head",
        }
    )

    # Class-level settings; presumably consumed by the SupportsTranscription
    # interface (its definition is not visible here).
    supports_transcription_only = True
    supports_segment_timestamp = True
    supported_languages = ISO639_1_SUPPORTED_LANGS

    @classmethod
    def validate_language(cls, language: str | None) -> str | None:
        """Validate ``language`` via the parent class, defaulting to ``"en"``.

        When ``language`` is ``None`` a warning is logged and ``"en"`` is
        passed to the parent's ``validate_language`` instead.
        """
        if language is not None:
            return super().validate_language(language)

        # TODO language should be optional and can be guessed.
        # For now we default to en. See
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
        logger.warning(
            "Defaulting to language='en'. If you wish to transcribe "
            "audio in a different language, pass the `language` field "
            "in the TranscriptionRequest."
        )
        return super().validate_language("en")

    @classmethod
    def get_generation_prompt(
        cls,
        audio: np.ndarray,
        model_config: ModelConfig,  # not needed here
        stt_config: SpeechToTextConfig,
        language: str | None,
        task_type: Literal["transcribe", "translate"],
        request_prompt: str,
        to_language: str | None,
    ) -> PromptType:
        """Return the fixed chat-style prompt with the audio attached.

        ``language`` must not be ``None`` but is otherwise not inserted into
        the prompt; ``task_type``, ``request_prompt`` and ``to_language`` are
        not read.

        Raises:
            ValueError: if ``language`` is ``None``.
        """
        if language is None:
            raise ValueError(
                "Language must be specified when creating the funasr prompt"
            )

        funasr_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n语音转写:<|AUDIO|><|im_end|>\n<|im_start|>assistant\n"  # noqa: E501
        # The audio is passed as a (waveform, sample_rate) pair.
        audio_with_rate = (audio, stt_config.sample_rate)
        return cast(
            PromptType,
            {
                "prompt": funasr_prompt,
                "multi_modal_data": {"audio": audio_with_rate},
            },
        )

    @classmethod
    def get_speech_to_text_config(
        cls, model_config: ModelConfig, task_type: str
    ) -> SpeechToTextConfig:
        """Build the speech-to-text config from the cached processor.

        Clip length and sample rate are taken from the processor's feature
        extractor; ``task_type`` is not read.
        """
        feature_extractor = cached_processor_from_config(
            model_config
        ).feature_extractor
        return SpeechToTextConfig(
            max_audio_clip_s=feature_extractor.chunk_length,
            sample_rate=feature_extractor.sampling_rate,
        )

    @classmethod
    def get_num_audio_tokens(
        cls,
        audio_duration_s: float,
        stt_config: SpeechToTextConfig,
        model_config: ModelConfig,
    ) -> int | None:
        """Return ``ceil(duration * sample_rate / hop_length)``.

        ``hop_length`` comes from the cached processor's feature extractor
        and is asserted to be set.
        """
        hop_length = cached_processor_from_config(
            model_config
        ).feature_extractor.hop_length
        assert hop_length is not None
        num_samples = audio_duration_s * stt_config.sample_rate
        return math.ceil(num_samples / hop_length)

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create the inner model, the LM head and the logits processor."""
        super().__init__()
        model_config = vllm_config.model_config
        hf_config = model_config.hf_config
        self.config = hf_config
        self.dtype = model_config.dtype

        self.model = FunASRModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )

        if not hf_config.tie_word_embeddings:
            self.lm_head = ParallelLMHead(
                hf_config.vocab_size,
                hf_config.hidden_size,
                quant_config=vllm_config.quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        else:
            # Tied embeddings: reuse the decoder's token embedding module.
            self.lm_head = self.model.decoder.embed_tokens

        # ``logit_scale`` is optional on the config; 1.0 when absent.
        self.logits_processor = LogitsProcessor(
            hf_config.vocab_size,
            scale=getattr(hf_config, "logit_scale", 1.0),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run the inner model and return its output; extra kwargs are ignored."""
        return self.model(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
        )

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Encode the audio inputs found in ``kwargs`` with the model's encoder."""
        parsed_audio = self._parse_and_validate_audio_input(**kwargs)
        return self.model.get_encoder_outputs(
            speech=parsed_audio["input_features"],
            speech_lengths=parsed_audio["speech_lengths"],
        )

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Embed ``input_ids`` with the decoder and merge in multimodal embeddings.

        ``handle_oov_mm_token`` is accepted but not read.
        """
        token_embeds = self.model.decoder.embed_input_ids(input_ids)
        mm_mask = _require_is_multimodal(is_multimodal)
        return _merge_multimodal_embeddings(
            inputs_embeds=token_embeds,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=mm_mask,
        )

    def _parse_and_validate_audio_input(self, **kwargs: object) -> FunASRAudioInputs:
        """Collect the audio fields from ``kwargs`` into ``FunASRAudioInputs``.

        Fields missing from ``kwargs`` are passed through as ``None``.
        """
        audio_fields = {
            field_name: kwargs.get(field_name)
            for field_name in (
                "input_features",
                "speech_lengths",
                "fake_token_lengths",
            )
        }
        return FunASRAudioInputs(**audio_fields)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the logits processor with the LM head to ``hidden_states``."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``weights`` via ``AutoWeightsLoader`` using the class name mapper."""
        return AutoWeightsLoader(self).load_weights(
            weights, mapper=self.hf_to_vllm_mapper
        )