"tests/vscode:/vscode.git/clone" did not exist on "774d0c014b8699d244ba2889d872591ca535b80f"
whisper.py 33.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Patrick von Platen's avatar
Patrick von Platen committed
4
import enum
5
import math
6
from collections.abc import Iterable, Mapping, Sequence
Patrick von Platen's avatar
Patrick von Platen committed
7
from contextlib import nullcontext
8
from typing import Annotated, Literal, cast
9

10
import numpy as np
11
12
import torch
from torch import nn
13
14
15
16
17
from transformers import (
    BatchFeature,
    WhisperConfig,
    WhisperFeatureExtractor,
)
18
19
from transformers.models.whisper.modeling_whisper import sinusoids

Patrick von Platen's avatar
Patrick von Platen committed
20
from vllm.attention.layer import Attention
21
from vllm.compilation.decorators import support_torch_compile
22
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
23
from vllm.config.multimodal import BaseDummyOptions
24
from vllm.distributed import get_tensor_model_parallel_world_size
25
from vllm.inputs.data import PromptType
26
27
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
28
29
from vllm.model_executor.layers.attention.cross_attention import CrossAttention
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
30
31
32
33
34
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
35
from vllm.model_executor.layers.logits_processor import LogitsProcessor
36
from vllm.model_executor.layers.quantization import QuantizationConfig
37
38
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
Patrick von Platen's avatar
Patrick von Platen committed
39
40
41
from vllm.model_executor.models.whisper_utils import (
    ISO639_1_SUPPORTED_LANGS,
)
42
from vllm.multimodal import MULTIMODAL_REGISTRY
43
44
45
46
47
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
48
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
49
from vllm.multimodal.processing import (
50
    BaseDummyInputsBuilder,
51
52
53
54
55
    BaseProcessingInfo,
    EncDecMultiModalProcessor,
    PromptReplacement,
    PromptUpdate,
)
56
from vllm.transformers_utils.processor import cached_processor_from_config
57
from vllm.utils.jsontree import json_map_leaves
58
from vllm.utils.tensor_schema import TensorSchema, TensorShape
59
from vllm.utils.torch_utils import set_default_torch_dtype
60
61
62
from vllm.v1.attention.backend import (
    AttentionType,
)
63

64
65
66
67
68
69
70
71
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsTranscription
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    cast_overflow_tensors,
    make_layers,
    maybe_prefix,
)
72
73
74

# Module-level logger; used, e.g., for the language-default warning emitted
# by WhisperForConditionalGeneration.validate_language.
logger = init_logger(__name__)

Patrick von Platen's avatar
Patrick von Platen committed
75
76
77

class WhisperPosEmbedType(enum.Enum):
    """Positional-embedding variants a Whisper-style config may request.

    Member values are the strings read from the config's ``pos_embed``
    attribute. The non-causal ``WhisperEncoder`` in this module accepts only
    ``SINUSOIDAL`` and ``LEARNED``.
    """

    SINUSOIDAL = "sinusoidal"
    ROPE = "rope"
    LEARNED = "learned"
80

81

82
83
84
85
86
87
88
89
class WhisperAudioInputs(TensorSchema):
    """
    Schema of the mel-spectrogram features handed to the Whisper encoder.

    Dimensions:
        - b: Batch size
        - nmb: Number of mel bins
        - t: Time frames (M)
    """

    # A list of feature tensors, or None when the request carries no audio.
    input_features: Annotated[
        list[torch.Tensor] | None,
        TensorShape("b", "nmb", "t"),
    ]
94
95


96
class WhisperEncoderAttention(MMEncoderAttention):
    """Whisper encoder attention that also accepts unbatched (2-D) inputs.

    Inputs may be ``batch_size x seq_len x hidden_size`` or
    ``seq_len x hidden_size``. For 2-D inputs a leading axis of size 1 is
    added before delegating to the parent implementation, and removed from
    the result again.
    """

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        # The query's rank alone decides whether the inputs are unbatched.
        if query.dim() != 2:
            return super().forward(query, key, value)

        batched_q, batched_k, batched_v = (
            tensor.unsqueeze(0) for tensor in (query, key, value)
        )
        batched_out = super().forward(batched_q, batched_k, batched_v)
        return batched_out.squeeze(0)


124
class WhisperPositionalEmbedding(nn.Embedding):
    """Position-embedding table whose rows are fetched by direct indexing."""

    def __init__(self, num_positions: int, embedding_dim: int):
        super().__init__(num_embeddings=num_positions, embedding_dim=embedding_dim)

    def forward(self, position_ids):
        # Plain tensor indexing of the weight table (instead of
        # nn.Embedding.forward); the result has position_ids' shape plus a
        # trailing embedding dimension.
        table = self.weight
        return table[position_ids]


class WhisperAttention(nn.Module):
    """Multi-head attention shared by the Whisper encoder and decoder layers.

    Queries, keys and values come from a fused ``QKVParallelLinear``
    (``qkv_proj``, created by ``_init_qkv``). The attention result is
    projected back by a ``RowParallelLinear`` (``out_proj``). The attention
    backend depends on ``attn_type``:

    * ``AttentionType.ENCODER`` -> ``WhisperEncoderAttention``
    * ``AttentionType.ENCODER_DECODER`` -> ``CrossAttention``
    * anything else (decoder self-attention) -> ``Attention``

    Args:
        embed_dim: Hidden size of the layer input and output.
        num_heads: Total number of attention heads across all
            tensor-parallel ranks; they are split evenly between ranks.
        bias: Whether the projections carry a bias.
        attn_type: Selects the attention backend (see above).
        per_layer_sliding_window: Forwarded only to ``Attention`` (decoder
            self-attention); ignored by the other backends.
        cache_config: Forwarded to ``CrossAttention`` / ``Attention``.
        quant_config: Forwarded to the linear layers and to
            ``CrossAttention`` / ``Attention``.
        prefix: Name prefix for this module's parameters and layers.

    Raises:
        ValueError: If ``num_heads`` is not divisible by the tensor-parallel
            world size, or ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        attn_type: AttentionType = AttentionType.DECODER,
        per_layer_sliding_window: int | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.embed_dim = embed_dim
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        # Heads are partitioned across tensor-parallel ranks, so every rank
        # must own a whole number of them. Replicating heads across ranks
        # (num_heads < tp_size) is not supported by this layer.
        if self.total_num_heads % tp_size != 0:
            raise ValueError(
                f"num_heads ({self.total_num_heads}) must be divisible by the "
                f"tensor parallel world size ({tp_size})."
            )
        self.num_heads = self.total_num_heads // tp_size
        # K/V are built with as many heads as Q (see `_init_qkv`), so they
        # are sharded exactly like the query heads.
        self.num_kv_heads = max(1, self.total_num_heads // tp_size)
        self.head_dim = self.embed_dim // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.attn_type = attn_type

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: "
                f"{self.embed_dim} and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5

        self._init_qkv(embed_dim, bias, quant_config, prefix=prefix)
        self.out_proj = RowParallelLinear(
            input_size=embed_dim,
            output_size=embed_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        if self.attn_type == AttentionType.ENCODER:
            self.attn = WhisperEncoderAttention(
                self.num_heads,
                self.head_dim,
                self.scaling,
                num_kv_heads=self.num_kv_heads,
            )
        elif self.attn_type == AttentionType.ENCODER_DECODER:
            self.attn = CrossAttention(
                self.num_heads,
                self.head_dim,
                self.scaling,
                num_kv_heads=self.num_kv_heads,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.attn",
                attn_type=self.attn_type,
            )
        else:  # AttentionType.DECODER (regular decoder self-attention)
            self.attn = Attention(
                self.num_heads,
                self.head_dim,
                self.scaling,
                num_kv_heads=self.num_kv_heads,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.attn",
                attn_type=self.attn_type,
                per_layer_sliding_window=per_layer_sliding_window,
            )

    def _init_qkv(
        self,
        embed_dim: int,
        bias: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Create the fused Q/K/V projection ``qkv_proj``.

        Subclasses (e.g. cross-attention) override this to build a different
        projection layout.
        """
        self.qkv_proj = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Project to Q/K/V, run attention and apply the output projection."""
        qkv, _ = self.qkv_proj(hidden_states)
        # Split along the last dim into this rank's Q, K and V slices.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        attn_output = self.attn(q, k, v)

        output, _ = self.out_proj(attn_output)

        return output


class WhisperCrossAttention(WhisperAttention):
    """Decoder-to-encoder (cross) attention.

    Queries are projected from the decoder hidden states through ``q_proj``.
    Keys and values are projected from the encoder output through a separate
    fused ``kv_proj``. The attention type is always
    ``AttentionType.ENCODER_DECODER``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__(
            embed_dim=embed_dim,
            num_heads=num_heads,
            bias=bias,
            attn_type=AttentionType.ENCODER_DECODER,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
        )

    def _init_qkv(
        self,
        embed_dim: int,
        bias: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        # Query projection applied to the decoder-side hidden states.
        self.q_proj = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=embed_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.q_proj",
        )
        # Keys and values only: the fused layer is built with zero query heads.
        self.kv_proj = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.head_dim,
            total_num_heads=0,
            total_num_kv_heads=self.total_num_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_proj",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None,
    ):
        query, _ = self.q_proj(hidden_states)

        # Encoder hidden states are only provided during the prefill phase;
        # afterwards the keys and values are expected to come from the
        # kv-cache, so None is passed for both.
        key = value = None
        if encoder_hidden_states is not None:
            fused_kv, _ = self.kv_proj(encoder_hidden_states)
            key, value = fused_kv.split([self.kv_size, self.kv_size], dim=-1)

        attended = self.attn(query, key, value)
        projected, _ = self.out_proj(attended)
        return projected


class WhisperMLP(nn.Module):
    """Feed-forward block computing ``fc2(activation(fc1(x)))``.

    ``fc1`` is column-parallel (``embed_dim -> ffn_dim``) and ``fc2`` is
    row-parallel (``ffn_dim -> embed_dim``). The activation is looked up by
    name via ``get_act_fn``.
    """

    def __init__(
        self,
        embed_dim: int,
        ffn_dim: int,
        act_fn: str,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()

        self.activation_fn = get_act_fn(act_fn)
        self.fc1 = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=ffn_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            input_size=ffn_dim,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, hidden_states: torch.Tensor):
        expanded, _ = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        contracted, _ = self.fc2(activated)
        return contracted


class WhisperEncoderLayer(nn.Module):
    """One pre-LayerNorm Whisper encoder block.

    Computes ``x + self_attn(LN(x))``, then ``x + mlp(LN(x))``, and finally
    passes the result through ``cast_overflow_tensors``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.embed_dim = hf_config.d_model
        self.self_attn = WhisperAttention(
            embed_dim=self.embed_dim,
            num_heads=hf_config.encoder_attention_heads,
            attn_type=AttentionType.ENCODER,
            # Optional config field; WhisperAttention only forwards it to the
            # decoder-side `Attention` backend, not to the encoder backend.
            per_layer_sliding_window=getattr(hf_config, "sliding_window", None),
            cache_config=vllm_config.cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.mlp = WhisperMLP(
            embed_dim=hf_config.d_model,
            ffn_dim=hf_config.encoder_ffn_dim,
            act_fn=hf_config.activation_function,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        # Self-attention sub-block with residual connection.
        attn_input = self.self_attn_layer_norm(hidden_states)
        hidden_states = hidden_states + self.self_attn(hidden_states=attn_input)

        # Feed-forward sub-block with residual connection.
        mlp_input = self.final_layer_norm(hidden_states)
        hidden_states = hidden_states + self.mlp(mlp_input)

        return cast_overflow_tensors(hidden_states)


class WhisperDecoderLayer(nn.Module):
    """One pre-LayerNorm Whisper decoder block.

    Three residual sub-blocks are applied in order: decoder self-attention,
    cross-attention over the encoder output, and the feed-forward MLP. Each
    one normalizes its input first.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        d_model = hf_config.d_model
        n_heads = hf_config.decoder_attention_heads
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.self_attn = WhisperAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            attn_type=AttentionType.DECODER,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.self_attn_layer_norm = nn.LayerNorm(d_model)
        self.encoder_attn = WhisperCrossAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder_attn",
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(d_model)
        self.mlp = WhisperMLP(
            embed_dim=d_model,
            ffn_dim=hf_config.decoder_ffn_dim,
            act_fn=hf_config.activation_function,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.final_layer_norm = nn.LayerNorm(d_model)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None,
    ):
        # 1) Decoder self-attention.
        normed = self.self_attn_layer_norm(hidden_states)
        hidden_states = hidden_states + self.self_attn(hidden_states=normed)

        # 2) Cross-attention; encoder_hidden_states may be None, in which case
        #    WhisperCrossAttention passes None keys/values to its backend.
        normed = self.encoder_attn_layer_norm(hidden_states)
        hidden_states = hidden_states + self.encoder_attn(
            hidden_states=normed,
            encoder_hidden_states=encoder_hidden_states,
        )

        # 3) Feed-forward.
        normed = self.final_layer_norm(hidden_states)
        return hidden_states + self.mlp(normed)


class WhisperEncoder(nn.Module):
    """Whisper audio encoder.

    Pipeline: two ``Conv1d`` layers with GELU (the second with stride 2),
    an additive position-embedding table, a stack of ``WhisperEncoderLayer``
    blocks, and a final ``LayerNorm``.

    Args:
        vllm_config: Global vLLM config; ``model_config.hf_config`` supplies
            the Whisper hyper-parameters.
        prefix: Name prefix for this module's parameters and layers.
        init_in_fp32: If True, the position-embedding table is created and
            filled with float32 as the default dtype.

    Raises:
        ValueError: If the config's ``pos_embed`` (default ``"sinusoidal"``)
            is not ``"sinusoidal"`` or ``"learned"``; unknown strings are
            rejected by ``WhisperPosEmbedType`` itself.
    """

    def __init__(
        self, *, vllm_config: VllmConfig, prefix: str = "", init_in_fp32: bool = False
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        embed_dim = config.d_model

        self.pos_embed_type = WhisperPosEmbedType(
            getattr(config, "pos_embed", "sinusoidal")
        )
        # Validate before allocating the convolutions and encoder layers so an
        # unsupported config fails fast instead of after building everything.
        if self.pos_embed_type not in (
            WhisperPosEmbedType.SINUSOIDAL,
            WhisperPosEmbedType.LEARNED,
        ):
            raise ValueError(
                "Only sinusoidal or learned position embeddings are supported "
                f"for non-causal models, but got {self.pos_embed_type}"
            )

        self.num_mel_bins = config.num_mel_bins
        self.max_source_positions = config.max_source_positions
        # NOTE(review): exposed as an attribute only; forward() below does not
        # apply this scale.
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(embed_dim, embed_dim, stride=2, kernel_size=3, padding=1)

        # Combined time-axis stride of the two convolutions.
        self.total_stride = self.conv1.stride[0] * self.conv2.stride[0]
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.encoder_layers,
            lambda prefix: WhisperEncoderLayer(
                vllm_config=vllm_config, prefix=f"{prefix}.layers"
            ),
            prefix=f"{prefix}.layers",
        )
        self.layer_norm = nn.LayerNorm(config.d_model)

        maybe_fp32_init_ctx = (
            set_default_torch_dtype(torch.float32) if init_in_fp32 else nullcontext()
        )

        with (
            torch.no_grad(),
            maybe_fp32_init_ctx,
        ):
            # The table is always initialised with sinusoids here; presumably
            # a LEARNED variant is overwritten from the checkpoint — confirm.
            self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
            self.embed_positions.weight.copy_(
                sinusoids(*self.embed_positions.weight.shape)
            )

    def forward(
        self, input_features: torch.Tensor | list[torch.Tensor]
    ) -> torch.Tensor:
        """Encode audio features.

        Args:
            input_features: Iterable of feature tensors, each fed to
                ``conv1`` (whose input channels are ``num_mel_bins``).
                Presumably each item is ``(num_mel_bins, frames)`` or an
                already-batched ``(b, num_mel_bins, frames)`` — TODO confirm.

        Returns:
            Encoder hidden states after the final LayerNorm, with the items
            combined along the leading (batch) dimension.
        """
        hidden_states = []
        input_is_batched = False
        for features in input_features:
            embeds = nn.functional.gelu(self.conv1(features))
            embeds = nn.functional.gelu(self.conv2(embeds))

            # Channels-last, then add the first `seq_len` position rows; the
            # cast restores the conv output dtype in case the table differs
            # (e.g. when it was initialised in fp32).
            embeds = embeds.transpose(-1, -2)
            embeds = (embeds + self.embed_positions.weight[: embeds.size(-2), :]).to(
                embeds.dtype
            )

            hidden_states.append(embeds)
            # Only the last item's rank is recorded here.
            input_is_batched = embeds.ndim > 2
        # Input to MHA must be B x T x D
        if input_is_batched:
            # Models using WhisperEncoder may handle batching internally.
            hidden_states = torch.cat(hidden_states)
        else:
            hidden_states = torch.stack(hidden_states, dim=0)

        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states)

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states


524
@support_torch_compile(dynamic_arg_dims={"input_ids": 0, "positions": -1})
class WhisperDecoder(nn.Module):
    """Whisper text decoder.

    Token embeddings plus position embeddings feed a stack of
    ``WhisperDecoderLayer`` blocks, followed by a final ``LayerNorm``.
    Compiled through ``support_torch_compile``, with ``input_ids`` dynamic
    along dim 0 and ``positions`` along its last dim.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        d_model = config.d_model

        self.layerdrop = config.decoder_layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_target_positions
        self.max_source_positions = config.max_source_positions
        self.embed_scale = math.sqrt(d_model) if config.scale_embedding else 1.0

        self.embed_tokens = nn.Embedding(config.vocab_size, d_model, self.padding_idx)
        self.embed_positions = WhisperPositionalEmbedding(
            self.max_target_positions, d_model
        )

        # The layer factory keeps the parameter name `prefix` because it is
        # supplied by make_layers. NOTE(review): ".layers" is appended to that
        # per-layer prefix; kept unchanged — confirm the naming is intended.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.decoder_layers,
            lambda prefix: WhisperDecoderLayer(
                vllm_config=vllm_config, prefix=f"{prefix}.layers"
            ),
            prefix=f"{prefix}.layers",
        )
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(
        self,
        input_ids,
        positions: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None,
    ):
        hidden_states = self.embed_input_ids(input_ids) + self.embed_positions(
            positions
        )

        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
            )

        return self.layer_norm(hidden_states)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings (no scaling is applied)."""
        return self.embed_tokens(input_ids)


class WhisperModel(nn.Module):
    """Whisper encoder + decoder stack, plus checkpoint weight loading."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.encoder = WhisperEncoder(
            vllm_config=vllm_config, prefix=f"{prefix}.encoder"
        )
        self.decoder = WhisperDecoder(
            vllm_config=vllm_config, prefix=f"{prefix}.decoder"
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        encoder_outputs: list[torch.Tensor],
    ) -> torch.Tensor:
        # An empty list means no encoder states are passed to the decoder
        # for this step; otherwise all outputs are concatenated along dim 0.
        enc_states = None
        if len(encoder_outputs):
            enc_states = torch.cat(encoder_outputs, dim=0)

        return self.decoder(
            input_ids=input_ids,
            positions=positions,
            encoder_hidden_states=enc_states,
        )

    def get_encoder_outputs(
        self,
        input_features: torch.Tensor | list[torch.Tensor] | None,
    ) -> torch.Tensor | None:
        """Run the encoder, or return None when there are no features."""
        return None if input_features is None else self.encoder(input_features)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this model's parameters.

        Separate self-attention ``q/k/v_proj`` tensors are loaded as shards of
        the fused ``qkv_proj``; cross-attention ``k/v_proj`` as shards of the
        fused ``kv_proj``. Bias tensors with no matching parameter are
        skipped.

        Returns:
            The (renamed) parameter names that received a tensor.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
            (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
            (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
            (".encoder_attn.kv_proj", ".encoder_attn.k_proj", "k"),
            (".encoder_attn.kv_proj", ".encoder_attn.v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        def is_extra_bias(param_name: str) -> bool:
            # Skip loading extra bias for GPTQ models.
            return param_name.endswith(".bias") and param_name not in params_dict

        for name, loaded_weight in weights:
            for packed_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, packed_name)
                if is_extra_bias(name):
                    continue

                packed_param = params_dict[name]
                packed_param.weight_loader(packed_param, loaded_weight, shard_id)
                break
            else:
                if is_extra_bias(name):
                    continue

                plain_param = params_dict[name]
                loader = getattr(plain_param, "weight_loader", default_weight_loader)
                loader(plain_param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


641
642
643
644
class WhisperProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Whisper: HF config, limits and feature extractor."""

    def get_hf_config(self) -> WhisperConfig:
        return self.ctx.get_hf_config(WhisperConfig)

    @property
    def skip_prompt_length_check(self) -> bool:
        # The encoder prompt is padded, so its length is not checked.
        return True

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # A single audio item is supported.
        return {"audio": 1}

    def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
        processor = self.get_hf_processor(**kwargs)
        extractor = processor.feature_extractor  # type: ignore
        assert isinstance(extractor, WhisperFeatureExtractor)
        return extractor

    def get_target_channels(self) -> int:
        """Return target audio channels for Whisper models (mono)."""
        return 1

    def get_num_audio_tokens(self) -> int:
        # Number of placeholder tokens reserved for the audio: the config's
        # max_source_positions.
        hf_config = self.get_hf_config()
        return hf_config.max_source_positions


class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]):
    """Builds dummy text and audio inputs for Whisper."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        # One start-of-transcript token per dummy audio item.
        return "<|startoftranscript|>" * mm_counts.get("audio", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        extractor = self.info.get_feature_extractor()

        # One full chunk of samples (presumably chunk_length is in seconds).
        samples_per_clip = extractor.chunk_length * extractor.sampling_rate
        clip_count = mm_counts.get("audio", 0)
        overrides = mm_options.get("audio") if mm_options else None

        dummy_audios = self._get_dummy_audios(
            length=samples_per_clip, num_audios=clip_count, overrides=overrides
        )
        return {"audio": dummy_audios}


693
class WhisperMultiModalProcessor(EncDecMultiModalProcessor[WhisperProcessingInfo]):
    """Multimodal processor for Whisper's encoder-decoder prompts.

    The encoder consumes audio features only. A one-token dummy encoder
    prompt ``[0]`` is expanded to ``get_num_audio_tokens()`` placeholder
    tokens, giving the encoder sequence a fixed length.
    """

    def _get_data_parser(self) -> MultiModalDataParser:
        # Audio is converted to the feature extractor's sampling rate and to
        # the model's target channel count.
        feature_extractor = self.info.get_feature_extractor()
        return MultiModalDataParser(
            target_sr=feature_extractor.sampling_rate,
            target_channels=self.info.get_target_channels(),
        )

    def create_encoder_prompt(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
    ) -> str | list[int]:
        # Strictly speaking, the Whisper encoder only accepts audio features.
        # We create a dummy encoder prompt here which will be padded to
        # num_audio_tokens, so that we can create dummy data from it
        # for encoder profiling.
        return [0]

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, renaming inputs/outputs for Whisper.

        When audio is present, it is passed under the key ``audio`` together
        with the feature extractor's ``sampling_rate``. Neither ``mm_data``
        nor ``mm_kwargs`` is mutated. If the HF processor returns
        ``labels``, they are exposed as ``input_ids``.
        """
        if mm_data:
            feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
            # Read (rather than pop) the audio list so the caller's mapping is
            # left untouched; `Mapping` does not promise a `pop` method.
            mm_data = {"audio": mm_data["audios"]}
            # The data parser resamples audio to the extractor's rate, so that
            # rate is authoritative. A dict merge (instead of
            # `dict(**mm_kwargs, sampling_rate=...)`) also avoids a TypeError
            # when mm_kwargs already carries `sampling_rate`.
            mm_kwargs = {
                **mm_kwargs,
                "sampling_rate": feature_extractor.sampling_rate,
            }
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        if "labels" in processed_outputs:
            processed_outputs["input_ids"] = processed_outputs.pop("labels")
        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # `input_features` is batched per audio item.
        return dict(input_features=MultiModalFieldConfig.batched("audio"))

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        # Expand the single dummy token [0] into num_tokens placeholders.
        num_tokens = self.info.get_num_audio_tokens()
        return [
            PromptReplacement(
                modality="audio",
                target=[0],
                replacement=[0] * num_tokens,
            )
        ]


@MULTIMODAL_REGISTRY.register_processor(
    WhisperMultiModalProcessor,
    info=WhisperProcessingInfo,
    dummy_inputs=WhisperDummyInputsBuilder,
)
class WhisperForConditionalGeneration(
    nn.Module, SupportsTranscription, SupportsMultiModal
):
    """Whisper encoder-decoder model used for speech-to-text generation.

    Audio features are run through ``self.model``'s encoder. The resulting
    ``encoder_outputs`` are passed to the decoder in ``forward``. Logits come
    from ``proj_out``, which is tied to the decoder token embedding.
    """

    # Separate checkpoint projections that are loaded into fused modules.
    packed_modules_mapping = {
        "self_attn.qkv_proj": [
            "self_attn.q_proj",
            "self_attn.k_proj",
            "self_attn.v_proj",
        ],
        "encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"],
    }

    # Rename checkpoint ".fc1."/".fc2." keys to the ".mlp." submodule used here.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."}
    )

    # Whisper only supports audio-conditioned generation.
    supports_transcription_only = True
    supports_segment_timestamp = True
    supported_languages = ISO639_1_SUPPORTED_LANGS

    @classmethod
    def validate_language(cls, language: str | None) -> str | None:
        """Validate ``language``, substituting ``"en"`` when it is ``None``.

        A warning is logged whenever the English default is applied. The
        final value is validated by the parent implementation.
        """
        if language is not None:
            return super().validate_language(language)
        # TODO language should be optional and can be guessed.
        # For now we default to en. See
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
        logger.warning(
            "Defaulting to language='en'. If you wish to transcribe "
            "audio in a different language, pass the `language` field "
            "in the TranscriptionRequest."
        )
        return super().validate_language("en")

    @classmethod
    def get_generation_prompt(
        cls,
        audio: np.ndarray,
        model_config: ModelConfig,  # not needed here
        stt_config: SpeechToTextConfig,
        language: str | None,
        task_type: Literal["transcribe", "translate"],
        request_prompt: str,
        to_language: str | None,
    ) -> PromptType:
        """Build the encoder/decoder prompt pair for one audio clip.

        The encoder prompt has empty text and carries ``audio`` together with
        ``stt_config.sample_rate``. The decoder prompt is the optional
        ``<|prev|>`` context followed by the Whisper control tokens for
        ``language`` and ``task_type``, with timestamps disabled.
        ``model_config`` and ``to_language`` are not used.

        Raises:
            ValueError: if ``language`` is ``None``.
        """
        if language is None:
            raise ValueError(
                "Language must be specified when creating the Whisper prompt"
            )
        previous_text = f"<|prev|>{request_prompt}" if request_prompt else ""
        decoder_text = (
            f"{previous_text}<|startoftranscript|><|{language}|>"
            f"<|{task_type}|><|notimestamps|>"
        )
        encoder_part = {
            # Whisper does not support encoder prompt.
            "prompt": "",
            "multi_modal_data": {"audio": (audio, stt_config.sample_rate)},
        }
        return cast(
            PromptType,
            {"encoder_prompt": encoder_part, "decoder_prompt": decoder_text},
        )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return ``None`` for audio modalities; Whisper uses no placeholder text.

        Raises:
            ValueError: for any modality not starting with ``"audio"``.
        """
        if not modality.startswith("audio"):
            raise ValueError("Only audio modality is supported")
        return None

    @classmethod
    def get_speech_to_text_config(
        cls, model_config: ModelConfig, task_type: str
    ) -> SpeechToTextConfig:
        """Derive clip length and sample rate from the HF feature extractor.

        ``task_type`` is not used.
        """
        extractor = cached_processor_from_config(model_config).feature_extractor
        return SpeechToTextConfig(
            max_audio_clip_s=extractor.chunk_length,
            sample_rate=extractor.sampling_rate,
        )

    @classmethod
    def get_num_audio_tokens(
        cls,
        audio_duration_s: float,
        stt_config: SpeechToTextConfig,
        model_config: ModelConfig,
    ) -> int | None:
        """Return the number of ``hop_length``-sample frames in the clip.

        Computed as ``ceil(audio_duration_s * sample_rate / hop_length)``,
        with ``sample_rate`` taken from ``stt_config`` and ``hop_length``
        taken from the HF feature extractor.
        """
        processor = cached_processor_from_config(model_config)
        hop_length = processor.feature_extractor.hop_length
        assert hop_length is not None
        # NOTE(NickLucche) user can't pass encoder
        # prompts directly at least not to Whisper.
        # One indicator of the encoder amount of processing
        # is the log-mel spectogram length.
        num_samples = audio_duration_s * stt_config.sample_rate
        return math.ceil(num_samples / hop_length)

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        model_config = vllm_config.model_config
        hf_config = model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = hf_config
        # Target dtype for incoming audio features.
        self.dtype = model_config.dtype

        with self._mark_composite_model(
            vllm_config,
            language_targets=WhisperDecoder,
            tower_targets={"audio": WhisperEncoder},
        ):
            self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix)

        self.proj_out = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.d_model,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "proj_out"),
        )
        # The LM head reuses the decoder's token-embedding weight.
        self.proj_out = self.proj_out.tie_weights(self.model.decoder.embed_tokens)
        self.logits_processor = LogitsProcessor(
            hf_config.vocab_size, scale=getattr(hf_config, "logit_scale", 1.0)
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        encoder_outputs: list[torch.Tensor] | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run the decoder; a missing ``encoder_outputs`` becomes ``[]``.

        Extra ``**kwargs`` are accepted and ignored.
        """
        return self.model(
            input_ids=input_ids,
            positions=positions,
            encoder_outputs=[] if encoder_outputs is None else encoder_outputs,
        )

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Encode audio features and return one tensor per audio input."""
        # Required as part of SupportsMultiModal interface.
        features = self._parse_and_validate_audio_input(**kwargs)["input_features"]
        encoded = self.model.get_encoder_outputs(features)
        # Split concatenated encoder outputs into one tensor per audio input.
        # The assumption is we can only process whole mm items (audios)
        return encoded.unbind(dim=0)

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Embed decoder tokens only; the multimodal arguments are ignored."""
        # Whisper does not have encoder text tokens, so only the decoder
        # sequence is embedded here.
        return self.model.decoder.embed_input_ids(input_ids)

    def _parse_and_validate_audio_input(self, **kwargs: object) -> WhisperAudioInputs:
        """Pop ``input_features`` from kwargs and cast its leaves to ``self.dtype``."""
        features = kwargs.pop("input_features", None)
        if features is not None:
            target_dtype = self.dtype
            features = json_map_leaves(lambda leaf: leaf.to(target_dtype), features)
        return WhisperAudioInputs(input_features=features)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project decoder hidden states to vocabulary logits via ``proj_out``."""
        return self.logits_processor(self.proj_out, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights and return the set reported by the loader.

        ``proj_out.*`` entries are skipped because ``proj_out`` is tied to the
        decoder embedding in ``__init__``. A zero bias is synthesized for every
        ``k_proj`` weight so the packed projections receive a k bias.
        """
        loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])
        # add fake zeros bias for k_proj to state_dict
        augmented = _create_fake_bias_for_k_proj(weights, ".k_proj.weight")
        return loader.load_weights(augmented, mapper=self.hf_to_vllm_mapper)


def _create_fake_bias_for_k_proj(
945
    weights: Iterable[tuple[str, torch.Tensor]], fake_bias_key_name: str
946
) -> Iterable[tuple[str, torch.Tensor]]:
947
    """
948
    Create full zeros bias for k_proj weight in self-attn and x-attn layers.
949
950
951
    So that the bias for k_proj in qkv_proj can be initialized with zeros.
    """
    for name, weight in weights:
952
        if name.endswith(fake_bias_key_name):
953
954
955
956
            bias = torch.zeros(weight.size(0))
            bias_name = name.replace("weight", "bias")
            yield from [(name, weight), (bias_name, bias)]
        yield name, weight