whisper.py 34.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Patrick von Platen's avatar
Patrick von Platen committed
4
import enum
5
import math
6
from collections.abc import Iterable, Mapping, Sequence
Patrick von Platen's avatar
Patrick von Platen committed
7
from contextlib import nullcontext
Patrick von Platen's avatar
Patrick von Platen committed
8
from functools import partial
9
from typing import Annotated, Literal, cast
10

11
import numpy as np
12
13
import torch
from torch import nn
14
15
16
17
18
from transformers import (
    BatchFeature,
    WhisperConfig,
    WhisperFeatureExtractor,
)
19
20
from transformers.models.whisper.modeling_whisper import sinusoids

Patrick von Platen's avatar
Patrick von Platen committed
21
22
23
24
from vllm.attention.backends.abstract import (
    AttentionType,
)
from vllm.attention.layer import Attention
25
from vllm.attention.layers.cross_attention import CrossAttention
26
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
27
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
28
from vllm.config.multimodal import BaseDummyOptions
29
from vllm.distributed import get_tensor_model_parallel_world_size
30
from vllm.inputs.data import PromptType
31
32
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
33
34
35
36
37
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
38
from vllm.model_executor.layers.logits_processor import LogitsProcessor
39
from vllm.model_executor.layers.quantization import QuantizationConfig
40
41
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
Patrick von Platen's avatar
Patrick von Platen committed
42
43
44
45
46
from vllm.model_executor.models.whisper_utils import (
    ISO639_1_SUPPORTED_LANGS,
    WhisperAttentionWithBlockPooling,
    WhisperCausalConv1d,
)
47
from vllm.multimodal import MULTIMODAL_REGISTRY
48
49
50
51
52
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
53
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
54
55
56
57
58
59
from vllm.multimodal.processing import (
    BaseProcessingInfo,
    EncDecMultiModalProcessor,
    PromptReplacement,
    PromptUpdate,
)
60
from vllm.multimodal.profiling import BaseDummyInputsBuilder
61
from vllm.transformers_utils.processor import cached_processor_from_config
62
from vllm.utils.jsontree import json_map_leaves
63
from vllm.utils.tensor_schema import TensorSchema, TensorShape
64
from vllm.utils.torch_utils import set_default_torch_dtype
65

66
67
68
69
70
71
72
73
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsTranscription
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    cast_overflow_tensors,
    make_layers,
    maybe_prefix,
)
74
75
76

# Module-level logger; used e.g. for the default-language warning emitted by
# WhisperForConditionalGeneration.validate_language.
logger = init_logger(__name__)

Patrick von Platen's avatar
Patrick von Platen committed
77
78
79
80
81

class WhisperPosEmbedType(enum.Enum):
    """How absolute positions are injected into the Whisper encoder.

    Member values are the strings read from the HF config's ``pos_embed``
    attribute by ``WhisperEncoder``.
    """

    # Fixed sinusoidal table; the encoder falls back to this when the config
    # has no ``pos_embed`` attribute.
    SINUSOIDAL = "sinusoidal"
    # No positional embedding at all; the only option accepted for causal
    # encoders.
    NOPE = "nope"
    # Position table that the encoder initialises from sinusoids, exactly like
    # SINUSOIDAL (presumably overwritten by checkpoint weights afterwards).
    LEARNED = "learned"
82

83

84
85
86
87
88
89
90
91
class WhisperAudioInputs(TensorSchema):
    """
    Log-mel audio features handed to the Whisper encoder.

    Dimensions:
        - b: Batch size
        - nmb: Number of mel bins
        - t: Time frames (M)
    """

    # One (nmb, t) feature tensor per audio item, or None when the request
    # carries no audio.
    input_features: Annotated[
        list[torch.Tensor] | None,
        TensorShape("b", "nmb", "t"),
    ]
96
97


98
class WhisperEncoderAttention(MMEncoderAttention):
    """Whisper encoder attention that also accepts unbatched (2D) inputs.

    Inputs whose query is not 2D go straight to ``MMEncoderAttention.forward``.
    2D inputs receive a temporary leading batch dimension of size 1, which is
    removed again from the result.
    """

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """
        Input shape: batch_size x seq_len x hidden_size
                     or seq_len x hidden_size
        """
        if query.dim() != 2:
            return super().forward(query, key, value)

        # Only the query rank is inspected; key/value are assumed to have the
        # same rank as the query.
        q, k, v = (t.unsqueeze(0) for t in (query, key, value))
        return super().forward(q, k, v).squeeze(0)


126
class WhisperPositionalEmbedding(nn.Embedding):
    """Position table that is indexed directly by position ids.

    ``forward`` indexes ``self.weight`` with the given ids instead of going
    through ``nn.Embedding.forward``. The result therefore has shape
    ``position_ids.shape + (embedding_dim,)``.
    """

    def __init__(self, num_positions: int, embedding_dim: int):
        super().__init__(num_positions, embedding_dim)

    def forward(self, position_ids):
        table = self.weight
        return table[position_ids]


class WhisperAttention(nn.Module):
    """Multi-head attention used by the Whisper encoder and decoder.

    Q/K/V are produced by one fused ``QKVParallelLinear`` (see ``_init_qkv``,
    which subclasses may override). The attention output goes through a
    ``RowParallelLinear``. The attention backend depends on ``attn_type``:

    * ``AttentionType.ENCODER``: ``WhisperEncoderAttention`` (no cache or
      quantization config is passed to it).
    * ``AttentionType.ENCODER_DECODER``: ``CrossAttention``.
    * anything else (decoder-style self-attention): ``Attention``, or
      ``WhisperAttentionWithBlockPooling`` when ``block_pool_size > 1``.

    Heads are split evenly across the tensor-parallel group. The fused
    projection is built with as many KV heads as query heads, so KV heads are
    split exactly like query heads.

    Raises:
        ValueError: if ``num_heads`` is not divisible by the tensor-parallel
            world size, or if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        attn_type: AttentionType = AttentionType.DECODER,
        per_layer_sliding_window: int | None = None,
        block_pool_size: int = 1,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.embed_dim = embed_dim
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        # Query heads cannot be replicated: every TP rank must own an equal,
        # whole share of them. This is an explicit raise rather than an
        # ``assert`` so the check still runs under ``python -O``.
        if self.total_num_heads % tp_size != 0:
            raise ValueError(
                f"num_heads ({self.total_num_heads}) must be divisible by the "
                f"tensor parallel world size ({tp_size})."
            )
        self.num_heads = self.total_num_heads // tp_size
        # Plain MHA: _init_qkv builds as many KV heads as query heads, so the
        # per-rank KV head count equals the per-rank query head count. The
        # check above already excludes the "fewer heads than TP ranks"
        # (KV replication) case.
        self.num_kv_heads = self.num_heads
        self.head_dim = self.embed_dim // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.attn_type = attn_type

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: "
                f"{self.embed_dim} and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5

        self._init_qkv(embed_dim, bias, quant_config, prefix=prefix)
        self.out_proj = RowParallelLinear(
            input_size=embed_dim,
            output_size=embed_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        if self.attn_type == AttentionType.ENCODER:
            self.attn = WhisperEncoderAttention(
                self.num_heads,
                self.head_dim,
                self.scaling,
                num_kv_heads=self.num_kv_heads,
            )
        elif self.attn_type == AttentionType.ENCODER_DECODER:
            self.attn = CrossAttention(
                self.num_heads,
                self.head_dim,
                self.scaling,
                num_kv_heads=self.num_kv_heads,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.attn",
                attn_type=self.attn_type,
            )
        else:  # AttentionType.DECODER (regular decoder self-attention)
            if block_pool_size > 1:
                attn_cls = partial(
                    WhisperAttentionWithBlockPooling, block_pool_size=block_pool_size
                )
            else:
                attn_cls = Attention

            self.attn = attn_cls(
                self.num_heads,
                self.head_dim,
                self.scaling,
                num_kv_heads=self.num_kv_heads,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.attn",
                attn_type=self.attn_type,
                per_layer_sliding_window=per_layer_sliding_window,
            )

    def _init_qkv(
        self,
        embed_dim: int,
        bias: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Create the fused Q/K/V projection (``self.qkv_proj``)."""
        self.qkv_proj = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Self-attention over ``hidden_states``; returns the projected output."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        attn_output = self.attn(q, k, v)

        output, _ = self.out_proj(attn_output)

        return output


class WhisperCrossAttention(WhisperAttention):
    """Decoder-to-encoder (cross) attention.

    Queries are projected from the decoder stream by ``q_proj``. Keys and
    values are projected from ``encoder_hidden_states`` by a fused ``kv_proj``
    that has no query heads. The attention backend is the parent's
    ``ENCODER_DECODER`` choice.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__(
            embed_dim=embed_dim,
            num_heads=num_heads,
            bias=bias,
            attn_type=AttentionType.ENCODER_DECODER,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
        )

    def _init_qkv(
        self,
        embed_dim: int,
        bias: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build a standalone Q projection plus a fused K/V projection.

        Overrides the parent's single fused QKV projection.
        """
        shared_kwargs = dict(bias=bias, quant_config=quant_config)
        self.q_proj = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=embed_dim,
            prefix=f"{prefix}.q_proj",
            **shared_kwargs,
        )
        self.kv_proj = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.head_dim,
            # No query heads here: queries come from q_proj above.
            total_num_heads=0,
            total_num_kv_heads=self.total_num_heads,
            prefix=f"{prefix}.kv_proj",
            **shared_kwargs,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None,
    ):
        query, _ = self.q_proj(hidden_states)

        if encoder_hidden_states is None:
            # Encoder hidden states are only computed once during the prefill
            # phase; afterwards the keys and values should be available in
            # the kv-cache.
            key, value = None, None
        else:
            fused_kv, _ = self.kv_proj(encoder_hidden_states)
            key, value = fused_kv.split([self.kv_size, self.kv_size], dim=-1)

        projected, _ = self.out_proj(self.attn(query, key, value))
        return projected


class WhisperMLP(nn.Module):
    """Feed-forward block: ``fc1`` -> activation -> ``fc2``.

    ``fc1`` is a ``ColumnParallelLinear`` (embed_dim -> ffn_dim) and ``fc2`` a
    ``RowParallelLinear`` (ffn_dim -> embed_dim). The activation is looked up
    by name through ``get_act_fn``.
    """

    def __init__(
        self,
        embed_dim: int,
        ffn_dim: int,
        act_fn: str,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()

        self.activation_fn = get_act_fn(act_fn)
        self.fc1 = ColumnParallelLinear(
            input_size=embed_dim, output_size=ffn_dim,
            quant_config=quant_config, prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            input_size=ffn_dim, output_size=embed_dim,
            quant_config=quant_config, prefix=f"{prefix}.fc2",
        )

    def forward(self, hidden_states: torch.Tensor):
        expanded, _ = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        contracted, _ = self.fc2(activated)
        return contracted


class WhisperEncoderLayer(nn.Module):
    """Pre-norm Whisper encoder block.

    The block runs LayerNorm -> self-attention -> residual add, then
    LayerNorm -> MLP -> residual add, and finally passes the result through
    ``cast_overflow_tensors``.

    Self-attention is bidirectional (``AttentionType.ENCODER``) unless the HF
    config sets ``is_causal``. In that case it uses decoder-style attention
    with the config's ``sliding_window`` and ``block_pool_size``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        causal = getattr(hf_config, "is_causal", False)

        self.embed_dim = hf_config.d_model
        self.self_attn = WhisperAttention(
            embed_dim=self.embed_dim,
            num_heads=hf_config.encoder_attention_heads,
            attn_type=AttentionType.DECODER if causal else AttentionType.ENCODER,
            block_pool_size=getattr(hf_config, "block_pool_size", 1),
            per_layer_sliding_window=getattr(hf_config, "sliding_window", None),
            cache_config=vllm_config.cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.mlp = WhisperMLP(
            embed_dim=hf_config.d_model,
            ffn_dim=hf_config.encoder_ffn_dim,
            act_fn=hf_config.activation_function,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        normed = self.self_attn_layer_norm(hidden_states)
        hidden_states = hidden_states + self.self_attn(hidden_states=normed)

        normed = self.final_layer_norm(hidden_states)
        hidden_states = hidden_states + self.mlp(normed)

        return cast_overflow_tensors(hidden_states)


class WhisperDecoderLayer(nn.Module):
    """Pre-norm Whisper decoder block.

    The block has three residual sub-blocks in order:
    1. causal self-attention over the decoder stream,
    2. cross-attention to ``encoder_hidden_states``,
    3. the feed-forward MLP.

    Each sub-block applies its own LayerNorm before the sub-layer.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        d_model = hf_config.d_model
        n_heads = hf_config.decoder_attention_heads
        attn_kwargs = dict(
            cache_config=vllm_config.cache_config,
            quant_config=vllm_config.quant_config,
        )

        self.self_attn = WhisperAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            attn_type=AttentionType.DECODER,
            prefix=f"{prefix}.self_attn",
            **attn_kwargs,
        )
        self.self_attn_layer_norm = nn.LayerNorm(d_model)
        self.encoder_attn = WhisperCrossAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            prefix=f"{prefix}.encoder_attn",
            **attn_kwargs,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(d_model)
        self.mlp = WhisperMLP(
            embed_dim=d_model,
            ffn_dim=hf_config.decoder_ffn_dim,
            act_fn=hf_config.activation_function,
            quant_config=vllm_config.quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.final_layer_norm = nn.LayerNorm(d_model)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None,
    ):
        normed = self.self_attn_layer_norm(hidden_states)
        hidden_states = hidden_states + self.self_attn(hidden_states=normed)

        normed = self.encoder_attn_layer_norm(hidden_states)
        hidden_states = hidden_states + self.encoder_attn(
            hidden_states=normed,
            encoder_hidden_states=encoder_hidden_states,
        )

        normed = self.final_layer_norm(hidden_states)
        hidden_states = hidden_states + self.mlp(normed)
        return hidden_states


class WhisperEncoder(nn.Module):
    """Whisper audio encoder.

    The encoder applies, in order:
    1. two 1-D convolutions over the mel features (GELU after each),
    2. optional absolute positional embeddings,
    3. a stack of ``WhisperEncoderLayer`` blocks,
    4. a final ``LayerNorm``.

    Args:
        vllm_config: Engine config; the HF model config is read from
            ``vllm_config.model_config.hf_config``.
        prefix: Module-name prefix for this encoder.
        init_in_fp32: Build the positional-embedding table under
            ``set_default_torch_dtype(torch.float32)``.

    Raises:
        ValueError: if ``config.pos_embed`` is not a ``WhisperPosEmbedType``
            value, or if a causal config requests anything other than
            ``"nope"`` positional embeddings.
    """

    def __init__(
        self, *, vllm_config: VllmConfig, prefix: str = "", init_in_fp32: bool = False
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        embed_dim = config.d_model

        self.pos_embed_type = WhisperPosEmbedType(
            getattr(config, "pos_embed", "sinusoidal")
        )
        self.num_mel_bins = config.num_mel_bins
        self.max_source_positions = config.max_source_positions
        # Stored for parity with the HF config; not applied in this module's
        # forward path.
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        self.is_causal = getattr(config, "is_causal", False)
        Conv1d = (
            WhisperCausalConv1d if self.is_causal else partial(nn.Conv1d, padding=1)
        )

        self.conv1 = Conv1d(self.num_mel_bins, embed_dim, kernel_size=3)
        self.conv2 = Conv1d(embed_dim, embed_dim, stride=2, kernel_size=3)

        self.total_stride = self.conv1.stride[0] * self.conv2.stride[0]

        # make_layers hands each factory call a per-layer prefix derived from
        # `prefix=` below (expected "<prefix>.layers.<idx>"; verify against
        # make_layers). Forward it unchanged: appending ".layers" again would
        # name the layer modules "<prefix>.layers.<idx>.layers.*", which no
        # longer matches the real module path used for quantization and
        # layer-name lookups.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.encoder_layers,
            lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.layer_norm = nn.LayerNorm(config.d_model)

        if self.is_causal and self.pos_embed_type != WhisperPosEmbedType.NOPE:
            raise ValueError(
                "Only NOPE position embeddings are supported "
                f"for causal models, but got {self.pos_embed_type}"
            )
        elif self.pos_embed_type in (
            WhisperPosEmbedType.SINUSOIDAL,
            WhisperPosEmbedType.LEARNED,
        ):
            maybe_fp32_init_ctx = (
                set_default_torch_dtype(torch.float32)
                if init_in_fp32
                else nullcontext()
            )

            # Both SINUSOIDAL and LEARNED start from the sinusoid table.
            with (
                torch.no_grad(),
                maybe_fp32_init_ctx,
            ):
                self.embed_positions = nn.Embedding(
                    self.max_source_positions, embed_dim
                )
                self.embed_positions.weight.copy_(
                    sinusoids(*self.embed_positions.weight.shape)
                )

    def forward_conv(
        self, input_features: torch.Tensor | list[torch.Tensor]
    ) -> torch.Tensor:
        """Run the conv front-end and add positional embeddings.

        Each element of ``input_features`` is processed independently. Each
        element is assumed to be ``(num_mel_bins, T)`` or
        ``(batch, num_mel_bins, T)`` (TODO confirm against callers). Results
        are transposed so time precedes channels. They are then either
        concatenated along dim 0 (batched elements, or a causal encoder whose
        sequences are not padded to equal length) or stacked into a new
        leading batch dim.
        """
        hidden_states = []
        input_is_batched = False
        for features in input_features:
            embeds = nn.functional.gelu(self.conv1(features))
            embeds = nn.functional.gelu(self.conv2(embeds))

            if self.pos_embed_type in (
                WhisperPosEmbedType.SINUSOIDAL,
                WhisperPosEmbedType.LEARNED,
            ):
                embeds = embeds.transpose(-1, -2)
                # Slice the table to the current length; cast back because the
                # table may have been built in fp32 (init_in_fp32).
                embeds = (
                    embeds + self.embed_positions.weight[: embeds.size(-2), :]
                ).to(embeds.dtype)
            elif self.pos_embed_type == WhisperPosEmbedType.NOPE:
                embeds = embeds.transpose(-1, -2)
            else:
                raise ValueError(f"Unknown pos_embed_type: {self.pos_embed_type}")

            hidden_states.append(embeds)
            input_is_batched = embeds.ndim > 2
        # Input to MHA must be B x T x D
        if input_is_batched or self.is_causal:
            # Models using WhisperEncoder may handle batching internally.
            # If WhisperEncoder is causal, sequences
            # are not padded to have identical seq length (T)
            # => concat over feature dim
            hidden_states = torch.cat(hidden_states)
        else:
            hidden_states = torch.stack(hidden_states, dim=0)

        return hidden_states

    def forward_layers(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply every encoder layer, then the final LayerNorm."""
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states)

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states

    def forward(self, input_features: torch.Tensor | list[torch.Tensor]):
        """Full encoder pass: ``forward_conv`` followed by ``forward_layers``."""
        hidden_states = self.forward_conv(input_features)
        return self.forward_layers(hidden_states)

563
564
565
566
567
568
569
570
571

class WhisperDecoder(nn.Module):
    """Whisper text decoder.

    The decoder applies, in order:
    1. token embeddings plus learned absolute position embeddings,
    2. a stack of ``WhisperDecoderLayer`` blocks that cross-attend to
       ``encoder_hidden_states``,
    3. a final ``LayerNorm``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        # `layerdrop` and `embed_scale` are stored for parity with the HF
        # config; neither is applied in this module's forward path.
        self.layerdrop = config.decoder_layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_target_positions
        self.max_source_positions = config.max_source_positions
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.d_model, self.padding_idx
        )
        self.embed_positions = WhisperPositionalEmbedding(
            self.max_target_positions, config.d_model
        )
        # make_layers hands each factory call a per-layer prefix derived from
        # `prefix=` below (expected "<prefix>.layers.<idx>"; verify against
        # make_layers). Forward it unchanged instead of appending ".layers"
        # a second time, so layer module names match their real path.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.decoder_layers,
            lambda prefix: WhisperDecoderLayer(vllm_config=vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.layer_norm = nn.LayerNorm(config.d_model)

    def forward(
        self,
        input_ids,
        positions: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None,
    ):
        """Decode ``input_ids`` at ``positions``, attending to the encoder output.

        ``encoder_hidden_states`` may be None; it is passed through to every
        layer's cross-attention unchanged.
        """
        inputs_embeds = self.embed_input_ids(input_ids)
        pos_embeds = self.embed_positions(positions)
        hidden_states = inputs_embeds + pos_embeds

        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
            )

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings (no scaling applied)."""
        return self.embed_tokens(input_ids)


class WhisperModel(nn.Module):
    """Whisper encoder-decoder backbone holding ``encoder`` and ``decoder``."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.encoder = WhisperEncoder(
            vllm_config=vllm_config, prefix=f"{prefix}.encoder"
        )
        self.decoder = WhisperDecoder(
            vllm_config=vllm_config, prefix=f"{prefix}.decoder"
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        encoder_outputs: list[torch.Tensor],
    ) -> torch.Tensor:
        """Run the decoder, cross-attending to the concatenated encoder outputs.

        An empty ``encoder_outputs`` list is passed on as ``None``.
        """
        enc_states = None
        if len(encoder_outputs) > 0:
            enc_states = torch.cat(encoder_outputs, dim=0)
        return self.decoder(
            input_ids=input_ids,
            positions=positions,
            encoder_hidden_states=enc_states,
        )

    def get_encoder_outputs(
        self,
        input_features: torch.Tensor | list[torch.Tensor] | None,
    ) -> torch.Tensor | None:
        """Encode audio features; returns None when there are none."""
        return None if input_features is None else self.encoder(input_features)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors; return the names of parameters loaded.

        Separate q/k/v (self-attention) and k/v (cross-attention) checkpoint
        tensors are loaded as shards of the fused ``qkv_proj`` / ``kv_proj``
        parameters. Every other tensor is loaded by name.
        """
        # (fused vLLM param fragment, checkpoint fragment, shard id)
        stacked_params_mapping = [
            (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
            (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
            (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
            (".encoder_attn.kv_proj", ".encoder_attn.k_proj", "k"),
            (".encoder_attn.kv_proj", ".encoder_attn.v_proj", "v"),
        ]
        own_params = dict(self.named_parameters())
        loaded: set[str] = set()

        for ckpt_name, tensor in weights:
            for fused_frag, ckpt_frag, shard_id in stacked_params_mapping:
                if ckpt_frag not in ckpt_name:
                    continue
                ckpt_name = ckpt_name.replace(ckpt_frag, fused_frag)
                # Skip loading extra bias for GPTQ models.
                if ckpt_name.endswith(".bias") and ckpt_name not in own_params:
                    continue

                target = own_params[ckpt_name]
                target.weight_loader(target, tensor, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if ckpt_name.endswith(".bias") and ckpt_name not in own_params:
                    continue

                target = own_params[ckpt_name]
                loader = getattr(target, "weight_loader", default_weight_loader)
                loader(target, tensor)
            loaded.add(ckpt_name)
        return loaded


680
681
682
683
class WhisperProcessingInfo(BaseProcessingInfo):
    """Config, feature-extractor and token-count lookups for Whisper inputs."""

    def get_hf_config(self) -> WhisperConfig:
        return self.ctx.get_hf_config(WhisperConfig)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # Limit of 1 for the "audio" modality (presumably one clip per prompt).
        return {"audio": 1}

    def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
        """Return the HF processor's feature extractor, type-checked."""
        extractor = self.get_hf_processor(**kwargs).feature_extractor  # type: ignore
        assert isinstance(extractor, WhisperFeatureExtractor)
        return extractor

    def get_num_audio_tokens(self) -> int:
        """Number of encoder placeholder tokens (``max_source_positions``)."""
        hf_config = self.get_hf_config()
        return hf_config.max_source_positions


class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]):
    """Builds placeholder text and audio inputs (presumably for profiling)."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        # One start-of-transcript token per requested audio item.
        return "<|startoftranscript|>" * mm_counts.get("audio", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        extractor = self.info.get_feature_extractor()
        overrides = None if not mm_options else mm_options.get("audio")

        # Each dummy clip has chunk_length * sampling_rate samples
        # (chunk_length is presumably in seconds).
        dummy_audios = self._get_dummy_audios(
            length=extractor.chunk_length * extractor.sampling_rate,
            num_audios=mm_counts.get("audio", 0),
            overrides=overrides,
        )
        return {"audio": dummy_audios}


724
class WhisperMultiModalProcessor(EncDecMultiModalProcessor[WhisperProcessingInfo]):
    """Multimodal processor for Whisper's encoder-decoder inputs.

    The audio clip becomes ``input_features`` for the encoder. The encoder
    prompt itself is a single placeholder token ``[0]``, which
    ``_get_prompt_updates`` expands to ``max_source_positions`` placeholders.
    """

    def _get_data_parser(self) -> MultiModalDataParser:
        # Audio is parsed at the feature extractor's sampling rate.
        feature_extractor = self.info.get_feature_extractor()
        return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)

    @property
    def pad_dummy_encoder_prompt(self) -> bool:
        return True

    def create_encoder_prompt(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
    ) -> str | list[int]:
        # Strictly speaking, the Whisper encoder only accepts audio features.
        # We create a dummy encoder prompt here which will be padded to
        # num_audio_tokens, so that dummy data can be created from it for
        # encoder profiling.
        return [0]

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, adapting audio inputs and renaming ``labels``.

        When ``mm_data`` is non-empty, ``mm_data["audios"]`` is forwarded under
        the ``audio`` key, together with the feature extractor's
        ``sampling_rate``. Neither input mapping is mutated. If the HF output
        contains ``labels``, they are returned as ``input_ids``.
        """
        if mm_data:
            feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
            # Read instead of pop: ``mm_data`` is a read-only Mapping owned by
            # the caller and must not be modified here.
            mm_data = dict(audio=mm_data["audios"])
            # Merge instead of ``dict(**mm_kwargs, sampling_rate=...)``, which
            # raises TypeError when mm_kwargs already has "sampling_rate".
            # The extractor's rate wins, matching the rate audio is parsed at
            # in _get_data_parser.
            mm_kwargs = {
                **mm_kwargs,
                "sampling_rate": feature_extractor.sampling_rate,
            }
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        if "labels" in processed_outputs:
            processed_outputs["input_ids"] = processed_outputs.pop("labels")
        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(input_features=MultiModalFieldConfig.batched("audio"))

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Expand the single ``[0]`` placeholder to one token per encoder position."""
        num_tokens = self.info.get_num_audio_tokens()
        return [
            PromptReplacement(
                modality="audio",
                target=[0],
                replacement=[0] * num_tokens,
            )
        ]


@MULTIMODAL_REGISTRY.register_processor(
    WhisperMultiModalProcessor,
    info=WhisperProcessingInfo,
    dummy_inputs=WhisperDummyInputsBuilder,
)
class WhisperForConditionalGeneration(
    nn.Module, SupportsTranscription, SupportsMultiModal
):
    """Whisper speech-to-text model: a ``WhisperModel`` plus an output head.

    The output projection ``proj_out`` is tied to the decoder token embedding,
    so its checkpoint weights are skipped in :meth:`load_weights`.
    """

    # Fused projections: self-attention q/k/v are packed into ``qkv_proj`` and
    # cross-attention (``encoder_attn``) k/v into ``kv_proj``.
    packed_modules_mapping = {
        "self_attn.qkv_proj": [
            "self_attn.q_proj",
            "self_attn.k_proj",
            "self_attn.v_proj",
        ],
        "encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"],
    }

    # Checkpoint ``.fc1.``/``.fc2.`` weights are loaded under an ``.mlp.``
    # submodule in this implementation.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."}
    )

    # Whisper only supports audio-conditioned generation.
    supports_transcription_only = True
    supports_segment_timestamp = True
    supported_languages = ISO639_1_SUPPORTED_LANGS

    @classmethod
    def validate_language(cls, language: str | None) -> str | None:
        """Validate *language*, defaulting to ``"en"`` (with a warning) if unset.

        The final check is delegated to the parent implementation.
        """
        if language is None:
            # TODO language should be optional and can be guessed.
            # For now we default to en. See
            # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
            logger.warning(
                "Defaulting to language='en'. If you wish to transcribe "
                "audio in a different language, pass the `language` field "
                "in the TranscriptionRequest."
            )
            language = "en"
        return super().validate_language(language)

    @classmethod
    def get_generation_prompt(
        cls,
        audio: np.ndarray,
        model_config: ModelConfig,  # not needed here
        stt_config: SpeechToTextConfig,
        language: str | None,
        task_type: Literal["transcribe", "translate"],
        request_prompt: str,
        to_language: str | None,
    ) -> PromptType:
        """Build the encoder/decoder prompt pair for one audio clip.

        The encoder prompt has empty text and carries the audio paired with
        ``stt_config.sample_rate``. The decoder prompt is a string of Whisper
        special tokens: an optional ``<|prev|>`` followed by *request_prompt*
        (only when *request_prompt* is non-empty), then
        ``<|startoftranscript|>``, the language token, the task token, and
        ``<|notimestamps|>``. ``model_config`` and ``to_language`` are unused.

        Raises:
            ValueError: if *language* is ``None``.
        """
        if language is None:
            raise ValueError(
                "Language must be specified when creating the Whisper prompt"
            )
        prompt = {
            "encoder_prompt": {
                # Whisper does not support encoder prompt.
                "prompt": "",
                "multi_modal_data": {
                    "audio": (audio, stt_config.sample_rate),
                },
            },
            "decoder_prompt": (
                (f"<|prev|>{request_prompt}" if request_prompt else "")
                + f"<|startoftranscript|><|{language}|>"
                + f"<|{task_type}|><|notimestamps|>"
            ),
        }
        return cast(PromptType, prompt)

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for item *i* of *modality*.

        Audio needs no textual placeholder, so ``None`` is returned for any
        modality starting with ``"audio"``.

        Raises:
            ValueError: for any non-audio modality.
        """
        if modality.startswith("audio"):
            return None

        raise ValueError("Only audio modality is supported")

    @classmethod
    def get_speech_to_text_config(
        cls, model_config: ModelConfig, task_type: str
    ) -> SpeechToTextConfig:
        """Derive speech-to-text limits from the HF feature extractor.

        ``max_audio_clip_s`` comes from the extractor's ``chunk_length`` and
        ``sample_rate`` from its ``sampling_rate``. *task_type* is unused.
        """
        processor = cached_processor_from_config(model_config)

        return SpeechToTextConfig(
            max_audio_clip_s=processor.feature_extractor.chunk_length,
            sample_rate=processor.feature_extractor.sampling_rate,
        )

    @classmethod
    def get_num_audio_tokens(
        cls,
        audio_duration_s: float,
        stt_config: SpeechToTextConfig,
        model_config: ModelConfig,
    ) -> int | None:
        """Estimate the encoder-side length of an audio clip.

        Returns ``ceil(audio_duration_s * stt_config.sample_rate / hop_length)``,
        with ``hop_length`` read from the HF feature extractor.
        """
        processor = cached_processor_from_config(model_config)
        hop_length = processor.feature_extractor.hop_length
        assert hop_length is not None
        # NOTE(NickLucche) user can't pass encoder
        # prompts directly at least not to Whisper.
        # One indicator of the encoder amount of processing
        # is the log-mel spectrogram length.
        return math.ceil(audio_duration_s * stt_config.sample_rate / hop_length)

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.dtype = vllm_config.model_config.dtype

        self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix)

        self.proj_out = ParallelLMHead(
            config.vocab_size,
            config.d_model,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "proj_out"),
        )
        # Tie the output projection to the decoder token embedding.
        self.proj_out = self.proj_out.tie_weights(self.model.decoder.embed_tokens)
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        encoder_outputs: list[torch.Tensor] | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run the wrapped ``WhisperModel`` and return its output.

        ``encoder_outputs`` defaults to an empty list. Extra ``kwargs`` are
        accepted and ignored.
        """
        if encoder_outputs is None:
            encoder_outputs = []
        decoder_outputs = self.model(
            input_ids=input_ids,
            positions=positions,
            encoder_outputs=encoder_outputs,
        )
        return decoder_outputs

    def get_language_model(self) -> torch.nn.Module:
        """Return the decoder, which serves as the language model."""
        return self.model.decoder

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Encode the audio features in *kwargs*, one tensor per audio item."""
        # Required as part of SupportsMultiModal interface.
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        enc_output = self.model.get_encoder_outputs(audio_input["input_features"])
        # Split the encoder output along dim 0 into one tensor per audio input.
        # The assumption is we can only process whole mm items (audios).
        return enc_output.unbind(dim=0)

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Embed decoder tokens.

        The multimodal arguments are accepted for interface compatibility and
        ignored. Audio reaches the model through the encoder
        (see ``embed_multimodal``), not through token embeddings.
        """
        # This method just returns the decoder sequence embeddings since
        # Whisper does not have encoder text tokens.
        return self.model.decoder.embed_input_ids(input_ids)

    def _parse_and_validate_audio_input(self, **kwargs: object) -> WhisperAudioInputs:
        """Extract ``input_features`` from *kwargs*, cast to the model dtype."""
        input_features = kwargs.pop("input_features", None)

        if input_features is not None:
            # Cast every tensor leaf of the (possibly nested) value to self.dtype.
            input_features = json_map_leaves(lambda x: x.to(self.dtype), input_features)

        return WhisperAudioInputs(input_features=input_features)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project *hidden_states* to vocabulary logits via ``proj_out``."""
        logits = self.logits_processor(self.proj_out, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights; returns the names reported by the loader.

        ``proj_out.*`` weights are skipped because ``proj_out`` is tied to the
        decoder embedding. Zero biases are synthesized for every ``k_proj``
        weight so the fused projection's k-bias slot gets initialized.
        """
        loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])

        # add fake zeros bias for k_proj to state_dict
        weights = _create_fake_bias_for_k_proj(weights)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)


def _create_fake_bias_for_k_proj(
975
    weights: Iterable[tuple[str, torch.Tensor]],
976
) -> Iterable[tuple[str, torch.Tensor]]:
977
    """
978
    Create full zeros bias for k_proj weight in self-attn and x-attn layers.
979
980
981
    So that the bias for k_proj in qkv_proj can be initialized with zeros.
    """
    for name, weight in weights:
982
        if name.endswith(".k_proj.weight"):
983
984
985
986
            bias = torch.zeros(weight.size(0))
            bias_name = name.replace("weight", "bias")
            yield from [(name, weight), (bias_name, bias)]
        yield name, weight