# whisper.py (32.8 KB)
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import math
5
from collections.abc import Iterable, Mapping, Sequence
# Blame annotation: committed by Patrick von Platen
6
from contextlib import nullcontext
7
from typing import Annotated, Literal, cast
8

9
import numpy as np
10
11
import torch
from torch import nn
12
13
14
15
16
from transformers import (
    BatchFeature,
    WhisperConfig,
    WhisperFeatureExtractor,
)
17
18
from transformers.models.whisper.modeling_whisper import sinusoids

19
20
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention, MultiHeadAttention
21
from vllm.attention.layers.cross_attention import CrossAttention
22
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
23
from vllm.config.multimodal import BaseDummyOptions
24
from vllm.distributed import get_tensor_model_parallel_world_size
25
from vllm.inputs.data import PromptType
26
27
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
28
29
30
31
32
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
33
from vllm.model_executor.layers.logits_processor import LogitsProcessor
34
from vllm.model_executor.layers.quantization import QuantizationConfig
35
36
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
37
from vllm.multimodal import MULTIMODAL_REGISTRY
38
39
40
41
42
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
43
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
44
45
46
47
48
49
from vllm.multimodal.processing import (
    BaseProcessingInfo,
    EncDecMultiModalProcessor,
    PromptReplacement,
    PromptUpdate,
)
50
from vllm.multimodal.profiling import BaseDummyInputsBuilder
51
from vllm.transformers_utils.processor import cached_processor_from_config
52
from vllm.utils.jsontree import json_map_leaves
53
from vllm.utils.tensor_schema import TensorSchema, TensorShape
54
from vllm.utils.torch_utils import set_default_torch_dtype
55

56
57
58
59
60
61
62
63
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsTranscription
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    cast_overflow_tensors,
    make_layers,
    maybe_prefix,
)
64
65
66

# Module-level logger; used by validate_language to warn when the language
# falls back to the default.
logger = init_logger(__name__)

67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages

ISO639_1_SUPPORTED_LANGS = {
    "af": "Afrikaans",
    "ar": "Arabic",
    "hy": "Armenian",
    "az": "Azerbaijani",
    "be": "Belarusian",
    "bs": "Bosnian",
    "bg": "Bulgarian",
    "ca": "Catalan",
    "zh": "Chinese",
    "hr": "Croatian",
    "cs": "Czech",
    "da": "Danish",
    "nl": "Dutch",
    "en": "English",
    "et": "Estonian",
    "fi": "Finnish",
    "fr": "French",
    "gl": "Galician",
    "de": "German",
    "el": "Greek",
    "he": "Hebrew",
    "hi": "Hindi",
    "hu": "Hungarian",
    "is": "Icelandic",
    "id": "Indonesian",
    "it": "Italian",
    "ja": "Japanese",
    "kn": "Kannada",
    "kk": "Kazakh",
    "ko": "Korean",
    "lv": "Latvian",
    "lt": "Lithuanian",
    "mk": "Macedonian",
    "ms": "Malay",
    "mr": "Marathi",
    "mi": "Maori",
    "ne": "Nepali",
    "no": "Norwegian",
    "fa": "Persian",
    "pl": "Polish",
    "pt": "Portuguese",
    "ro": "Romanian",
    "ru": "Russian",
    "sr": "Serbian",
    "sk": "Slovak",
    "sl": "Slovenian",
    "es": "Spanish",
    "sw": "Swahili",
    "sv": "Swedish",
    "tl": "Tagalog",
    "ta": "Tamil",
    "th": "Thai",
    "tr": "Turkish",
    "uk": "Ukrainian",
    "ur": "Urdu",
    "vi": "Vietnamese",
126
    "cy": "Welsh",
127
128
}

class WhisperAudioInputs(TensorSchema):
    """
    Audio feature inputs passed to the Whisper encoder.

    Dimensions:
        - b: Batch size
        - nmb: Number of mel bins
        - t: Time frames (M)
    """

    # Per-item input features validated against shape (b, nmb, t); the
    # annotation also permits None.
    input_features: Annotated[
        list[torch.Tensor] | None,
        TensorShape("b", "nmb", "t"),
    ]
class WhisperEncoderAttention(MultiHeadAttention):
    """Whisper encoder attention that also accepts unbatched 2D inputs.

    2D ``(seq_len, hidden_size)`` tensors get a temporary leading batch axis
    of size 1 before calling ``MultiHeadAttention.forward``; that axis is
    removed from the result again. Other inputs are forwarded unchanged.
    """

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """
        Input shape: batch_size x seq_len x hidden_size
                     or seq_len x hidden_size
        """
        if query.dim() != 2:
            return super().forward(query, key, value)

        # Unbatched input: add a singleton batch dim, then strip it again.
        batched_q, batched_k, batched_v = (
            tensor.unsqueeze(0) for tensor in (query, key, value)
        )
        result = super().forward(batched_q, batched_k, batched_v)
        return result.squeeze(0)
class WhisperPositionalEmbedding(nn.Embedding):
    """Position-embedding table indexed directly by position id.

    ``forward`` indexes ``self.weight`` directly instead of calling
    ``nn.Embedding.forward``, so no padding-index or max-norm handling is
    applied to the lookup.
    """

    def __init__(self, num_positions: int, embedding_dim: int):
        super().__init__(num_positions, embedding_dim)

    def forward(self, position_ids: torch.Tensor) -> torch.Tensor:
        # Plain advanced indexing: one table row per position id.
        return self.weight[position_ids]

class WhisperAttention(nn.Module):
    """Attention used by the Whisper encoder and decoder layers.

    Q, K and V come from one fused ``QKVParallelLinear``. It is built in
    ``_init_qkv`` so that subclasses can override it. The attention output is
    projected back with a row-parallel ``out_proj``. The attention op is
    selected by ``attn_type``:

    * ``AttentionType.ENCODER``: ``WhisperEncoderAttention`` (no cache or
      quant config is passed to it);
    * ``AttentionType.ENCODER_DECODER``: ``CrossAttention``;
    * anything else (``AttentionType.DECODER``): ``Attention``.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        attn_type: AttentionType = AttentionType.DECODER,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.embed_dim = embed_dim
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        # Heads must split evenly across tensor-parallel ranks. Whisper uses
        # plain multi-head attention: _init_qkv builds as many KV heads as
        # query heads, so KV heads are partitioned exactly like query heads.
        # Because of this divisibility requirement, a "fewer heads than TP
        # ranks" (KV replication) case cannot occur.
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.num_kv_heads = max(1, self.total_num_heads // tp_size)
        self.head_dim = self.embed_dim // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.attn_type = attn_type

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: "
                f"{self.embed_dim} and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5

        self._init_qkv(embed_dim, bias, quant_config, prefix=prefix)
        self.out_proj = RowParallelLinear(
            input_size=embed_dim,
            output_size=embed_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        if self.attn_type == AttentionType.ENCODER:
            self.attn = WhisperEncoderAttention(
                self.num_heads,
                self.head_dim,
                self.scaling,
                num_kv_heads=self.num_kv_heads,
            )
        elif self.attn_type == AttentionType.ENCODER_DECODER:
            self.attn = CrossAttention(
                self.num_heads,
                self.head_dim,
                self.scaling,
                num_kv_heads=self.num_kv_heads,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.attn",
                attn_type=self.attn_type,
            )
        else:  # AttentionType.DECODER (regular decoder self-attention)
            self.attn = Attention(
                self.num_heads,
                self.head_dim,
                self.scaling,
                num_kv_heads=self.num_kv_heads,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.attn",
                attn_type=self.attn_type,
            )

    def _init_qkv(
        self,
        embed_dim: int,
        bias: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Create the fused Q/K/V projection (KV heads == query heads)."""
        self.qkv_proj = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, attend, and project back to ``embed_dim``."""
        qkv, _ = self.qkv_proj(hidden_states)
        # Slice sizes are per-TP-rank (q_size / kv_size computed in __init__).
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        attn_output = self.attn(q, k, v)
        output, _ = self.out_proj(attn_output)
        return output


class WhisperCrossAttention(WhisperAttention):
    """Decoder-to-encoder (cross) attention.

    Queries are projected from the decoder hidden states with ``q_proj``.
    Keys and values are projected from the encoder hidden states with a
    separate fused ``kv_proj``. The attention op is created by
    ``WhisperAttention.__init__`` with
    ``attn_type=AttentionType.ENCODER_DECODER``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__(
            embed_dim=embed_dim,
            num_heads=num_heads,
            bias=bias,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
            attn_type=AttentionType.ENCODER_DECODER,
        )

    def _init_qkv(
        self,
        embed_dim: int,
        bias: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Create a Q projection plus a fused K/V projection (overrides parent)."""
        self.q_proj = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=embed_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.q_proj",
        )
        # total_num_heads=0: no query slice, so the output is [K | V], which
        # forward() splits into two kv_size halves.
        self.kv_proj = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.head_dim,
            total_num_heads=0,
            total_num_kv_heads=self.total_num_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_proj",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None,
    ) -> torch.Tensor:
        """Attend from decoder ``hidden_states`` to the encoder outputs.

        ``encoder_hidden_states`` may be None. In that case K and V are
        passed as None to the attention op.
        """
        q, _ = self.q_proj(hidden_states)

        # Encoder hidden states are only computed once during prefill phase.
        # Afterwards, the keys and values should be available in the kv-cache.
        if encoder_hidden_states is not None:
            kv, _ = self.kv_proj(encoder_hidden_states)
            k, v = kv.split([self.kv_size, self.kv_size], dim=-1)
        else:
            k = v = None

        attn_output = self.attn(q, k, v)
        output, _ = self.out_proj(attn_output)
        return output


class WhisperMLP(nn.Module):
    """Feed-forward block: ``fc1`` -> activation -> ``fc2``.

    ``fc1`` (embed_dim -> ffn_dim) is a ``ColumnParallelLinear`` and ``fc2``
    (ffn_dim -> embed_dim) is a ``RowParallelLinear``; both have bias
    settings left at their layer defaults.
    """

    def __init__(
        self,
        embed_dim: int,
        ffn_dim: int,
        act_fn: str,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()

        self.activation_fn = get_act_fn(act_fn)
        self.fc1 = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=ffn_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            input_size=ffn_dim,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The parallel linear layers return (output, bias); only output is used.
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class WhisperEncoderLayer(nn.Module):
    """Pre-LayerNorm encoder block: self-attention then MLP, each residual."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.embed_dim = config.d_model
        self.self_attn = WhisperAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            attn_type=AttentionType.ENCODER,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.mlp = WhisperMLP(
            embed_dim=config.d_model,
            ffn_dim=config.encoder_ffn_dim,
            act_fn=config.activation_function,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        # Self-attention sub-block (LayerNorm applied before attention).
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states

        # Feed-forward sub-block.
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        # NOTE(review): cast_overflow_tensors comes from .utils; presumably it
        # clamps values that overflowed in low precision. Confirm there.
        hidden_states = cast_overflow_tensors(hidden_states)

        return hidden_states


class WhisperDecoderLayer(nn.Module):
    """Pre-LayerNorm decoder block: self-attn, cross-attn, MLP, each residual."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.self_attn = WhisperAttention(
            embed_dim=config.d_model,
            num_heads=config.decoder_attention_heads,
            attn_type=AttentionType.DECODER,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.self_attn_layer_norm = nn.LayerNorm(config.d_model)
        self.encoder_attn = WhisperCrossAttention(
            embed_dim=config.d_model,
            num_heads=config.decoder_attention_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder_attn",
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(config.d_model)
        self.mlp = WhisperMLP(
            embed_dim=config.d_model,
            ffn_dim=config.decoder_ffn_dim,
            act_fn=config.activation_function,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.final_layer_norm = nn.LayerNorm(config.d_model)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None,
    ) -> torch.Tensor:
        """Run one decoder block.

        ``encoder_hidden_states`` may be None; it is forwarded unchanged to
        the cross-attention sub-block.
        """
        # Decoder self-attention.
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states

        # Cross-attention over the encoder outputs.
        residual = hidden_states
        hidden_states = self.encoder_attn_layer_norm(hidden_states)
        hidden_states = self.encoder_attn(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
        )
        hidden_states = residual + hidden_states

        # Feed-forward.
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class WhisperEncoder(nn.Module):
    """Whisper audio encoder.

    Pipeline: two 1D convolutions (the second with stride 2), each followed
    by GELU. Next come sinusoidal position embeddings, then the encoder
    layers, then a final LayerNorm.
    """

    def __init__(
        self, *, vllm_config: VllmConfig, prefix: str = "", init_in_fp32: bool = False
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        embed_dim = config.d_model
        self.num_mel_bins = config.num_mel_bins
        self.max_source_positions = config.max_source_positions
        # Stored for config parity; not applied in forward().
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
        # NOTE(review): make_layers presumably already hands each layer a
        # per-layer prefix under f"{prefix}.layers"; appending ".layers" again
        # in the lambda may produce a doubled segment. Confirm before changing,
        # since attention layers may be registered by this prefix.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.encoder_layers,
            lambda prefix: WhisperEncoderLayer(
                vllm_config=vllm_config, prefix=f"{prefix}.layers"
            ),
            prefix=f"{prefix}.layers",
        )
        self.layer_norm = nn.LayerNorm(config.d_model)

        maybe_fp32_init_ctx = (
            set_default_torch_dtype(torch.float32) if init_in_fp32 else nullcontext()
        )

        # Fixed sinusoidal position table. It is filled in place without
        # autograd tracking. When init_in_fp32 is set, fp32 is the default
        # dtype while the table is created.
        with (
            torch.no_grad(),
            maybe_fp32_init_ctx,
        ):
            self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
            self.embed_positions.weight.copy_(
                sinusoids(*self.embed_positions.weight.shape)
            )

    def forward(self, input_features: torch.Tensor | list[torch.Tensor]) -> torch.Tensor:
        """Encode audio features into hidden states of width ``d_model``.

        Each element of ``input_features`` is passed through the conv stem
        separately. The results are concatenated along dim 0 when they are
        already batched (>2 dims), and stacked into a new batch dim otherwise.
        """
        per_item_embeds = []
        input_is_batched = False
        for features in input_features:
            embeds = nn.functional.gelu(self.conv1(features))
            embeds = nn.functional.gelu(self.conv2(embeds))
            # Channels-first conv output -> (..., time, embed_dim).
            embeds = embeds.transpose(-1, -2)
            # The position table may be fp32 (init_in_fp32), so cast the sum
            # back to the activation dtype.
            embeds = (embeds + self.embed_positions.weight[: embeds.size(-2), :]).to(
                embeds.dtype
            )
            per_item_embeds.append(embeds)
            input_is_batched = embeds.ndim > 2

        # Input to MHA must be B x T x D
        if input_is_batched:
            # Models using WhisperEncoder may handle batching internally.
            hidden_states = torch.cat(per_item_embeds)
        else:
            hidden_states = torch.stack(per_item_embeds, dim=0)

        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states)

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states


class WhisperDecoder(nn.Module):
    """Whisper text decoder.

    Token embeddings plus learned position embeddings feed a stack of
    decoder layers (self-attention + cross-attention + MLP), followed by a
    final LayerNorm.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        # layerdrop and embed_scale are stored for config parity; forward()
        # applies neither.
        self.layerdrop = config.decoder_layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_target_positions
        self.max_source_positions = config.max_source_positions
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.d_model, self.padding_idx
        )
        self.embed_positions = WhisperPositionalEmbedding(
            self.max_target_positions, config.d_model
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.decoder_layers,
            lambda prefix: WhisperDecoderLayer(
                vllm_config=vllm_config, prefix=f"{prefix}.layers"
            ),
            prefix=f"{prefix}.layers",
        )
        self.layer_norm = nn.LayerNorm(config.d_model)

    def forward(
        self,
        input_ids,
        positions: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None,
    ) -> torch.Tensor:
        """Decode ``input_ids`` at ``positions``, cross-attending to the encoder.

        ``encoder_hidden_states`` may be None and is passed unchanged to every
        decoder layer.
        """
        inputs_embeds = self.embed_input_ids(input_ids)
        position_embeds = self.embed_positions(positions)
        hidden_states = inputs_embeds + position_embeds

        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
            )

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings (no scaling applied)."""
        return self.embed_tokens(input_ids)


class WhisperModel(nn.Module):
    """Whisper encoder-decoder backbone (this class has no LM head).

    ``get_encoder_outputs`` runs the audio encoder. ``forward`` runs the text
    decoder over the concatenated encoder outputs.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.encoder = WhisperEncoder(
            vllm_config=vllm_config, prefix=f"{prefix}.encoder"
        )
        self.decoder = WhisperDecoder(
            vllm_config=vllm_config, prefix=f"{prefix}.decoder"
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        encoder_outputs: list[torch.Tensor],
    ) -> torch.Tensor:
        """Run the decoder.

        An empty ``encoder_outputs`` list yields ``encoder_hidden_states=None``
        for the decoder's cross-attention.
        """
        enc_states = torch.cat(encoder_outputs, dim=0) if len(encoder_outputs) else None
        decoder_outputs = self.decoder(
            input_ids=input_ids,
            positions=positions,
            encoder_hidden_states=enc_states,
        )
        return decoder_outputs

    def get_encoder_outputs(
        self,
        input_features: torch.Tensor | list[torch.Tensor] | None,
    ) -> torch.Tensor | None:
        """Encode audio features; returns None when there are no features."""
        if input_features is None:
            return None
        return self.encoder(input_features)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate q/k/v (self-attention) and k/v (cross-attention) checkpoint
        tensors are routed into the fused ``qkv_proj`` / ``kv_proj`` params
        through their ``weight_loader`` with a shard id. All other tensors are
        loaded 1:1.

        Returns:
            The (renamed) parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
            (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
            (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
            (".encoder_attn.kv_proj", ".encoder_attn.k_proj", "k"),
            (".encoder_attn.kv_proj", ".encoder_attn.v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked shard: load the tensor directly.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
class WhisperProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Whisper: HF config, feature extractor, limits."""

    def get_hf_config(self) -> WhisperConfig:
        return self.ctx.get_hf_config(WhisperConfig)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # At most one audio item per prompt.
        return {"audio": 1}

    def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
        """Return the HF processor's feature extractor (asserted Whisper type)."""
        hf_processor = self.get_hf_processor(**kwargs)
        feature_extractor = hf_processor.feature_extractor  # type: ignore
        assert isinstance(feature_extractor, WhisperFeatureExtractor)
        return feature_extractor

    def get_num_audio_tokens(self) -> int:
        # Number of encoder placeholder tokens per audio item; this equals
        # the config's max_source_positions.
        return self.get_hf_config().max_source_positions


class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]):
    """Builds dummy prompt text and audio data for profiling runs."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        # One start-of-transcript token per requested audio item.
        num_audios = mm_counts.get("audio", 0)

        return "<|startoftranscript|>" * num_audios

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return ``mm_counts["audio"]`` dummy clips of one full chunk each.

        Clip length is ``chunk_length * sampling_rate`` samples, taken from
        the feature extractor. ``seq_len`` is not used.
        """
        feature_extractor = self.info.get_feature_extractor()

        sampling_rate = feature_extractor.sampling_rate
        audio_len = feature_extractor.chunk_length * sampling_rate
        num_audios = mm_counts.get("audio", 0)

        audio_overrides = mm_options.get("audio") if mm_options else None

        return {
            "audio": self._get_dummy_audios(
                length=audio_len, num_audios=num_audios, overrides=audio_overrides
            )
        }
class WhisperMultiModalProcessor(EncDecMultiModalProcessor[WhisperProcessingInfo]):
    """Multimodal processor for Whisper's encoder-decoder audio inputs."""

    def _get_data_parser(self) -> MultiModalDataParser:
        # Audio is resampled to the feature extractor's sampling rate.
        feature_extractor = self.info.get_feature_extractor()
        return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)

    @property
    def pad_dummy_encoder_prompt(self) -> bool:
        return True

    def create_encoder_prompt(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
    ) -> str | list[int]:
        # Strictly speaking, whisper encoder only accept audio features.
        # We create a dummy encoder prompt here which will be padded to
        # num_audio_tokens. So that we can create dummy data from this
        # for encoder profiling.
        return [0]

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Call the HF processor, adapting vLLM's audio keys to HF's.

        ``"audios"`` is renamed to ``"audio"`` and the feature extractor's
        sampling rate is added to ``mm_kwargs``. If the HF output contains
        ``"labels"``, they are exposed as ``"input_ids"``.
        """
        if mm_data:
            feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
            # NOTE(review): pop() removes "audios" from the caller's mapping
            # (Mapping has no pop, so a dict is assumed). Confirm no caller
            # reads it afterwards.
            mm_data = dict(audio=mm_data.pop("audios"))
            mm_kwargs = dict(
                **mm_kwargs,
                sampling_rate=feature_extractor.sampling_rate,
            )
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        if "labels" in processed_outputs:
            processed_outputs["input_ids"] = processed_outputs.pop("labels")
        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(input_features=MultiModalFieldConfig.batched("audio"))

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        # Expand the single dummy encoder token [0] (see create_encoder_prompt)
        # into one placeholder per audio token.
        num_tokens = self.info.get_num_audio_tokens()
        return [
            PromptReplacement(
                modality="audio",
                target=[0],
                replacement=[0] * num_tokens,
            )
        ]
@MULTIMODAL_REGISTRY.register_processor(
    WhisperMultiModalProcessor,
    info=WhisperProcessingInfo,
    dummy_inputs=WhisperDummyInputsBuilder,
)
class WhisperForConditionalGeneration(
    nn.Module, SupportsTranscription, SupportsMultiModal
):
784
785
786
787
788
789
790
791
792
    packed_modules_mapping = {
        "self_attn.qkv_proj": [
            "self_attn.q_proj",
            "self_attn.k_proj",
            "self_attn.v_proj",
        ],
        "encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"],
    }

793
794
795
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."}
    )
796

797
798
    # Whisper only supports audio-conditioned generation.
    supports_transcription_only = True
799
    supports_segment_timestamp = True
800
    supported_languages = ISO639_1_SUPPORTED_LANGS
801

802
    @classmethod
803
    def validate_language(cls, language: str | None) -> str | None:
804
805
806
807
        if language is None:
            # TODO language should be optional and can be guessed.
            # For now we default to en. See
            # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
808
            logger.warning(
809
810
                "Defaulting to language='en'. If you wish to transcribe "
                "audio in a different language, pass the `language` field "
811
812
                "in the TranscriptionRequest."
            )
813
814
            language = "en"
        return super().validate_language(language)
815
816

    @classmethod
Patrick von Platen's avatar
Patrick von Platen committed
817
    def get_generation_prompt(
818
819
        cls,
        audio: np.ndarray,
820
        model_config: ModelConfig,  # not needed here
821
        stt_config: SpeechToTextConfig,
822
        language: str | None,
823
824
        task_type: Literal["transcribe", "translate"],
        request_prompt: str,
825
        to_language: str | None,
826
    ) -> PromptType:
827
828
        if language is None:
            raise ValueError(
829
830
                "Language must be specified when creating the Whisper prompt"
            )
831
832
833
834
835
836
837
838
        prompt = {
            "encoder_prompt": {
                # Whisper does not support encoder prompt.
                "prompt": "",
                "multi_modal_data": {
                    "audio": (audio, stt_config.sample_rate),
                },
            },
839
840
841
842
843
            "decoder_prompt": (
                (f"<|prev|>{request_prompt}" if request_prompt else "")
                + f"<|startoftranscript|><|{language}|>"
                + f"<|{task_type}|><|notimestamps|>"
            ),
844
845
        }
        return cast(PromptType, prompt)
846
847

    @classmethod
848
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
849
850
851
852
853
        if modality.startswith("audio"):
            return None

        raise ValueError("Only audio modality is supported")

854
    @classmethod
855
    def get_speech_to_text_config(
856
        cls, model_config: ModelConfig, task_type: str
857
    ) -> SpeechToTextConfig:
858
        processor = cached_processor_from_config(model_config)
859
860
861
862
863
864
865

        return SpeechToTextConfig(
            max_audio_clip_s=processor.feature_extractor.chunk_length,
            sample_rate=processor.feature_extractor.sampling_rate,
        )

    @classmethod
866
867
868
869
    def get_num_audio_tokens(
        cls,
        audio_duration_s: float,
        stt_config: SpeechToTextConfig,
870
        model_config: ModelConfig,
871
    ) -> int | None:
872
        processor = cached_processor_from_config(model_config)
873
874
875
876
877
878
        hop_length = processor.feature_extractor.hop_length
        assert hop_length is not None
        # NOTE(NickLucche) user can't pass encoder
        # prompts directly at least not to Whisper.
        # One indicator of the encoder amount of processing
        # is the log-mel spectogram length.
879
        return math.ceil(audio_duration_s * stt_config.sample_rate / hop_length)
880

881
882
883
884
885
886
887
888
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.dtype = vllm_config.model_config.dtype

        self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix)
889

890
891
892
893
894
895
896
        self.proj_out = ParallelLMHead(
            config.vocab_size,
            config.d_model,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "proj_out"),
        )
        self.proj_out = self.proj_out.tie_weights(self.model.decoder.embed_tokens)
897
        logit_scale = getattr(config, "logit_scale", 1.0)
898
        self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale)
899
900
901
902
903

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
904
        encoder_outputs: list[torch.Tensor] | None = None,
905
906
        **kwargs,
    ) -> torch.Tensor:
907
908
        if encoder_outputs is None:
            encoder_outputs = []
909
910
911
        decoder_outputs = self.model(
            input_ids=input_ids,
            positions=positions,
912
            encoder_outputs=encoder_outputs,
913
914
915
        )
        return decoder_outputs

916
917
918
    def get_language_model(self) -> torch.nn.Module:
        return self.model.decoder

919
    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
920
        # Required as part of SupportsMultiModal interface.
921
        audio_input = self._parse_and_validate_audio_input(**kwargs)
922
923
924
925
        # Split concatenated encoder outputs into one tensor per audio input
        enc_output = self.model.get_encoder_outputs(audio_input["input_features"])
        # The assumption is we can only process whole mm items (audios)
        return enc_output.unbind(dim=0)
926

927
    def embed_input_ids(
928
929
        self,
        input_ids: torch.Tensor,
930
        multimodal_embeddings: MultiModalEmbeddings | None = None,
931
        *,
932
        is_multimodal: torch.Tensor | None = None,
933
        handle_oov_mm_token: bool = False,
934
    ) -> torch.Tensor:
935
936
        # This method just returns the decoder sequence embeddings since
        # Whisper does not have encoder text tokens.
937
        return self.model.decoder.embed_input_ids(input_ids)
938

939
    def _parse_and_validate_audio_input(self, **kwargs: object) -> WhisperAudioInputs:
940
941
942
        input_features = kwargs.pop("input_features", None)

        if input_features is not None:
943
            input_features = json_map_leaves(lambda x: x.to(self.dtype), input_features)
944
945
946

        return WhisperAudioInputs(input_features=input_features)

947
948
    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        logits = self.logits_processor(self.proj_out, hidden_states)
949
950
        return logits

951
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
952
        loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])
953

954
955
        # add fake zeros bias for k_proj to state_dict
        weights = _create_fake_bias_for_k_proj(weights)
956
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
957
958
959


def _create_fake_bias_for_k_proj(
960
    weights: Iterable[tuple[str, torch.Tensor]],
961
) -> Iterable[tuple[str, torch.Tensor]]:
962
    """
963
    Create full zeros bias for k_proj weight in self-attn and x-attn layers.
964
965
966
    So that the bias for k_proj in qkv_proj can be initialized with zeros.
    """
    for name, weight in weights:
967
        if name.endswith(".k_proj.weight"):
968
969
970
971
            bias = torch.zeros(weight.size(0))
            bias_name = name.replace("weight", "bias")
            yield from [(name, weight), (bias_name, bias)]
        yield name, weight