voxtral.py 30.7 KB
Newer Older
Patrick von Platen's avatar
Patrick von Platen committed
1
2
3
4
5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math
from collections.abc import Iterable, Mapping, Sequence
6
from functools import partial
7
from typing import Literal, cast
Patrick von Platen's avatar
Patrick von Platen committed
8
9
10
11
12

import numpy as np
import regex as re
import torch
import torch.nn as nn
13
from mistral_common.audio import Audio, mel_filter_bank
14
15
from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
Patrick von Platen's avatar
Patrick von Platen committed
16
17
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.transcription.request import TranscriptionRequest
18
from transformers import BatchFeature, WhisperConfig
Patrick von Platen's avatar
Patrick von Platen committed
19

20
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
21
from vllm.config.multimodal import BaseDummyOptions
22
from vllm.inputs.data import PromptType, TokensPrompt
Patrick von Platen's avatar
Patrick von Platen committed
23
from vllm.logger import init_logger
24
from vllm.model_executor.layers.quantization import QuantizationConfig
Patrick von Platen's avatar
Patrick von Platen committed
25
26
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import SupportsPP
27
from vllm.model_executor.models.module_mapping import MultiModelKeys
28
29
30
31
32
from vllm.model_executor.models.whisper import (
    WhisperEncoder,
    _create_fake_bias_for_k_proj,
)
from vllm.model_executor.models.whisper_causal import WhisperCausalEncoder
Patrick von Platen's avatar
Patrick von Platen committed
33
from vllm.multimodal import MULTIMODAL_REGISTRY
34
35
36
37
38
39
40
41
42
43
44
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    NestedTensors,
)
from vllm.multimodal.parse import (
    AudioProcessorItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
45
from vllm.multimodal.processing import BaseDummyInputsBuilder
46
from vllm.multimodal.processing.processor import (
47
48
49
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    MultiModalProcessingInfo,
50
    PlaceholderFeaturesInfo,
51
    ProcessorInputs,
52
53
    PromptReplacement,
    PromptUpdate,
54
    TimingContext,
55
)
Patrick von Platen's avatar
Patrick von Platen committed
56
from vllm.sequence import IntermediateTensors
57
58
from vllm.tokenizers import cached_tokenizer_from_config
from vllm.tokenizers.mistral import MistralTokenizer
59
from vllm.transformers_utils.processors.voxtral import MistralCommonVoxtralProcessor
Patrick von Platen's avatar
Patrick von Platen committed
60

61
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
62
from .utils import init_vllm_registered_model, maybe_prefix
Patrick von Platen's avatar
Patrick von Platen committed
63
64
65

# Module-level logger for this file.
logger = init_logger(__name__)

66
67
68
69
70
71
72
73
74
75
76
77
# ISO 639-1 language code -> English language name for the languages this
# model advertises (exposed as ``supported_languages`` on the model class).
ISO639_1_SUPPORTED_LANGS = dict(
    ar="Arabic",
    nl="Dutch",
    en="English",
    fr="French",
    de="German",
    hi="Hindi",
    it="Italian",
    pt="Portuguese",
    es="Spanish",
)

Patrick von Platen's avatar
Patrick von Platen committed
78
79
80

class VoxtralProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Voxtral: tokenizer, processor and audio limits.

    Voxtral only works with the ``mistral_common`` tokenizer backend, so every
    helper here goes through :class:`MistralTokenizer` (see ``get_tokenizer``).
    """

    def get_tokenizer(self) -> MistralTokenizer:
        """Return the cached tokenizer, requiring the Mistral backend.

        Raises:
            ValueError: if the configured tokenizer is not a
                ``MistralTokenizer`` (i.e. ``--tokenizer-mode mistral`` was
                not used).
        """
        tokenizer = cached_tokenizer_from_config(self.ctx.model_config)

        if not isinstance(tokenizer, MistralTokenizer):
            raise ValueError("This model requires `--tokenizer-mode mistral`")

        return tokenizer

    def get_hf_processor(self, **kwargs) -> MistralCommonVoxtralProcessor:
        """Build the mistral_common-backed Voxtral processor via the context.

        Extra ``kwargs`` are forwarded to ``ctx.init_processor``.
        """
        return self.ctx.init_processor(
            MistralCommonVoxtralProcessor,
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )

    def get_data_parser(self):
        """Return a data parser targeting the feature extractor's sample rate.

        Audio is requested as a single channel (``target_channels=1``).
        """
        feature_extractor = self.get_hf_processor().feature_extractor

        return MultiModalDataParser(
            target_sr=feature_extractor.sampling_rate,
            target_channels=1,
            expected_hidden_size=self._get_expected_hidden_size(),
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Cap the number of audio items accepted per prompt."""
        return {"audio": 5}  # Performance tends to degrade after 5

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the per-item token budget for each supported modality."""
        return {"audio": self.get_max_audio_tokens()}

    def get_max_audio_tokens(self) -> int:
        """Upper bound on audio tokens per item: the model's full context."""
        return self.ctx.model_config.max_model_len

    def get_max_audio_array_len(self) -> int:
        """Number of audio samples corresponding to ``get_max_audio_tokens()``.

        Multiplies the token budget by ``sampling_rate // frame_rate`` (samples
        per frame), i.e. it assumes one audio token per frame.
        """
        feature_extractor = self.get_hf_processor().feature_extractor
        samples_per_frame = int(
            feature_extractor.sampling_rate // feature_extractor.frame_rate
        )
        return self.get_max_audio_tokens() * samples_per_frame
Patrick von Platen's avatar
Patrick von Platen committed
122
123
124
125
126
127
128
129
130
131


class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
    """Builds dummy audio inputs and prompts for the Voxtral processor."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Dummy prompts carry no text; only audio chunks are used."""
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Return ``mm_counts["audio"]`` dummy clips of maximal sample length.

        The length comes from ``info.get_max_audio_array_len()``; any
        user-provided ``mm_options["audio"]`` overrides are passed through.
        """
        num_audios = mm_counts.get("audio", 0)
        target_length = self.info.get_max_audio_array_len()
        audio_overrides = mm_options.get("audio")

        return {
            "audio": self._get_dummy_audios(
                length=target_length,
                num_audios=num_audios,
                overrides=audio_overrides,
            )
        }

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> ProcessorInputs:
        """Tokenize a dummy chat request that contains the dummy audio clips.

        The request is encoded with ``mistral_common``'s chat-completion
        encoder, so the returned prompt is the token ids it produced. The
        audio arrays handed to ``parse_mm_data`` are the ones returned by the
        encoder rather than the original dummy clips.
        """
        tokenizer = self.info.get_tokenizer()
        feature_extractor = self.info.get_hf_processor().feature_extractor

        dummy_text = self.get_dummy_text(mm_counts)
        dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
        dummy_audios = dummy_mm_data.get("audio", [])

        # Named ``audio_format`` so the builtin ``format`` is not shadowed.
        audio_format = "wav"
        audio_chunks: list[AudioChunk] = []
        for audio in dummy_audios:
            audio_item = Audio(
                audio_array=audio,
                sampling_rate=feature_extractor.sampling_rate,
                format=audio_format,
            )
            audio_chunks.append(
                AudioChunk(input_audio=RawAudio.from_audio(audio_item))
            )

        request = ChatCompletionRequest(
            messages=[
                UserMessage(content=[TextChunk(text=dummy_text), *audio_chunks]),
            ]
        )

        res = tokenizer.mistral.encode_chat_completion(request)
        dummy_tokens = res.tokens

        dummy_mm_items = self.info.parse_mm_data(
            # The Voxtral (mistral_common) tokenizer pads the audio, so use the
            # arrays it returns instead of the unpadded dummy clips.
            {**dummy_mm_data, "audio": [a.audio_array for a in res.audios]},
        )

        return ProcessorInputs(prompt=dummy_tokens, mm_data_items=dummy_mm_items)
Patrick von Platen's avatar
Patrick von Platen committed
187
188


189
class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]):
    """Multimodal processor for Voxtral backed by ``mistral_common``.

    The audio placeholder tokens are already present in the prompt produced
    by the Mistral chat template, so placeholder validation is skipped and
    prompt updates are reported as already applied.
    """

    def _get_mm_fields_config(
        self,
        hf_inputs: Mapping[str, NestedTensors],
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Each audio item corresponds to one batched ``audio_arrays`` entry."""
        return dict(audio_arrays=MultiModalFieldConfig.batched("audio"))

    def _validate_mm_placeholders(
        self,
        mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
        mm_item_counts: Mapping[str, int],
    ) -> None:
        # mistral_common's tokenizer does not follow HF's placeholder norms,
        # so validation is intentionally skipped here.
        ...

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Delegate to the base implementation, renaming ``audios`` to ``audio``.

        ``mm_data`` is copied first, so the caller's mapping is not mutated.
        """
        mm_data = dict(mm_data)
        audios = mm_data.pop("audios", [])

        if audios:
            # MistralCommonVoxtralProcessor accepts "audio"
            mm_data["audio"] = audios

        return super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Describe how many audio tokens each audio item occupies.

        The replacement never matches the prompt (``target=""``); it is only
        used to size each item's placeholder run of ``audio_token_id``.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        feature_extractor = processor.feature_extractor
        audio_id = processor.audio_token_id

        out_mm_data = out_mm_kwargs.require_data()
        out_audio_items = out_mm_data.get("audio", [])

        def get_replacement(item_idx: int):
            if item_idx < len(out_audio_items):
                # Prefer the length of the array the processor actually output.
                out_audio_data = out_audio_items[item_idx].get_data()
                audio_arr = out_audio_data["audio_arrays"]
                if not isinstance(audio_arr, (torch.Tensor, np.ndarray)):
                    raise TypeError(
                        "Unexpected type for audio_arrays in out_mm_kwargs: "
                        f"{type(audio_arr)}"
                    )
                audio_len = len(audio_arr)
            else:
                # Fallback for unexpected processor outputs.
                audios = mm_items.get_items("audio", AudioProcessorItems)
                audio_len = audios.get_audio_length(item_idx)

            nb_audio_tokens = feature_extractor.get_num_audio_tokens(audio_len)
            return [audio_id] * nb_audio_tokens

        return [
            PromptReplacement(
                modality="audio",
                # Never match the prompt: the tokens are already inserted by
                # the chat template (see _cached_apply_hf_processor).
                target="",
                replacement=get_replacement,
            ),
        ]

    def _cached_apply_hf_processor(
        self,
        inputs: ProcessorInputs,
        timing_ctx: TimingContext,
    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
        """Run the base processor and report prompt updates as applied."""
        prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(inputs, timing_ctx)

        # NOTE: The tokens are already inserted by the chat template
        return prompt_ids, mm_info, True
Patrick von Platen's avatar
Patrick von Platen committed
277
278


279
280
281
282
283
284
285
286
@MULTIMODAL_REGISTRY.register_processor(
    VoxtralMultiModalProcessor,
    info=VoxtralProcessingInfo,
    dummy_inputs=VoxtralDummyInputsBuilder,
)
class VoxtralForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsTranscription
):
    """Voxtral: a Whisper-style audio encoder feeding a vLLM language model.

    Per-item encoder outputs are downsampled by concatenating
    ``downsample_factor`` consecutive frames into one row, then projected to
    the language model's hidden size by ``audio_language_adapter``.
    """

    supported_languages = ISO639_1_SUPPORTED_LANGS

    # transformers currently has limited support for the MistralCommon backend
    # and cached_get_processor, so skip warmup until that is fixed.
    skip_warmup_audio_preprocessing = True

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.tokenizer = cached_tokenizer_from_config(vllm_config.model_config)

        # Update the quant config so that ignored module and target module
        # names match the vLLM model names.
        if hasattr(vllm_config, "quant_config"):
            vllm_config.quant_config = self.maybe_update_quant_config(
                vllm_config.quant_config
            )

        config = vllm_config.model_config.hf_config
        self.config = config
        self.downsample_factor = self.config.audio_config.downsample_factor

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        with self._mark_tower_model(vllm_config, "audio"):
            self.whisper_encoder = VoxtralEncoderModel(
                vllm_config.with_hf_config(config.audio_config),
                prefix=maybe_prefix(prefix, "whisper_encoder"),
            )
            # Input width is d_model * downsample_factor because
            # embed_multimodal concatenates that many encoder frames per row.
            self.audio_language_adapter = AudioLanguageAdapter(
                hidden_size=config.audio_config.d_model * self.downsample_factor,
                dim=config.text_config.hidden_size,
            )

    def get_mm_mapping(self) -> MultiModelKeys:
        """Get module prefix for multimodal models to filter LoRA modules."""
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="audio_language_adapter",
            tower_model=["whisper_encoder"],
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language model's ``model`` (logits: ``compute_logits``)."""
        # When intermediate tensors are supplied (presumably a non-first
        # pipeline stage), precomputed embeddings are not used.
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

        return hidden_states

    def embed_multimodal(
        self, **kwargs
    ) -> list[torch.Tensor] | torch.Tensor | tuple[torch.Tensor, ...] | None:
        """Encode audio inputs into language-model embeddings.

        Returns ``None`` when no ``audio_arrays`` are supplied. Otherwise returns
        one tensor per audio item of shape
        ``(ceil(seq_len / downsample_factor), text_config.hidden_size)``, where
        ``seq_len`` is that item's encoder output length.
        """
        audio_inputs = self._parse_and_validate_audio_arrays(**kwargs)
        if audio_inputs is None:
            return None

        audio_embeddings = self.whisper_encoder(audio_inputs)

        factor = self.downsample_factor
        for i, audio_embedding in enumerate(audio_embeddings):
            seq_len, dim = audio_embedding.shape
            # Zero-pad the time axis so seq_len is divisible by the factor,
            # then fold each group of `factor` frames into one wider row.
            target_seq_len = factor * math.ceil(seq_len / factor)
            audio_embedding = torch.nn.functional.pad(
                audio_embedding,
                (0, 0, 0, target_seq_len - seq_len),
            )
            audio_embeddings[i] = audio_embedding.reshape(
                target_seq_len // factor, dim * factor
            )

        # Concat, project once over all items, and resplit per item.
        audio_embeddings_packed = torch.cat(audio_embeddings, dim=0)
        audio_embeddings_packed = self.audio_language_adapter(audio_embeddings_packed)
        audio_embeddings = torch.split(
            audio_embeddings_packed, [a.shape[0] for a in audio_embeddings], dim=0
        )

        return audio_embeddings

    def _parse_and_validate_audio_arrays(
        self, **kwargs: object
    ) -> list[torch.Tensor] | None:
        """Extract ``audio_arrays`` from kwargs as a list of tensors.

        A stacked tensor is unbound along dim 0; a list is returned unchanged.

        Raises:
            ValueError: if ``audio_arrays`` is neither a tensor nor a list.
        """
        audio_arrays = kwargs.pop("audio_arrays", None)
        if audio_arrays is None:
            return None

        if not isinstance(audio_arrays, (torch.Tensor, list)):
            raise ValueError(
                f"Incorrect type of audio_arrays. Got type: {type(audio_arrays)}"
            )

        if isinstance(audio_arrays, torch.Tensor):
            audio_arrays = list(audio_arrays.unbind(0))
        return audio_arrays

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    @classmethod
    def get_speech_to_text_config(
        cls, model_config: ModelConfig, task_type: str
    ) -> SpeechToTextConfig:
        """Derive clip length and sample rate from the tokenizer audio config."""
        tokenizer = cached_tokenizer_from_config(model_config)
        audio_config = tokenizer.instruct.audio_encoder.audio_config
        max_audio_clip_s = audio_config.chunk_length_s
        sample_rate = audio_config.sampling_rate
        return SpeechToTextConfig(
            max_audio_clip_s=max_audio_clip_s,
            sample_rate=sample_rate,
            # mistral_common and whisper encoder take care of chunking
            min_energy_split_window_size=None,
        )

    @classmethod
    # for speech-to-text transcription
    def get_generation_prompt(
        cls,
        audio: np.ndarray,
        model_config: ModelConfig,
        stt_config: SpeechToTextConfig,
        language: str | None,
        task_type: Literal["transcribe", "translate"],
        request_prompt: str,
        to_language: str | None,
    ) -> PromptType:
        """Build a tokenized transcription prompt for ``audio``.

        Only ``language`` is forwarded to mistral_common's transcription
        encoder; ``task_type``, ``request_prompt`` and ``to_language`` are not
        used by this implementation.
        """
        tokenizer = cached_tokenizer_from_config(model_config)
        # Distinct local name so the ``audio`` argument is not shadowed;
        # "wav" keeps the encoding lossless.
        audio_item = Audio(audio, int(stt_config.sample_rate), format="wav")
        req = TranscriptionRequest(
            model=model_config.model,
            audio=RawAudio.from_audio(audio_item),
            language=language,
        )

        tokenized = tokenizer.instruct.encode_transcription(req)

        return TokensPrompt(
            prompt_token_ids=tokenized.tokens,
            multi_modal_data={
                "audio": [
                    (encoded.audio_array, stt_config.sample_rate)
                    for encoded in tokenized.audios
                ],
            },
        )

    @classmethod
    def get_num_audio_tokens(
        cls,
        audio_duration_s: float,
        stt_config: SpeechToTextConfig,
        model_config: ModelConfig,
    ) -> int | None:
        """
        Map from audio duration to number of audio tokens produced by the ASR
        model, without running a forward pass.
        This is used for estimating the amount of processing for this audio.
        """
        tokenizer = cached_tokenizer_from_config(model_config)
        adapter = MistralCommonVoxtralProcessor(tokenizer)
        return adapter.feature_extractor.get_num_audio_tokens(
            int(audio_duration_s * stt_config.sample_rate)
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into the encoder, adapter and language model.

        Encoder weights are routed to ``whisper_encoder.load_weight``, adapter
        weights are copied with ``default_weight_loader``, and everything else
        is streamed to ``language_model.load_weights``.

        Returns:
            The set of fully qualified parameter names that were loaded.
        """
        # Rules are applied in order and each one sees the output of the
        # previous one, e.g.
        # "mm_whisper_embeddings.audio_language_projection.0.weight"
        # -> "audio_language_adapter.w_in.weight".
        remapping_rules = [
            (r"mm_streams_embeddings.embedding_module\.(.*)", r"\1"),
            (r"mm_whisper_embeddings\.(.*)", r"\1"),
            (r"audio_language_projection\.(.*)", r"audio_language_adapter.\1"),
            (
                r"audio_language_adapter\.0\.weight",
                r"audio_language_adapter.w_in.weight",
            ),
            (
                r"audio_language_adapter\.2\.weight",
                r"audio_language_adapter.w_out.weight",
            ),
        ]

        audio_params = dict(
            nn.ModuleDict(
                {
                    "audio_language_adapter": self.audio_language_adapter,
                }
            ).named_parameters()
        )
        weights = _create_fake_bias_for_k_proj(weights, ".wk.weight")

        encoder_prefixes = (
            "mm_whisper_embeddings",
            "mm_streams_embeddings.embedding_module",
        )
        loaded_weights: set[str] = set()

        def llm_weights_generator():
            for name, w in weights:
                # Encoder weights live under these prefixes, except the token
                # embeddings and the audio->text projection.
                is_encoder = any(
                    name.startswith(k)
                    and not name.startswith(f"{k}.tok_embeddings")
                    and not name.startswith(f"{k}.audio_language_projection")
                    for k in encoder_prefixes
                )

                for pattern, repl in remapping_rules:
                    if re.fullmatch(pattern, name):
                        name = re.sub(pattern, repl, name)

                if is_encoder:
                    name = self.whisper_encoder.load_weight((name, w))
                    loaded_weights.add(f"whisper_encoder.{name}")
                    continue

                if name in audio_params:
                    with torch.no_grad():
                        default_weight_loader(audio_params[name], w)
                    loaded_weights.add(name)
                else:
                    yield (name, w)

        for name in self.language_model.load_weights(llm_weights_generator()):
            loaded_weights.add(f"language_model.{name}")

        # The encoder position embeddings may be absent from the checkpoint;
        # report them as loaded so the loaded-weights check does not fail.
        # NOTE(review): presumably they are computed at init — confirm.
        sin_key = "whisper_encoder.whisper_encoder.embed_positions.weight"
        if sin_key not in loaded_weights:
            loaded_weights.add(sin_key)

        return loaded_weights

    def maybe_update_quant_config(
        self, quant_config: QuantizationConfig
    ) -> QuantizationConfig:
        """
        Update the quant config so that ignored module and target module names
        match the vLLM model names.
        Right now this is specific for compressed-tensors format and
        load_format mistral.

        ``quant_config.ignore`` and every ``config_groups[...]["targets"]`` list
        are rewritten when present; groups without a ``"targets"`` entry are
        left untouched.
        """
        remapping_rules = [
            (r"output", r"language_model.lm_head"),
            (
                r"layers\.(\d+)\.attention\.wo",
                r"language_model.model.layers.\1.self_attn.out_proj",
            ),
            (
                r"layers\.(\d+)\.attention\.w(.*)",
                r"language_model.model.layers.\1.self_attn.\2_proj",
            ),
            (
                r"layers\.(\d+)\.feed_forward\.w1",
                r"language_model.model.layers.\1.mlp.gate_proj",
            ),
            (
                r"layers\.(\d+)\.feed_forward\.w2",
                r"language_model.model.layers.\1.mlp.down_proj",
            ),
            (
                r"layers\.(\d+)\.feed_forward\.w3",
                r"language_model.model.layers.\1.mlp.up_proj",
            ),
            (
                r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.w(.*)",
                r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.\2_proj",
            ),
            (
                r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.wo",
                r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.out_proj",
            ),
            (
                r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward.w(\d+)",
                r"whisper_encoder.whisper_encoder.layers.\1.layers.mlp.fc\2",
            ),
            (
                r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.0",
                r"whisper_encoder.whisper_encoder.conv1",
            ),
            (
                r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.1",
                r"whisper_encoder.whisper_encoder.conv2",
            ),
            (
                r"mm_whisper_embeddings\.audio_language_projection\.0",
                r"audio_language_adapter.w_in",
            ),
            (
                r"mm_whisper_embeddings\.audio_language_projection\.2",
                r"audio_language_adapter.w_out",
            ),
        ]

        def remap(name: str) -> str:
            # Every rule is tested against the *original* name and a later
            # match overrides an earlier one, so rule order matters.
            # NOTE(review): "layers.N.attention.wo" also matches the more
            # general "attention.w(.*)" rule listed after it, so it maps to
            # "...self_attn.o_proj", not "out_proj" — confirm this is intended.
            mistral_name = name
            for pattern, repl in remapping_rules:
                if re.fullmatch(pattern, name):
                    mistral_name = re.sub(pattern, repl, name)
            return mistral_name

        # Update ignore list
        if hasattr(quant_config, "ignore"):
            quant_config.ignore = [remap(name) for name in quant_config.ignore]

        # Update target list
        if hasattr(quant_config, "config_groups"):
            config_groups = quant_config.config_groups
            for group_name in config_groups:
                group = config_groups[group_name]
                # Only rewrite groups that declare targets. Writing outside this
                # check would reuse the previous group's list (or raise
                # NameError on the first group).
                if "targets" in group:
                    group["targets"] = [remap(name) for name in group["targets"]]
            quant_config.config_groups = config_groups

        return quant_config

Patrick von Platen's avatar
Patrick von Platen committed
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645

class AudioLanguageAdapter(nn.Module):
    """Bias-free two-layer GELU MLP mapping audio features to width ``dim``.

    Args:
        hidden_size: width of the incoming audio features.
        dim: width of both the hidden layer and the output.
    """

    def __init__(self, hidden_size: int, dim: int) -> None:
        super().__init__()
        # Submodule names (w_in / gelu / w_out) and their registration order
        # are part of the parameter naming used when loading weights.
        self.w_in = nn.Linear(hidden_size, dim, bias=False)
        self.gelu = nn.GELU()
        self.w_out = nn.Linear(dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected = self.w_in(x)
        activated = self.gelu(projected)
        return self.w_out(activated)


class VoxtralEncoderModel(nn.Module):
    """Whisper-style audio encoder used by Voxtral.

    Turns raw audio waveforms into normalized log-mel spectrograms, splits
    them into fixed-size chunks, runs every chunk through a (causal or
    bidirectional) Whisper encoder and re-assembles the outputs per example.
    Also rewrites Mistral-format checkpoint names onto this module's
    parameter names (see ``mistral_remapping`` / ``load_weight``).
    """

    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    # (regex, replacement) pairs applied in order by `load_weight`. Rules are
    # chained: each rule sees the name produced by the previous ones, so the
    # first rule strips the outer prefix before the `whisper_encoder.*` rules
    # run. Both `conv_layers.N.*` and `conv_layers.N.conv.*` layouts are
    # accepted for the two conv layers.
    mistral_remapping = [
        (r"mm_streams_embeddings\.embedding_module\.(.*)", r"\1"),
        (
            r"whisper_encoder\.conv_layers\.0\.(weight|bias)",
            r"whisper_encoder.conv1.\1",
        ),
        (
            r"whisper_encoder\.conv_layers\.1\.(weight|bias)",
            r"whisper_encoder.conv2.\1",
        ),
        (
            r"whisper_encoder\.conv_layers\.0\.conv\.(weight|bias)",
            r"whisper_encoder.conv1.\1",
        ),
        (
            r"whisper_encoder\.conv_layers\.1\.conv\.(weight|bias)",
            r"whisper_encoder.conv2.\1",
        ),
        (
            r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.w([qkv])\.(weight|bias)",  # noqa: E501
            r"whisper_encoder.layers.\1.self_attn.\2_proj.\3",
        ),
        (
            r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wo\.(weight|bias)",  # noqa: E501
            r"whisper_encoder.layers.\1.self_attn.out_proj.\2",
        ),
        (
            r"whisper_encoder\.transformer\.layers\.(\d+)\.attention_norm\.(weight|bias)",  # noqa: E501
            r"whisper_encoder.layers.\1.self_attn_layer_norm.\2",
        ),
        (
            r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w1\.(weight|bias)",  # noqa: E501
            r"whisper_encoder.layers.\1.mlp.fc1.\2",
        ),
        (
            r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(weight|bias)",  # noqa: E501
            r"whisper_encoder.layers.\1.mlp.fc2.\2",
        ),
        (
            r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w3\.(weight|bias)",  # noqa: E501
            r"whisper_encoder.layers.\1.mlp.fc3.\2",
        ),
        (
            r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(weight|bias)",
            r"whisper_encoder.layers.\1.final_layer_norm.\2",
        ),
        (
            r"whisper_encoder\.transformer\.norm\.(weight|bias)",
            r"whisper_encoder.layer_norm.\1",
        ),
    ]

    def __init__(
        self,
        vllm_config: VllmConfig,
        *,
        prefix: str = "",
    ) -> None:
        """Build the wrapped Whisper encoder and the mel filter bank.

        Args:
            vllm_config: Engine config. ``model_config.hf_config`` is used as
                the (Whisper-style) encoder config and ``model_config.dtype``
                is the dtype mel features are cast to before encoding.
            prefix: Parameter-name prefix for the wrapped encoder.
        """
        super().__init__()
        self.config = cast(WhisperConfig, vllm_config.model_config.hf_config)
        self.dtype: torch.dtype = vllm_config.model_config.dtype

        # Configs without `is_causal` get the bidirectional WhisperEncoder.
        self.is_causal = getattr(self.config, "is_causal", False)
        if self.is_causal:
            WhisperEncoderCls = WhisperCausalEncoder
        else:
            WhisperEncoderCls = partial(WhisperEncoder, init_in_fp32=True)

        self.whisper_encoder = WhisperEncoderCls(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "whisper_encoder"),
        )

        # Mel filter bank over 0-8 kHz. Stored as a plain float32 tensor
        # attribute (not a parameter or registered buffer), so module-level
        # `.to(...)` calls do not move it; `compute_whisper_melspec` aligns it
        # with its input instead.
        mel_filters = mel_filter_bank(
            num_frequency_bins=1 + self.config.window_size // 2,
            num_mel_bins=self.config.num_mel_bins,
            min_frequency=0.0,
            max_frequency=8000.0,
            sampling_rate=self.config.sampling_rate,
        )
        self.mel_filters = torch.tensor(mel_filters, dtype=torch.float32)

    def compute_whisper_melspec(
        self,
        audio_waveforms: torch.Tensor,
    ) -> torch.Tensor:
        """Compute a normalized Whisper-style log-mel spectrogram.

        Args:
            audio_waveforms: Waveform samples passed to ``torch.stft``
                (presumably a 1-D tensor sampled at ``config.sampling_rate``
                — TODO confirm against callers).

        Returns:
            Log-mel features (``[num_mel_bins, num_frames]`` for a 1-D
            input), cast back to the dtype of ``audio_waveforms``.

        Raises:
            TypeError: if ``config.global_log_mel_max`` is set to something
                other than a float.
        """
        input_dtype = audio_waveforms.dtype

        window = torch.hann_window(
            self.config.window_size, device=audio_waveforms.device
        )
        stft = torch.stft(
            audio_waveforms,
            self.config.window_size,
            self.config.hop_length,
            window=window,
            return_complex=True,
        )
        # Power spectrum; the last STFT frame is dropped, mirroring the
        # reference Whisper feature extractor.
        magnitudes = stft[..., :-1].abs() ** 2

        # `.to` is a no-op when device/dtype already match; it only guards
        # against the filter bank and the STFT output living apart.
        mel_filters = self.mel_filters.to(
            device=magnitudes.device, dtype=magnitudes.dtype
        )
        mel_spec = mel_filters.T @ magnitudes
        log_spec = torch.clamp(mel_spec, min=1e-10).log10()

        # Optional fixed ceiling for the dynamic-range floor below. Compare
        # against None (not truthiness) so an explicit 0.0 is honoured, and
        # tolerate configs that do not define the field at all.
        global_log_mel_max = getattr(self.config, "global_log_mel_max", None)
        if global_log_mel_max is not None:
            if not isinstance(global_log_mel_max, float):
                raise TypeError(f"{global_log_mel_max=} needs to be of type float.")
            log_spec_max = torch.tensor(
                global_log_mel_max,
                device=log_spec.device,
                dtype=log_spec.dtype,
            )
        else:
            log_spec_max = log_spec.max()

        # Keep at most 8 orders of magnitude below the ceiling, then rescale.
        log_spec = torch.maximum(log_spec, log_spec_max - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        return log_spec.to(input_dtype)

    @property
    def downsample_factor(self) -> int:
        """Total temporal stride of the encoder's two conv layers."""
        return (
            self.whisper_encoder.conv1.stride[0] * self.whisper_encoder.conv2.stride[0]
        )

    @property
    def chunk_size(self) -> int:
        """Mel frames per encoder chunk (positions x conv downsampling)."""
        return self.config.max_source_positions * self.downsample_factor

    def prepare_inputs_for_conv(
        self,
        audio_waveforms: list[torch.Tensor],
    ) -> tuple[torch.Tensor, list[int]]:
        """Turn waveforms into one batch of fixed-size mel-spectrogram chunks.

        Args:
            audio_waveforms: One waveform tensor per example.

        Returns:
            ``(chunks, chunks_per_example)``: ``chunks`` stacks every chunk of
            every example along dim 0, and ``chunks_per_example[i]`` is how
            many of those chunks came from ``audio_waveforms[i]``.
        """
        assert isinstance(audio_waveforms, list)
        # list[num_mel_bins, seq_len]
        input_features = [
            self.compute_whisper_melspec(audio).to(self.dtype)
            for audio in audio_waveforms
        ]

        chunked_features: list[torch.Tensor] = []
        chunks_per_example: list[int] = []
        for feature in input_features:
            chunks = feature.split(self.chunk_size, dim=-1)
            chunked_features.extend(chunks)
            chunks_per_example.append(len(chunks))

        # torch.stack requires every chunk to share a shape; presumably
        # callers guarantee this (e.g. by padding audio) — TODO confirm.
        # [total_num_chunks, num_mel_bins, chunk_size]
        return torch.stack(chunked_features), chunks_per_example

    def forward(
        self, input_features: torch.Tensor | list[torch.Tensor]
    ) -> list[torch.Tensor]:
        """Encode one or more waveforms.

        Args:
            input_features: Raw waveform(s); despite the name, each entry is
                converted with ``compute_whisper_melspec`` first. A single
                tensor is treated as a one-element list.

        Returns:
            One tensor per input: that example's encoder outputs for all of
            its chunks, with the chunk and position dims flattened together.
            An empty input list yields an empty list.
        """
        if not isinstance(input_features, list):
            input_features = [input_features]
        if not input_features:
            # Nothing to encode; torch.stack would fail on an empty list.
            return []

        # Split long inputs into chunks
        input_embeds, chunks_per_example = self.prepare_inputs_for_conv(input_features)

        # [total_num_chunks, ceil(chunk_size / downsample_factor), hidden_size]
        out = self.whisper_encoder([input_embeds])

        # Re-concatenate the chunks belonging to each example.
        results = []
        chunk_idx = 0
        for n_chunks in chunks_per_example:
            results.append(out[chunk_idx : chunk_idx + n_chunks].flatten(0, 1))
            chunk_idx += n_chunks

        return results

    def load_weight(self, weight: tuple[str, torch.Tensor]) -> str:
        """Load a single ``(checkpoint_name, tensor)`` pair into this module.

        The name is first rewritten with ``mistral_remapping``. Shards of
        fused projections (q/k/v, plus fc1/fc3 for the causal encoder) are
        routed to the packed parameter's ``weight_loader`` with a shard id;
        everything else uses the parameter's ``weight_loader`` if it has
        one, else ``default_weight_loader``.

        Returns:
            The parameter name the tensor was loaded into.

        Raises:
            KeyError: if the rewritten name is not a parameter of this module.
        """
        stacked_params_mapping: list[tuple[str, str, str | int]] = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_mapping: list[tuple[str, str]] = []

        if self.is_causal:
            # `WhisperCausalEncoder` fuses fc1/fc3 into gate_up_proj
            # (shards 0 and 1) and names fc2 `down_proj`.
            stacked_params_mapping.extend(
                [
                    (".mlp.gate_up_proj", ".mlp.fc1", 0),
                    (".mlp.gate_up_proj", ".mlp.fc3", 1),
                ]
            )
            params_mapping.append((".mlp.down_proj", ".mlp.fc2"))

        # NOTE(review): rebuilt on every call, i.e. once per checkpoint tensor.
        params_dict = dict(self.named_parameters())

        name, loaded_weight = weight
        # Rules are applied in order and chained (see `mistral_remapping`).
        for pattern, repl in self.mistral_remapping:
            if re.fullmatch(pattern, name):
                name = re.sub(pattern, repl, name)

        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)

            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # Not a fused shard: apply plain renames, then load directly.
            for param_name, weight_name in params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)

        return name