# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from math import ceil
from typing import Literal, Optional, Union, cast

import numpy as np
import regex as re
import torch
import torch.nn as nn
from mistral_common.audio import mel_filter_bank
from mistral_common.protocol.instruct.messages import (AudioChunk, RawAudio,
                                                       TextChunk, UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.transcription.request import TranscriptionRequest
from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder
from transformers import BatchFeature, TensorType, WhisperConfig
from transformers.tokenization_utils_base import TextInput

from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import SupportsPP
from vllm.model_executor.models.module_mapping import MultiModelKeys
# yapf: disable
from vllm.model_executor.models.whisper import WhisperEncoder
# yapf: enable
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalKwargsItems, MultiModalUUIDDict,
                                    NestedTensors)
from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
                                   MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo,
                                        MultiModalProcessingInfo,
                                        PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
                                               cached_tokenizer_from_config)

from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
from .utils import flatten_bn, init_vllm_registered_model, maybe_prefix

logger = init_logger(__name__)

ISO639_1_SUPPORTED_LANGS = {
    "ar": "Arabic",
    "nl": "Dutch",
    "en": "English",
    "fr": "French",
    "de": "German",
    "hi": "Hindi",
    "it": "Italian",
    "pt": "Portuguese",
    "es": "Spanish",
}


class VoxtralProcessorAdapter:
    """
    Provide a HF-compatible interface for
    :class:`mistral_common.tokens.tokenizers.audio.AudioEncoder`.

    The adapter wraps a ``MistralTokenizer`` and exposes the audio special
    token ids, the audio sampling/frame rates, and a processor-style
    ``__call__`` that turns raw audio arrays into placeholder token ids.
    """

    def __init__(self, tokenizer: MistralTokenizer) -> None:
        super().__init__()
        self.tokenizer = tokenizer

    @cached_property
    def _audio_processor(self) -> AudioEncoder:
        """Audio encoder taken from ``tokenizer.instruct.audio_encoder``."""
        audio_encoder = self.tokenizer.instruct.audio_encoder
        assert isinstance(audio_encoder, AudioEncoder)
        return audio_encoder

    @cached_property
    def audio_token_id(self) -> int:
        """Id of the per-frame audio placeholder token."""
        return self._audio_processor.special_ids.audio

    @cached_property
    def begin_audio_token_id(self) -> int:
        """Id of the token that opens an audio segment."""
        return self._audio_processor.special_ids.begin_audio

    @cached_property
    def sampling_rate(self) -> int:
        """Sampling rate read from the audio encoder's ``audio_config``."""
        return self._audio_processor.audio_config.sampling_rate

    @cached_property
    def frame_rate(self) -> float:
        """Frame rate read from the audio encoder's ``audio_config``."""
        return self._audio_processor.audio_config.frame_rate

    def get_num_audio_tokens(
        self,
        audio_length: int,
    ) -> int:
        """
        Return the number of audio placeholder tokens for ``audio_length``
        samples.

        The length is first rounded up by the audio encoder's
        ``next_multiple_of_chunk_frames`` and then divided (rounding up) by
        the number of samples per frame, ``sampling_rate // frame_rate``.
        """
        pad_audio_length = self._audio_processor.next_multiple_of_chunk_frames(
            audio_length, self.sampling_rate)
        return ceil(pad_audio_length / (self.sampling_rate // self.frame_rate))

    def __call__(
        self,
        text: Optional[Union[TextInput, list[TextInput]]] = None,
        audios: Optional[Union[np.ndarray, list[np.ndarray]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> Mapping[str, NestedTensors]:
        """
        Tokenize ``text`` or build audio placeholder tokens for ``audios``.

        Without audio, ``text`` is tokenized and returned under
        ``"input_ids"``. With audio, every entry of ``text`` must be empty;
        the result holds the placeholder token ids (``"input_ids"``) and the
        padded waveforms (``"audio_arrays"``). ``return_tensors`` and
        ``kwargs`` are accepted for interface compatibility and not used.

        Raises:
            ValueError: if audio is given together with non-empty text.
        """
        if text is None:
            text = []
        if not isinstance(text, list):
            text = [text]
        if audios is None:
            audios = []
        if not isinstance(audios, list):
            audios = [audios]

        if not audios:
            input_ids = self.tokenizer(text).input_ids
            return {"input_ids": torch.tensor(input_ids)}

        # Allow dummy text, which is used for profiling as well as token inputs
        if any(len(t) > 0 for t in text):
            raise ValueError(
                "You've passed text inputs instead of token inputs. "
                "Make sure to process your input via `mistral_common`'s "
                "tokenizer or pass a chat completion request. "
                "For more info, see: "
                "https://github.com/vllm-project/vllm/issues/8411.")

        audios_tokens = list[torch.Tensor]()
        audios_processed = list[torch.Tensor]()
        for audio in audios:
            assert isinstance(audio, np.ndarray)
            assert audio.ndim == 1

            # pad if necessary
            audio = self._audio_processor.pad(audio, self.sampling_rate)

            # One begin-audio marker followed by one placeholder per frame.
            audio_tokens = [
                self.begin_audio_token_id
            ] + [self.audio_token_id] * self.get_num_audio_tokens(len(audio))

            audios_tokens.append(torch.tensor(audio_tokens))
            audios_processed.append(torch.tensor(audio))

        return BatchFeature({
            # The same token sequence is broadcast to one row per text entry.
            "input_ids":
            torch.cat(audios_tokens)[None].expand(len(text), -1),
            "audio_arrays":
            audios_processed,
        })


class VoxtralProcessingInfo(BaseProcessingInfo):
    """Tokenizer access and audio size limits used while processing."""

    def get_tokenizer(self) -> MistralTokenizer:
        """Return the cached tokenizer; it must be a ``MistralTokenizer``."""
        candidate = cached_tokenizer_from_config(self.ctx.model_config)
        if isinstance(candidate, MistralTokenizer):
            return candidate

        raise ValueError("This model requires `--tokenizer-mode mistral`")

    def get_hf_processor(self) -> VoxtralProcessorAdapter:
        """Wrap the tokenizer in a new processor adapter."""
        mistral_tokenizer = self.get_tokenizer()
        return VoxtralProcessorAdapter(mistral_tokenizer)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the per-prompt limit for each supported modality."""
        # Performance tends to degrade after 5
        return dict(audio=5)

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the maximum number of tokens a single audio item may use."""
        max_tokens = self.get_max_audio_tokens()
        return dict(audio=max_tokens)

    def get_max_audio_tokens(self) -> int:
        """Return the model's ``max_model_len`` as the audio token bound."""
        model_config = self.ctx.model_config
        return model_config.max_model_len

    def get_max_audio_array_len(self) -> int:
        """Return the longest waveform length (in samples) to consider."""
        adapter = self.get_hf_processor()
        max_tokens = self.get_max_audio_tokens()
        samples_per_token = int(adapter.sampling_rate // adapter.frame_rate)
        return max_tokens * samples_per_token


class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
    """Build dummy text/audio inputs used for profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the dummy prompt text, which is always empty."""
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return ``mm_counts["audio"]`` dummy audios of the maximum length."""
        num_audios = mm_counts.get("audio", 0)

        target_length = self.info.get_max_audio_array_len()

        return {
            "audio":
            self._get_dummy_audios(length=target_length, num_audios=num_audios)
        }

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """
        Encode the dummy audios as a chat completion request and return the
        resulting tokens together with the audio arrays the tokenizer produced.
        """
        tokenizer = self.info.get_tokenizer()
        # get_hf_processor() builds a new adapter on every call, so read the
        # sampling rate once instead of once per dummy audio.
        sampling_rate = self.info.get_hf_processor().sampling_rate

        dummy_text = self.get_dummy_text(mm_counts)
        dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts)
        dummy_audios = dummy_mm_data.get("audio", [])

        audio_chunks: list[AudioChunk] = [
            AudioChunk(input_audio=RawAudio.from_audio(
                Audio(
                    audio_array=audio,
                    sampling_rate=sampling_rate,
                    format="wav",
                ))) for audio in dummy_audios
        ]

        request = ChatCompletionRequest(messages=[
            UserMessage(content=[TextChunk(text=dummy_text), *audio_chunks]),
        ])
        res = tokenizer.mistral.encode_chat_completion(request)
        dummy_tokens = res.tokens
        # The tokenizer adds padding to the audio, so the audio arrays must be
        # replaced by the ones it returns.
        dummy_mm_data["audio"] = [a.audio_array for a in res.audios]

        return ProcessorInputs(prompt=dummy_tokens, mm_data=dummy_mm_data)


class VoxtralMultiModalProcessor(
        BaseMultiModalProcessor[VoxtralProcessingInfo]):
    """Multi-modal processor for Voxtral audio inputs."""

    def _get_mm_fields_config(
        self,
        hf_inputs: Mapping[str, NestedTensors],
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare ``audio_arrays`` as a batched field of modality "audio"."""
        return dict(audio_arrays=MultiModalFieldConfig.batched("audio"))

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """
        Describe the placeholder tokens of each audio item.

        The replacement for item ``i`` is the audio token id repeated once per
        audio token computed from that item's length.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        audio_id = processor.audio_token_id

        def get_replacement(item_idx: int):
            audios = mm_items.get_items("audio", AudioProcessorItems)
            audio_len = audios.get_audio_length(item_idx)

            nb_audio_tokens = processor.get_num_audio_tokens(audio_len)

            return [audio_id] * nb_audio_tokens

        return [
            PromptReplacement(
                modality="audio",
                target="",  # Never match the prompt (see below note)
                replacement=get_replacement,
            ),
        ]

    def _cached_apply_hf_processor(
        self,
        prompt: Union[str, list[int]],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
        """
        Delegate to the base implementation but always report ``True`` as the
        third element of the result, discarding the flag the base returned.
        """
        prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
            prompt=prompt,
            mm_data_items=mm_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

        # NOTE: The tokens are already inserted by the chat template
        return prompt_ids, mm_info, True

    def _get_data_parser(self) -> MultiModalDataParser:
        """Return a data parser targeting the processor's sampling rate."""
        sampling_rate = self.info.get_hf_processor().sampling_rate
        return MultiModalDataParser(target_sr=sampling_rate)


@MULTIMODAL_REGISTRY.register_processor(VoxtralMultiModalProcessor,
                                        info=VoxtralProcessingInfo,
                                        dummy_inputs=VoxtralDummyInputsBuilder)
class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsPP, SupportsLoRA,
                                      SupportsTranscription):
    supported_languages = ISO639_1_SUPPORTED_LANGS
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """
        Build the language model, the Whisper-based audio encoder and the
        adapter that projects audio features into the text hidden size.

        Args:
            vllm_config: Engine configuration; its ``quant_config`` (when the
                attribute exists) is rewritten in place via
                ``maybe_update_quant_config`` before any submodule is built.
            prefix: Module-name prefix applied to the submodules.
        """
        super().__init__()
        self.tokenizer = cached_tokenizer_from_config(vllm_config.model_config)

        # Update the quant config so that ignored module and target module
        # names match the vLLM model names.
        if hasattr(vllm_config, "quant_config"):
            vllm_config.quant_config = self.maybe_update_quant_config(
                vllm_config.quant_config)

        config = vllm_config.model_config.hf_config
        self.config = config
        self.downsample_factor = self.config.audio_config.downsample_factor

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        self.whisper_encoder = VoxtralEncoderModel(
            vllm_config.with_hf_config(config.audio_config),
            prefix=maybe_prefix(prefix, "whisper_encoder"),
        )
        # The adapter input is `downsample_factor` encoder frames concatenated
        # along the feature axis (see get_multimodal_embeddings).
        self.audio_language_adapter = AudioLanguageAdapter(
            hidden_size=config.audio_config.d_model * self.downsample_factor,
            dim=config.text_config.hidden_size,
        )

    def get_language_model(self) -> torch.nn.Module:
        """Return the language model submodule stored on this instance."""
        language_model = self.language_model
        return language_model

    def get_mm_mapping(self) -> MultiModelKeys:
        """Get module prefix for multimodal models to filter LoRA modules."""
        module_prefixes = {
            "language_model": "language_model",
            "connector": "audio_language_adapter",
            "tower_model": ["whisper_encoder"],
        }
        return MultiModelKeys.from_string_field(**module_prefixes)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """
        Run the language model's inner ``model`` and return its output.

        ``inputs_embeds`` is dropped whenever ``intermediate_tensors`` is
        supplied; extra keyword arguments are ignored.
        """
        embeds = None if intermediate_tensors is not None else inputs_embeds
        decoder = self.language_model.model
        return decoder(input_ids,
                       positions,
                       intermediate_tensors,
                       inputs_embeds=embeds)

    def get_multimodal_embeddings(
        self, **kwargs
    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...],
               None]:
        """
        Encode the audio found in ``kwargs`` and project it for the language
        model.

        Each encoder output of shape ``(seq_len, dim)`` is zero-padded along
        the sequence axis to a multiple of ``downsample_factor`` and reshaped
        so that ``downsample_factor`` consecutive frames form one row. All
        rows are projected by the adapter in one call and split back per
        audio. Returns ``None`` when no audio arrays were passed.
        """
        audio_inputs = self._parse_and_validate_audio_arrays(**kwargs)
        if audio_inputs is None:
            return None

        audio_embeddings = self.whisper_encoder(audio_inputs)
        factor = self.downsample_factor

        def fold_frames(embedding: torch.Tensor) -> torch.Tensor:
            seq_len, dim = embedding.shape
            # Rows needed to reach the next multiple of `factor`.
            missing_rows = (-seq_len) % factor
            padded = torch.nn.functional.pad(
                embedding,
                (0, 0, 0, missing_rows),
            )
            return padded.reshape((seq_len + missing_rows) // factor,
                                  dim * factor)

        # Replace the entries in place, as the encoder output list is reused.
        audio_embeddings[:] = [fold_frames(e) for e in audio_embeddings]

        # Concat, project and resplit
        rows_per_audio = [e.shape[0] for e in audio_embeddings]
        projected = self.audio_language_adapter(
            torch.cat(audio_embeddings, dim=0))
        return torch.split(projected, rows_per_audio, dim=0)

    def _parse_and_validate_audio_arrays(
            self, **kwargs: object) -> Union[list[torch.Tensor], None]:
        """
        Extract ``audio_arrays`` from ``kwargs`` as a list of tensors.

        Returns ``None`` when the key is absent (or ``None``). A tensor result
        of ``flatten_bn`` is unbound along dim 0 into a list.

        Raises:
            ValueError: if ``audio_arrays`` is neither a tensor nor a list.
        """
        audio_arrays = kwargs.pop("audio_arrays", None)
        if audio_arrays is None:
            return None

        if isinstance(audio_arrays, (torch.Tensor, list)):
            flattened = flatten_bn(audio_arrays)
            if isinstance(flattened, torch.Tensor):
                return list(flattened.unbind(0))
            return flattened

        raise ValueError("Incorrect type of audio_arrays. "
                         f"Got type: {type(audio_arrays)}")

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Return the language model's logits for ``hidden_states``."""
        return self.language_model.compute_logits(hidden_states)

    @classmethod
    def get_speech_to_text_config(cls, model_config: ModelConfig,
                                  task_type: str) -> SpeechToTextConfig:
        """
        Build the speech-to-text config from the tokenizer's audio settings.

        The clip length and sample rate come from the audio encoder's
        ``audio_config``; ``task_type`` is not consulted.
        """
        audio_encoder = cached_tokenizer_from_config(
            model_config).instruct.audio_encoder
        audio_settings = audio_encoder.audio_config
        clip_seconds = audio_settings.chunk_length_s
        rate = audio_settings.sampling_rate
        # mistral_common and whisper encoder take care of chunking
        return SpeechToTextConfig(max_audio_clip_s=clip_seconds,
                                  sample_rate=rate,
                                  min_energy_split_window_size=None)

    @classmethod
    # for speech-to-text transcription
    def get_generation_prompt(cls, audio: np.ndarray,
                              model_config: ModelConfig,
                              stt_config: SpeechToTextConfig,
                              language: Optional[str],
                              task_type: Literal["transcribe", "translate"],
                              request_prompt: str,
                              to_language: Optional[str]) -> PromptType:
        """
        Build the token prompt for a transcription request.

        The waveform is wrapped in a ``TranscriptionRequest`` and encoded by
        the tokenizer. The returned prompt carries the encoded token ids and
        the first audio array produced by the tokenizer, paired with the
        configured sample rate. ``task_type``, ``request_prompt`` and
        ``to_language`` are accepted but not used.
        """
        tokenizer = cached_tokenizer_from_config(model_config)
        wav_audio = Audio(audio, int(stt_config.sample_rate),
                          format="wav")  # lossless
        req = TranscriptionRequest(model=model_config.model,
                                   audio=RawAudio.from_audio(wav_audio),
                                   language=language)

        tokenized = tokenizer.instruct.encode_transcription(req)
        # Use the tokenizer's audio array rather than the raw input.
        audio_item = (tokenized.audios[0].audio_array, stt_config.sample_rate)
        prompts_dict = {
            "multi_modal_data": {
                "audio": audio_item
            },
            "prompt_token_ids": tokenized.tokens,
        }
        return cast(PromptType, prompts_dict)

    @classmethod
    def get_num_audio_tokens(cls, audio_duration_s: float,
                             stt_config: SpeechToTextConfig,
                             model_config: ModelConfig) -> Optional[int]:
        """
        Estimate how many audio tokens the ASR model produces for an audio of
        the given duration, without running a forward pass.

        The estimate is used to gauge the amount of processing the audio
        requires.
        """
        processor = VoxtralProcessorAdapter(
            cached_tokenizer_from_config(model_config))
        # Duration in seconds -> number of samples at the configured rate.
        num_samples = int(audio_duration_s * stt_config.sample_rate)
        return processor.get_num_audio_tokens(num_samples)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """
        Load checkpoint weights into the encoder, the adapter and the
        language model.

        Each ``(name, tensor)`` pair is renamed with ``remapping_rules`` and
        then routed: encoder weights go to ``whisper_encoder.load_weight``,
        adapter weights are copied with ``default_weight_loader``, and all
        remaining weights are streamed to ``language_model.load_weights``.

        Returns:
            The set of loaded parameter names, prefixed with the submodule
            they were loaded into.
        """
        # fmt: off
        remapping_rules = [
            (r"mm_whisper_embeddings\.(.*)", r"\1"),
            (r"audio_language_projection\.(.*)", r"audio_language_adapter.\1"),
            (r"audio_language_adapter\.0\.weight", r"audio_language_adapter.w_in.weight"),  # noqa: E501
            (r"audio_language_adapter\.2\.weight", r"audio_language_adapter.w_out.weight"),  # noqa: E501
        ]
        # fmt: on

        # Adapter parameters keyed as "audio_language_adapter.<param>", which
        # is the form the remapped checkpoint names take.
        audio_params = dict(
            nn.ModuleDict({
                "audio_language_adapter":
                self.audio_language_adapter,
            }).named_parameters())

        loaded_weights = set()

        def llm_weights_generator():
            # Consumed lazily by language_model.load_weights below; encoder
            # and adapter weights are loaded as a side effect of iteration.
            nonlocal loaded_weights
            for name, w in weights:
                # Decided on the ORIGINAL checkpoint name, before remapping.
                is_encoder = (
                    name.startswith("mm_whisper_embeddings") and
                    not name.startswith("mm_whisper_embeddings.tok_embeddings")
                    and not name.startswith(
                        "mm_whisper_embeddings.audio_language_projection"))

                # Rules chain: each one is matched against the name produced
                # by the rules before it.
                for pattern, repl in remapping_rules:
                    if re.fullmatch(pattern, name):
                        name = re.sub(pattern, repl, name)

                if is_encoder:
                    # load_weight returns the final parameter name it used.
                    name = self.whisper_encoder.load_weight((name, w))
                    loaded_weights.add(f"whisper_encoder.{name}")
                    continue

                if name in audio_params:
                    param = audio_params[name]
                    with torch.no_grad():
                        default_weight_loader(param, w)
                    loaded_weights.add(name)
                else:
                    # Everything else is handed to the language model.
                    yield (name, w)

        for name in self.language_model.load_weights(llm_weights_generator()):
            loaded_weights.add(f"language_model.{name}")

        # potentially manually add position embeddings
        # NOTE(review): presumably this weight can be absent from the
        # checkpoint and is marked as loaded to satisfy a later completeness
        # check — TODO confirm.
        sin_key = "whisper_encoder.whisper_encoder.embed_positions.weight"
        if sin_key not in loaded_weights:
            # make sure we don't hit an error here
            loaded_weights.add(sin_key)

        return loaded_weights

    def maybe_update_quant_config(
            self, quant_config: QuantizationConfig) -> QuantizationConfig:
        """
        Update the quant config so that ignored module and target module names
        match the vLLM model names.
        Right now this is specific for compressed-tensors format and
        load_format mistral.

        Args:
            quant_config: Object that may expose an ``ignore`` list and a
                ``config_groups`` mapping whose values may contain a
                ``"targets"`` list. Objects lacking these attributes
                (including ``None``) are returned untouched.

        Returns:
            The same ``quant_config`` object, modified in place.
        """
        remapping_rules = [
            (r"output", r"language_model.lm_head"),
            (r"layers\.(\d+)\.attention\.wo",
             r"language_model.model.layers.\1.self_attn.out_proj"),
            (r"layers\.(\d+)\.attention\.w(.*)",
             r"language_model.model.layers.\1.self_attn.\2_proj"),
            (r"layers\.(\d+)\.feed_forward\.w1",
             r"language_model.model.layers.\1.mlp.gate_proj"),
            (r"layers\.(\d+)\.feed_forward\.w2",
             r"language_model.model.layers.\1.mlp.down_proj"),
            (r"layers\.(\d+)\.feed_forward\.w3",
             r"language_model.model.layers.\1.mlp.up_proj"),
            (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.w(.*)",
             r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.\2_proj"
             ),
            (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.wo",
             r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.out_proj"
             ),
            (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward.w(\d+)",
             r"whisper_encoder.whisper_encoder.layers.\1.layers.mlp.fc\2"),
            (r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.0",
             r"whisper_encoder.whisper_encoder.conv1"),
            (r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.1",
             r"whisper_encoder.whisper_encoder.conv2"),
            (r"mm_whisper_embeddings\.audio_language_projection\.0",
             r"audio_language_adapter.w_in"),
            (r"mm_whisper_embeddings\.audio_language_projection\.2",
             r"audio_language_adapter.w_out"),
        ]

        def to_vllm_name(name: str) -> str:
            # Every rule is tested against the ORIGINAL name and the loop
            # never exits early, so when several rules match the LAST one
            # wins. Names matching no rule are returned unchanged.
            vllm_name = name
            for pattern, repl in remapping_rules:
                if re.fullmatch(pattern, name):
                    vllm_name = re.sub(pattern, repl, name)
            return vllm_name

        # Update ignore list
        if hasattr(quant_config, "ignore"):
            quant_config.ignore = [
                to_vllm_name(name) for name in quant_config.ignore
            ]

        # Update target list
        if hasattr(quant_config, "config_groups"):
            config_groups = quant_config.config_groups
            for group_name in config_groups:
                group = config_groups[group_name]
                # Only rewrite groups that define targets; groups without
                # them must not receive a "targets" entry.
                if "targets" in group:
                    group["targets"] = [
                        to_vllm_name(name) for name in group["targets"]
                    ]
            quant_config.config_groups = config_groups

        return quant_config


class AudioLanguageAdapter(nn.Module):
    """Two bias-free linear layers with a GELU in between."""

    def __init__(self, hidden_size: int, dim: int) -> None:
        super().__init__()
        # hidden_size -> dim, activation, then dim -> dim.
        self.w_in = nn.Linear(hidden_size, dim, bias=False)
        self.gelu = nn.GELU()
        self.w_out = nn.Linear(dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` from ``hidden_size`` to ``dim`` features."""
        projected = self.w_in(x)
        activated = self.gelu(projected)
        return self.w_out(activated)


class VoxtralEncoderModel(nn.Module):
    """Waveform front end (log-mel + chunking) around a ``WhisperEncoder``."""
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    # fmt: off
    mistral_remapping = [
        (r"whisper_encoder\.conv_layers\.0\.(weight|bias)", r"whisper_encoder.conv1.\1"), # noqa: E501
        (r"whisper_encoder\.conv_layers\.1\.(weight|bias)", r"whisper_encoder.conv2.\1"), # noqa: E501
        (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.w([qkv])\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn.\2_proj.\3"), # noqa: E501
        (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wo\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn.out_proj.\2"), # noqa: E501
        (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention_norm\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn_layer_norm.\2"), # noqa: E501
        (r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w1\.(weight|bias)", r"whisper_encoder.layers.\1.mlp.fc1.\2"), # noqa: E501
        (r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(weight|bias)", r"whisper_encoder.layers.\1.mlp.fc2.\2"), # noqa: E501
        (r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(weight|bias)", r"whisper_encoder.layers.\1.final_layer_norm.\2"), # noqa: E501
        (r"whisper_encoder\.transformer\.norm\.(weight|bias)", r"whisper_encoder.layer_norm.\1"), # noqa: E501
    ]
    # fmt: on

    def __init__(
        self,
        vllm_config: VllmConfig,
        *,
        prefix: str = "",
    ) -> None:
        """Create the Whisper encoder and the mel filter bank."""
        super().__init__()
        model_config = vllm_config.model_config
        self.config = cast(WhisperConfig, model_config.hf_config)
        self.dtype: torch.dtype = model_config.dtype
        encoder_prefix = maybe_prefix(prefix, "whisper_encoder")
        self.whisper_encoder = WhisperEncoder(vllm_config=vllm_config,
                                              prefix=encoder_prefix,
                                              init_in_fp32=True)
        filter_bank = mel_filter_bank(
            num_frequency_bins=1 + self.config.window_size // 2,
            num_mel_bins=self.config.num_mel_bins,
            min_frequency=0.0,
            max_frequency=8000.0,
            sampling_rate=self.config.sampling_rate,
        )
        self.mel_filters = torch.tensor(filter_bank, dtype=torch.float32)

    def compute_whisper_melspec(
        self,
        audio_waveforms: torch.Tensor,
    ) -> torch.Tensor:
        """Return the normalized log-mel spectrogram in the input's dtype."""
        cfg = self.config
        hann = torch.hann_window(cfg.window_size).to(audio_waveforms.device)
        spectrum = torch.stft(
            audio_waveforms,
            cfg.window_size,
            cfg.hop_length,
            window=hann,
            return_complex=True,
        )
        # Power spectrum without the final STFT frame.
        power = spectrum[..., :-1].abs()**2
        mel_energy = self.mel_filters.T @ power
        log_mel = torch.clamp(mel_energy, min=1e-10).log10()
        # Limit the dynamic range to 8 (log10 units) below the peak.
        floor = log_mel.max() - 8.0
        log_mel = torch.maximum(log_mel, floor)
        normalized = (log_mel + 4.0) / 4.0
        return normalized.to(audio_waveforms.dtype)

    @property
    def downsample_factor(self) -> int:
        """Product of the strides of the encoder's two convolutions."""
        encoder = self.whisper_encoder
        return encoder.conv1.stride[0] * encoder.conv2.stride[0]

    @property
    def chunk_size(self) -> int:
        """Number of mel frames per encoder chunk."""
        return self.config.max_source_positions * self.downsample_factor

    def prepare_inputs_for_conv(
        self,
        audio_waveforms: list[torch.Tensor],
    ) -> tuple[torch.Tensor, list[int]]:
        """
        Convert waveforms to mel features and cut them into chunks.

        Returns the stacked chunks of all inputs together with the number of
        chunks that belong to each input, in input order.
        """
        assert isinstance(audio_waveforms, list)
        frames_per_chunk = self.chunk_size

        all_chunks: list[torch.Tensor] = []
        chunk_counts: list[int] = []
        for waveform in audio_waveforms:
            # [num_mel_bins, seq_len]
            mel = self.compute_whisper_melspec(waveform).to(self.dtype)
            pieces = mel.split(frames_per_chunk, dim=-1)
            all_chunks.extend(pieces)
            chunk_counts.append(len(pieces))

        # [total_num_chunks, num_mel_bins, chunk_size]
        return torch.stack(all_chunks), chunk_counts

    def forward(
        self, input_features: Union[torch.Tensor, list[torch.Tensor]]
    ) -> list[torch.Tensor]:
        """Encode each waveform and return one embedding tensor per input."""
        if isinstance(input_features, list):
            waveforms = input_features
        else:
            waveforms = [input_features]

        # Split long inputs into chunks
        conv_inputs, chunk_counts = self.prepare_inputs_for_conv(waveforms)

        # [total_num_chunks, ceil(chunk_size / downsample_factor), hidden_size]
        encoded = self.whisper_encoder([conv_inputs])

        # Re-concatenate the chunks
        merged = []
        start = 0
        for count in chunk_counts:
            stop = start + count
            merged.append(encoded[start:stop].flatten(0, 1))
            start = stop

        return merged

    def load_weight(self, weight: tuple[str, torch.Tensor]) -> str:
        """
        Load one checkpoint tensor and return the parameter name it went to.

        The name is first rewritten with ``mistral_remapping``; q/k/v
        projection weights are then redirected to the fused ``qkv_proj``
        parameter and loaded with their shard id.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        name, loaded_weight = weight
        for pattern, repl in self.mistral_remapping:
            if re.fullmatch(pattern, name):
                name = re.sub(pattern, repl, name)

        params_dict = dict(self.named_parameters())

        # First stacked entry whose shard name occurs in the weight name.
        stacked_entry = next(
            (entry for entry in stacked_params_mapping if entry[1] in name),
            None,
        )
        if stacked_entry is None:
            param = params_dict[name]
            loader = getattr(param, "weight_loader", default_weight_loader)
            loader(param, loaded_weight)
        else:
            fused_name, shard_name, shard_id = stacked_entry
            name = name.replace(shard_name, fused_name)
            param = params_dict[name]
            param.weight_loader(param, loaded_weight, shard_id)

        return name