whisper.py 35.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Patrick von Platen's avatar
Patrick von Platen committed
4
import enum
5
import math
6
from collections.abc import Iterable, Mapping, Sequence
Patrick von Platen's avatar
Patrick von Platen committed
7
from contextlib import nullcontext
8
from typing import Annotated
9

10
import numpy as np
11
12
import torch
from torch import nn
13
14
15
16
17
from transformers import (
    BatchFeature,
    WhisperConfig,
    WhisperFeatureExtractor,
)
18
19
from transformers.models.whisper.modeling_whisper import sinusoids

20
from vllm.compilation.decorators import support_torch_compile
21
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
22
from vllm.config.multimodal import BaseDummyOptions
23
from vllm.config.speech_to_text import SpeechToTextParams
24
from vllm.distributed import get_tensor_model_parallel_world_size
25
26
27
28
29
30
from vllm.inputs import (
    ExplicitEncoderDecoderPrompt,
    MultiModalDataDict,
    PromptType,
    TextPrompt,
)
31
32
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
33
34
35
36
37
from vllm.model_executor.layers.attention import (
    Attention,
    CrossAttention,
    MMEncoderAttention,
)
38
39
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
40
    MergedColumnParallelLinear,
41
42
43
    QKVParallelLinear,
    RowParallelLinear,
)
44
from vllm.model_executor.layers.logits_processor import LogitsProcessor
45
from vllm.model_executor.layers.quantization import QuantizationConfig
46
47
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
Patrick von Platen's avatar
Patrick von Platen committed
48
49
50
from vllm.model_executor.models.whisper_utils import (
    ISO639_1_SUPPORTED_LANGS,
)
51
from vllm.multimodal import MULTIMODAL_REGISTRY
52
53
54
55
from vllm.multimodal.inputs import (
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
56
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
57
from vllm.multimodal.processing import (
58
    BaseDummyInputsBuilder,
59
60
61
62
63
    BaseProcessingInfo,
    EncDecMultiModalProcessor,
    PromptReplacement,
    PromptUpdate,
)
64
from vllm.renderers import TokenizeParams
65
from vllm.transformers_utils.processor import cached_processor_from_config
66
from vllm.utils.jsontree import json_map_leaves
67
from vllm.utils.tensor_schema import TensorSchema, TensorShape
68
from vllm.utils.torch_utils import set_default_torch_dtype
69
70
71
from vllm.v1.attention.backend import (
    AttentionType,
)
72

73
74
from .interfaces import (
    MultiModalEmbeddings,
75
    SupportsLoRA,
76
77
78
    SupportsMultiModal,
    SupportsTranscription,
)
79
80
81
82
83
84
85
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    cast_overflow_tensors,
    make_layers,
    maybe_prefix,
)
86
87
88

logger = init_logger(__name__)

Patrick von Platen's avatar
Patrick von Platen committed
89
90
91

class WhisperPosEmbedType(enum.Enum):
    """Kinds of positional embedding a Whisper-style config may request.

    The enum value is the string looked up from the model config
    (e.g. ``WhisperPosEmbedType("sinusoidal")``).
    """

    SINUSOIDAL = "sinusoidal"
    ROPE = "rope"
    LEARNED = "learned"
94

95

96
97
98
99
100
101
102
103
class WhisperAudioInputs(TensorSchema):
    """
    Schema for the audio features handed to the Whisper encoder.

    Dimensions:
        - b: Batch size
        - nmb: Number of mel bins
        - t: Time frames (M)
    """

    # One tensor per audio item, or None when no audio is attached.
    input_features: Annotated[
        list[torch.Tensor] | None,
        TensorShape("b", "nmb", "t"),
    ]
108
109


110
class WhisperEncoderAttention(MMEncoderAttention):
    """Multi-headed attention for Whisper encoder with 2D tensor support."""

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """
        Run the parent attention, accepting unbatched inputs as well.

        Input shape: batch_size x seq_len x hidden_size
                     or seq_len x hidden_size

        A 2D input gets a leading dimension of size 1 added before the
        parent call; that dimension is removed again from the output.
        """
        # Only `query` is inspected; `key` and `value` are reshaped the
        # same way as `query`.
        is_2d = query.dim() == 2
        if is_2d:
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
            value = value.unsqueeze(0)

        # Call the parent forward method
        out = super().forward(query, key, value)

        if is_2d:
            out = out.squeeze(0)

        return out


138
class WhisperPositionalEmbedding(nn.Embedding):
    """Position-embedding table that is looked up by direct indexing."""

    def __init__(self, num_positions: int, embedding_dim: int) -> None:
        """Create a table with ``num_positions`` rows of ``embedding_dim``."""
        super().__init__(num_positions, embedding_dim)

    def forward(self, position_ids: torch.Tensor) -> torch.Tensor:
        """Return the rows of ``self.weight`` selected by ``position_ids``."""
        return self.weight[position_ids]


class WhisperAttention(nn.Module):
    """Tensor-parallel attention block with a fused QKV projection.

    Depending on ``attn_type`` the inner attention operator is a
    ``WhisperEncoderAttention``, a ``CrossAttention`` or a regular
    ``Attention``. Subclasses may override ``_init_qkv`` to build a
    different set of input projections.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        attn_type: AttentionType = AttentionType.DECODER,
        per_layer_sliding_window: int | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """
        Args:
            embed_dim: Model hidden size; must be divisible by ``num_heads``.
            num_heads: Total number of attention heads before TP sharding.
            bias: Whether the projections carry a bias term.
            attn_type: Selects which attention operator is instantiated.
            per_layer_sliding_window: Forwarded only to the decoder
                self-attention operator (the final ``else`` branch below).
            cache_config: Forwarded to the cross/decoder attention operators.
            quant_config: Forwarded to the linear layers and to the
                cross/decoder attention operators.
            prefix: Dotted module-name prefix for the sub-layers.

        Raises:
            AssertionError: If ``num_heads`` is not divisible by the
                tensor-parallel world size.
            ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
        """
        super().__init__()
        self.embed_dim = embed_dim
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        # NOTE(review): the assert above already requires divisibility by
        # tp_size, so for a positive head count the `else` (replication)
        # branch below cannot be reached.
        if self.total_num_heads >= tp_size:
            # Number of heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_heads % tp_size == 0
        else:
            # Number of heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_heads == 0
        self.num_kv_heads = max(1, self.total_num_heads // tp_size)
        self.head_dim = self.embed_dim // self.total_num_heads
        # Per-rank widths used to split the fused QKV output in forward().
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.attn_type = attn_type

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: "
                f"{self.embed_dim} and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5

        self._init_qkv(embed_dim, bias, quant_config, prefix=prefix)
        self.out_proj = RowParallelLinear(
            input_size=embed_dim,
            output_size=embed_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )
        if attn_type == AttentionType.ENCODER:
            self.attn = WhisperEncoderAttention(
                self.num_heads,
                self.head_dim,
                self.scaling,
                num_kv_heads=self.num_kv_heads,
            )
        elif self.attn_type == AttentionType.ENCODER_DECODER:
            self.attn = CrossAttention(
                self.num_heads,
                self.head_dim,
                self.scaling,
                num_kv_heads=self.num_kv_heads,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.attn",
                attn_type=self.attn_type,
            )
        else:  # AttentionType.DECODER (regular decoder self-attention)
            self.attn = Attention(
                self.num_heads,
                self.head_dim,
                self.scaling,
                num_kv_heads=self.num_kv_heads,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.attn",
                attn_type=self.attn_type,
                per_layer_sliding_window=per_layer_sliding_window,
            )

    def _init_qkv(
        self,
        embed_dim: int,
        bias: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Create the fused ``qkv_proj`` layer used by ``forward``.

        The KV head count passed to the layer equals the query head count.
        """
        self.qkv_proj = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Project to Q/K/V, run attention, and apply the output projection."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        attn_output = self.attn(q, k, v)

        output, _ = self.out_proj(attn_output)

        return output


class WhisperCrossAttention(WhisperAttention):
    """Encoder-decoder attention: queries from the decoder, K/V from the encoder.

    Uses a separate ``q_proj`` and a merged ``kv_proj`` instead of the
    fused ``qkv_proj`` of the base class.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the base attention with ``AttentionType.ENCODER_DECODER``."""
        super().__init__(
            embed_dim=embed_dim,
            num_heads=num_heads,
            bias=bias,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
            attn_type=AttentionType.ENCODER_DECODER,
        )

    def _init_qkv(
        self,
        embed_dim: int,
        bias: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Create ``q_proj`` and the merged ``kv_proj`` layers."""
        self.q_proj = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=embed_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.q_proj",
        )
        # Use MergedColumnParallelLinear for K and V projections.
        # This enables LoRA support via MergedColumnParallelLinearWithLoRA
        # which handles 2-slice configurations.
        self.kv_proj = MergedColumnParallelLinear(
            input_size=embed_dim,
            output_sizes=[embed_dim, embed_dim],
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_proj",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None,
    ):
        """Attend from ``hidden_states`` to ``encoder_hidden_states``.

        When ``encoder_hidden_states`` is None, ``None`` is passed as both
        key and value to the attention operator.
        """
        q, _ = self.q_proj(hidden_states)

        # Encoder hidden states are only computed once during prefill phase.
        # Afterwards, the keys and values should be available in the kv-cache.
        if encoder_hidden_states is not None:
            kv, _ = self.kv_proj(encoder_hidden_states)
            k, v = kv.split([self.kv_size, self.kv_size], dim=-1)
        else:
            k = v = None

        attn_output = self.attn(q, k, v)

        output, _ = self.out_proj(attn_output)

        return output


class WhisperMLP(nn.Module):
    """Two-layer feed-forward block: ``fc1`` -> activation -> ``fc2``."""

    def __init__(
        self,
        embed_dim: int,
        ffn_dim: int,
        act_fn: str,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """
        Args:
            embed_dim: Input and output width of the block.
            ffn_dim: Width of the intermediate projection.
            act_fn: Activation name resolved through ``get_act_fn``.
            quant_config: Forwarded to both linear layers.
            prefix: Dotted module-name prefix for the sub-layers.
        """
        super().__init__()

        self.activation_fn = get_act_fn(act_fn)
        self.fc1 = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=ffn_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            input_size=ffn_dim,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, hidden_states: torch.Tensor):
        """Apply ``fc1``, the activation, then ``fc2``."""
        # The linear layers return a tuple; only the first element is used.
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class WhisperEncoderLayer(nn.Module):
    """Pre-norm encoder layer: self-attention and MLP, each with a residual."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the layer from ``vllm_config.model_config.hf_config``."""
        super().__init__()
        config = vllm_config.model_config.hf_config
        # Optional config attribute; None when the config does not define it.
        sliding_window = getattr(config, "sliding_window", None)
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.embed_dim = config.d_model
        self.self_attn = WhisperAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            attn_type=AttentionType.ENCODER,
            per_layer_sliding_window=sliding_window,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.mlp = WhisperMLP(
            embed_dim=config.d_model,
            ffn_dim=config.encoder_ffn_dim,
            act_fn=config.activation_function,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Run norm -> self-attention -> residual, then norm -> MLP -> residual.

        The result is passed through ``cast_overflow_tensors`` before
        being returned.
        """
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        hidden_states = cast_overflow_tensors(hidden_states)

        return hidden_states


class WhisperDecoderLayer(nn.Module):
    """Pre-norm decoder layer: self-attention, cross-attention, then MLP."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the layer from ``vllm_config.model_config.hf_config``."""
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.self_attn = WhisperAttention(
            embed_dim=config.d_model,
            num_heads=config.decoder_attention_heads,
            attn_type=AttentionType.DECODER,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.self_attn_layer_norm = nn.LayerNorm(config.d_model)
        self.encoder_attn = WhisperCrossAttention(
            embed_dim=config.d_model,
            num_heads=config.decoder_attention_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder_attn",
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(config.d_model)
        self.mlp = WhisperMLP(
            embed_dim=config.d_model,
            ffn_dim=config.decoder_ffn_dim,
            act_fn=config.activation_function,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.final_layer_norm = nn.LayerNorm(config.d_model)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None,
    ):
        """Apply the three sub-blocks, each as norm -> op -> residual add.

        Args:
            hidden_states: Decoder activations.
            encoder_hidden_states: Encoder output handed to the
                cross-attention, or None.
        """
        # 1) Decoder self-attention.
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states

        # 2) Cross-attention over the encoder output.
        residual = hidden_states
        hidden_states = self.encoder_attn_layer_norm(hidden_states)
        hidden_states = self.encoder_attn(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
        )
        hidden_states = residual + hidden_states

        # 3) Feed-forward block.
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class WhisperEncoder(nn.Module):
    """Audio encoder: two 1D convolutions, position embeddings, encoder layers."""

    def __init__(
        self, *, vllm_config: VllmConfig, prefix: str = "", init_in_fp32: bool = False
    ):
        """
        Args:
            vllm_config: Source of the HF model config and layer settings.
            prefix: Dotted module-name prefix for the sub-layers.
            init_in_fp32: When True, the position-embedding table is created
                while the default torch dtype is set to float32.

        Raises:
            ValueError: If the configured ``pos_embed`` type is neither
                sinusoidal nor learned.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        embed_dim = config.d_model

        # Falls back to "sinusoidal" when the config has no `pos_embed`.
        self.pos_embed_type = WhisperPosEmbedType(
            getattr(config, "pos_embed", "sinusoidal")
        )
        self.num_mel_bins = config.num_mel_bins
        self.max_source_positions = config.max_source_positions
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(embed_dim, embed_dim, stride=2, kernel_size=3, padding=1)

        # Product of the strides of both convolutions.
        self.total_stride = self.conv1.stride[0] * self.conv2.stride[0]
        # NOTE(review): the factory appends ".layers" to a prefix that was
        # itself built from f"{prefix}.layers" - confirm this is intended.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.encoder_layers,
            lambda prefix: WhisperEncoderLayer(
                vllm_config=vllm_config, prefix=f"{prefix}.layers"
            ),
            prefix=f"{prefix}.layers",
        )
        self.layer_norm = nn.LayerNorm(config.d_model)

        if self.pos_embed_type not in (
            WhisperPosEmbedType.SINUSOIDAL,
            WhisperPosEmbedType.LEARNED,
        ):
            raise ValueError(
                "Only sinusoidal or learned position embeddings are supported "
                f"for non-causal models, but got {self.pos_embed_type}"
            )

        maybe_fp32_init_ctx = (
            set_default_torch_dtype(torch.float32) if init_in_fp32 else nullcontext()
        )

        with (
            torch.no_grad(),
            maybe_fp32_init_ctx,
        ):
            # The table is filled with sinusoids for both accepted
            # position-embedding types.
            self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
            self.embed_positions.weight.copy_(
                sinusoids(*self.embed_positions.weight.shape)
            )

    def forward(
        self, input_features: torch.Tensor | list[torch.Tensor]
    ) -> torch.Tensor:
        """Encode audio features into hidden states.

        Each element of ``input_features`` goes through both convolutions
        (each followed by GELU), has its last two dimensions swapped, and
        gets the leading rows of the position table added. The per-item
        results are then concatenated or stacked, run through every
        encoder layer, and normalized.
        """
        hidden_states = []
        input_is_batched = False
        for features in input_features:
            embeds = nn.functional.gelu(self.conv1(features))
            embeds = nn.functional.gelu(self.conv2(embeds))

            embeds = embeds.transpose(-1, -2)
            # Add positions, then cast back to the dtype of the conv output.
            embeds = (embeds + self.embed_positions.weight[: embeds.size(-2), :]).to(
                embeds.dtype
            )

            hidden_states.append(embeds)
            # Overwritten every iteration: the last item decides.
            input_is_batched = embeds.ndim > 2
        # Input to MHA must be B x T x D
        if input_is_batched:
            # Models using WhisperEncoder may handle batching internally.
            hidden_states = torch.cat(hidden_states)
        else:
            hidden_states = torch.stack(hidden_states, dim=0)

        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states)

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states


539
@support_torch_compile(dynamic_arg_dims={"input_ids": 0, "positions": -1})
class WhisperDecoder(nn.Module):
    """Text decoder: token + position embeddings followed by decoder layers."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the decoder from ``vllm_config.model_config.hf_config``."""
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.layerdrop = config.decoder_layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_target_positions
        self.max_source_positions = config.max_source_positions
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.d_model, self.padding_idx
        )
        self.embed_positions = WhisperPositionalEmbedding(
            self.max_target_positions, config.d_model
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.decoder_layers,
            lambda prefix: WhisperDecoderLayer(
                vllm_config=vllm_config, prefix=f"{prefix}.layers"
            ),
            prefix=f"{prefix}.layers",
        )
        self.layer_norm = nn.LayerNorm(config.d_model)

    def forward(
        self,
        input_ids,
        positions: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None,
    ):
        """Decode ``input_ids`` conditioned on ``encoder_hidden_states``.

        Args:
            input_ids: Token ids, embedded via ``embed_input_ids``.
            positions: Position ids looked up in ``embed_positions``.
            encoder_hidden_states: Passed unchanged to every decoder layer.

        Returns:
            The hidden states after the final layer norm.
        """
        inputs_embeds = self.embed_input_ids(input_ids)
        # `positions` is rebound from ids to their embeddings here.
        positions = self.embed_positions(positions)
        hidden_states = inputs_embeds + positions

        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
            )

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)


class WhisperModel(nn.Module):
    """Whisper encoder/decoder pair plus checkpoint weight loading."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create the ``encoder`` and ``decoder`` sub-modules."""
        super().__init__()
        self.encoder = WhisperEncoder(
            vllm_config=vllm_config, prefix=f"{prefix}.encoder"
        )
        self.decoder = WhisperDecoder(
            vllm_config=vllm_config, prefix=f"{prefix}.decoder"
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        encoder_outputs: list[torch.Tensor],
    ) -> torch.Tensor:
        """Run the decoder over ``input_ids``.

        ``encoder_outputs`` is concatenated along dim 0; an empty list
        results in ``None`` being passed as the encoder hidden states.
        """
        enc_states = torch.cat(encoder_outputs, dim=0) if len(encoder_outputs) else None
        decoder_outputs = self.decoder(
            input_ids=input_ids,
            positions=positions,
            encoder_hidden_states=enc_states,
        )
        return decoder_outputs

    def get_encoder_outputs(
        self,
        input_features: torch.Tensor | list[torch.Tensor] | None,
    ) -> torch.Tensor | None:
        """Run the encoder, or return None when there are no features."""
        if input_features is None:
            return None
        return self.encoder(input_features)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate q/k/v (and cross-attention k/v) checkpoint weights are
        routed into the fused ``qkv_proj`` / ``kv_proj`` parameters.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The (possibly remapped) names that reached the end of the
            loop body.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
            (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
            (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
            # MergedColumnParallelLinear uses integer indices (0, 1)
            (".encoder_attn.kv_proj", ".encoder_attn.k_proj", 0),
            (".encoder_attn.kv_proj", ".encoder_attn.v_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                # This `continue` advances the inner mapping loop only.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Reached when no mapping entry ended with `break`.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


657
658
659
660
class WhisperProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Whisper: config, feature extractor, limits."""

    def get_hf_config(self) -> WhisperConfig:
        """Return the HF config, requested as a ``WhisperConfig``."""
        return self.ctx.get_hf_config(WhisperConfig)

    def get_default_tok_params(self) -> TokenizeParams:
        """Return the parent tokenization params with special tokens disabled."""
        # Special tokens should be provided by the user based on the
        # task and language of their request. Also needed to avoid
        # appending an EOS token to the prompt which disrupts generation.
        return super().get_default_tok_params().with_kwargs(add_special_tokens=False)

    def get_data_parser(self):
        """Build a data parser targeting the feature extractor's sampling rate."""
        feature_extractor = self.get_feature_extractor()

        return MultiModalDataParser(
            target_sr=feature_extractor.sampling_rate,
            target_channels=self.get_target_channels(),
            expected_hidden_size=self._get_expected_hidden_size(),
        )

    @property
    def skip_prompt_length_check(self) -> bool:
        """Always True for Whisper."""
        return True  # Because the encoder prompt is padded

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the per-modality item limit: one audio."""
        return {"audio": 1}

    def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
        """Return the ``feature_extractor`` of the HF processor.

        ``kwargs`` are forwarded to ``get_hf_processor``.
        """
        hf_processor = self.get_hf_processor(**kwargs)
        feature_extractor = hf_processor.feature_extractor  # type: ignore
        assert isinstance(feature_extractor, WhisperFeatureExtractor)
        return feature_extractor

    def get_target_channels(self) -> int:
        """Return target audio channels for Whisper models (mono)."""
        return 1

    def get_num_audio_tokens(self) -> int:
        """Return ``max_source_positions`` from the HF config."""
        return self.get_hf_config().max_source_positions


class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]):
    """Builds placeholder text and audio inputs for Whisper."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one ``<|startoftranscript|>`` token per requested audio."""
        num_audios = mm_counts.get("audio", 0)

        return "<|startoftranscript|>" * num_audios

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Return dummy audio data for the requested number of audios.

        Args:
            seq_len: Not used by this implementation.
            mm_counts: Number of items per modality; only "audio" is read.
            mm_options: Per-modality overrides; only "audio" is read.
        """
        feature_extractor = self.info.get_feature_extractor()

        sampling_rate = feature_extractor.sampling_rate
        # Length in samples of one `chunk_length` chunk at this rate.
        audio_len = feature_extractor.chunk_length * sampling_rate
        num_audios = mm_counts.get("audio", 0)

        audio_overrides = mm_options.get("audio")

        return {
            "audio": self._get_dummy_audios(
                length=audio_len,
                num_audios=num_audios,
                overrides=audio_overrides,
            )
        }


726
class WhisperMultiModalProcessor(EncDecMultiModalProcessor[WhisperProcessingInfo]):
    """Encoder-decoder multimodal processor for Whisper audio inputs."""

    def create_encoder_prompt(
        self,
        prompt: str | list[int],
        mm_items: MultiModalDataItems,
    ) -> str | list[int]:
        """Return the placeholder encoder prompt ``[0]``."""
        # Strictly speaking, whisper encoder only accept audio features.
        # We create a dummy encoder prompt here which will be padded to
        # num_audio_tokens. So that we can create dummy data from this
        # for encoder profiling.
        return [0]

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Call the parent HF-processor hook with Whisper-specific arguments.

        When ``mm_data`` is non-empty, its "audios" entry is re-keyed as
        "audio" and the feature extractor's sampling rate is added to
        ``mm_kwargs``. A "labels" entry in the result is renamed to
        "input_ids".
        """
        if mm_data:
            feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
            # `pop` removes "audios" from the mapping that was passed in.
            mm_data = dict(audio=mm_data.pop("audios"))
            mm_kwargs = dict(
                **mm_kwargs,
                sampling_rate=feature_extractor.sampling_rate,
            )
        # The HF WhisperProcessor passes **kwargs to both the tokenizer
        # and the feature extractor. Text-tokenizer kwargs like
        # `truncation` and `max_length` are dropped here (unconditionally),
        # otherwise the feature extractor interprets `max_length` as raw
        # audio samples and truncates the audio.
        tok_kwargs = {
            k: v for k, v in tok_kwargs.items() if k not in ("truncation", "max_length")
        }
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        if "labels" in processed_outputs:
            processed_outputs["input_ids"] = processed_outputs.pop("labels")
        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare ``input_features`` as a batched field of the audio modality."""
        return dict(input_features=MultiModalFieldConfig.batched("audio"))

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace the ``[0]`` placeholder with ``num_audio_tokens`` zeros."""
        num_tokens = self.info.get_num_audio_tokens()
        return [
            PromptReplacement(
                modality="audio",
                target=[0],
                replacement=[0] * num_tokens,
            )
        ]


793
794
795
796
797
798
@MULTIMODAL_REGISTRY.register_processor(
    WhisperMultiModalProcessor,
    info=WhisperProcessingInfo,
    dummy_inputs=WhisperDummyInputsBuilder,
)
class WhisperForConditionalGeneration(
799
800
801
    nn.Module,
    SupportsTranscription,
    SupportsMultiModal,
802
    SupportsLoRA,
803
):
804
    # LoRA-specific attributes
805
    packed_modules_mapping = {
806
807
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "kv_proj": ["k_proj", "v_proj"],
808
809
    }

810
811
812
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."}
    )
813

814
815
    # Whisper only supports audio-conditioned generation.
    supports_transcription_only = True
816
    supports_segment_timestamp = True
817
    supports_explicit_language_detection = True
818
    supported_languages = ISO639_1_SUPPORTED_LANGS
819

820
    @classmethod
    def validate_language(cls, language: str | None) -> str | None:
        """Validate ``language``, letting ``None`` through for auto-detection.

        ``None`` is returned unchanged (after a debug log) so the language
        can be detected from the audio. Any other value is delegated to the
        parent class's ``validate_language``.
        """
        if language is not None:
            return super().validate_language(language)

        logger.debug(
            "No language specified. Language will be auto-detected "
            "from audio. To skip detection, pass the `language` field "
            "in the TranscriptionRequest."
        )
        return None
830
831

    @classmethod
    def get_generation_prompt(
        cls,
        stt_params: SpeechToTextParams,
    ) -> PromptType:
        """Assemble the encoder/decoder prompt for one speech-to-text request.

        The audio is attached to an empty encoder prompt. The decoder prompt
        is the Whisper control-token sequence
        ``<|startoftranscript|><|lang|><|task|><|notimestamps|>``, preceded
        by ``<|prev|>`` plus the request prompt when one is given.

        Raises:
            ValueError: If ``stt_params.language`` is ``None``.
        """
        audio = stt_params.audio
        stt_config = stt_params.stt_config
        language = stt_params.language
        task_type = stt_params.task_type
        request_prompt = stt_params.request_prompt

        if language is None:
            raise ValueError(
                "Language must be specified when creating the Whisper prompt"
            )

        previous_context = f"<|prev|>{request_prompt}" if request_prompt else ""
        control_tokens = (
            f"<|startoftranscript|><|{language}|><|{task_type}|><|notimestamps|>"
        )

        # Whisper does not support encoder prompt, so the encoder side only
        # carries the audio.
        encoder_prompt = TextPrompt(
            prompt="",
            multi_modal_data={"audio": (audio, stt_config.sample_rate)},
        )
        decoder_prompt = TextPrompt(prompt=previous_context + control_tokens)
        return ExplicitEncoderDecoderPrompt(
            encoder_prompt=encoder_prompt,
            decoder_prompt=decoder_prompt,
        )
858

859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
    @classmethod
    def get_language_token_ids(
        cls,
        tokenizer: object,
    ) -> list[int]:
        """Return token IDs for all supported language tokens.

        Used with ``SamplingParams.allowed_token_ids`` to constrain
        language detection to only produce valid language tokens.
        """
        token_ids = [
            tokenizer.convert_tokens_to_ids(f"<|{lang_code}|>")
            for lang_code in cls.supported_languages
        ]
        return token_ids

    @classmethod
    def get_language_detection_prompt(
        cls,
        audio: np.ndarray,
        stt_config: SpeechToTextConfig,
    ) -> PromptType:
        """Build the prompt used to detect the spoken language of ``audio``.

        The decoder is fed only ``<|startoftranscript|>`` so that the model
        predicts the most likely language token (e.g. ``<|de|>``) next. The
        audio is attached to an empty encoder prompt together with
        ``stt_config.sample_rate``.
        """
        audio_prompt = TextPrompt(
            prompt="",
            multi_modal_data={"audio": (audio, stt_config.sample_rate)},
        )
        detection_start = TextPrompt(prompt="<|startoftranscript|>")
        return ExplicitEncoderDecoderPrompt(
            encoder_prompt=audio_prompt,
            decoder_prompt=detection_start,
        )

    @classmethod
    def parse_language_detection_output(
        cls,
        token_ids: list[int],
        tokenizer: object,
    ) -> str | None:
        """Parse the language token predicted by Whisper.

        Decodes the first token ID and extracts the language code from the
        ``<|xx|>`` format. Expects a valid language token from constrained generation.
        """

        decoded = tokenizer.decode(
            [token_ids[0]],
            skip_special_tokens=False,
        )
        # Whisper language tokens have the form <|xx|>
        assert decoded.startswith("<|") and decoded.endswith("|>")
        lang_code = decoded[2:-2]
        assert lang_code in cls.supported_languages
        return lang_code

916
    @classmethod
917
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
918
919
920
921
922
        if modality.startswith("audio"):
            return None

        raise ValueError("Only audio modality is supported")

923
    @classmethod
    def get_speech_to_text_config(
        cls, model_config: ModelConfig, task_type: str
    ) -> SpeechToTextConfig:
        """Derive the speech-to-text settings from the model's HF processor.

        The maximum clip length and the sample rate are read from the
        processor's feature extractor (``chunk_length`` and
        ``sampling_rate``). ``task_type`` is not consulted.
        """
        processor = cached_processor_from_config(model_config)
        clip_length_s = processor.feature_extractor.chunk_length
        sample_rate = processor.feature_extractor.sampling_rate
        return SpeechToTextConfig(
            max_audio_clip_s=clip_length_s,
            sample_rate=sample_rate,
        )

    @classmethod
    def get_num_audio_tokens(
        cls,
        audio_duration_s: float,
        stt_config: SpeechToTextConfig,
        model_config: ModelConfig,
    ) -> int | None:
        """Estimate the amount of encoder processing for an audio clip.

        Returns ``ceil(audio_duration_s * sample_rate / hop_length)``, where
        ``hop_length`` comes from the HF processor's feature extractor and
        ``sample_rate`` from ``stt_config``.
        """
        processor = cached_processor_from_config(model_config)
        hop_length = processor.feature_extractor.hop_length
        assert hop_length is not None

        # NOTE(NickLucche) user can't pass encoder
        # prompts directly at least not to Whisper.
        # One indicator of the encoder amount of processing
        # is the log-mel spectrogram length.
        num_samples = audio_duration_s * stt_config.sample_rate
        return math.ceil(num_samples / hop_length)
949

950
951
952
953
954
955
956
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the Whisper model, its tied output head and logits processor.

        Args:
            vllm_config: Supplies the HF config, dtype and quantization config.
            prefix: Module-name prefix passed to the wrapped model and used
                to name the ``proj_out`` head.
        """
        super().__init__()
        model_config = vllm_config.model_config
        hf_config = model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = hf_config
        self.dtype = model_config.dtype

        # Construct the model inside the composite-model marker so the
        # decoder is tagged as the language target and the encoder as the
        # "audio" tower.
        composite_scope = self._mark_composite_model(
            vllm_config,
            language_targets=WhisperDecoder,
            tower_targets={"audio": WhisperEncoder},
        )
        with composite_scope:
            self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix)

        self.proj_out = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.d_model,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "proj_out"),
        )
        # The output head is tied to the decoder's token embedding; keep
        # whatever object ``tie_weights`` hands back.
        decoder_embedding = self.model.decoder.embed_tokens
        self.proj_out = self.proj_out.tie_weights(decoder_embedding)

        self.logits_processor = LogitsProcessor(
            hf_config.vocab_size,
            scale=getattr(hf_config, "logit_scale", 1.0),
        )
973
974
975
976
977

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
978
        encoder_outputs: list[torch.Tensor] | None = None,
979
980
        **kwargs,
    ) -> torch.Tensor:
981
982
        if encoder_outputs is None:
            encoder_outputs = []
983
984
985
        decoder_outputs = self.model(
            input_ids=input_ids,
            positions=positions,
986
            encoder_outputs=encoder_outputs,
987
988
989
        )
        return decoder_outputs

990
    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Encode the audio inputs into one encoder output per audio item.

        Required as part of SupportsMultiModal interface.
        """
        parsed = self._parse_and_validate_audio_input(**kwargs)
        encoded = self.model.get_encoder_outputs(parsed["input_features"])
        # The assumption is we can only process whole mm items (audios), so
        # the encoder output is split along dim 0 into one tensor per audio.
        return encoded.unbind(dim=0)
997

998
    def embed_input_ids(
999
1000
        self,
        input_ids: torch.Tensor,
1001
        multimodal_embeddings: MultiModalEmbeddings | None = None,
1002
        *,
1003
        is_multimodal: torch.Tensor | None = None,
1004
    ) -> torch.Tensor:
1005
1006
        # This method just returns the decoder sequence embeddings since
        # Whisper does not have encoder text tokens.
1007
        return self.model.decoder.embed_input_ids(input_ids)
1008

1009
    def _parse_and_validate_audio_input(self, **kwargs: object) -> WhisperAudioInputs:
        """Wrap the ``input_features`` keyword argument as ``WhisperAudioInputs``.

        When ``input_features`` is present, every leaf is converted with
        ``.to(self.dtype)``; a missing value is passed through as ``None``.
        """
        input_features = kwargs.get("input_features")
        if input_features is not None:
            input_features = json_map_leaves(
                lambda leaf: leaf.to(self.dtype),
                input_features,
            )
        return WhisperAudioInputs(input_features=input_features)

1017
1018
    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        logits = self.logits_processor(self.proj_out, hidden_states)
1019
1020
        return logits

1021
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights and return the loader's set of names.

        Entries under ``proj_out.`` are skipped by the loader, and names are
        remapped through ``hf_to_vllm_mapper``.
        """
        # Add fake zeros bias for k_proj to the incoming weights.
        patched = _create_fake_bias_for_k_proj(weights, ".k_proj.weight")
        loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])
        return loader.load_weights(patched, mapper=self.hf_to_vllm_mapper)
1027
1028
1029


def _create_fake_bias_for_k_proj(
1030
    weights: Iterable[tuple[str, torch.Tensor]], fake_bias_key_name: str
1031
) -> Iterable[tuple[str, torch.Tensor]]:
1032
    """
1033
    Create full zeros bias for k_proj weight in self-attn and x-attn layers.
1034
1035
1036
    So that the bias for k_proj in qkv_proj can be initialized with zeros.
    """
    for name, weight in weights:
1037
        yield name, weight
1038
        if name.endswith(fake_bias_key_name):
1039
1040
            bias = torch.zeros(weight.size(0))
            bias_name = name.replace("weight", "bias")
1041
            yield bias_name, bias