kimi_audio.py 25.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""Inference-only Kimi-Audio model compatible with HuggingFace weights."""

import os
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, ClassVar, Literal

import numpy as np
import torch
import torch.nn as nn
from safetensors import safe_open
from transformers import BatchFeature
from transformers import WhisperConfig as HFWhisperConfig

from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
)
from vllm.model_executor.models.interfaces import (
    SupportsMultiModal,
    SupportsPP,
    SupportsTranscription,
)
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
from vllm.model_executor.models.whisper import WhisperEncoder
from vllm.model_executor.models.whisper_utils import ISO639_1_SUPPORTED_LANGS
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig
from vllm.multimodal.parse import (
    AudioItem,
    DictEmbeddingItems,
    ModalityData,
    ModalityDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
    BaseDummyInputsBuilder,
    BaseProcessingInfo,
    PromptReplacement,
)
from vllm.multimodal.processing.processor import BaseMultiModalProcessor
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.kimi_audio import KimiAudioTokenizer
from vllm.transformers_utils.processor import cached_feature_extractor_from_config
from vllm.transformers_utils.processors.kimi_audio import KimiAudioProcessor
from vllm.v1.sample.metadata import SamplingMetadata

# Kimi-Audio constants
# Name of the checkpoint subfolder from which this module reads the Whisper
# config, the audio feature extractor, and the encoder weights
# (model.safetensors).
KIMIA_WHISPER_SUBFOLDER = "whisper-large-v3"


def _get_feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor:
    """Compute output lengths after Whisper feature extraction.

    Whisper processes audio through multiple conv layers with stride=2,
    producing 13 output features per 100 input samples.
    """
    input_lengths_leave = input_lengths % 100
    feat_lengths = (input_lengths_leave - 1) // 2 + 1
    output_lengths = (
        ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
    )
    return output_lengths


class KimiAudioWhisperEncoder(WhisperEncoder):
    """WhisperEncoder for Kimi-Audio with packed_modules_mapping.

    The encoder is built from the Whisper config found in the
    ``KIMIA_WHISPER_SUBFOLDER`` directory under the model path, rather than
    from the top-level Kimi-Audio HF config.
    """

    # packed_modules_mapping for Q/K/V fusion during weight loading
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "kv_proj": ["k_proj", "v_proj"],
    }

    def __init__(
        self, *, vllm_config: VllmConfig, prefix: str = "", init_in_fp32: bool = False
    ):
        """Construct the encoder using the Whisper config from the subfolder.

        Args:
            vllm_config: Engine configuration; ``model_config.model`` is joined
                with ``KIMIA_WHISPER_SUBFOLDER`` to locate the Whisper config.
            prefix: Parameter-name prefix forwarded to ``WhisperEncoder``.
            init_in_fp32: Forwarded unchanged to ``WhisperEncoder``.
        """
        model_config = vllm_config.model_config

        # Kimi-Audio stores the Whisper config in whisper-large-v3/config.json
        whisper_config_path = os.path.join(model_config.model, KIMIA_WHISPER_SUBFOLDER)
        whisper_config = HFWhisperConfig.from_pretrained(whisper_config_path)

        # WhisperEncoder.__init__() reads hf_config from the shared model
        # config, so swap it in temporarily. The restore sits in ``finally`` so
        # that a failure during construction cannot leave the shared config
        # pointing at the Whisper config.
        original_config = model_config.hf_config
        model_config.hf_config = whisper_config
        try:
            super().__init__(
                vllm_config=vllm_config, prefix=prefix, init_in_fp32=init_in_fp32
            )
        finally:
            model_config.hf_config = original_config


# -----------------------------------------------------------------------------
# Processing Info, Dummy Inputs, and MultiModal Processor
# (Following Qwen3ASR pattern - same file as model)
# -----------------------------------------------------------------------------


class KimiAudioProcessingInfo(BaseProcessingInfo):
    """Processing info for vLLM registry.

    Exposes the HF config, the audio feature extractor, the combined
    processor, the per-prompt audio limit, and the audio data parser.
    """

    def get_hf_config(self):
        """Return the HF config held by the model config."""
        return self.ctx.model_config.hf_config

    def get_feature_extractor(self, **kwargs: object):
        """Load the feature extractor from the Whisper subfolder (cached)."""
        model_config = self.ctx.model_config
        return cached_feature_extractor_from_config(
            model_config, subfolder=KIMIA_WHISPER_SUBFOLDER
        )

    def get_hf_processor(self, **kwargs: object) -> KimiAudioProcessor:
        """Assemble a KimiAudioProcessor from the extractor and tokenizer.

        The tokenizer comes from ``self.get_tokenizer()``; ``kwargs`` are
        accepted but not used.
        """
        return KimiAudioProcessor(
            feature_extractor=self.get_feature_extractor(),
            tokenizer=self.get_tokenizer(),
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Declare the modality limits: one audio item."""
        return {"audio": 1}

    def get_data_parser(self) -> "KimiAudioMultiModalDataParser":
        """Create the audio data parser with the expected hidden size."""
        hidden_size = self._get_expected_hidden_size()
        return KimiAudioMultiModalDataParser(expected_hidden_size=hidden_size)


class KimiAudioDummyInputsBuilder(BaseDummyInputsBuilder[KimiAudioProcessingInfo]):
    """Dummy inputs builder for vLLM registry."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> list[int]:
        """Build the dummy prompt directly as token IDs.

        Returns one begin/blank/end placeholder triple per requested audio,
        or a single filler token when no audio is requested.
        """
        audio_count = mm_counts.get("audio", 0)
        if audio_count == 0:
            # NOTE(review): the previous comment described 198 as the token
            # for "Transcribe"; that is not verifiable from this file.
            return [198]

        # Token IDs are used directly to avoid going through the tokenizer.
        placeholder_triple = [
            KimiAudioProcessor.KIMIA_MEDIA_BEGIN,
            KimiAudioProcessor.KIMIA_TEXT_BLANK,
            KimiAudioProcessor.KIMIA_MEDIA_END,
        ]
        return placeholder_triple * audio_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Build dummy audio clips for the requested number of audios.

        Each clip spans ``min(chunk_length, 30)`` multiplied by the feature
        extractor's sampling rate. ``seq_len`` and ``mm_options`` are unused.
        """
        audio_count = mm_counts.get("audio", 0)
        if audio_count == 0:
            return {}

        extractor = self.info.get_feature_extractor()
        clip_seconds = min(extractor.chunk_length, 30)
        clip_samples = clip_seconds * extractor.sampling_rate

        dummy_audios = self._get_dummy_audios(
            length=clip_samples, num_audios=audio_count
        )
        return {"audio": dummy_audios}


# Field config for Kimi-Audio multimodal data: both processor outputs are
# batched per item of the "audio" modality.
_KIMIAUDIO_FIELD_CONFIG = {
    field_name: MultiModalFieldConfig.batched("audio")
    for field_name in ("whisper_input_features", "feature_attention_mask")
}


class KimiAudioMultiModalDataParser(MultiModalDataParser):
    """Custom data parser for Kimi-Audio multimodal data.

    Raw audio is handled by the base parser with a 16 kHz target sampling
    rate; dict inputs are treated as precomputed Whisper features.
    """

    def __init__(self, **kwargs):
        # Whisper expects 16kHz audio
        super().__init__(**kwargs, target_sr=16000)

    def _parse_audio_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[AudioItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse audio data, accepting either raw audio or a feature dict.

        A dict must carry ``whisper_input_features`` and
        ``feature_attention_mask``; anything else is delegated to the base
        parser.
        """
        if not isinstance(data, dict):
            return super()._parse_audio_data(data)

        def field_config_for(hf_inputs):
            # The field layout does not depend on the inputs.
            return _KIMIAUDIO_FIELD_CONFIG

        return DictEmbeddingItems(
            data,
            modality="audio",
            required_fields={"whisper_input_features", "feature_attention_mask"},
            fields_factory=field_config_for,
        )


class KimiAudioMultiModalProcessor(BaseMultiModalProcessor[KimiAudioProcessingInfo]):
    """vLLM multi-modal processor wrapper for Kimi-Audio."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Call the HuggingFace processor.

        The ``audios`` entry of ``mm_data`` is renamed to ``audio``. Any
        two-element tuple/list item is taken to be ``(audio_array,
        sampling_rate)`` and reduced to its first element; every other item is
        passed through unchanged.
        """
        processor_inputs = dict(mm_data)
        audios = processor_inputs.pop("audios", [])

        if audios:
            # KimiAudioProcessor expects raw arrays, not (array, sr) pairs.
            processor_inputs["audio"] = [
                item[0] if isinstance(item, (tuple, list)) and len(item) == 2 else item
                for item in audios
            ]

        hf_processor = self.info.get_hf_processor(**mm_kwargs)
        # Delegate to the context's call_hf_processor for proper handling.
        return self.info.ctx.call_hf_processor(
            hf_processor,
            dict(text=prompt, **processor_inputs),
            dict(**mm_kwargs, **tok_kwargs),
        )

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, Any]:
        """Return the fixed field configuration for Kimi-Audio outputs."""
        return _KIMIAUDIO_FIELD_CONFIG

    def _get_prompt_updates(
        self,
        mm_items,
        hf_processor_mm_kwargs,
        out_mm_kwargs,
    ) -> Sequence[PromptReplacement]:
        """Expand each audio placeholder token to the item's feature count.

        Feature counts are derived from ``feature_attention_mask`` in the
        processed output. When the mask is absent, the item index is out of
        range, or the computed count is zero, 376 placeholders are used.
        """
        default_num_features = 376  # Default Kimi-Audio sequence length

        attention_mask = out_mm_kwargs.get_data().get("feature_attention_mask")
        if attention_mask is None:
            per_item_lengths = []
        else:
            per_item_lengths = _get_feat_extract_output_lengths(
                attention_mask.sum(-1)
            ).tolist()

        def get_replacement_kimiaudio(item_idx: int):
            count = default_num_features
            if item_idx < len(per_item_lengths):
                count = per_item_lengths[item_idx]
            if count == 0:
                count = default_num_features
            # The placeholder token ID repeated once per audio feature.
            return [KimiAudioProcessor.KIMIA_TEXT_BLANK] * count

        # The target is the placeholder token ID, given as a one-element list.
        audio_replacement = PromptReplacement(
            modality="audio",
            target=[KimiAudioProcessor.KIMIA_TEXT_BLANK],
            replacement=get_replacement_kimiaudio,
        )
        return [audio_replacement]


# -----------------------------------------------------------------------------
# Model Definition
# -----------------------------------------------------------------------------


class KimiAudioMultiModalProjector(nn.Module):
    """Projects Whisper features to LLM embedding space.

    Kimi-Audio VQ-Adaptor architecture (default sizes):
    features (5120) → Linear[5120→3584] → GELU → Linear[3584→3584] → LayerNorm

    The attribute suffixes ``_0``, ``_3`` and ``_4`` follow the ``layers.N``
    indices of the checkpoint's ``vq_adaptor``.
    """

    def __init__(
        self,
        whisper_dim: int = 5120,  # Kimi-Audio custom Whisper encoder dim
        llm_dim: int = 3584,
        prefix: str = "",
    ):
        """Create the adaptor layers.

        Args:
            whisper_dim: Size of the last dimension of the incoming features.
            llm_dim: Size of the last dimension of the projected output.
            prefix: Accepted for interface compatibility; not used here.
        """
        super().__init__()
        self.whisper_dim, self.llm_dim = whisper_dim, llm_dim

        # Registration order matters for state_dict ordering: 0, 3, then 4.
        self.vq_adaptor_layers_0 = nn.Linear(whisper_dim, llm_dim)  # layers.0
        self.vq_adaptor_layers_3 = nn.Linear(llm_dim, llm_dim)  # layers.3
        self.vq_adaptor_layers_4 = nn.LayerNorm(llm_dim)  # layers.4

    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        """Project ``[..., whisper_dim]`` features to ``[..., llm_dim]``."""
        activated = nn.functional.gelu(self.vq_adaptor_layers_0(audio_features))
        return self.vq_adaptor_layers_4(self.vq_adaptor_layers_3(activated))


@MULTIMODAL_REGISTRY.register_processor(
    KimiAudioMultiModalProcessor,
    info=KimiAudioProcessingInfo,
    dummy_inputs=KimiAudioDummyInputsBuilder,
)
class KimiAudioForConditionalGeneration(
    nn.Module,
    SupportsMultiModal,
    SupportsPP,
    SupportsTranscription,
):
    """Kimi-Audio model for ASR transcription."""

    # Kimi-Audio supports a subset of Whisper's supported languages
    # (language code -> language name, looked up in ISO639_1_SUPPORTED_LANGS).
    supported_languages: ClassVar[Mapping[str, str]] = {
        k: ISO639_1_SUPPORTED_LANGS[k]
        for k in ["zh", "en", "ja", "ko", "de", "fr", "es", "it", "pt", "ru", "ar"]
    }
    supports_transcription: ClassVar[Literal[True]] = True

    # Prefix renames applied to checkpoint weight names in load_weights() so
    # that they line up with this module's attributes (multi_modal_projector,
    # language_model).
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # Audio projector (VQ-Adaptor)
            "model.vq_adaptor.layers.0.": "multi_modal_projector.vq_adaptor_layers_0.",
            "model.vq_adaptor.layers.3.": "multi_modal_projector.vq_adaptor_layers_3.",
            "model.vq_adaptor.layers.4.": "multi_modal_projector.vq_adaptor_layers_4.",
            # Language model
            "model.layers.": "language_model.model.layers.",
            # Embeddings and output
            "model.embed_tokens.": "language_model.model.embed_tokens.",
            "model.norm.": "language_model.model.norm.",
            "lm_head.": "language_model.lm_head.",
        }
    )

    # Audio placeholder token sequence: media-begin, one text-blank token, and
    # media-end. Returned by get_placeholder_str() and embedded in the prompt
    # built by get_generation_prompt().
    AUDIO_PLACEHOLDER = "<|im_media_begin|><|im_kimia_text_blank|><|im_media_end|>"

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        return cls.AUDIO_PLACEHOLDER if modality.startswith("audio") else None

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the audio tower, the projector, and the language model.

        Args:
            vllm_config: Engine configuration; its HF config supplies the
                hidden size, vocabulary size, and (optionally)
                ``kimia_adaptor_input_dim``.
            prefix: Parameter-name prefix under which submodules are created.
        """
        super().__init__()
        model_config = vllm_config.model_config
        hf_config = model_config.hf_config

        self.config = hf_config
        self.quant_config = vllm_config.quant_config
        self.multimodal_config = model_config.multimodal_config
        self.model_path = model_config.model

        # Whisper encoder for the audio features.
        self.audio_tower = KimiAudioWhisperEncoder(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "audio_tower"),
        )

        # Adaptor from encoder features to the LLM hidden size; the input
        # width falls back to 5120 when the config does not specify it.
        adaptor_input_dim = getattr(hf_config, "kimia_adaptor_input_dim", 5120)
        self.multi_modal_projector = KimiAudioMultiModalProjector(
            whisper_dim=adaptor_input_dim,
            llm_dim=hf_config.hidden_size,
            prefix=maybe_prefix(prefix, "multi_modal_projector"),
        )

        # The text backbone is instantiated as a Qwen2ForCausalLM.
        llm_vllm_config = vllm_config.with_hf_config(
            hf_config, architectures=["Qwen2ForCausalLM"]
        )
        self.language_model = init_vllm_registered_model(
            vllm_config=llm_vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        vocab_size = hf_config.vocab_size
        self.logits_processor = LogitsProcessor(vocab_size, vocab_size)

        # Re-export the language model's helper for pipeline parallelism.
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> dict[str, torch.Tensor] | None:
        whisper_input_features = kwargs.pop("whisper_input_features", None)
        if whisper_input_features is None:
            return None

        return {"whisper_input_features": whisper_input_features}

    def _process_audio_input(
        self, audio_input: dict[str, torch.Tensor]
    ) -> torch.Tensor:
        input_features = audio_input["whisper_input_features"]

        # KimiAudioWhisperEncoder expects list of tensors
        if input_features.dim() == 3:
            input_features = input_features.unbind(dim=0)

        # Run through Whisper encoder
        audio_features = self.audio_tower(input_features)

        # Reshape for 4x downsampling (Whisper outputs at 50Hz, need 12.5Hz)
        B, T, D = audio_features.shape
        if T % 4 != 0:
            pad_len = 4 - (T % 4)
            audio_features = torch.nn.functional.pad(audio_features, (0, 0, 0, pad_len))
            T = audio_features.shape[1]  # Update T after padding

        audio_features = audio_features.reshape(B, T // 4, D * 4)

        # Project to LLM dimension
        audio_embeds = self.multi_modal_projector(audio_features)
        return audio_embeds

    def embed_multimodal(self, **kwargs: object) -> list[torch.Tensor] | None:
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return []

        audio_embeds = self._process_audio_input(audio_input)

        # audio_embeds shape: [batch_size, seq_len, hidden_dim]
        # Return as list of 2D tensors, one per batch item
        if audio_embeds.dim() == 3:
            # Unbind batch dimension: [B, T, D] -> list of B tensors [T, D]
            return list(audio_embeds.unbind(dim=0))
        else:
            # Single sample: [T, D] -> wrap in list
            return [audio_embeds]

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: tuple[torch.Tensor, ...] | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Embed input IDs and fuse with audio embeddings.

        Kimi-Audio fusion: inputs_embeds = (text_emb + audio_emb) × √2

        For PP compatibility, we use the is_multimodal mask from vLLM engine
        which is correctly computed per pipeline stage.
        """
        # Get text embeddings
        inputs_embeds = self.language_model.model.embed_tokens(input_ids)

        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return inputs_embeds

        # is_multimodal must be provided for PP to work correctly
        if is_multimodal is None or not is_multimodal.any():
            return inputs_embeds

        # multimodal_embeddings[0] contains audio embeddings
        audio_embeds = multimodal_embeddings[0]

        # Handle different tensor structures
        if isinstance(audio_embeds, (list, tuple)):
            audio_embeds = torch.cat(audio_embeds, dim=0)
        elif audio_embeds.dim() == 3:
            audio_embeds = audio_embeds.reshape(-1, audio_embeds.shape[-1])

        # In PP, audio_embeds count should match is_multimodal.sum()
        # For now, use embeddings sequentially
        # (works for non-PP, PP needs vLLM infra fix)
        num_mm_tokens = is_multimodal.sum().item()
        num_audio_embeds = audio_embeds.shape[0]

        # Use the minimum of available embeddings and positions
        # This ensures we don't access out-of-bounds
        num_to_use = min(num_audio_embeds, num_mm_tokens)

        # Get positions for the tokens we'll actually process
        mm_positions = is_multimodal.nonzero(as_tuple=True)[0]
        actual_mm_mask = torch.zeros_like(is_multimodal)
        actual_mm_mask[mm_positions[:num_to_use]] = True

        # Use corresponding embeddings
        used_audio_embeds = audio_embeds[:num_to_use]

        # Save text embeddings at multimodal positions
        text_at_mm_positions = inputs_embeds[actual_mm_mask].clone()

        # Replace text with audio at multimodal positions
        inputs_embeds[actual_mm_mask] = used_audio_embeds.to(dtype=inputs_embeds.dtype)

        # Apply Kimi-Audio's unique fusion formula: (text + audio) × √2
        inputs_embeds[actual_mm_mask] = (
            inputs_embeds[actual_mm_mask] + text_at_mm_positions
        ) * (2**0.5)

        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language model backbone.

        When ``intermediate_tensors`` is given, ``inputs_embeds`` is dropped
        and ``None`` is forwarded in its place. Extra keyword arguments are
        ignored.
        """
        embeds = None if intermediate_tensors is not None else inputs_embeds
        backbone = self.language_model.model
        return backbone(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata | None = None,
    ) -> torch.Tensor | None:
        """Turn hidden states into logits using the language model's LM head.

        ``sampling_metadata`` is forwarded as the third positional argument of
        the logits processor.
        """
        # NOTE(review): confirm that the third positional parameter of
        # LogitsProcessor is sampling metadata in the targeted vLLM version.
        lm_head = self.language_model.lm_head
        return self.logits_processor(lm_head, hidden_states, sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights, skipping MIMO layers (TTS-only) for ASR.

        LLM and projector weights are taken from ``weights`` (renamed through
        ``hf_to_vllm_mapper``); Whisper encoder weights are read separately
        from ``<model_path>/whisper-large-v3/model.safetensors`` when that
        file exists.

        Args:
            weights: ``(name, tensor)`` pairs from the main checkpoint. They
                are consumed lazily, one pair at a time, so the whole
                checkpoint is never held in a Python list.

        Returns:
            The names of all parameters that were loaded.
        """
        # MIMO/TTS weights are not needed for speech-to-text.
        skipped_patterns = (
            "mimo_layers.",
            "mimo_output.",
            "mimo_norm.",
            "audio_decoder.",
        )

        def main_weights():
            # Stream the LLM + projector weights; Whisper ("audio_tower.")
            # entries are dropped here because the encoder is loaded from its
            # own file below.
            for name, param in weights:
                if any(pattern in name for pattern in skipped_patterns):
                    continue
                if name.startswith("audio_tower."):
                    continue
                yield name, param

        # Load main model weights (LLM + projector) with mapper
        loader = AutoWeightsLoader(self)
        loaded = loader.load_weights(main_weights(), mapper=self.hf_to_vllm_mapper)

        # Load Whisper encoder weights from subfolder
        whisper_path = os.path.join(
            self.model_path, f"{KIMIA_WHISPER_SUBFOLDER}/model.safetensors"
        )
        if os.path.exists(whisper_path):
            loaded.update(self._load_whisper_weights_from_file(whisper_path))

        return loaded

    def _load_whisper_weights_from_file(self, whisper_path: str) -> set[str]:
        """Load ``audio_tower`` parameters from a Whisper safetensors file.

        Only ``model.encoder.*`` tensors are read, excluding any key that
        contains ``embed_positions``. Names are adapted in two ways:
        ``.fc1.``/``.fc2.`` gain an ``.mlp.`` parent, and separate
        ``q_proj``/``k_proj``/``v_proj`` tensors are loaded as shards of the
        fused ``qkv_proj`` parameter. Tensors without a matching parameter are
        skipped.

        Returns:
            Loaded parameter names prefixed with ``audio_tower.``; empty when
            the file does not exist.
        """
        if not os.path.exists(whisper_path):
            return set()

        # Step 1: read the encoder tensors, stripping the checkpoint prefix.
        encoder_prefix = "model.encoder."
        encoder_weights = []
        with safe_open(whisper_path, framework="pt") as checkpoint:
            for key in checkpoint.keys():  # noqa: SIM118
                if not key.startswith(encoder_prefix) or "embed_positions" in key:
                    continue
                short_key = key.replace(encoder_prefix, "")
                encoder_weights.append((short_key, checkpoint.get_tensor(key)))

        # Step 2: fc → mlp renaming via WeightsMapper.
        mlp_renamer = WeightsMapper(
            orig_to_new_substr={".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."}
        )
        renamed_weights = list(mlp_renamer.apply(encoder_weights))

        # Step 3: route q/k/v tensors into the fused qkv parameter.
        # Each entry: (fused parameter substring, checkpoint substring, shard).
        qkv_shards = [
            (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
            (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
            (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
        ]

        tower_params = dict(self.audio_tower.named_parameters())
        whisper_loaded: set[str] = set()

        for name, tensor in renamed_weights:
            for fused_substr, shard_substr, shard_id in qkv_shards:
                if shard_substr not in name:
                    continue
                fused_name = name.replace(shard_substr, fused_substr)
                if fused_name in tower_params:
                    target = tower_params[fused_name]
                    target.weight_loader(target, tensor, shard_id)
                    whisper_loaded.add(f"audio_tower.{fused_name}")
                    break
            else:
                # Not a fused shard: load it directly if the tower has a
                # parameter of that name, otherwise skip it.
                if name in tower_params:
                    target = tower_params[name]
                    load_fn = getattr(target, "weight_loader", default_weight_loader)
                    load_fn(target, tensor)
                    whisper_loaded.add(f"audio_tower.{name}")

        # embed_positions is never read from the file; it is still reported as
        # loaded. NOTE(review): the previous comment said it "is initialized
        # randomly" — confirm how WhisperEncoder initializes it.
        whisper_loaded.add("audio_tower.embed_positions.weight")

        return whisper_loaded

    @classmethod
    def get_speech_to_text_config(
        cls, model_config: ModelConfig, task_type: str
    ) -> SpeechToTextConfig:
        """Build the speech-to-text config from the feature extractor.

        The maximum clip length and the sample rate are taken from the feature
        extractor stored in the Whisper subfolder. ``task_type`` is not used.
        """
        extractor = cached_feature_extractor_from_config(
            model_config, subfolder=KIMIA_WHISPER_SUBFOLDER
        )
        max_clip_seconds = extractor.chunk_length
        sample_rate = extractor.sampling_rate
        return SpeechToTextConfig(
            max_audio_clip_s=max_clip_seconds,
            sample_rate=sample_rate,
        )

    @classmethod
    def get_generation_prompt(
        cls,
        audio: np.ndarray,
        model_config: ModelConfig,
        stt_config: SpeechToTextConfig,
        language: str | None,
        task_type: Literal["transcribe", "translate"],
        request_prompt: str,
        to_language: str | None,
    ) -> PromptType:
        """Build the tokenized chat prompt for one audio clip.

        The prompt is a user message holding the optional ``request_prompt``
        (followed by a newline) and the audio placeholder, then the start of
        the assistant message. ``stt_config``, ``language`` and
        ``to_language`` are not used.

        Raises:
            ValueError: If ``task_type`` is neither ``"transcribe"`` nor
                ``"translate"``.
        """
        tokenizer = cached_get_tokenizer(
            model_config.tokenizer,
            tokenizer_cls=KimiAudioTokenizer,
            tokenizer_mode=model_config.tokenizer_mode,
            revision=model_config.tokenizer_revision,
            trust_remote_code=model_config.trust_remote_code,
        )

        supported_tasks = ("transcribe", "translate")
        if task_type not in supported_tasks:
            raise ValueError(
                f"Unsupported task_type '{task_type}'. "
                "Supported task types are 'transcribe' and 'translate'."
            )

        # A non-empty request_prompt is prepended as context/instruction.
        user_content = cls.AUDIO_PLACEHOLDER
        if request_prompt:
            user_content = f"{request_prompt}\n{cls.AUDIO_PLACEHOLDER}"

        prompt = "".join(
            (
                f"<|im_kimia_user_msg_start|>{user_content}",
                "<|im_msg_end|><|im_kimia_assistant_msg_start|>",
            )
        )

        return TokensPrompt(
            prompt_token_ids=tokenizer.encode(prompt),
            multi_modal_data={"audio": audio},
        )

    @classmethod
    def post_process_output(cls, text: str) -> str:
        if not text:
            return ""
        return text.strip()