voxtral.py 31.6 KB
Newer Older
Patrick von Platen's avatar
Patrick von Platen committed
1
2
3
4
5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math
from collections.abc import Iterable, Mapping, Sequence
6
from functools import cached_property, partial
Patrick von Platen's avatar
Patrick von Platen committed
7
from math import ceil
8
from typing import Literal, cast
Patrick von Platen's avatar
Patrick von Platen committed
9
10
11
12
13
14

import numpy as np
import regex as re
import torch
import torch.nn as nn
from mistral_common.audio import mel_filter_bank
15
16
from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
Patrick von Platen's avatar
Patrick von Platen committed
17
18
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.transcription.request import TranscriptionRequest
19
20
21
22
from mistral_common.tokens.tokenizers.audio import (
    Audio,
    AudioEncoder,
)
23
from transformers import BatchFeature, TensorType, WhisperConfig
Patrick von Platen's avatar
Patrick von Platen committed
24
25
from transformers.tokenization_utils_base import TextInput

26
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
27
from vllm.config.multimodal import BaseDummyOptions
Patrick von Platen's avatar
Patrick von Platen committed
28
29
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
30
from vllm.model_executor.layers.quantization import QuantizationConfig
Patrick von Platen's avatar
Patrick von Platen committed
31
32
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import SupportsPP
33
from vllm.model_executor.models.module_mapping import MultiModelKeys
34
35
36
37
38
from vllm.model_executor.models.whisper import (
    WhisperEncoder,
    _create_fake_bias_for_k_proj,
)
from vllm.model_executor.models.whisper_causal import WhisperCausalEncoder
Patrick von Platen's avatar
Patrick von Platen committed
39
from vllm.multimodal import MULTIMODAL_REGISTRY
40
41
42
43
44
45
46
47
48
49
50
51
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    MultiModalUUIDDict,
    NestedTensors,
)
from vllm.multimodal.parse import (
    AudioProcessorItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
52
53
from vllm.multimodal.processing import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.processing.processor import (
54
55
56
57
58
59
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    MultiModalProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
Patrick von Platen's avatar
Patrick von Platen committed
60
from vllm.sequence import IntermediateTensors
61
62
from vllm.tokenizers import cached_tokenizer_from_config
from vllm.tokenizers.mistral import MistralTokenizer
Patrick von Platen's avatar
Patrick von Platen committed
63

64
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
65
from .utils import init_vllm_registered_model, maybe_prefix
Patrick von Platen's avatar
Patrick von Platen committed
66
67
68

# Module-level logger for this model file (created via vLLM's init_logger).
logger = init_logger(__name__)

69
70
71
72
73
74
75
76
77
78
79
80
# Languages exposed as `supported_languages` by the Voxtral transcription
# model, keyed by ISO 639-1 code and mapped to their English display name.
ISO639_1_SUPPORTED_LANGS = dict(
    ar="Arabic",
    nl="Dutch",
    en="English",
    fr="French",
    de="German",
    hi="Hindi",
    it="Italian",
    pt="Portuguese",
    es="Spanish",
)

Patrick von Platen's avatar
Patrick von Platen committed
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117

class VoxtralProcessorAdapter:
    """
    Provide a HF-compatible interface for
    :class:`mistral_common.tokens.tokenizers.audio.AudioEncoder`.

    Wraps a :class:`MistralTokenizer` and exposes the audio special-token ids,
    the sampling/frame rates, and a ``__call__`` that returns prompt token ids
    plus the (possibly padded) audio arrays, like a Hugging Face processor.
    """

    def __init__(self, tokenizer: MistralTokenizer) -> None:
        super().__init__()
        self.tokenizer = tokenizer

    @cached_property
    def _audio_processor(self) -> AudioEncoder:
        """The tokenizer's audio encoder, narrowed to ``AudioEncoder``."""
        audio_encoder = self.tokenizer.instruct.audio_encoder
        assert isinstance(audio_encoder, AudioEncoder)
        return audio_encoder

    @cached_property
    def audio_token_id(self) -> int:
        """Id of the placeholder token repeated once per audio frame."""
        return self._audio_processor.special_ids.audio

    @cached_property
    def begin_audio_token_id(self) -> int:
        """Id of the token that precedes each audio's placeholder tokens."""
        return self._audio_processor.special_ids.begin_audio

    @cached_property
    def sampling_rate(self) -> int:
        """Audio sampling rate expected by the audio encoder."""
        return self._audio_processor.audio_config.sampling_rate

    @cached_property
    def frame_rate(self) -> float:
        """Number of audio tokens per second of audio."""
        return self._audio_processor.audio_config.frame_rate

    def get_num_audio_tokens(
        self,
        audio_length: int,
    ) -> int:
        """Return how many audio placeholder tokens ``audio_length`` samples
        produce, i.e. ``ceil(audio_length / (sampling_rate // frame_rate))``.
        """
        return ceil(audio_length / (self.sampling_rate // self.frame_rate))

    def __call__(
        self,
        text: TextInput | list[TextInput] | None = None,
        audios: np.ndarray | list[np.ndarray] | None = None,
        return_tensors: str | TensorType | None = None,
        **kwargs,
    ) -> Mapping[str, NestedTensors]:
        """Tokenize ``text`` or build audio placeholder tokens for ``audios``.

        Without audio, ``text`` is tokenized with the Mistral tokenizer. With
        audio, only empty (dummy) text is accepted: the returned ``input_ids``
        hold ``[begin_audio] + [audio] * n`` for every audio, concatenated, and
        ``audio_arrays`` holds the audio samples (padded by the encoder unless
        it is configured for streaming).

        Raises:
            ValueError: if non-empty text is combined with audio, or an audio
                array is not 1-D.
            TypeError: if an audio item is not a numpy array.
        """
        if text is None:
            text = []
        if not isinstance(text, list):
            text = [text]
        if audios is None:
            audios = []
        if not isinstance(audios, list):
            audios = [audios]

        if not audios:
            input_ids = self.tokenizer(text).input_ids
            return {"input_ids": torch.tensor(input_ids)}

        # Allow dummy text, which is used for profiling as well as token inputs
        if any(len(t) > 0 for t in text):
            raise ValueError(
                "You've passed text inputs instead of token inputs. "
                "Make sure to process your input via `mistral_common`'s "
                "tokenizer or pass a chat completion request. "
                "For more info, see: "
                "https://github.com/vllm-project/vllm/issues/8411."
            )

        audios_tokens = list[torch.Tensor]()
        audios_processed = list[torch.Tensor]()
        for audio in audios:
            # Explicit checks instead of `assert`: this is user-provided data
            # and asserts are stripped when Python runs with -O.
            if not isinstance(audio, np.ndarray):
                raise TypeError(
                    f"Expected audio as a numpy array, got {type(audio)}"
                )
            if audio.ndim != 1:
                raise ValueError(
                    f"Expected a 1-D audio array, got shape {audio.shape}"
                )

            if not self._audio_processor.audio_config.is_streaming:
                audio = self._audio_processor.pad(
                    audio, self.sampling_rate, is_online_streaming=False
                )

            audio_tokens = [self.begin_audio_token_id] + [
                self.audio_token_id
            ] * self.get_num_audio_tokens(len(audio))

            audios_tokens.append(torch.tensor(audio_tokens))
            audios_processed.append(torch.tensor(audio))

        # Each text entry gets the same audio token row. With no text at all,
        # still emit one row: `expand(0, -1)` would silently drop the tokens.
        batch_size = max(len(text), 1)
        return BatchFeature(
            {
                "input_ids": torch.cat(audios_tokens)[None].expand(batch_size, -1),
                "audio_arrays": audios_processed,
            }
        )
Patrick von Platen's avatar
Patrick von Platen committed
174
175
176
177


class VoxtralProcessingInfo(BaseProcessingInfo):
    """Tokenizer, processor and size limits used to process Voxtral inputs."""

    def get_tokenizer(self) -> MistralTokenizer:
        """Return the model's cached tokenizer.

        Raises:
            ValueError: if the configured tokenizer is not a
                ``MistralTokenizer`` (Voxtral requires the Mistral one).
        """
        tokenizer = cached_tokenizer_from_config(self.ctx.model_config)
        if not isinstance(tokenizer, MistralTokenizer):
            raise ValueError("This model requires `--tokenizer-mode mistral`")

        return tokenizer

    def get_hf_processor(self, **kwargs: object) -> VoxtralProcessorAdapter:
        """Return a HF-style processor adapter around the Mistral tokenizer.

        ``kwargs`` are accepted for interface compatibility: callers forward
        ``hf_processor_mm_kwargs`` here. The adapter has no options, so they
        are ignored.
        """
        return VoxtralProcessorAdapter(self.get_tokenizer())

    def get_data_parser(self):
        """Data parser targeting the tokenizer's audio sampling rate."""
        return MultiModalDataParser(
            target_sr=self.get_hf_processor().sampling_rate,
            expected_hidden_size=self._get_expected_hidden_size(),
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        return {"audio": 5}  # Performance tends to degrade after 5

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        return {"audio": self.get_max_audio_tokens()}

    def get_max_audio_tokens(self) -> int:
        # A single audio item may use up to the whole context window.
        return self.ctx.model_config.max_model_len

    def get_max_audio_array_len(self) -> int:
        """Maximum number of audio samples: max audio tokens times the number
        of samples covered by one audio token."""
        processor = self.get_hf_processor()
        return self.get_max_audio_tokens() * int(
            processor.sampling_rate // processor.frame_rate
        )
Patrick von Platen's avatar
Patrick von Platen committed
211
212
213
214
215
216
217
218
219
220


class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
    """Build dummy (profiling) inputs for Voxtral."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        # The dummy prompt is built as token ids in get_dummy_processor_inputs,
        # so no dummy text is needed.
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return ``mm_counts["audio"]`` dummy audios of the maximum length."""
        num_audios = mm_counts.get("audio", 0)

        target_length = self.info.get_max_audio_array_len()

        audio_overrides = mm_options.get("audio") if mm_options else None

        return {
            "audio": self._get_dummy_audios(
                length=target_length, num_audios=num_audios, overrides=audio_overrides
            )
        }

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> ProcessorInputs:
        """Encode a dummy chat request with the Mistral tokenizer and return its
        token ids together with the audio arrays the tokenizer produced."""
        tokenizer = self.info.get_tokenizer()
        # Loop-invariant: look up the sampling rate once, not per audio.
        sampling_rate = self.info.get_hf_processor().sampling_rate

        dummy_text = self.get_dummy_text(mm_counts)
        dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
        dummy_audios = dummy_mm_data.get("audio", [])

        audio_chunks: list[AudioChunk] = []
        audio_format = "wav"
        for audio in dummy_audios:
            audio_item = Audio(
                audio_array=audio,
                sampling_rate=sampling_rate,
                format=audio_format,
            )
            chunk = AudioChunk(input_audio=RawAudio.from_audio(audio_item))
            audio_chunks.append(chunk)

        request = ChatCompletionRequest(
            messages=[
                UserMessage(content=[TextChunk(text=dummy_text), *audio_chunks]),
            ]
        )
        res = tokenizer.mistral.encode_chat_completion(request)
        dummy_tokens = res.tokens
        # The tokenizer pads the audio, so replace the dummy arrays with the
        # padded ones it returns to keep them consistent with `dummy_tokens`.
        dummy_mm_data["audio"] = [a.audio_array for a in res.audios]

        return ProcessorInputs(prompt=dummy_tokens, mm_data=dummy_mm_data)


272
class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]):
    """Multimodal processor for Voxtral.

    Prompts reach this processor with their audio placeholder tokens already
    inserted by the Mistral chat template, so the prompt replacement below is
    declared only to describe the placeholders, never to rewrite the prompt.
    """

    def _get_mm_fields_config(
        self,
        hf_inputs: Mapping[str, NestedTensors],
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # `audio_arrays` is batched along the "audio" modality.
        return {"audio_arrays": MultiModalFieldConfig.batched("audio")}

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        placeholder_id = hf_processor.audio_token_id

        def audio_placeholder_ids(item_idx: int) -> list[int]:
            # Looked up lazily, per item, when the replacement is resolved.
            audio_items = mm_items.get_items("audio", AudioProcessorItems)
            num_tokens = hf_processor.get_num_audio_tokens(
                audio_items.get_audio_length(item_idx)
            )
            return [placeholder_id] * num_tokens

        return [
            PromptReplacement(
                modality="audio",
                # Empty target: never matches the prompt (see the note in
                # _cached_apply_hf_processor).
                target="",
                replacement=audio_placeholder_ids,
            ),
        ]

    def _cached_apply_hf_processor(
        self,
        prompt: str | list[int],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
        token_ids, processing_info, _ = super()._cached_apply_hf_processor(
            prompt=prompt,
            mm_data_items=mm_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

        # NOTE: The tokens are already inserted by the chat template. The
        # trailing `True` presumably tells the base class that prompt updates
        # are already applied — confirm against BaseMultiModalProcessor.
        return token_ids, processing_info, True
Patrick von Platen's avatar
Patrick von Platen committed
324
325


326
327
328
329
330
331
332
333
@MULTIMODAL_REGISTRY.register_processor(
    VoxtralMultiModalProcessor,
    info=VoxtralProcessingInfo,
    dummy_inputs=VoxtralDummyInputsBuilder,
)
class VoxtralForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsTranscription
):
    """Voxtral: an audio encoder feeding a text language model.

    Audio is encoded by ``whisper_encoder``. Every ``downsample_factor``
    consecutive encoder frames are concatenated into one vector, and
    ``audio_language_adapter`` projects the result to the language model's
    hidden size.
    """

    supported_languages = ISO639_1_SUPPORTED_LANGS

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.tokenizer = cached_tokenizer_from_config(vllm_config.model_config)

        # update quant config so that ignored module and target module names
        # match the vLLM model names
        if hasattr(vllm_config, "quant_config"):
            vllm_config.quant_config = self.maybe_update_quant_config(
                vllm_config.quant_config
            )

        config = vllm_config.model_config.hf_config
        self.config = config
        self.downsample_factor = self.config.audio_config.downsample_factor

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        with self._mark_tower_model(vllm_config, "audio"):
            self.whisper_encoder = VoxtralEncoderModel(
                vllm_config.with_hf_config(config.audio_config),
                prefix=maybe_prefix(prefix, "whisper_encoder"),
            )
            self.audio_language_adapter = AudioLanguageAdapter(
                hidden_size=config.audio_config.d_model * self.downsample_factor,
                dim=config.text_config.hidden_size,
            )

    def get_mm_mapping(self) -> MultiModelKeys:
        """Get module prefix for multimodal models to filter LoRA modules."""
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="audio_language_adapter",
            tower_model=["whisper_encoder"],
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language-model backbone and return its hidden states.

        When ``intermediate_tensors`` is given (presumably on a non-first
        pipeline stage), ``inputs_embeds`` is ignored.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

        return hidden_states

    def embed_multimodal(
        self, **kwargs
    ) -> list[torch.Tensor] | torch.Tensor | tuple[torch.Tensor, ...] | None:
        """Encode ``audio_arrays`` into one embedding tensor per audio.

        Returns ``None`` when no ``audio_arrays`` are passed.
        """
        audio_inputs = self._parse_and_validate_audio_arrays(**kwargs)
        if audio_inputs is None:
            return None

        audio_embeddings = self.whisper_encoder(audio_inputs)

        for i, audio_embedding in enumerate(audio_embeddings):
            seq_len, dim = audio_embedding.shape
            # Pad such that seq_len is divisible by downsample_factor
            target_seq_len = self.downsample_factor * math.ceil(
                seq_len / self.downsample_factor
            )
            audio_embedding = torch.nn.functional.pad(
                audio_embedding,
                (0, 0, 0, target_seq_len - seq_len),
            )
            # Fold `downsample_factor` consecutive frames into a single row.
            audio_embeddings[i] = audio_embedding.reshape(
                target_seq_len // self.downsample_factor, dim * self.downsample_factor
            )

        # Concat, project and resplit
        audio_embeddings_packed = torch.cat(audio_embeddings, dim=0)
        audio_embeddings_packed = self.audio_language_adapter(audio_embeddings_packed)
        audio_embeddings = torch.split(
            audio_embeddings_packed, [a.shape[0] for a in audio_embeddings], dim=0
        )

        return audio_embeddings

    def _parse_and_validate_audio_arrays(
        self, **kwargs: object
    ) -> list[torch.Tensor] | None:
        """Extract ``audio_arrays`` from ``kwargs`` as a list of tensors.

        Returns ``None`` if absent. A single tensor is unbound along dim 0.

        Raises:
            ValueError: if ``audio_arrays`` is neither a tensor nor a list.
        """
        audio_arrays = kwargs.pop("audio_arrays", None)
        if audio_arrays is None:
            return None

        if not isinstance(audio_arrays, (torch.Tensor, list)):
            raise ValueError(
                f"Incorrect type of audio_arrays. Got type: {type(audio_arrays)}"
            )

        if isinstance(audio_arrays, torch.Tensor):
            audio_arrays = list(audio_arrays.unbind(0))
        return audio_arrays

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    @classmethod
    def get_speech_to_text_config(
        cls, model_config: ModelConfig, task_type: str
    ) -> SpeechToTextConfig:
        """Build the speech-to-text config from the tokenizer's audio config."""
        tokenizer = cached_tokenizer_from_config(model_config)
        audio_config = tokenizer.instruct.audio_encoder.audio_config
        max_audio_clip_s = audio_config.chunk_length_s
        sample_rate = audio_config.sampling_rate
        return SpeechToTextConfig(
            max_audio_clip_s=max_audio_clip_s,
            sample_rate=sample_rate,
            # mistral_common and whisper encoder take care of chunking
            min_energy_split_window_size=None,
        )

    @classmethod
    # for speech-to-text transcription
    def get_generation_prompt(
        cls,
        audio: np.ndarray,
        model_config: ModelConfig,
        stt_config: SpeechToTextConfig,
        language: str | None,
        task_type: Literal["transcribe", "translate"],
        request_prompt: str,
        to_language: str | None,
    ) -> PromptType:
        """Build a transcription prompt using mistral_common's
        ``encode_transcription``.

        ``task_type``, ``request_prompt`` and ``to_language`` are part of the
        interface but are not used here.
        """
        tokenizer = cached_tokenizer_from_config(model_config)
        audio_obj = Audio(audio, int(stt_config.sample_rate), format="wav")  # lossless
        req = TranscriptionRequest(
            model=model_config.model,
            audio=RawAudio.from_audio(audio_obj),
            language=language,
        )

        tokenized = tokenizer.instruct.encode_transcription(req)
        # Use the audio returned by the tokenizer (it may have been padded) so
        # it matches the audio tokens in `tokenized.tokens`.
        encoded_audio = (tokenized.audios[0].audio_array, stt_config.sample_rate)
        prompts_dict = {
            "multi_modal_data": {"audio": encoded_audio},
            "prompt_token_ids": tokenized.tokens,
        }
        return cast(PromptType, prompts_dict)

    @classmethod
    def get_num_audio_tokens(
        cls,
        audio_duration_s: float,
        stt_config: SpeechToTextConfig,
        model_config: ModelConfig,
    ) -> int | None:
        """
        Map from audio duration to number of audio tokens produced by the ASR
        model, without running a forward pass.
        This is used for estimating the amount of processing for this audio.
        """
        tokenizer = cached_tokenizer_from_config(model_config)
        adapter = VoxtralProcessorAdapter(tokenizer)
        return adapter.get_num_audio_tokens(
            int(audio_duration_s * stt_config.sample_rate)
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load Mistral-format checkpoint weights.

        Encoder weights are routed to ``whisper_encoder.load_weight``. Adapter
        weights are loaded into ``audio_language_adapter`` directly. All other
        weights are streamed to the language model.

        Returns:
            The set of loaded parameter names, prefixed as in this module.
        """
        remapping_rules = [
            (r"mm_streams_embeddings.embedding_module\.(.*)", r"\1"),
            (r"mm_whisper_embeddings\.(.*)", r"\1"),
            (r"audio_language_projection\.(.*)", r"audio_language_adapter.\1"),
            (
                r"audio_language_adapter\.0\.weight",
                r"audio_language_adapter.w_in.weight",
            ),
            (
                r"audio_language_adapter\.2\.weight",
                r"audio_language_adapter.w_out.weight",
            ),
        ]

        audio_params = dict(
            nn.ModuleDict(
                {
                    "audio_language_adapter": self.audio_language_adapter,
                }
            ).named_parameters()
        )
        weights = _create_fake_bias_for_k_proj(weights, ".wk.weight")

        # Mutated (never rebound) by the generator below, so no `nonlocal`.
        loaded_weights = set()

        def llm_weights_generator():
            for name, w in weights:
                is_encoder = False
                for k in [
                    "mm_whisper_embeddings",
                    "mm_streams_embeddings.embedding_module",
                ]:
                    is_encoder |= (
                        name.startswith(k)
                        and not name.startswith(f"{k}.tok_embeddings")
                        and not name.startswith(f"{k}.audio_language_projection")
                    )

                # Rules are chained: each rule sees the name as rewritten by
                # the rules before it.
                for pattern, repl in remapping_rules:
                    if re.fullmatch(pattern, name):
                        name = re.sub(pattern, repl, name)

                if is_encoder:
                    name = self.whisper_encoder.load_weight((name, w))
                    loaded_weights.add(f"whisper_encoder.{name}")
                    continue

                if name in audio_params:
                    param = audio_params[name]
                    with torch.no_grad():
                        default_weight_loader(param, w)
                    loaded_weights.add(name)
                else:
                    yield (name, w)

        for name in self.language_model.load_weights(llm_weights_generator()):
            loaded_weights.add(f"language_model.{name}")

        # The encoder position embeddings may be absent from the checkpoint
        # (the `sin_key` name suggests they are sinusoidal/computed — TODO
        # confirm). Report them as loaded so the loaded-weights check passes.
        sin_key = "whisper_encoder.whisper_encoder.embed_positions.weight"
        if sin_key not in loaded_weights:
            loaded_weights.add(sin_key)

        return loaded_weights

    def maybe_update_quant_config(
        self, quant_config: QuantizationConfig
    ) -> QuantizationConfig:
        """
        Update quant config so that ignored module and target module names
        match the vLLM model names.
        Right now this is specific for compressed-tensors format and
        load_format mistral.

        Groups in ``config_groups`` that define no ``targets`` are left
        untouched.
        """
        remapping_rules = [
            (r"output", r"language_model.lm_head"),
            (
                r"layers\.(\d+)\.attention\.wo",
                r"language_model.model.layers.\1.self_attn.out_proj",
            ),
            (
                r"layers\.(\d+)\.attention\.w(.*)",
                r"language_model.model.layers.\1.self_attn.\2_proj",
            ),
            (
                r"layers\.(\d+)\.feed_forward\.w1",
                r"language_model.model.layers.\1.mlp.gate_proj",
            ),
            (
                r"layers\.(\d+)\.feed_forward\.w2",
                r"language_model.model.layers.\1.mlp.down_proj",
            ),
            (
                r"layers\.(\d+)\.feed_forward\.w3",
                r"language_model.model.layers.\1.mlp.up_proj",
            ),
            (
                r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.w(.*)",
                r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.\2_proj",
            ),
            (
                r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.wo",
                r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.out_proj",
            ),
            (
                r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward.w(\d+)",
                r"whisper_encoder.whisper_encoder.layers.\1.layers.mlp.fc\2",
            ),
            (
                r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.0",
                r"whisper_encoder.whisper_encoder.conv1",
            ),
            (
                r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.1",
                r"whisper_encoder.whisper_encoder.conv2",
            ),
            (
                r"mm_whisper_embeddings\.audio_language_projection\.0",
                r"audio_language_adapter.w_in",
            ),
            (
                r"mm_whisper_embeddings\.audio_language_projection\.2",
                r"audio_language_adapter.w_out",
            ),
        ]

        def remap(name: str) -> str:
            # Every rule is tested against the ORIGINAL name and the last
            # matching rule wins (rules are not chained). The order above
            # relies on this: `layers.N.attention.wo` matches both the `wo`
            # and the later `w(.*)` rule, and the latter (`o_proj`) wins.
            new_name = name
            for pattern, repl in remapping_rules:
                if re.fullmatch(pattern, name):
                    new_name = re.sub(pattern, repl, name)
            return new_name

        # Update ignore list
        if hasattr(quant_config, "ignore"):
            quant_config.ignore = [remap(name) for name in quant_config.ignore]

        # Update target list
        if hasattr(quant_config, "config_groups"):
            config_groups = quant_config.config_groups
            for group_name in config_groups:
                group = config_groups[group_name]
                # Only rewrite groups that define targets; a group without
                # them must not receive another group's (or unbound) targets.
                if "targets" in group:
                    group["targets"] = [remap(name) for name in group["targets"]]
            quant_config.config_groups = config_groups

        return quant_config

Patrick von Platen's avatar
Patrick von Platen committed
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683

class AudioLanguageAdapter(nn.Module):
    """Bias-free two-layer MLP (Linear -> GELU -> Linear) mapping features of
    size ``hidden_size`` to size ``dim``."""

    def __init__(self, hidden_size: int, dim: int) -> None:
        super().__init__()
        # Submodule names (`w_in`, `w_out`) are the names checkpoint weights
        # are remapped to, so they must not change; registration order is kept.
        self.w_in = nn.Linear(hidden_size, dim, bias=False)
        self.gelu = nn.GELU()
        self.w_out = nn.Linear(dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected = self.w_in(x)
        activated = self.gelu(projected)
        return self.w_out(activated)


class VoxtralEncoderModel(nn.Module):
    """Whisper-style audio encoder used by Voxtral.

    ``forward`` takes raw waveforms and processes them in four steps:
    1. Compute a log-mel spectrogram for each waveform.
    2. Split each spectrogram into ``chunk_size``-frame chunks.
    3. Encode all chunks in one batch with a Whisper encoder. The encoder is
       causal or bidirectional depending on ``config.is_causal``.
    4. Re-join the chunk outputs per input.

    ``load_weight`` maps Mistral-format checkpoint names onto this module's
    parameters.
    """

    # The fused ``qkv_proj`` is assembled from separate q/k/v checkpoint
    # tensors (see the stacked mapping in ``load_weight``).
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    # ``(regex, replacement)`` pairs used by ``load_weight``. Each rule is tried
    # with ``re.fullmatch`` and applied with ``re.sub``, in order. A later rule
    # therefore sees the output of earlier ones.
    mistral_remapping = [
        # Strip the ``mm_streams_embeddings.embedding_module.`` prefix first.
        (r"mm_streams_embeddings\.embedding_module\.(.*)", r"\1"),
        # Conv stem: accept both ``conv_layers.N.*`` and ``conv_layers.N.conv.*``.
        (
            r"whisper_encoder\.conv_layers\.0\.(weight|bias)",
            r"whisper_encoder.conv1.\1",
        ),
        (
            r"whisper_encoder\.conv_layers\.1\.(weight|bias)",
            r"whisper_encoder.conv2.\1",
        ),
        (
            r"whisper_encoder\.conv_layers\.0\.conv\.(weight|bias)",
            r"whisper_encoder.conv1.\1",
        ),
        (
            r"whisper_encoder\.conv_layers\.1\.conv\.(weight|bias)",
            r"whisper_encoder.conv2.\1",
        ),
        # Transformer blocks.
        (
            r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.w([qkv])\.(weight|bias)",  # noqa: E501
            r"whisper_encoder.layers.\1.self_attn.\2_proj.\3",
        ),
        (
            r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wo\.(weight|bias)",  # noqa: E501
            r"whisper_encoder.layers.\1.self_attn.out_proj.\2",
        ),
        (
            r"whisper_encoder\.transformer\.layers\.(\d+)\.attention_norm\.(weight|bias)",  # noqa: E501
            r"whisper_encoder.layers.\1.self_attn_layer_norm.\2",
        ),
        (
            r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w1\.(weight|bias)",  # noqa: E501
            r"whisper_encoder.layers.\1.mlp.fc1.\2",
        ),
        (
            r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(weight|bias)",  # noqa: E501
            r"whisper_encoder.layers.\1.mlp.fc2.\2",
        ),
        (
            r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w3\.(weight|bias)",  # noqa: E501
            r"whisper_encoder.layers.\1.mlp.fc3.\2",
        ),
        (
            r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(weight|bias)",
            r"whisper_encoder.layers.\1.final_layer_norm.\2",
        ),
        (
            r"whisper_encoder\.transformer\.norm\.(weight|bias)",
            r"whisper_encoder.layer_norm.\1",
        ),
    ]

    def __init__(
        self,
        vllm_config: VllmConfig,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = cast(WhisperConfig, vllm_config.model_config.hf_config)
        # Mel features are cast to this dtype before entering the encoder.
        self.dtype: torch.dtype = vllm_config.model_config.dtype

        # Choose the encoder implementation. The bidirectional
        # ``WhisperEncoder`` is built with ``init_in_fp32=True``.
        self.is_causal = getattr(self.config, "is_causal", False)
        if self.is_causal:
            WhisperEncoderCls = WhisperCausalEncoder
        else:
            WhisperEncoderCls = partial(WhisperEncoder, init_in_fp32=True)

        self.whisper_encoder = WhisperEncoderCls(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "whisper_encoder"),
        )

        mel_filters = mel_filter_bank(
            num_frequency_bins=1 + self.config.window_size // 2,
            num_mel_bins=self.config.num_mel_bins,
            min_frequency=0.0,
            max_frequency=8000.0,
            sampling_rate=self.config.sampling_rate,
        )
        # Plain tensor attribute, not a registered buffer. It is therefore not
        # in the state dict and is not moved by ``module.to()``.
        # ``compute_whisper_melspec`` aligns it with its input at use time.
        self.mel_filters = torch.tensor(mel_filters, dtype=torch.float32)

    def compute_whisper_melspec(
        self,
        audio_waveforms: torch.Tensor,
    ) -> torch.Tensor:
        """Compute a log-mel spectrogram of ``audio_waveforms``.

        Args:
            audio_waveforms: Waveform samples, passed directly to
                ``torch.stft``. Presumably 1-D mono audio at
                ``config.sampling_rate``; TODO confirm with callers.

        Returns:
            Log-mel features with mel bins before frames, cast back to the
            input's dtype.
        """
        input_dtype = audio_waveforms.dtype
        window = torch.hann_window(self.config.window_size).to(audio_waveforms.device)

        stft = torch.stft(
            audio_waveforms,
            self.config.window_size,
            self.config.hop_length,
            window=window,
            return_complex=True,
        )
        # Power spectrum, dropping the last STFT frame.
        magnitudes = stft[..., :-1].abs() ** 2

        # ``@`` promotes neither device nor dtype, so align the (unregistered)
        # filter bank with the spectrogram. This is a no-op when they match.
        mel_filters = self.mel_filters.to(
            device=magnitudes.device, dtype=magnitudes.dtype
        )
        mel_spec = mel_filters.T @ magnitudes

        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        # Limit the dynamic range to 8 (log10 units) below the maximum, then
        # shift and scale.
        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        return log_spec.to(input_dtype)

    @property
    def downsample_factor(self) -> int:
        """Temporal downsampling of the conv stem (product of conv strides)."""
        return (
            self.whisper_encoder.conv1.stride[0] * self.whisper_encoder.conv2.stride[0]
        )

    @property
    def chunk_size(self) -> int:
        """Number of mel frames fed to the encoder per chunk."""
        return self.config.max_source_positions * self.downsample_factor

    def prepare_inputs_for_conv(
        self,
        audio_waveforms: list[torch.Tensor],
    ) -> tuple[torch.Tensor, list[int]]:
        """Convert waveforms into one batch of fixed-size mel chunks.

        Args:
            audio_waveforms: List of waveform tensors.

        Returns:
            ``(chunks, chunks_per_example)``, where:
            - ``chunks`` stacks every chunk of every input along dim 0.
            - ``chunks_per_example[i]`` is the number of chunks produced by
              ``audio_waveforms[i]``.

        Note:
            ``torch.stack`` requires all chunks to have the same length. Each
            spectrogram's frame count therefore presumably has to be a multiple
            of ``chunk_size``, with padding done upstream; TODO confirm.
        """
        assert isinstance(audio_waveforms, list)
        # list[num_mel_bins, seq_len]
        input_features = [
            self.compute_whisper_melspec(audio).to(self.dtype)
            for audio in audio_waveforms
        ]

        chunked_features: list[torch.Tensor] = []
        chunks_per_example: list[int] = []
        for feature in input_features:
            chunks = feature.split(self.chunk_size, dim=-1)
            chunked_features += chunks
            chunks_per_example.append(len(chunks))

        # [total_num_chunks, num_mel_bins, chunk_size]
        return torch.stack(chunked_features), chunks_per_example

    def forward(
        self, input_features: torch.Tensor | list[torch.Tensor]
    ) -> list[torch.Tensor]:
        """Encode one or more raw waveforms.

        Args:
            input_features: A waveform tensor or a list of them. Despite the
                name, these are waveforms; mel features are computed here.

        Returns:
            One tensor per input. Each is the encoder output of all that
            input's chunks, concatenated along the first (frame) dimension.
        """
        if not isinstance(input_features, list):
            input_features = [input_features]

        # Split long inputs into chunks
        input_embeds, chunks_per_example = self.prepare_inputs_for_conv(input_features)

        # [total_num_chunks, ceil(chunk_size / downsample_factor), hidden_size]
        out = self.whisper_encoder([input_embeds])

        # Re-concatenate the chunks belonging to each input.
        chunk_idx = 0
        results = []
        for n_chunks in chunks_per_example:
            result = out[chunk_idx : chunk_idx + n_chunks].flatten(0, 1)
            results.append(result)
            chunk_idx += n_chunks

        return results

    def load_weight(self, weight: tuple[str, torch.Tensor]) -> str:
        """Load a single checkpoint tensor into this module.

        The checkpoint name is first rewritten with ``mistral_remapping``.
        It is then routed either to a fused (stacked) parameter with a shard
        id, or to a plain parameter (after optional renames).

        Args:
            weight: ``(checkpoint_name, tensor)`` pair.

        Returns:
            The name of the parameter the tensor was loaded into.

        Raises:
            KeyError: If the remapped name matches no parameter of this module.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        # (param_name, weight_name): renames that need no sharding.
        params_mapping: list[tuple[str, str]] = []

        if self.is_causal:
            # ``WhisperCausalEncoder`` fuses fc1/fc3 into ``gate_up_proj`` and
            # names fc2 ``down_proj``.
            stacked_params_mapping.extend(
                [
                    (".mlp.gate_up_proj", ".mlp.fc1", 0),
                    (".mlp.gate_up_proj", ".mlp.fc3", 1),
                ]
            )
            params_mapping.extend(
                [
                    (".mlp.down_proj", ".mlp.fc2"),
                ]
            )

        params_dict = dict(self.named_parameters())

        name, loaded_weight = weight
        for pattern, repl in self.mistral_remapping:
            if re.fullmatch(pattern, name):
                name = re.sub(pattern, repl, name)

        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)

            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            for param_name, weight_name in params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)

        return name