"vllm/vscode:/vscode.git/clone" did not exist on "be633fba0f8fc41b19a774a89ad055e54865af53"
whisper.py 35.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Patrick von Platen's avatar
Patrick von Platen committed
4
import enum
5
import math
6
from collections.abc import Iterable, Mapping, Sequence
Patrick von Platen's avatar
Patrick von Platen committed
7
from contextlib import nullcontext
8
from typing import Annotated, Literal
9

10
import numpy as np
11
12
import torch
from torch import nn
13
14
15
16
17
from transformers import (
    BatchFeature,
    WhisperConfig,
    WhisperFeatureExtractor,
)
18
19
from transformers.models.whisper.modeling_whisper import sinusoids

20
from vllm.compilation.decorators import support_torch_compile
21
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
22
from vllm.config.multimodal import BaseDummyOptions
23
from vllm.distributed import get_tensor_model_parallel_world_size
24
from vllm.inputs.data import ExplicitEncoderDecoderPrompt, PromptType, TextPrompt
25
26
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
27
28
29
30
31
from vllm.model_executor.layers.attention import (
    Attention,
    CrossAttention,
    MMEncoderAttention,
)
32
33
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
34
    MergedColumnParallelLinear,
35
36
37
    QKVParallelLinear,
    RowParallelLinear,
)
38
from vllm.model_executor.layers.logits_processor import LogitsProcessor
39
from vllm.model_executor.layers.quantization import QuantizationConfig
40
41
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
Patrick von Platen's avatar
Patrick von Platen committed
42
43
44
from vllm.model_executor.models.whisper_utils import (
    ISO639_1_SUPPORTED_LANGS,
)
45
from vllm.multimodal import MULTIMODAL_REGISTRY
46
47
48
49
50
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
51
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
52
from vllm.multimodal.processing import (
53
    BaseDummyInputsBuilder,
54
55
56
57
58
    BaseProcessingInfo,
    EncDecMultiModalProcessor,
    PromptReplacement,
    PromptUpdate,
)
59
from vllm.renderers import TokenizeParams
60
from vllm.transformers_utils.processor import cached_processor_from_config
61
from vllm.utils.jsontree import json_map_leaves
62
from vllm.utils.tensor_schema import TensorSchema, TensorShape
63
from vllm.utils.torch_utils import set_default_torch_dtype
64
65
66
from vllm.v1.attention.backend import (
    AttentionType,
)
67

68
69
from .interfaces import (
    MultiModalEmbeddings,
70
    SupportsLoRA,
71
72
73
    SupportsMultiModal,
    SupportsTranscription,
)
74
75
76
77
78
79
80
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    cast_overflow_tensors,
    make_layers,
    maybe_prefix,
)
81
82
83

# Module-level logger, named after this module.
logger = init_logger(__name__)

Patrick von Platen's avatar
Patrick von Platen committed
84
85
86

class WhisperPosEmbedType(enum.Enum):
    """Positional-embedding flavour requested by the HF config.

    Each value is the string stored in the config's optional ``pos_embed``
    attribute (``"sinusoidal"`` when the attribute is absent).
    ``WhisperEncoder`` accepts only SINUSOIDAL and LEARNED.
    """

    # Member order is part of the enum's iteration order; keep it stable.
    SINUSOIDAL = "sinusoidal"
    ROPE = "rope"
    LEARNED = "learned"
89

90

91
92
93
94
95
96
97
98
class WhisperAudioInputs(TensorSchema):
    """
    Audio inputs handed to the Whisper encoder.

    Dimensions:
        - b: Batch size
        - nmb: Number of mel bins
        - t: Time frames (M)
    """

    # One mel-feature tensor per audio item; the field itself may be None.
    input_features: Annotated[
        list[torch.Tensor] | None,
        TensorShape("b", "nmb", "t"),
    ]
103
104


105
class WhisperEncoderAttention(MMEncoderAttention):
    """Whisper encoder attention that additionally accepts unbatched inputs.

    Batched ``(batch, seq_len, hidden)`` tensors go straight to the parent
    implementation. For 2D ``(seq_len, hidden)`` inputs, a size-1 batch axis
    is added before calling the parent and removed from its result.
    """

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Attend over batched (3D) or unbatched (2D) q/k/v tensors."""
        if query.dim() != 2:
            # Already batched: nothing to adapt.
            return super().forward(query, key, value)

        # Unbatched: give the parent a batch of one, then drop that axis again.
        batched_out = super().forward(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
        )
        return batched_out.squeeze(0)


133
class WhisperPositionalEmbedding(nn.Embedding):
    """Position-embedding table looked up directly by position ids.

    ``forward`` indexes ``self.weight`` itself instead of going through
    ``nn.Embedding.forward``.
    """

    def __init__(self, num_positions: int, embedding_dim: int):
        super().__init__(num_positions, embedding_dim)

    def forward(self, position_ids):
        """Return the table rows selected by ``position_ids``."""
        table = self.weight
        return table[position_ids]


class WhisperAttention(nn.Module):
    """Multi-head attention shared by the Whisper encoder and decoder.

    Projects ``hidden_states`` with a fused QKV layer, runs the attention
    variant selected by ``attn_type`` and applies the output projection:

    * ``AttentionType.ENCODER`` -> ``WhisperEncoderAttention`` (constructed
      without cache or quantization config).
    * ``AttentionType.ENCODER_DECODER`` -> ``CrossAttention``.
    * any other value (normally ``AttentionType.DECODER``) -> ``Attention``;
      this is the only variant that receives ``per_layer_sliding_window``.

    Args:
        embed_dim: Hidden size; must equal ``head_dim * num_heads``.
        num_heads: Total number of heads; must be divisible by the
            tensor-parallel world size (heads are sharded, never replicated).
        bias: Whether the projection layers carry a bias term.
        attn_type: Attention variant to build (see above).
        per_layer_sliding_window: Sliding window for decoder self-attention;
            ignored by the encoder and cross-attention variants.
        cache_config: Cache config forwarded to the cached attention layers.
        quant_config: Quantization config for projections and attention.
        prefix: Module-name prefix for the created sub-layers.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        attn_type: AttentionType = AttentionType.DECODER,
        per_layer_sliding_window: int | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.embed_dim = embed_dim
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        # Every TP rank owns an equal slice of the heads. Replicating heads
        # across ranks (num_heads < tp_size) is not supported.
        assert self.total_num_heads % tp_size == 0, (
            f"Whisper attention heads ({self.total_num_heads}) must be "
            f"divisible by the tensor parallel size ({tp_size})."
        )
        self.num_heads = self.total_num_heads // tp_size
        # Plain MHA: the QKV projection uses as many K/V heads as Q heads
        # (see `_init_qkv`), so K/V are sharded exactly like Q.
        self.num_kv_heads = self.num_heads
        self.head_dim = self.embed_dim // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.attn_type = attn_type

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: "
                f"{self.embed_dim} and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5

        self._init_qkv(embed_dim, bias, quant_config, prefix=prefix)
        self.out_proj = RowParallelLinear(
            input_size=embed_dim,
            output_size=embed_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        if self.attn_type == AttentionType.ENCODER:
            self.attn = WhisperEncoderAttention(
                self.num_heads,
                self.head_dim,
                self.scaling,
                num_kv_heads=self.num_kv_heads,
            )
        elif self.attn_type == AttentionType.ENCODER_DECODER:
            self.attn = CrossAttention(
                self.num_heads,
                self.head_dim,
                self.scaling,
                num_kv_heads=self.num_kv_heads,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.attn",
                attn_type=self.attn_type,
            )
        else:  # AttentionType.DECODER (regular decoder self-attention)
            self.attn = Attention(
                self.num_heads,
                self.head_dim,
                self.scaling,
                num_kv_heads=self.num_kv_heads,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.attn",
                attn_type=self.attn_type,
                per_layer_sliding_window=per_layer_sliding_window,
            )

    def _init_qkv(
        self,
        embed_dim: int,
        bias: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Create the fused Q/K/V projection (overridden by cross-attention)."""
        self.qkv_proj = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Self-attention over ``hidden_states``; returns the projected output."""
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused output is laid out as [Q | K | V] for this rank's heads.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        attn_output = self.attn(q, k, v)

        output, _ = self.out_proj(attn_output)

        return output


class WhisperCrossAttention(WhisperAttention):
    """Decoder-to-encoder (cross) attention.

    Queries are projected from the decoder hidden states. Keys and values are
    projected from the encoder hidden states, so Q and the fused K/V
    projection are separate layers instead of one QKV layer.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__(
            embed_dim=embed_dim,
            num_heads=num_heads,
            bias=bias,
            attn_type=AttentionType.ENCODER_DECODER,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
        )

    def _init_qkv(
        self,
        embed_dim: int,
        bias: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Create the query projection and the merged key/value projection."""
        shared_kwargs = dict(bias=bias, quant_config=quant_config)

        self.q_proj = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=embed_dim,
            prefix=f"{prefix}.q_proj",
            **shared_kwargs,
        )
        # K and V live in one two-slice MergedColumnParallelLinear so that
        # LoRA can wrap it via MergedColumnParallelLinearWithLoRA.
        self.kv_proj = MergedColumnParallelLinear(
            input_size=embed_dim,
            output_sizes=[embed_dim, embed_dim],
            prefix=f"{prefix}.kv_proj",
            **shared_kwargs,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None,
    ):
        """Attend from decoder ``hidden_states`` to the encoder output.

        ``encoder_hidden_states`` is None when no fresh encoder output is
        supplied; K/V are then passed as None to the attention layer.
        """
        query, _ = self.q_proj(hidden_states)

        if encoder_hidden_states is None:
            # Encoder states are only computed once during prefill; afterwards
            # the keys and values are expected to be in the kv-cache.
            key = value = None
        else:
            fused_kv, _ = self.kv_proj(encoder_hidden_states)
            key, value = fused_kv.split([self.kv_size, self.kv_size], dim=-1)

        context = self.attn(query, key, value)
        projected, _ = self.out_proj(context)
        return projected


class WhisperMLP(nn.Module):
    """Feed-forward block: ``fc2(act(fc1(x)))``.

    ``fc1`` is a ColumnParallelLinear (embed_dim -> ffn_dim) and ``fc2`` a
    RowParallelLinear (ffn_dim -> embed_dim).
    """

    def __init__(
        self,
        embed_dim: int,
        ffn_dim: int,
        act_fn: str,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()

        self.activation_fn = get_act_fn(act_fn)
        self.fc1 = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=ffn_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            input_size=ffn_dim,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, hidden_states: torch.Tensor):
        """Expand, activate and project back to ``embed_dim``."""
        expanded, _ = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        projected, _ = self.fc2(activated)
        return projected


class WhisperEncoderLayer(nn.Module):
    """One pre-LayerNorm Whisper encoder block.

    ``x = x + self_attn(LN(x))`` followed by ``x = x + mlp(LN(x))``, with the
    result passed through ``cast_overflow_tensors``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.embed_dim = hf_config.d_model

        self.self_attn = WhisperAttention(
            embed_dim=self.embed_dim,
            num_heads=hf_config.encoder_attention_heads,
            attn_type=AttentionType.ENCODER,
            # Forwarded for completeness; WhisperAttention only uses the
            # sliding window for decoder self-attention.
            per_layer_sliding_window=getattr(hf_config, "sliding_window", None),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.mlp = WhisperMLP(
            embed_dim=self.embed_dim,
            ffn_dim=hf_config.encoder_ffn_dim,
            act_fn=hf_config.activation_function,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Apply the attention and MLP sub-blocks, each with a residual add."""
        attn_out = self.self_attn(
            hidden_states=self.self_attn_layer_norm(hidden_states)
        )
        hidden_states = hidden_states + attn_out

        mlp_out = self.mlp(self.final_layer_norm(hidden_states))
        hidden_states = hidden_states + mlp_out

        # Presumably guards against reduced-precision overflow; see the helper.
        return cast_overflow_tensors(hidden_states)


class WhisperDecoderLayer(nn.Module):
    """One pre-LayerNorm Whisper decoder block.

    Runs decoder self-attention, then cross-attention over the encoder
    output, then an MLP. Each sub-block is ``x = x + sublayer(LN(x))``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        d_model = hf_config.d_model
        heads = hf_config.decoder_attention_heads

        self.self_attn = WhisperAttention(
            embed_dim=d_model,
            num_heads=heads,
            attn_type=AttentionType.DECODER,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.self_attn_layer_norm = nn.LayerNorm(d_model)
        self.encoder_attn = WhisperCrossAttention(
            embed_dim=d_model,
            num_heads=heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder_attn",
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(d_model)
        self.mlp = WhisperMLP(
            embed_dim=d_model,
            ffn_dim=hf_config.decoder_ffn_dim,
            act_fn=hf_config.activation_function,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.final_layer_norm = nn.LayerNorm(d_model)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None,
    ):
        """Run self-attn, cross-attn and MLP sub-blocks with residual adds.

        ``encoder_hidden_states`` may be None; it is forwarded unchanged to
        the cross-attention layer.
        """
        self_out = self.self_attn(
            hidden_states=self.self_attn_layer_norm(hidden_states)
        )
        hidden_states = hidden_states + self_out

        cross_out = self.encoder_attn(
            hidden_states=self.encoder_attn_layer_norm(hidden_states),
            encoder_hidden_states=encoder_hidden_states,
        )
        hidden_states = hidden_states + cross_out

        mlp_out = self.mlp(self.final_layer_norm(hidden_states))
        return hidden_states + mlp_out


class WhisperEncoder(nn.Module):
    """Whisper audio encoder.

    A two-layer Conv1d front-end (the second conv has stride 2) is followed
    by added position embeddings, a stack of ``WhisperEncoderLayer`` blocks
    and a final LayerNorm.

    Args:
        vllm_config: Global vLLM config; the HF config is read from
            ``vllm_config.model_config.hf_config``.
        prefix: Module-name prefix for this encoder.
        init_in_fp32: Create the position-embedding table under a float32
            default dtype.

    Raises:
        ValueError: if the config's ``pos_embed`` is neither sinusoidal nor
            learned.
    """

    def __init__(
        self, *, vllm_config: VllmConfig, prefix: str = "", init_in_fp32: bool = False
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        embed_dim = config.d_model

        self.pos_embed_type = WhisperPosEmbedType(
            getattr(config, "pos_embed", "sinusoidal")
        )
        # Validate before allocating any layers.
        if self.pos_embed_type not in (
            WhisperPosEmbedType.SINUSOIDAL,
            WhisperPosEmbedType.LEARNED,
        ):
            raise ValueError(
                "Only sinusoidal or learned position embeddings are supported "
                f"for non-causal models, but got {self.pos_embed_type}"
            )

        self.num_mel_bins = config.num_mel_bins
        self.max_source_positions = config.max_source_positions
        # Kept as an attribute; forward() below does not apply it.
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(embed_dim, embed_dim, stride=2, kernel_size=3, padding=1)
        # Combined temporal downsampling factor of the conv front-end.
        self.total_stride = self.conv1.stride[0] * self.conv2.stride[0]

        # make_layers hands each factory call the complete per-layer prefix
        # (`<prefix>.layers.<idx>`); use it as-is. Appending ".layers" again
        # would yield names like `...layers.0.layers.self_attn` that do not
        # match the module tree (relevant e.g. for quant-config matching).
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.encoder_layers,
            lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.layer_norm = nn.LayerNorm(config.d_model)

        maybe_fp32_init_ctx = (
            set_default_torch_dtype(torch.float32) if init_in_fp32 else nullcontext()
        )

        with (
            torch.no_grad(),
            maybe_fp32_init_ctx,
        ):
            # Initialised with sinusoids; for LEARNED embeddings the table is
            # presumably overwritten by checkpoint weights during loading.
            self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
            self.embed_positions.weight.copy_(
                sinusoids(*self.embed_positions.weight.shape)
            )

    def forward(
        self, input_features: torch.Tensor | list[torch.Tensor]
    ) -> torch.Tensor:
        """Encode mel features.

        Args:
            input_features: Items iterated one by one (a list, or a tensor
                iterated along dim 0). Each item is fed to ``conv1``; assumed
                ``(num_mel_bins, T)`` or ``(B, num_mel_bins, T)`` — NOTE(review):
                inferred from Conv1d usage, confirm with callers.

        Returns:
            Hidden states after the final LayerNorm. Unbatched items are
            stacked along a new dim 0; batched items (still >2-D after the
            conv stack) are concatenated along dim 0.
        """
        hidden_states = []
        input_is_batched = False
        for features in input_features:
            embeds = nn.functional.gelu(self.conv1(features))
            embeds = nn.functional.gelu(self.conv2(embeds))

            # (..., D, T) -> (..., T, D), then add the first T position rows.
            embeds = embeds.transpose(-1, -2)
            # The position table may be fp32 (init_in_fp32); cast back.
            embeds = (embeds + self.embed_positions.weight[: embeds.size(-2), :]).to(
                embeds.dtype
            )

            hidden_states.append(embeds)
            input_is_batched = embeds.ndim > 2
        # Input to MHA must be B x T x D
        if input_is_batched:
            # Models using WhisperEncoder may handle batching internally.
            hidden_states = torch.cat(hidden_states)
        else:
            hidden_states = torch.stack(hidden_states, dim=0)

        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states)

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states


534
@support_torch_compile(dynamic_arg_dims={"input_ids": 0, "positions": -1})
class WhisperDecoder(nn.Module):
    """Whisper text decoder.

    Sums token and position embeddings, runs a stack of
    ``WhisperDecoderLayer`` blocks (each cross-attending to the encoder
    output) and applies a final LayerNorm.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        # Stored config values; layerdrop and embed_scale are not used by
        # forward() below.
        self.layerdrop = config.decoder_layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_target_positions
        self.max_source_positions = config.max_source_positions
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.d_model, self.padding_idx
        )
        self.embed_positions = WhisperPositionalEmbedding(
            self.max_target_positions, config.d_model
        )
        # make_layers hands each factory call the complete per-layer prefix
        # (`<prefix>.layers.<idx>`); use it as-is. Appending ".layers" again
        # would yield names like `...layers.0.layers.self_attn` that do not
        # match the module tree (relevant e.g. for quant-config matching).
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.decoder_layers,
            lambda prefix: WhisperDecoderLayer(vllm_config=vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.layer_norm = nn.LayerNorm(config.d_model)

    def forward(
        self,
        input_ids,
        positions: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None,
    ):
        """Decode tokens.

        Args:
            input_ids: Token ids to embed.
            positions: Position index of each token; used to index the
                position-embedding table.
            encoder_hidden_states: Encoder output for cross-attention, or
                None (forwarded unchanged to every layer).

        Returns:
            Hidden states after the final LayerNorm.
        """
        inputs_embeds = self.embed_input_ids(input_ids)
        pos_embeds = self.embed_positions(positions)
        hidden_states = inputs_embeds + pos_embeds

        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
            )

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings (no scaling applied)."""
        return self.embed_tokens(input_ids)


class WhisperModel(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
586
587
588
589
590
591
        self.encoder = WhisperEncoder(
            vllm_config=vllm_config, prefix=f"{prefix}.encoder"
        )
        self.decoder = WhisperDecoder(
            vllm_config=vllm_config, prefix=f"{prefix}.decoder"
        )
592
593
594

    def forward(
        self,
595
        input_ids: torch.Tensor | None,
596
        positions: torch.Tensor,
597
        encoder_outputs: list[torch.Tensor],
598
    ) -> torch.Tensor:
599
        enc_states = torch.cat(encoder_outputs, dim=0) if len(encoder_outputs) else None
600
601
602
        decoder_outputs = self.decoder(
            input_ids=input_ids,
            positions=positions,
603
            encoder_hidden_states=enc_states,
604
605
606
607
608
        )
        return decoder_outputs

    def get_encoder_outputs(
        self,
609
610
        input_features: torch.Tensor | list[torch.Tensor] | None,
    ) -> torch.Tensor | None:
611
612
        if input_features is None:
            return None
613
        return self.encoder(input_features)
614

615
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
616
617
618
619
620
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
            (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
            (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
621
622
623
            # MergedColumnParallelLinear uses integer indices (0, 1)
            (".encoder_attn.kv_proj", ".encoder_attn.k_proj", 0),
            (".encoder_attn.kv_proj", ".encoder_attn.v_proj", 1),
624
625
        ]
        params_dict = dict(self.named_parameters())
626
        loaded_params: set[str] = set()
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
646
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
647
648
649
650
651
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


652
653
654
655
class WhisperProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Whisper.

    Provides access to the HF config and feature extractor, audio limits and
    tokenization defaults.
    """

    def get_hf_config(self) -> WhisperConfig:
        return self.ctx.get_hf_config(WhisperConfig)

    def get_default_tok_params(self) -> TokenizeParams:
        # Callers supply the special tokens that match their task/language.
        # Disabling automatic special tokens also keeps an EOS token from
        # being appended to the prompt, which would disrupt generation.
        base_params = super().get_default_tok_params()
        return base_params.with_kwargs(add_special_tokens=False)

    def get_data_parser(self):
        """Build a parser that resamples audio to the extractor's rate."""
        sampling_rate = self.get_feature_extractor().sampling_rate
        return MultiModalDataParser(
            target_sr=sampling_rate,
            target_channels=self.get_target_channels(),
            expected_hidden_size=self._get_expected_hidden_size(),
        )

    @property
    def skip_prompt_length_check(self) -> bool:
        return True  # Because the encoder prompt is padded

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # At most one audio clip per request.
        return {"audio": 1}

    def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
        """Return the HF processor's feature extractor (must be Whisper's)."""
        extractor = self.get_hf_processor(**kwargs).feature_extractor  # type: ignore
        assert isinstance(extractor, WhisperFeatureExtractor)
        return extractor

    def get_target_channels(self) -> int:
        """Return target audio channels for Whisper models (mono)."""
        return 1

    def get_num_audio_tokens(self) -> int:
        # Placeholder count for one audio item: the config's max_source_positions.
        return self.get_hf_config().max_source_positions


class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]):
    """Builds dummy text and audio inputs (e.g. for profiling)."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        # One start-of-transcript token per requested audio item.
        return "<|startoftranscript|>" * mm_counts.get("audio", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Return dummy audio clips of ``chunk_length`` seconds each."""
        extractor = self.info.get_feature_extractor()
        clip_samples = extractor.chunk_length * extractor.sampling_rate

        dummy_audios = self._get_dummy_audios(
            length=clip_samples,
            num_audios=mm_counts.get("audio", 0),
            overrides=mm_options.get("audio"),
        )
        return {"audio": dummy_audios}


721
class WhisperMultiModalProcessor(EncDecMultiModalProcessor[WhisperProcessingInfo]):
    """Multimodal processor for Whisper's audio-encoder / text-decoder setup."""

    def create_encoder_prompt(
        self,
        prompt: str | list[int],
        mm_items: MultiModalDataItems,
    ) -> str | list[int]:
        # Strictly speaking, whisper encoder only accept audio features.
        # We create a dummy encoder prompt here which will be padded to
        # num_audio_tokens. So that we can create dummy data from this
        # for encoder profiling.
        return [0]

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Call the HF processor, adapting audio data and kwargs for it.

        ``mm_data`` / ``mm_kwargs`` / ``tok_kwargs`` are read but never
        mutated; adapted copies are passed on instead.
        """
        if mm_data:
            feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
            # Read (don't pop) the audio entry: the caller's mapping is typed
            # read-only and must not lose its "audios" key as a side effect.
            mm_data = dict(audio=mm_data["audios"])
            mm_kwargs = dict(
                **mm_kwargs,
                sampling_rate=feature_extractor.sampling_rate,
            )

        # The HF WhisperProcessor passes **kwargs to both the tokenizer
        # and the feature extractor. Text-tokenizer kwargs like
        # `truncation` and `max_length` must be removed when audio data
        # is present, otherwise the feature extractor interprets
        # `max_length` as raw audio samples and truncates the audio.
        # (The filter is applied on every call, with or without audio.)
        tok_kwargs = {
            k: v for k, v in tok_kwargs.items() if k not in ("truncation", "max_length")
        }

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        # If the processor returned the tokenized text under "labels",
        # expose it as "input_ids".
        if "labels" in processed_outputs:
            processed_outputs["input_ids"] = processed_outputs.pop("labels")
        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # `input_features` is batched per audio item.
        return dict(input_features=MultiModalFieldConfig.batched("audio"))

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Expand the single ``[0]`` encoder placeholder to the audio length."""
        num_tokens = self.info.get_num_audio_tokens()
        return [
            PromptReplacement(
                modality="audio",
                target=[0],
                replacement=[0] * num_tokens,
            )
        ]


788
789
790
791
792
793
@MULTIMODAL_REGISTRY.register_processor(
    WhisperMultiModalProcessor,
    info=WhisperProcessingInfo,
    dummy_inputs=WhisperDummyInputsBuilder,
)
class WhisperForConditionalGeneration(
    nn.Module,
    SupportsTranscription,
    SupportsMultiModal,
    SupportsLoRA,
):
    """Whisper encoder-decoder model for speech-to-text in vLLM.

    Wraps ``WhisperModel`` (encoder + decoder) together with an LM head
    (``proj_out``) that is tied to the decoder token embedding. It also
    implements the transcription hooks: prompt construction, language
    validation, and explicit language detection.
    """

    # LoRA-specific attributes
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "kv_proj": ["k_proj", "v_proj"],
    }

    # Checkpoint names put fc1/fc2 directly on the layer; vLLM nests them
    # under an `mlp` submodule.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."}
    )

    # Whisper only supports audio-conditioned generation.
    supports_transcription_only = True
    supports_segment_timestamp = True
    supports_explicit_language_detection = True
    supported_languages = ISO639_1_SUPPORTED_LANGS

    @classmethod
    def validate_language(cls, language: str | None) -> str | None:
        """Validate the requested language, allowing ``None``.

        ``None`` is accepted and returned unchanged; a debug message notes
        that the language will be auto-detected. Any other value is checked
        by the parent implementation.
        """
        if language is None:
            logger.debug(
                "No language specified. Language will be auto-detected "
                "from audio. To skip detection, pass the `language` field "
                "in the TranscriptionRequest."
            )
            return None
        return super().validate_language(language)

    @classmethod
    def get_generation_prompt(
        cls,
        audio: np.ndarray,
        model_config: ModelConfig,  # not needed here
        stt_config: SpeechToTextConfig,
        language: str | None,
        task_type: Literal["transcribe", "translate"],
        request_prompt: str,
        to_language: str | None,
    ) -> PromptType:
        """Build the encoder/decoder prompt for one transcription request.

        The encoder prompt carries no text; the audio is attached as
        multimodal data together with ``stt_config.sample_rate``. The
        decoder prompt is an optional ``<|prev|>{request_prompt}`` prefix
        followed by
        ``<|startoftranscript|><|{language}|><|{task_type}|><|notimestamps|>``.
        ``model_config`` and ``to_language`` are not used.

        Raises:
            ValueError: if ``language`` is ``None``. Detection must have
                resolved it before this point.
        """
        if language is None:
            raise ValueError(
                "Language must be specified when creating the Whisper prompt"
            )

        prev_context = f"<|prev|>{request_prompt}" if request_prompt else ""
        decoder_text = (
            prev_context
            + f"<|startoftranscript|><|{language}|><|{task_type}|><|notimestamps|>"
        )

        return ExplicitEncoderDecoderPrompt(
            encoder_prompt=TextPrompt(
                prompt="",  # Whisper does not support encoder prompt.
                multi_modal_data={"audio": (audio, stt_config.sample_rate)},
            ),
            decoder_prompt=TextPrompt(prompt=decoder_text),
        )

    @classmethod
    def get_language_token_ids(
        cls,
        tokenizer: object,
    ) -> list[int]:
        """Return token IDs for all supported language tokens.

        Used with ``SamplingParams.allowed_token_ids`` to constrain
        language detection to only produce valid language tokens. IDs are
        returned in the iteration order of ``cls.supported_languages``.
        """
        return [
            tokenizer.convert_tokens_to_ids(f"<|{lang_code}|>")
            for lang_code in cls.supported_languages
        ]

    @classmethod
    def get_language_detection_prompt(
        cls,
        audio: np.ndarray,
        stt_config: SpeechToTextConfig,
    ) -> PromptType:
        """Return a prompt that elicits a single language token from Whisper.

        Feed only ``<|startoftranscript|>`` as the decoder input so the model
        predicts the most likely language token (e.g. ``<|de|>``).
        """
        return ExplicitEncoderDecoderPrompt(
            encoder_prompt=TextPrompt(
                prompt="",
                multi_modal_data={"audio": (audio, stt_config.sample_rate)},
            ),
            decoder_prompt=TextPrompt(prompt="<|startoftranscript|>"),
        )

    @classmethod
    def parse_language_detection_output(
        cls,
        token_ids: list[int],
        tokenizer: object,
    ) -> str | None:
        """Parse the language token predicted by Whisper.

        Decodes the first token ID and extracts the language code from the
        ``<|xx|>`` format. Expects a valid language token from constrained
        generation.

        Raises:
            ValueError: if ``token_ids`` is empty, if the decoded token is
                not of the form ``<|xx|>``, or if ``xx`` is not in
                ``cls.supported_languages``.
        """
        if not token_ids:
            raise ValueError(
                "Whisper language detection produced no tokens; expected a "
                "single language token."
            )

        decoded = tokenizer.decode(
            [token_ids[0]],
            skip_special_tokens=False,
        )
        # Whisper language tokens have the form <|xx|>. Validate explicitly:
        # `assert` is stripped under `python -O`, which would let a garbage
        # string through as a "language code".
        if not (decoded.startswith("<|") and decoded.endswith("|>")):
            raise ValueError(
                "Expected a Whisper language token of the form <|xx|>, "
                f"got {decoded!r}."
            )
        lang_code = decoded[2:-2]
        if lang_code not in cls.supported_languages:
            raise ValueError(
                f"Whisper predicted unsupported language code {lang_code!r}."
            )
        return lang_code

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for a multimodal item.

        Audio needs no textual placeholder, so ``None`` is returned for any
        modality starting with ``"audio"``.

        Raises:
            ValueError: for any other modality.
        """
        if modality.startswith("audio"):
            return None

        raise ValueError("Only audio modality is supported")

    @classmethod
    def get_speech_to_text_config(
        cls, model_config: ModelConfig, task_type: str
    ) -> SpeechToTextConfig:
        """Build the speech-to-text config from the HF feature extractor.

        The maximum clip length comes from ``feature_extractor.chunk_length``
        and the sample rate from ``feature_extractor.sampling_rate``.
        ``task_type`` is not used.
        """
        processor = cached_processor_from_config(model_config)

        return SpeechToTextConfig(
            max_audio_clip_s=processor.feature_extractor.chunk_length,
            sample_rate=processor.feature_extractor.sampling_rate,
        )

    @classmethod
    def get_num_audio_tokens(
        cls,
        audio_duration_s: float,
        stt_config: SpeechToTextConfig,
        model_config: ModelConfig,
    ) -> int | None:
        """Estimate the processing cost of an audio clip.

        Returns ``ceil(audio_duration_s * sample_rate / hop_length)``, where
        ``hop_length`` comes from the HF feature extractor.
        """
        processor = cached_processor_from_config(model_config)
        hop_length = processor.feature_extractor.hop_length
        assert hop_length is not None
        # NOTE(NickLucche) user can't pass encoder
        # prompts directly at least not to Whisper.
        # One indicator of the encoder amount of processing
        # is the log-mel spectogram length.
        return math.ceil(audio_duration_s * stt_config.sample_rate / hop_length)

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        # Used to cast incoming audio features before running the encoder.
        self.dtype = vllm_config.model_config.dtype

        with self._mark_composite_model(
            vllm_config,
            language_targets=WhisperDecoder,
            tower_targets={"audio": WhisperEncoder},
        ):
            self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix)

        self.proj_out = ParallelLMHead(
            config.vocab_size,
            config.d_model,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "proj_out"),
        )
        # The LM head shares its weight with the decoder token embedding,
        # which is why `load_weights` skips any `proj_out.` checkpoint keys.
        self.proj_out = self.proj_out.tie_weights(self.model.decoder.embed_tokens)
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        encoder_outputs: list[torch.Tensor] | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run the decoder against the given encoder outputs.

        A missing ``encoder_outputs`` is treated as an empty list. Extra
        ``kwargs`` are accepted and ignored.
        """
        if encoder_outputs is None:
            encoder_outputs = []
        decoder_outputs = self.model(
            input_ids=input_ids,
            positions=positions,
            encoder_outputs=encoder_outputs,
        )
        return decoder_outputs

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Encode audio features and return one tensor per audio item."""
        # Required as part of SupportsMultiModal interface.
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        # Split concatenated encoder outputs into one tensor per audio input
        enc_output = self.model.get_encoder_outputs(audio_input["input_features"])
        # The assumption is we can only process whole mm items (audios)
        return enc_output.unbind(dim=0)

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Embed decoder token IDs.

        ``multimodal_embeddings``, ``is_multimodal`` and
        ``handle_oov_mm_token`` are accepted for interface compatibility
        and ignored.
        """
        # This method just returns the decoder sequence embeddings since
        # Whisper does not have encoder text tokens.
        return self.model.decoder.embed_input_ids(input_ids)

    def _parse_and_validate_audio_input(self, **kwargs: object) -> WhisperAudioInputs:
        """Extract ``input_features`` and cast every tensor leaf to ``self.dtype``."""
        input_features = kwargs.pop("input_features", None)

        if input_features is not None:
            input_features = json_map_leaves(lambda x: x.to(self.dtype), input_features)

        return WhisperAudioInputs(input_features=input_features)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project hidden states through the tied LM head into vocab logits."""
        logits = self.logits_processor(self.proj_out, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights and return the set of loaded parameter names.

        ``proj_out.`` keys are skipped because the LM head is tied to the
        decoder embedding. A zero bias is injected after each
        ``.k_proj.weight`` so the fused projection receives a bias for its
        k slice.
        """
        loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])

        # add fake zeros bias for k_proj to state_dict
        weights = _create_fake_bias_for_k_proj(weights, ".k_proj.weight")
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)


def _create_fake_bias_for_k_proj(
1026
    weights: Iterable[tuple[str, torch.Tensor]], fake_bias_key_name: str
1027
) -> Iterable[tuple[str, torch.Tensor]]:
1028
    """
1029
    Create full zeros bias for k_proj weight in self-attn and x-attn layers.
1030
1031
1032
    So that the bias for k_proj in qkv_proj can be initialized with zeros.
    """
    for name, weight in weights:
1033
        yield name, weight
1034
        if name.endswith(fake_bias_key_name):
1035
1036
            bias = torch.zeros(weight.size(0))
            bias_name = name.replace("weight", "bias")
1037
            yield bias_name, bias