Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a8e8d62d
Unverified
Commit
a8e8d62d
authored
Mar 14, 2026
by
Isotr0py
Committed by
GitHub
Mar 14, 2026
Browse files
[Misc] Clean up Kimi-audio whisper encoder loading (#36903)
Signed-off-by:
Isotr0py
<
mozf@mail2.sysu.edu.cn
>
parent
e42b49bd
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
89 additions
and
116 deletions
+89
-116
vllm/model_executor/model_loader/default_loader.py
vllm/model_executor/model_loader/default_loader.py
+15
-4
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+13
-1
vllm/model_executor/models/kimi_audio.py
vllm/model_executor/models/kimi_audio.py
+61
-111
No files found.
vllm/model_executor/model_loader/default_loader.py
View file @
a8e8d62d
...
...
@@ -52,6 +52,9 @@ class DefaultModelLoader(BaseModelLoader):
revision
:
str
|
None
"""The optional model revision."""
subfolder
:
str
|
None
=
None
"""The subfolder inside the model repo."""
prefix
:
str
=
""
"""A prefix to prepend to all weights."""
...
...
@@ -81,6 +84,7 @@ class DefaultModelLoader(BaseModelLoader):
def
_prepare_weights
(
self
,
model_name_or_path
:
str
,
subfolder
:
str
|
None
,
revision
:
str
|
None
,
fall_back_to_pt
:
bool
,
allow_patterns_overrides
:
list
[
str
]
|
None
,
...
...
@@ -143,11 +147,15 @@ class DefaultModelLoader(BaseModelLoader):
self
.
load_config
.
download_dir
,
allow_patterns
,
revision
,
subfolder
=
subfolder
,
ignore_patterns
=
self
.
load_config
.
ignore_patterns
,
)
else
:
hf_folder
=
model_name_or_path
if
subfolder
is
not
None
:
hf_folder
=
os
.
path
.
join
(
hf_folder
,
subfolder
)
hf_weights_files
:
list
[
str
]
=
[]
for
pattern
in
allow_patterns
:
hf_weights_files
+=
glob
.
glob
(
os
.
path
.
join
(
hf_folder
,
pattern
))
...
...
@@ -166,8 +174,9 @@ class DefaultModelLoader(BaseModelLoader):
download_safetensors_index_file_from_hf
(
model_name_or_path
,
index_file
,
self
.
load_config
.
download_dir
,
revision
,
cache_dir
=
self
.
load_config
.
download_dir
,
subfolder
=
subfolder
,
revision
=
revision
,
)
hf_weights_files
=
filter_duplicate_safetensors_files
(
hf_weights_files
,
hf_folder
,
index_file
...
...
@@ -189,6 +198,7 @@ class DefaultModelLoader(BaseModelLoader):
extra_config
=
self
.
load_config
.
model_loader_extra_config
hf_folder
,
hf_weights_files
,
use_safetensors
=
self
.
_prepare_weights
(
source
.
model_or_path
,
source
.
subfolder
,
source
.
revision
,
source
.
fall_back_to_pt
,
source
.
allow_patterns_overrides
,
...
...
@@ -269,8 +279,9 @@ class DefaultModelLoader(BaseModelLoader):
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
,
model_name_or_path
=
model_config
.
model
,
subfolder
=
None
,
revision
=
model_config
.
revision
,
fall_back_to_pt
=
True
,
allow_patterns_overrides
=
None
,
)
...
...
vllm/model_executor/model_loader/weight_utils.py
View file @
a8e8d62d
...
...
@@ -472,6 +472,7 @@ def download_weights_from_hf(
cache_dir
:
str
|
None
,
allow_patterns
:
list
[
str
],
revision
:
str
|
None
=
None
,
subfolder
:
str
|
None
=
None
,
ignore_patterns
:
str
|
list
[
str
]
|
None
=
None
,
)
->
str
:
"""Download model weights from Hugging Face Hub.
...
...
@@ -484,6 +485,8 @@ def download_weights_from_hf(
weight files. Files matched by any of the patterns will be
downloaded.
revision (Optional[str]): The revision of the model.
subfolder (Optional[str]): The subfolder within the model repository
to download weights from.
ignore_patterns (Optional[Union[str, list[str]]]): The patterns to
filter out the weight files. Files matched by any of the patterns
will be ignored.
...
...
@@ -498,7 +501,11 @@ def download_weights_from_hf(
# so we only have to call snapshot_download once.
try
:
fs
=
HfFileSystem
()
file_list
=
fs
.
ls
(
model_name_or_path
,
detail
=
False
,
revision
=
revision
)
file_list
=
fs
.
ls
(
os
.
path
.
join
(
model_name_or_path
,
subfolder
or
""
),
detail
=
False
,
revision
=
revision
,
)
# If downloading safetensors and an index file exists, use the
# specific file names from the index to avoid downloading
...
...
@@ -510,6 +517,7 @@ def download_weights_from_hf(
filename
=
SAFE_WEIGHTS_INDEX_NAME
,
cache_dir
=
cache_dir
,
revision
=
revision
,
subfolder
=
subfolder
,
)
with
open
(
index_path
)
as
f
:
weight_map
=
json
.
load
(
f
)[
"weight_map"
]
...
...
@@ -570,6 +578,7 @@ def download_safetensors_index_file_from_hf(
model_name_or_path
:
str
,
index_file
:
str
,
cache_dir
:
str
|
None
,
subfolder
:
str
|
None
=
None
,
revision
:
str
|
None
=
None
,
)
->
None
:
"""Download hf safetensors index file from Hugging Face Hub.
...
...
@@ -579,6 +588,8 @@ def download_safetensors_index_file_from_hf(
index_file (str): The safetensors index file name
cache_dir (Optional[str]): The cache directory to store the model
weights. If None, will use HF defaults.
subfolder (Optional[str]): The subfolder within the model repository
to download weights from.
revision (Optional[str]): The revision of the model.
"""
# Use file lock to prevent multiple processes from
...
...
@@ -591,6 +602,7 @@ def download_safetensors_index_file_from_hf(
filename
=
index_file
,
cache_dir
=
cache_dir
,
revision
=
revision
,
subfolder
=
subfolder
,
local_files_only
=
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
,
)
# If file not found on remote or locally, we should not fail since
...
...
vllm/model_executor/models/kimi_audio.py
View file @
a8e8d62d
...
...
@@ -3,15 +3,12 @@
"""Inference-only Kimi-Audio model compatible with HuggingFace weights."""
import
os
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
Any
,
ClassVar
,
Literal
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
huggingface_hub
import
snapshot_download
from
safetensors
import
safe_open
from
transformers
import
BatchFeature
from
transformers
import
WhisperConfig
as
HFWhisperConfig
...
...
@@ -19,9 +16,8 @@ from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from
vllm.config.multimodal
import
BaseDummyOptions
from
vllm.inputs.data
import
PromptType
,
TokensPrompt
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
)
from
vllm.model_executor.model_loader
import
DefaultModelLoader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
(
SupportsMultiModal
,
SupportsPP
,
...
...
@@ -64,15 +60,6 @@ from vllm.v1.sample.metadata import SamplingMetadata
KIMIA_WHISPER_SUBFOLDER
=
"whisper-large-v3"
def
_get_whisper_local_path
(
repo_id
:
str
):
if
os
.
path
.
exists
(
repo_id
):
repo_local_path
=
repo_id
else
:
repo_local_path
=
snapshot_download
(
repo_id
,
local_files_only
=
True
)
return
os
.
path
.
join
(
repo_local_path
,
KIMIA_WHISPER_SUBFOLDER
)
def
_get_feat_extract_output_lengths
(
input_lengths
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Compute output lengths after Whisper feature extraction.
...
...
@@ -93,7 +80,6 @@ class KimiAudioWhisperEncoder(WhisperEncoder):
# packed_modules_mapping for Q/K/V fusion during weight loading
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
"kv_proj"
:
[
"k_proj"
,
"v_proj"
],
}
def
__init__
(
...
...
@@ -104,19 +90,49 @@ class KimiAudioWhisperEncoder(WhisperEncoder):
model_path
=
vllm_config
.
model_config
.
model
# Load WhisperConfig from the subfolder
whisper_dir
=
_get_whisper_local_path
(
model_path
)
whisper_config
=
HFWhisperConfig
.
from_pretrained
(
whisper_dir
)
# Temporarily replace hf_config for WhisperEncoder.__init__()
original_config
=
vllm_config
.
model_config
.
hf_config
vllm_config
.
model_config
.
hf_config
=
whisper_config
whisper_config
=
HFWhisperConfig
.
from_pretrained
(
model_path
,
subfolder
=
KIMIA_WHISPER_SUBFOLDER
,
)
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
init_in_fp32
=
init_in_fp32
vllm_config
=
vllm_config
.
with_hf_config
(
whisper_config
),
prefix
=
prefix
,
init_in_fp32
=
init_in_fp32
,
)
# Restore original config
vllm_config
.
model_config
.
hf_config
=
original_config
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
# -----------------------------------------------------------------------------
...
...
@@ -374,6 +390,8 @@ class KimiAudioForConditionalGeneration(
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
# audio tower
"model.encoder."
:
"audio_tower."
,
# Audio projector (VQ-Adaptor)
"model.vq_adaptor.layers.0."
:
"multi_modal_projector.vq_adaptor_layers_0."
,
"model.vq_adaptor.layers.3."
:
"multi_modal_projector.vq_adaptor_layers_3."
,
...
...
@@ -384,7 +402,11 @@ class KimiAudioForConditionalGeneration(
"model.embed_tokens."
:
"language_model.model.embed_tokens."
,
"model.norm."
:
"language_model.model.norm."
,
"lm_head."
:
"language_model.lm_head."
,
}
},
orig_to_new_substr
=
{
".fc1."
:
".mlp.fc1."
,
".fc2."
:
".mlp.fc2."
,
},
)
# Audio placeholder token sequence
...
...
@@ -401,6 +423,14 @@ class KimiAudioForConditionalGeneration(
self
.
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
self
.
model_path
=
vllm_config
.
model_config
.
model
self
.
secondary_weights
=
[
DefaultModelLoader
.
Source
(
model_or_path
=
vllm_config
.
model_config
.
model
,
subfolder
=
"whisper-large-v3"
,
revision
=
None
,
)
]
self
.
audio_tower
=
KimiAudioWhisperEncoder
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"audio_tower"
),
...
...
@@ -577,99 +607,19 @@ class KimiAudioForConditionalGeneration(
"""Load weights, skipping MIMO layers (TTS-only) for ASR."""
# Filter out MIMO/TTS weights since we only do ASR (speech-to-text)
skipped_patterns
=
[
# Audio tower
"model."
,
# MIMO/TTS
"mimo_layers."
,
"mimo_output."
,
"mimo_norm."
,
"audio_decoder."
,
]
# Filter weights
filtered_weights
=
[
(
name
,
param
)
for
name
,
param
in
weights
if
not
any
(
pattern
in
name
for
pattern
in
skipped_patterns
)
]
# Separate main weights (non-Whisper) from Whisper weights
main_weights
=
[
(
name
,
param
)
for
name
,
param
in
filtered_weights
if
not
name
.
startswith
(
"audio_tower."
)
]
# Load main model weights (LLM + projector) with mapper
loader
=
AutoWeightsLoader
(
self
)
loaded
=
loader
.
load_weights
(
main_weights
,
mapper
=
self
.
hf_to_vllm_mapper
)
# Load Whisper encoder weights from subfolder
whisper_dir
=
_get_whisper_local_path
(
self
.
model_path
)
whisper_path
=
os
.
path
.
join
(
whisper_dir
,
"model.safetensors"
)
if
os
.
path
.
exists
(
whisper_path
):
whisper_loaded
=
self
.
_load_whisper_weights_from_file
(
whisper_path
)
loaded
.
update
(
whisper_loaded
)
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
skipped_patterns
)
loaded
=
loader
.
load_weights
(
weights
,
mapper
=
self
.
hf_to_vllm_mapper
)
return
loaded
def
_load_whisper_weights_from_file
(
self
,
whisper_path
:
str
)
->
set
[
str
]:
"""Load Whisper encoder weights from safetensors file with transformations."""
if
not
os
.
path
.
exists
(
whisper_path
):
return
set
()
# Step 1: Load raw weights from safetensors file
whisper_weights
=
[]
with
safe_open
(
whisper_path
,
framework
=
"pt"
)
as
f
:
for
key
in
f
.
keys
():
# noqa: SIM118
if
key
.
startswith
(
"model.encoder."
)
and
"embed_positions"
not
in
key
:
new_key
=
key
.
replace
(
"model.encoder."
,
""
)
whisper_weights
.
append
((
new_key
,
f
.
get_tensor
(
key
)))
# Step 2: Apply fc → mlp mapping using WeightsMapper
fc_mapper
=
WeightsMapper
(
orig_to_new_substr
=
{
".fc1."
:
".mlp.fc1."
,
".fc2."
:
".mlp.fc2."
}
)
whisper_mapped
=
list
(
fc_mapper
.
apply
(
whisper_weights
))
# Step 3: Apply Q/K/V fusion manually
stacked_params_mapping
=
[
(
".self_attn.qkv_proj"
,
".self_attn.q_proj"
,
"q"
),
(
".self_attn.qkv_proj"
,
".self_attn.k_proj"
,
"k"
),
(
".self_attn.qkv_proj"
,
".self_attn.v_proj"
,
"v"
),
]
params_dict
=
dict
(
self
.
audio_tower
.
named_parameters
())
whisper_loaded
:
set
[
str
]
=
set
()
for
name
,
loaded_weight
in
whisper_mapped
:
fused
=
False
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
fused_name
=
name
.
replace
(
weight_name
,
param_name
)
if
fused_name
not
in
params_dict
:
continue
param
=
params_dict
[
fused_name
]
param
.
weight_loader
(
param
,
loaded_weight
,
shard_id
)
whisper_loaded
.
add
(
f
"audio_tower.
{
fused_name
}
"
)
fused
=
True
break
if
not
fused
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
whisper_loaded
.
add
(
f
"audio_tower.
{
name
}
"
)
# Add embed_positions which is initialized randomly
whisper_loaded
.
add
(
"audio_tower.embed_positions.weight"
)
return
whisper_loaded
@
classmethod
def
get_speech_to_text_config
(
cls
,
model_config
:
ModelConfig
,
task_type
:
str
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment