whisper.py 33 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Patrick von Platen's avatar
Patrick von Platen committed
4
import enum
5
import math
6
from collections.abc import Iterable, Mapping, Sequence
Patrick von Platen's avatar
Patrick von Platen committed
7
from contextlib import nullcontext
8
from typing import Annotated, Literal, cast
9

10
import numpy as np
11
12
import torch
from torch import nn
13
14
15
16
17
from transformers import (
    BatchFeature,
    WhisperConfig,
    WhisperFeatureExtractor,
)
18
19
from transformers.models.whisper.modeling_whisper import sinusoids

20
from vllm.compilation.decorators import support_torch_compile
21
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
22
from vllm.config.multimodal import BaseDummyOptions
23
from vllm.distributed import get_tensor_model_parallel_world_size
24
from vllm.inputs.data import PromptType
25
26
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
27
28
29
30
31
from vllm.model_executor.layers.attention import (
    Attention,
    CrossAttention,
    MMEncoderAttention,
)
32
33
34
35
36
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
37
from vllm.model_executor.layers.logits_processor import LogitsProcessor
38
from vllm.model_executor.layers.quantization import QuantizationConfig
39
40
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
Patrick von Platen's avatar
Patrick von Platen committed
41
42
43
from vllm.model_executor.models.whisper_utils import (
    ISO639_1_SUPPORTED_LANGS,
)
44
from vllm.multimodal import MULTIMODAL_REGISTRY
45
46
47
48
49
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
50
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
51
from vllm.multimodal.processing import (
52
    BaseDummyInputsBuilder,
53
54
55
56
57
    BaseProcessingInfo,
    EncDecMultiModalProcessor,
    PromptReplacement,
    PromptUpdate,
)
58
from vllm.transformers_utils.processor import cached_processor_from_config
59
from vllm.utils.jsontree import json_map_leaves
60
from vllm.utils.tensor_schema import TensorSchema, TensorShape
61
from vllm.utils.torch_utils import set_default_torch_dtype
62
63
64
from vllm.v1.attention.backend import (
    AttentionType,
)
65

66
67
68
69
70
71
72
73
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsTranscription
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    cast_overflow_tensors,
    make_layers,
    maybe_prefix,
)
74
75
76

logger = init_logger(__name__)

Patrick von Platen's avatar
Patrick von Platen committed
77
78
79

class WhisperPosEmbedType(enum.Enum):
    """Kinds of positional embedding a Whisper-style config can request.

    ``WhisperEncoder`` builds this from the config's optional ``pos_embed``
    attribute (falling back to ``"sinusoidal"``) and only accepts
    ``SINUSOIDAL`` and ``LEARNED``.
    """

    # Default when the config carries no ``pos_embed`` attribute.
    SINUSOIDAL = "sinusoidal"
    # Rotary embeddings; WhisperEncoder rejects this value.
    ROPE = "rope"
    # Learned position table.
    LEARNED = "learned"
82

83

84
85
86
87
88
89
90
91
class WhisperAudioInputs(TensorSchema):
    """Schema of the audio features fed to the Whisper encoder.

    Dimensions:
        - b: Batch size
        - nmb: Number of mel bins
        - t: Time frames (M)
    """

    # Per-item feature tensors (mel bins x time frames), or None.
    input_features: Annotated[
        list[torch.Tensor] | None,
        TensorShape("b", "nmb", "t"),
    ]
96
97


98
class WhisperEncoderAttention(MMEncoderAttention):
    """Whisper encoder attention that also accepts unbatched (2D) inputs.

    A 2D ``(seq_len, hidden_size)`` query/key/value triple is given a leading
    batch dimension of 1 before calling ``MMEncoderAttention.forward`` and
    that dimension is removed from the result again. 3D inputs are passed
    through unchanged.
    """

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Attend over ``(batch, seq, hidden)`` or ``(seq, hidden)`` inputs.

        Only ``query`` is inspected to decide whether the inputs are 2D.
        """
        if query.dim() != 2:
            return super().forward(query, key, value)

        # Unbatched input: add a singleton batch dim for the parent, then drop it.
        batched_out = super().forward(
            query.unsqueeze(0), key.unsqueeze(0), value.unsqueeze(0)
        )
        return batched_out.squeeze(0)


126
class WhisperPositionalEmbedding(nn.Embedding):
    """Position table whose lookup is a direct index into ``weight``.

    ``forward`` returns ``weight[position_ids]`` instead of going through
    ``nn.Embedding.forward``.
    """

    def __init__(self, num_positions: int, embedding_dim: int):
        super().__init__(num_embeddings=num_positions, embedding_dim=embedding_dim)

    def forward(self, position_ids):
        table = self.weight
        return table[position_ids]


class WhisperAttention(nn.Module):
    """Whisper multi-head attention with tensor-parallel projections.

    The base class uses a single fused ``qkv_proj``; subclasses override
    ``_init_qkv``/``forward`` for other projection layouts (e.g.
    ``WhisperCrossAttention``). ``attn_type`` selects the attention layer:

    * ``AttentionType.ENCODER``: ``WhisperEncoderAttention`` (no cache config,
      quant config or sliding window is passed to it).
    * ``AttentionType.ENCODER_DECODER``: ``CrossAttention``.
    * anything else (decoder self-attention): ``Attention``, which also
      receives ``per_layer_sliding_window``.

    Args:
        embed_dim: Hidden size; must be divisible by ``num_heads``.
        num_heads: Total number of heads across all tensor-parallel ranks;
            must be divisible by the tensor-parallel world size.
        bias: Whether the projections have a bias.
        attn_type: Which attention flavor to build (see above).
        per_layer_sliding_window: Forwarded only to the decoder ``Attention``.
        cache_config: Passed to ``CrossAttention`` / ``Attention``.
        quant_config: Quantization config for the linear layers (and for
            ``CrossAttention`` / ``Attention``).
        prefix: Module-name prefix for the sub-layers.

    Raises:
        AssertionError: If ``num_heads`` is not divisible by the TP size.
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        attn_type: AttentionType = AttentionType.DECODER,
        per_layer_sliding_window: int | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.embed_dim = embed_dim
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        # Heads are split evenly across TP ranks. Because even divisibility is
        # required, the "fewer heads than ranks -> replicate KV heads" case
        # cannot occur, so no replication path is needed.
        assert self.total_num_heads % tp_size == 0, (
            f"num_heads ({num_heads}) must be divisible by the tensor "
            f"parallel size ({tp_size})"
        )
        self.num_heads = self.total_num_heads // tp_size
        # The projections use as many KV heads as query heads (see
        # ``_init_qkv``), so each rank holds the same number of both.
        self.num_kv_heads = self.num_heads
        self.head_dim = self.embed_dim // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.attn_type = attn_type

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: "
                f"{self.embed_dim} and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5

        self._init_qkv(embed_dim, bias, quant_config, prefix=prefix)
        self.out_proj = RowParallelLinear(
            input_size=embed_dim,
            output_size=embed_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        if attn_type == AttentionType.ENCODER:
            self.attn = WhisperEncoderAttention(
                self.num_heads,
                self.head_dim,
                self.scaling,
                num_kv_heads=self.num_kv_heads,
            )
        elif self.attn_type == AttentionType.ENCODER_DECODER:
            self.attn = CrossAttention(
                self.num_heads,
                self.head_dim,
                self.scaling,
                num_kv_heads=self.num_kv_heads,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.attn",
                attn_type=self.attn_type,
            )
        else:  # AttentionType.DECODER (regular decoder self-attention)
            self.attn = Attention(
                self.num_heads,
                self.head_dim,
                self.scaling,
                num_kv_heads=self.num_kv_heads,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.attn",
                attn_type=self.attn_type,
                per_layer_sliding_window=per_layer_sliding_window,
            )

    def _init_qkv(
        self,
        embed_dim: int,
        bias: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Create the fused Q/K/V projection (``qkv_proj``).

        K and V use the same head count as Q (``total_num_kv_heads`` equals
        ``total_num_heads``).
        """
        self.qkv_proj = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, attend, and apply the output projection."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        attn_output = self.attn(q, k, v)

        output, _ = self.out_proj(attn_output)

        return output


class WhisperCrossAttention(WhisperAttention):
    """Decoder-to-encoder (cross) attention.

    Queries are projected from the decoder stream by ``q_proj``; keys and
    values come from the encoder output through a fused ``kv_proj``. When
    ``encoder_hidden_states`` is None, ``None`` is passed for both k and v to
    the ``CrossAttention`` layer (the encoder states are only supplied during
    prefill; later steps rely on the KV cache).
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__(
            embed_dim=embed_dim,
            num_heads=num_heads,
            bias=bias,
            attn_type=AttentionType.ENCODER_DECODER,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
        )

    def _init_qkv(
        self,
        embed_dim: int,
        bias: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Create a separate Q projection and a fused K/V projection."""
        shared = dict(bias=bias, quant_config=quant_config)
        self.q_proj = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=embed_dim,
            prefix=f"{prefix}.q_proj",
            **shared,
        )
        # total_num_heads=0: this fused layer produces only the K and V parts.
        self.kv_proj = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.head_dim,
            total_num_heads=0,
            total_num_kv_heads=self.total_num_heads,
            prefix=f"{prefix}.kv_proj",
            **shared,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None,
    ):
        """Attend from decoder ``hidden_states`` to the encoder output."""
        q, _ = self.q_proj(hidden_states)

        # Encoder hidden states are only computed once during prefill phase.
        # Afterwards, the keys and values should be available in the kv-cache.
        k = v = None
        if encoder_hidden_states is not None:
            kv, _ = self.kv_proj(encoder_hidden_states)
            k, v = kv.split([self.kv_size, self.kv_size], dim=-1)

        projected, _ = self.out_proj(self.attn(q, k, v))
        return projected


class WhisperMLP(nn.Module):
    """Feed-forward block: ``fc1`` -> activation -> ``fc2``.

    ``fc1`` is column-parallel (``embed_dim`` -> ``ffn_dim``) and ``fc2`` is
    row-parallel (``ffn_dim`` -> ``embed_dim``); the activation is looked up
    by name via ``get_act_fn``.
    """

    def __init__(
        self,
        embed_dim: int,
        ffn_dim: int,
        act_fn: str,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()

        self.activation_fn = get_act_fn(act_fn)
        self.fc1 = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=ffn_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            input_size=ffn_dim,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, hidden_states: torch.Tensor):
        expanded, _ = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        contracted, _ = self.fc2(activated)
        return contracted


class WhisperEncoderLayer(nn.Module):
    """One pre-norm Whisper encoder block: self-attention, then MLP.

    Each sub-block normalizes its input, applies the sub-layer and adds the
    result to the un-normalized input. The final output goes through
    ``cast_overflow_tensors`` (presumably an fp16 overflow guard; see helper).
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.embed_dim = config.d_model

        self.self_attn = WhisperAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            attn_type=AttentionType.ENCODER,
            # NOTE(review): WhisperAttention's ENCODER branch does not forward
            # this value to the attention layer.
            per_layer_sliding_window=getattr(config, "sliding_window", None),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.mlp = WhisperMLP(
            embed_dim=config.d_model,
            ffn_dim=config.encoder_ffn_dim,
            act_fn=config.activation_function,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        normed = self.self_attn_layer_norm(hidden_states)
        hidden_states = hidden_states + self.self_attn(hidden_states=normed)

        normed = self.final_layer_norm(hidden_states)
        hidden_states = hidden_states + self.mlp(normed)

        return cast_overflow_tensors(hidden_states)


class WhisperDecoderLayer(nn.Module):
    """One pre-norm Whisper decoder block.

    Three residual sub-blocks in order: causal self-attention, cross-attention
    over ``encoder_hidden_states``, and the MLP. Each one normalizes its input
    and adds the sub-layer output back to the un-normalized input.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        d_model = config.d_model
        attn_kwargs = dict(
            cache_config=vllm_config.cache_config,
            quant_config=vllm_config.quant_config,
        )

        self.self_attn = WhisperAttention(
            embed_dim=d_model,
            num_heads=config.decoder_attention_heads,
            attn_type=AttentionType.DECODER,
            prefix=f"{prefix}.self_attn",
            **attn_kwargs,
        )
        self.self_attn_layer_norm = nn.LayerNorm(d_model)
        self.encoder_attn = WhisperCrossAttention(
            embed_dim=d_model,
            num_heads=config.decoder_attention_heads,
            prefix=f"{prefix}.encoder_attn",
            **attn_kwargs,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(d_model)
        self.mlp = WhisperMLP(
            embed_dim=d_model,
            ffn_dim=config.decoder_ffn_dim,
            act_fn=config.activation_function,
            quant_config=vllm_config.quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.final_layer_norm = nn.LayerNorm(d_model)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None,
    ):
        normed = self.self_attn_layer_norm(hidden_states)
        hidden_states = hidden_states + self.self_attn(hidden_states=normed)

        normed = self.encoder_attn_layer_norm(hidden_states)
        hidden_states = hidden_states + self.encoder_attn(
            hidden_states=normed,
            encoder_hidden_states=encoder_hidden_states,
        )

        normed = self.final_layer_norm(hidden_states)
        return hidden_states + self.mlp(normed)


class WhisperEncoder(nn.Module):
    """Whisper audio encoder.

    ``forward`` runs two GELU-activated ``Conv1d`` layers over the mel-bin
    channels (the second with stride 2), adds a positional table, applies the
    ``WhisperEncoderLayer`` stack and a final ``LayerNorm``.

    Args:
        vllm_config: Engine config; the HF config is read from
            ``vllm_config.model_config.hf_config``.
        prefix: Module-name prefix for this encoder.
        init_in_fp32: If True, the positional table is created and filled
            with float32 as the default dtype.

    Raises:
        ValueError: If the config requests a positional embedding other than
            sinusoidal or learned.
    """

    def __init__(
        self, *, vllm_config: VllmConfig, prefix: str = "", init_in_fp32: bool = False
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        embed_dim = config.d_model

        self.pos_embed_type = WhisperPosEmbedType(
            getattr(config, "pos_embed", "sinusoidal")
        )
        self.num_mel_bins = config.num_mel_bins
        self.max_source_positions = config.max_source_positions
        # NOTE(review): not applied in forward(); kept for attribute compat.
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(embed_dim, embed_dim, stride=2, kernel_size=3, padding=1)

        self.total_stride = self.conv1.stride[0] * self.conv2.stride[0]
        # make_layers receives "<prefix>.layers" and calls the factory with a
        # per-layer prefix derived from it (presumably "<prefix>.layers.<idx>").
        # Use that prefix as-is: appending ".layers" again produced doubled
        # names like "encoder.layers.0.layers.self_attn".
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.encoder_layers,
            lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.layer_norm = nn.LayerNorm(config.d_model)

        if self.pos_embed_type not in (
            WhisperPosEmbedType.SINUSOIDAL,
            WhisperPosEmbedType.LEARNED,
        ):
            raise ValueError(
                "Only sinusoidal or learned position embeddings are supported "
                f"for non-causal models, but got {self.pos_embed_type}"
            )

        maybe_fp32_init_ctx = (
            set_default_torch_dtype(torch.float32) if init_in_fp32 else nullcontext()
        )

        # The table is always initialized with sinusoids; for LEARNED the
        # values are presumably overwritten by checkpoint weights.
        with (
            torch.no_grad(),
            maybe_fp32_init_ctx,
        ):
            self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
            self.embed_positions.weight.copy_(
                sinusoids(*self.embed_positions.weight.shape)
            )

    def forward(
        self, input_features: torch.Tensor | list[torch.Tensor]
    ) -> torch.Tensor:
        """Encode audio features.

        Each element of ``input_features`` is passed through the conv stem on
        its own. 2D results (``(frames, d_model)`` after the transpose) are
        stacked into a new batch dim. Results with more than 2 dims are
        concatenated along dim 0; the last element decides which path is
        taken.
        """
        hidden_states = []
        input_is_batched = False
        for features in input_features:
            embeds = nn.functional.gelu(self.conv1(features))
            embeds = nn.functional.gelu(self.conv2(embeds))

            # Channels-last: (..., d_model, frames) -> (..., frames, d_model).
            embeds = embeds.transpose(-1, -2)
            # Add the first `frames` rows of the position table. Cast back
            # because the table may be float32 when init_in_fp32 is set.
            embeds = (embeds + self.embed_positions.weight[: embeds.size(-2), :]).to(
                embeds.dtype
            )

            hidden_states.append(embeds)
            input_is_batched = embeds.ndim > 2
        # Input to MHA must be B x T x D
        if input_is_batched:
            # Models using WhisperEncoder may handle batching internally.
            hidden_states = torch.cat(hidden_states)
        else:
            hidden_states = torch.stack(hidden_states, dim=0)

        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states)

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states


526
@support_torch_compile(dynamic_arg_dims={"input_ids": 0, "positions": -1})
class WhisperDecoder(nn.Module):
    """Whisper text decoder, wrapped with ``support_torch_compile``.

    Token embeddings plus a position table indexed by ``positions`` feed a
    stack of ``WhisperDecoderLayer`` blocks (self-attention, cross-attention
    to the encoder output, MLP), followed by a final ``LayerNorm``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        # NOTE(review): layerdrop and embed_scale are not used in forward();
        # kept for attribute compatibility.
        self.layerdrop = config.decoder_layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_target_positions
        self.max_source_positions = config.max_source_positions
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.d_model, self.padding_idx
        )
        self.embed_positions = WhisperPositionalEmbedding(
            self.max_target_positions, config.d_model
        )
        # make_layers receives "<prefix>.layers" and calls the factory with a
        # per-layer prefix derived from it (presumably "<prefix>.layers.<idx>").
        # Use that prefix as-is: appending ".layers" again produced doubled
        # names like "decoder.layers.0.layers.self_attn.attn".
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.decoder_layers,
            lambda prefix: WhisperDecoderLayer(vllm_config=vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.layer_norm = nn.LayerNorm(config.d_model)

    def forward(
        self,
        input_ids,
        positions: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None,
    ):
        """Decode ``input_ids`` at ``positions``, cross-attending to
        ``encoder_hidden_states`` (None when no new encoder states exist)."""
        inputs_embeds = self.embed_input_ids(input_ids)
        positions = self.embed_positions(positions)
        hidden_states = inputs_embeds + positions

        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
            )

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)


class WhisperModel(nn.Module):
    """Container pairing the audio ``encoder`` with the text ``decoder``."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        encoder_prefix = f"{prefix}.encoder"
        decoder_prefix = f"{prefix}.decoder"
        self.encoder = WhisperEncoder(vllm_config=vllm_config, prefix=encoder_prefix)
        self.decoder = WhisperDecoder(vllm_config=vllm_config, prefix=decoder_prefix)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        encoder_outputs: list[torch.Tensor],
    ) -> torch.Tensor:
        """Run the decoder; non-empty ``encoder_outputs`` are concatenated
        along dim 0 and used as cross-attention keys/values, otherwise None
        is passed."""
        enc_states = None
        if len(encoder_outputs) > 0:
            enc_states = torch.cat(encoder_outputs, dim=0)
        return self.decoder(
            input_ids=input_ids,
            positions=positions,
            encoder_hidden_states=enc_states,
        )

    def get_encoder_outputs(
        self,
        input_features: torch.Tensor | list[torch.Tensor] | None,
    ) -> torch.Tensor | None:
        """Encode ``input_features``, or return None when there are none."""
        return None if input_features is None else self.encoder(input_features)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors, routing q/k/v shards into fused params.

        Returns the set of (possibly renamed) parameter names that were loaded.
        """
        # (fused param substring, checkpoint substring, shard id)
        stacked_params_mapping = [
            (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
            (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
            (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
            (".encoder_attn.kv_proj", ".encoder_attn.k_proj", "k"),
            (".encoder_attn.kv_proj", ".encoder_attn.v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for target, tensor in weights:
            for fused_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in target:
                    continue
                target = target.replace(shard_name, fused_name)
                # Skip loading extra bias for GPTQ models.
                if target.endswith(".bias") and target not in params_dict:
                    continue

                fused_param = params_dict[target]
                fused_param.weight_loader(fused_param, tensor, shard_id)
                break
            else:
                # Not a stacked shard (or a skipped extra bias): plain load.
                # Skip loading extra bias for GPTQ models.
                if target.endswith(".bias") and target not in params_dict:
                    continue

                plain_param = params_dict[target]
                loader = getattr(plain_param, "weight_loader", default_weight_loader)
                loader(plain_param, tensor)
            loaded_params.add(target)
        return loaded_params


643
644
645
646
class WhisperProcessingInfo(BaseProcessingInfo):
    """Model-specific processing metadata for Whisper."""

    def get_hf_config(self) -> WhisperConfig:
        return self.ctx.get_hf_config(WhisperConfig)

    @property
    def skip_prompt_length_check(self) -> bool:
        # The encoder prompt is padded, so its length is not checked.
        return True

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # At most one audio item per prompt.
        return {"audio": 1}

    def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
        """Return the feature extractor attached to the HF processor."""
        processor = self.get_hf_processor(**kwargs)
        extractor = processor.feature_extractor  # type: ignore
        assert isinstance(extractor, WhisperFeatureExtractor)
        return extractor

    def get_target_channels(self) -> int:
        """Return target audio channels for Whisper models (mono)."""
        return 1

    def get_num_audio_tokens(self) -> int:
        """Number of encoder placeholder tokens: ``max_source_positions``."""
        hf_config = self.get_hf_config()
        return hf_config.max_source_positions


class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]):
    """Builds dummy text/audio inputs for Whisper profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        # One start-of-transcript token per requested audio item.
        return "<|startoftranscript|>" * mm_counts.get("audio", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy audio clips, each one full extractor chunk long."""
        extractor = self.info.get_feature_extractor()
        clip_len = extractor.chunk_length * extractor.sampling_rate
        clip_count = mm_counts.get("audio", 0)

        audio_overrides = None
        if mm_options:
            audio_overrides = mm_options.get("audio")

        dummy_audios = self._get_dummy_audios(
            length=clip_len, num_audios=clip_count, overrides=audio_overrides
        )
        return {"audio": dummy_audios}


695
class WhisperMultiModalProcessor(EncDecMultiModalProcessor[WhisperProcessingInfo]):
    """Multimodal processor for Whisper's encoder-decoder prompts.

    The encoder prompt is a single placeholder token ``0`` that
    ``_get_prompt_updates`` expands to ``get_num_audio_tokens()`` copies. The
    audio is turned into ``input_features`` by the HF processor.
    """

    def _get_data_parser(self) -> MultiModalDataParser:
        """Parse audio against the extractor's sampling rate and mono target."""
        feature_extractor = self.info.get_feature_extractor()
        return MultiModalDataParser(
            target_sr=feature_extractor.sampling_rate,
            target_channels=self.info.get_target_channels(),
        )

    def create_encoder_prompt(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
    ) -> str | list[int]:
        # Strictly speaking, whisper encoder only accept audio features.
        # We create a dummy encoder prompt here which will be padded to
        # num_audio_tokens. So that we can create dummy data from this
        # for encoder profiling.
        return [0]

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, renaming ``audios`` -> ``audio`` and pinning
        ``sampling_rate`` to the feature extractor's rate.

        If the HF output contains ``labels``, they are returned as
        ``input_ids``.
        """
        if mm_data:
            feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
            # Read instead of pop: ``mm_data`` is a caller-owned (read-only)
            # Mapping and must not be mutated.
            mm_data = {"audio": mm_data["audios"]}
            # Merge rather than ``dict(**mm_kwargs, sampling_rate=...)``, which
            # raises TypeError if the caller already passed ``sampling_rate``.
            # The extractor's rate wins.
            mm_kwargs = {
                **mm_kwargs,
                "sampling_rate": feature_extractor.sampling_rate,
            }
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        if "labels" in processed_outputs:
            processed_outputs["input_ids"] = processed_outputs.pop("labels")
        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # input_features are batched per audio item.
        return dict(input_features=MultiModalFieldConfig.batched("audio"))

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Expand the single ``0`` placeholder to the audio token count."""
        num_tokens = self.info.get_num_audio_tokens()
        return [
            PromptReplacement(
                modality="audio",
                target=[0],
                replacement=[0] * num_tokens,
            )
        ]


@MULTIMODAL_REGISTRY.register_processor(
    WhisperMultiModalProcessor,
    info=WhisperProcessingInfo,
    dummy_inputs=WhisperDummyInputsBuilder,
)
class WhisperForConditionalGeneration(
    nn.Module, SupportsTranscription, SupportsMultiModal
):
    """Whisper speech-to-text model: ``WhisperModel`` plus an output head.

    Registered with ``WhisperMultiModalProcessor``. The audio encoder
    consumes the resulting ``input_features``. Logits come from
    ``proj_out``, whose weight is tied to the decoder token embedding.
    """

    # Fused vLLM projection -> checkpoint projections packed into it.
    packed_modules_mapping = {
        "self_attn.qkv_proj": [
            "self_attn.q_proj",
            "self_attn.k_proj",
            "self_attn.v_proj",
        ],
        "encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"],
    }

    # Checkpoint keys use ``.fc1.`` / ``.fc2.``; here they live under ``.mlp``.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."}
    )

    # Whisper only supports audio-conditioned generation.
    supports_transcription_only = True
    supports_segment_timestamp = True
    supported_languages = ISO639_1_SUPPORTED_LANGS

    @classmethod
    def validate_language(cls, language: str | None) -> str | None:
        """Default a missing language to ``"en"``, then run base validation.

        A warning is logged whenever the default is applied.
        """
        if language is None:
            # TODO language should be optional and can be guessed.
            # For now we default to en. See
            # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
            logger.warning(
                "Defaulting to language='en'. If you wish to transcribe "
                "audio in a different language, pass the `language` field "
                "in the TranscriptionRequest."
            )
            language = "en"
        return super().validate_language(language)

    @classmethod
    def get_generation_prompt(
        cls,
        audio: np.ndarray,
        model_config: ModelConfig,  # not needed here
        stt_config: SpeechToTextConfig,
        language: str | None,
        task_type: Literal["transcribe", "translate"],
        request_prompt: str,
        to_language: str | None,
    ) -> PromptType:
        """Build the encoder/decoder prompt pair for one audio clip.

        The encoder prompt has empty text and carries the audio, paired with
        ``stt_config.sample_rate``. The decoder prompt is the Whisper
        control-token prefix. It is preceded by ``<|prev|>`` plus
        ``request_prompt`` when ``request_prompt`` is non-empty.
        ``model_config`` and ``to_language`` are unused.

        Raises:
            ValueError: if ``language`` is ``None``.
        """
        if language is None:
            raise ValueError(
                "Language must be specified when creating the Whisper prompt"
            )
        prompt = {
            "encoder_prompt": {
                # Whisper does not support encoder prompt.
                "prompt": "",
                "multi_modal_data": {
                    "audio": (audio, stt_config.sample_rate),
                },
            },
            "decoder_prompt": (
                (f"<|prev|>{request_prompt}" if request_prompt else "")
                + f"<|startoftranscript|><|{language}|>"
                + f"<|{task_type}|><|notimestamps|>"
            ),
        }
        return cast(PromptType, prompt)

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return ``None`` for audio modalities; there is no placeholder text.

        Raises:
            ValueError: for any modality whose name does not start with
                ``"audio"``.
        """
        if modality.startswith("audio"):
            return None

        raise ValueError("Only audio modality is supported")

    @classmethod
    def get_speech_to_text_config(
        cls, model_config: ModelConfig, task_type: str
    ) -> SpeechToTextConfig:
        """Take clip length and sample rate from the HF feature extractor.

        ``task_type`` is not consulted.
        """
        processor = cached_processor_from_config(model_config)

        return SpeechToTextConfig(
            max_audio_clip_s=processor.feature_extractor.chunk_length,
            sample_rate=processor.feature_extractor.sampling_rate,
        )

    @classmethod
    def get_num_audio_tokens(
        cls,
        audio_duration_s: float,
        stt_config: SpeechToTextConfig,
        model_config: ModelConfig,
    ) -> int | None:
        """Estimate encoder work for a clip as its log-mel frame count.

        Returns ``ceil(duration * sample_rate / hop_length)``.
        """
        processor = cached_processor_from_config(model_config)
        hop_length = processor.feature_extractor.hop_length
        assert hop_length is not None
        # NOTE(NickLucche) user can't pass encoder
        # prompts directly at least not to Whisper.
        # One indicator of the encoder amount of processing
        # is the log-mel spectogram length.
        return math.ceil(audio_duration_s * stt_config.sample_rate / hop_length)

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.dtype = vllm_config.model_config.dtype

        # WhisperDecoder is marked as the language part and WhisperEncoder as
        # the "audio" tower of this composite model.
        with self._mark_composite_model(
            vllm_config,
            language_targets=WhisperDecoder,
            tower_targets={"audio": WhisperEncoder},
        ):
            self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix)

        self.proj_out = ParallelLMHead(
            config.vocab_size,
            config.d_model,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "proj_out"),
        )
        # Share the output projection with the decoder token embedding.
        self.proj_out = self.proj_out.tie_weights(self.model.decoder.embed_tokens)
        # Falls back to a scale of 1.0 when the HF config has no logit_scale.
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        encoder_outputs: list[torch.Tensor] | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run ``self.model`` on decoder tokens and the given encoder outputs.

        ``encoder_outputs`` defaults to an empty list. Extra keyword arguments
        are accepted and ignored.
        """
        if encoder_outputs is None:
            encoder_outputs = []
        decoder_outputs = self.model(
            input_ids=input_ids,
            positions=positions,
            encoder_outputs=encoder_outputs,
        )
        return decoder_outputs

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Encode audio ``input_features`` and return one tensor per item.

        Required as part of the ``SupportsMultiModal`` interface.
        """
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        enc_output = self.model.get_encoder_outputs(audio_input["input_features"])
        # Split the encoder output along dim 0 into one tensor per audio
        # input. The assumption is we can only process whole mm items (audios).
        return enc_output.unbind(dim=0)

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Embed decoder tokens; the multimodal arguments are not used."""
        # This method just returns the decoder sequence embeddings since
        # Whisper does not have encoder text tokens.
        return self.model.decoder.embed_input_ids(input_ids)

    def _parse_and_validate_audio_input(self, **kwargs: object) -> WhisperAudioInputs:
        """Extract ``input_features`` from ``kwargs`` and cast to model dtype.

        The cast is applied to every leaf via ``json_map_leaves``. A missing
        entry is passed to ``WhisperAudioInputs`` as ``None``.
        """
        input_features = kwargs.pop("input_features", None)

        if input_features is not None:
            input_features = json_map_leaves(lambda x: x.to(self.dtype), input_features)

        return WhisperAudioInputs(input_features=input_features)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project decoder hidden states through ``proj_out`` to logits."""
        logits = self.logits_processor(self.proj_out, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights; returns what ``AutoWeightsLoader`` reports.

        ``proj_out.`` entries are skipped because ``proj_out`` is tied to the
        decoder token embedding in ``__init__``. A zero bias is added for each
        ``.k_proj.weight`` before the HF -> vLLM name mapping is applied.
        """
        loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])

        # add fake zeros bias for k_proj to state_dict
        weights = _create_fake_bias_for_k_proj(weights, ".k_proj.weight")
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)


def _create_fake_bias_for_k_proj(
947
    weights: Iterable[tuple[str, torch.Tensor]], fake_bias_key_name: str
948
) -> Iterable[tuple[str, torch.Tensor]]:
949
    """
950
    Create full zeros bias for k_proj weight in self-attn and x-attn layers.
951
952
953
    So that the bias for k_proj in qkv_proj can be initialized with zeros.
    """
    for name, weight in weights:
954
        if name.endswith(fake_bias_key_name):
955
956
957
958
            bias = torch.zeros(weight.size(0))
            bias_name = name.replace("weight", "bias")
            yield from [(name, weight), (bias_name, bias)]
        yield name, weight