Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a8e8d62d
Unverified
Commit
a8e8d62d
authored
Mar 14, 2026
by
Isotr0py
Committed by
GitHub
Mar 14, 2026
Browse files
[Misc] Clean up Kimi-audio whisper encoder loading (#36903)
Signed-off-by:
Isotr0py
<
mozf@mail2.sysu.edu.cn
>
parent
e42b49bd
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
89 additions
and
116 deletions
+89
-116
vllm/model_executor/model_loader/default_loader.py
vllm/model_executor/model_loader/default_loader.py
+15
-4
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+13
-1
vllm/model_executor/models/kimi_audio.py
vllm/model_executor/models/kimi_audio.py
+61
-111
No files found.
vllm/model_executor/model_loader/default_loader.py
View file @
a8e8d62d
...
@@ -52,6 +52,9 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -52,6 +52,9 @@ class DefaultModelLoader(BaseModelLoader):
revision
:
str
|
None
revision
:
str
|
None
"""The optional model revision."""
"""The optional model revision."""
subfolder
:
str
|
None
=
None
"""The subfolder inside the model repo."""
prefix
:
str
=
""
prefix
:
str
=
""
"""A prefix to prepend to all weights."""
"""A prefix to prepend to all weights."""
...
@@ -81,6 +84,7 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -81,6 +84,7 @@ class DefaultModelLoader(BaseModelLoader):
def
_prepare_weights
(
def
_prepare_weights
(
self
,
self
,
model_name_or_path
:
str
,
model_name_or_path
:
str
,
subfolder
:
str
|
None
,
revision
:
str
|
None
,
revision
:
str
|
None
,
fall_back_to_pt
:
bool
,
fall_back_to_pt
:
bool
,
allow_patterns_overrides
:
list
[
str
]
|
None
,
allow_patterns_overrides
:
list
[
str
]
|
None
,
...
@@ -143,11 +147,15 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -143,11 +147,15 @@ class DefaultModelLoader(BaseModelLoader):
self
.
load_config
.
download_dir
,
self
.
load_config
.
download_dir
,
allow_patterns
,
allow_patterns
,
revision
,
revision
,
subfolder
=
subfolder
,
ignore_patterns
=
self
.
load_config
.
ignore_patterns
,
ignore_patterns
=
self
.
load_config
.
ignore_patterns
,
)
)
else
:
else
:
hf_folder
=
model_name_or_path
hf_folder
=
model_name_or_path
if
subfolder
is
not
None
:
hf_folder
=
os
.
path
.
join
(
hf_folder
,
subfolder
)
hf_weights_files
:
list
[
str
]
=
[]
hf_weights_files
:
list
[
str
]
=
[]
for
pattern
in
allow_patterns
:
for
pattern
in
allow_patterns
:
hf_weights_files
+=
glob
.
glob
(
os
.
path
.
join
(
hf_folder
,
pattern
))
hf_weights_files
+=
glob
.
glob
(
os
.
path
.
join
(
hf_folder
,
pattern
))
...
@@ -166,8 +174,9 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -166,8 +174,9 @@ class DefaultModelLoader(BaseModelLoader):
download_safetensors_index_file_from_hf
(
download_safetensors_index_file_from_hf
(
model_name_or_path
,
model_name_or_path
,
index_file
,
index_file
,
self
.
load_config
.
download_dir
,
cache_dir
=
self
.
load_config
.
download_dir
,
revision
,
subfolder
=
subfolder
,
revision
=
revision
,
)
)
hf_weights_files
=
filter_duplicate_safetensors_files
(
hf_weights_files
=
filter_duplicate_safetensors_files
(
hf_weights_files
,
hf_folder
,
index_file
hf_weights_files
,
hf_folder
,
index_file
...
@@ -189,6 +198,7 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -189,6 +198,7 @@ class DefaultModelLoader(BaseModelLoader):
extra_config
=
self
.
load_config
.
model_loader_extra_config
extra_config
=
self
.
load_config
.
model_loader_extra_config
hf_folder
,
hf_weights_files
,
use_safetensors
=
self
.
_prepare_weights
(
hf_folder
,
hf_weights_files
,
use_safetensors
=
self
.
_prepare_weights
(
source
.
model_or_path
,
source
.
model_or_path
,
source
.
subfolder
,
source
.
revision
,
source
.
revision
,
source
.
fall_back_to_pt
,
source
.
fall_back_to_pt
,
source
.
allow_patterns_overrides
,
source
.
allow_patterns_overrides
,
...
@@ -269,8 +279,9 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -269,8 +279,9 @@ class DefaultModelLoader(BaseModelLoader):
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
self
.
_prepare_weights
(
model_config
.
model
,
model_name_or_path
=
model_config
.
model
,
model_config
.
revision
,
subfolder
=
None
,
revision
=
model_config
.
revision
,
fall_back_to_pt
=
True
,
fall_back_to_pt
=
True
,
allow_patterns_overrides
=
None
,
allow_patterns_overrides
=
None
,
)
)
...
...
vllm/model_executor/model_loader/weight_utils.py
View file @
a8e8d62d
...
@@ -472,6 +472,7 @@ def download_weights_from_hf(
...
@@ -472,6 +472,7 @@ def download_weights_from_hf(
cache_dir
:
str
|
None
,
cache_dir
:
str
|
None
,
allow_patterns
:
list
[
str
],
allow_patterns
:
list
[
str
],
revision
:
str
|
None
=
None
,
revision
:
str
|
None
=
None
,
subfolder
:
str
|
None
=
None
,
ignore_patterns
:
str
|
list
[
str
]
|
None
=
None
,
ignore_patterns
:
str
|
list
[
str
]
|
None
=
None
,
)
->
str
:
)
->
str
:
"""Download model weights from Hugging Face Hub.
"""Download model weights from Hugging Face Hub.
...
@@ -484,6 +485,8 @@ def download_weights_from_hf(
...
@@ -484,6 +485,8 @@ def download_weights_from_hf(
weight files. Files matched by any of the patterns will be
weight files. Files matched by any of the patterns will be
downloaded.
downloaded.
revision (Optional[str]): The revision of the model.
revision (Optional[str]): The revision of the model.
subfolder (Optional[str]): The subfolder within the model repository
to download weights from.
ignore_patterns (Optional[Union[str, list[str]]]): The patterns to
ignore_patterns (Optional[Union[str, list[str]]]): The patterns to
filter out the weight files. Files matched by any of the patterns
filter out the weight files. Files matched by any of the patterns
will be ignored.
will be ignored.
...
@@ -498,7 +501,11 @@ def download_weights_from_hf(
...
@@ -498,7 +501,11 @@ def download_weights_from_hf(
# so we only have to call snapshot_download once.
# so we only have to call snapshot_download once.
try
:
try
:
fs
=
HfFileSystem
()
fs
=
HfFileSystem
()
file_list
=
fs
.
ls
(
model_name_or_path
,
detail
=
False
,
revision
=
revision
)
file_list
=
fs
.
ls
(
os
.
path
.
join
(
model_name_or_path
,
subfolder
or
""
),
detail
=
False
,
revision
=
revision
,
)
# If downloading safetensors and an index file exists, use the
# If downloading safetensors and an index file exists, use the
# specific file names from the index to avoid downloading
# specific file names from the index to avoid downloading
...
@@ -510,6 +517,7 @@ def download_weights_from_hf(
...
@@ -510,6 +517,7 @@ def download_weights_from_hf(
filename
=
SAFE_WEIGHTS_INDEX_NAME
,
filename
=
SAFE_WEIGHTS_INDEX_NAME
,
cache_dir
=
cache_dir
,
cache_dir
=
cache_dir
,
revision
=
revision
,
revision
=
revision
,
subfolder
=
subfolder
,
)
)
with
open
(
index_path
)
as
f
:
with
open
(
index_path
)
as
f
:
weight_map
=
json
.
load
(
f
)[
"weight_map"
]
weight_map
=
json
.
load
(
f
)[
"weight_map"
]
...
@@ -570,6 +578,7 @@ def download_safetensors_index_file_from_hf(
...
@@ -570,6 +578,7 @@ def download_safetensors_index_file_from_hf(
model_name_or_path
:
str
,
model_name_or_path
:
str
,
index_file
:
str
,
index_file
:
str
,
cache_dir
:
str
|
None
,
cache_dir
:
str
|
None
,
subfolder
:
str
|
None
=
None
,
revision
:
str
|
None
=
None
,
revision
:
str
|
None
=
None
,
)
->
None
:
)
->
None
:
"""Download hf safetensors index file from Hugging Face Hub.
"""Download hf safetensors index file from Hugging Face Hub.
...
@@ -579,6 +588,8 @@ def download_safetensors_index_file_from_hf(
...
@@ -579,6 +588,8 @@ def download_safetensors_index_file_from_hf(
index_file (str): The safetensors index file name
index_file (str): The safetensors index file name
cache_dir (Optional[str]): The cache directory to store the model
cache_dir (Optional[str]): The cache directory to store the model
weights. If None, will use HF defaults.
weights. If None, will use HF defaults.
subfolder (Optional[str]): The subfolder within the model repository
to download weights from.
revision (Optional[str]): The revision of the model.
revision (Optional[str]): The revision of the model.
"""
"""
# Use file lock to prevent multiple processes from
# Use file lock to prevent multiple processes from
...
@@ -591,6 +602,7 @@ def download_safetensors_index_file_from_hf(
...
@@ -591,6 +602,7 @@ def download_safetensors_index_file_from_hf(
filename
=
index_file
,
filename
=
index_file
,
cache_dir
=
cache_dir
,
cache_dir
=
cache_dir
,
revision
=
revision
,
revision
=
revision
,
subfolder
=
subfolder
,
local_files_only
=
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
,
local_files_only
=
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
,
)
)
# If file not found on remote or locally, we should not fail since
# If file not found on remote or locally, we should not fail since
...
...
vllm/model_executor/models/kimi_audio.py
View file @
a8e8d62d
...
@@ -3,15 +3,12 @@
...
@@ -3,15 +3,12 @@
"""Inference-only Kimi-Audio model compatible with HuggingFace weights."""
"""Inference-only Kimi-Audio model compatible with HuggingFace weights."""
import
os
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
Any
,
ClassVar
,
Literal
from
typing
import
Any
,
ClassVar
,
Literal
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
huggingface_hub
import
snapshot_download
from
safetensors
import
safe_open
from
transformers
import
BatchFeature
from
transformers
import
BatchFeature
from
transformers
import
WhisperConfig
as
HFWhisperConfig
from
transformers
import
WhisperConfig
as
HFWhisperConfig
...
@@ -19,9 +16,8 @@ from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
...
@@ -19,9 +16,8 @@ from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from
vllm.config.multimodal
import
BaseDummyOptions
from
vllm.config.multimodal
import
BaseDummyOptions
from
vllm.inputs.data
import
PromptType
,
TokensPrompt
from
vllm.inputs.data
import
PromptType
,
TokensPrompt
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader
import
DefaultModelLoader
default_weight_loader
,
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
)
from
vllm.model_executor.models.interfaces
import
(
from
vllm.model_executor.models.interfaces
import
(
SupportsMultiModal
,
SupportsMultiModal
,
SupportsPP
,
SupportsPP
,
...
@@ -64,15 +60,6 @@ from vllm.v1.sample.metadata import SamplingMetadata
...
@@ -64,15 +60,6 @@ from vllm.v1.sample.metadata import SamplingMetadata
KIMIA_WHISPER_SUBFOLDER
=
"whisper-large-v3"
KIMIA_WHISPER_SUBFOLDER
=
"whisper-large-v3"
def
_get_whisper_local_path
(
repo_id
:
str
):
if
os
.
path
.
exists
(
repo_id
):
repo_local_path
=
repo_id
else
:
repo_local_path
=
snapshot_download
(
repo_id
,
local_files_only
=
True
)
return
os
.
path
.
join
(
repo_local_path
,
KIMIA_WHISPER_SUBFOLDER
)
def
_get_feat_extract_output_lengths
(
input_lengths
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_get_feat_extract_output_lengths
(
input_lengths
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Compute output lengths after Whisper feature extraction.
"""Compute output lengths after Whisper feature extraction.
...
@@ -93,7 +80,6 @@ class KimiAudioWhisperEncoder(WhisperEncoder):
...
@@ -93,7 +80,6 @@ class KimiAudioWhisperEncoder(WhisperEncoder):
# packed_modules_mapping for Q/K/V fusion during weight loading
# packed_modules_mapping for Q/K/V fusion during weight loading
packed_modules_mapping
=
{
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
"kv_proj"
:
[
"k_proj"
,
"v_proj"
],
}
}
def
__init__
(
def
__init__
(
...
@@ -104,19 +90,49 @@ class KimiAudioWhisperEncoder(WhisperEncoder):
...
@@ -104,19 +90,49 @@ class KimiAudioWhisperEncoder(WhisperEncoder):
model_path
=
vllm_config
.
model_config
.
model
model_path
=
vllm_config
.
model_config
.
model
# Load WhisperConfig from the subfolder
# Load WhisperConfig from the subfolder
whisper_dir
=
_get_whisper_local_path
(
model_path
)
whisper_config
=
HFWhisperConfig
.
from_pretrained
(
whisper_config
=
HFWhisperConfig
.
from_pretrained
(
whisper_dir
)
model_path
,
subfolder
=
KIMIA_WHISPER_SUBFOLDER
,
# Temporarily replace hf_config for WhisperEncoder.__init__()
)
original_config
=
vllm_config
.
model_config
.
hf_config
vllm_config
.
model_config
.
hf_config
=
whisper_config
super
().
__init__
(
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
init_in_fp32
=
init_in_fp32
vllm_config
=
vllm_config
.
with_hf_config
(
whisper_config
),
prefix
=
prefix
,
init_in_fp32
=
init_in_fp32
,
)
)
# Restore original config
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
vllm_config
.
model_config
.
hf_config
=
original_config
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
...
@@ -374,6 +390,8 @@ class KimiAudioForConditionalGeneration(
...
@@ -374,6 +390,8 @@ class KimiAudioForConditionalGeneration(
hf_to_vllm_mapper
=
WeightsMapper
(
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
orig_to_new_prefix
=
{
# audio tower
"model.encoder."
:
"audio_tower."
,
# Audio projector (VQ-Adaptor)
# Audio projector (VQ-Adaptor)
"model.vq_adaptor.layers.0."
:
"multi_modal_projector.vq_adaptor_layers_0."
,
"model.vq_adaptor.layers.0."
:
"multi_modal_projector.vq_adaptor_layers_0."
,
"model.vq_adaptor.layers.3."
:
"multi_modal_projector.vq_adaptor_layers_3."
,
"model.vq_adaptor.layers.3."
:
"multi_modal_projector.vq_adaptor_layers_3."
,
...
@@ -384,7 +402,11 @@ class KimiAudioForConditionalGeneration(
...
@@ -384,7 +402,11 @@ class KimiAudioForConditionalGeneration(
"model.embed_tokens."
:
"language_model.model.embed_tokens."
,
"model.embed_tokens."
:
"language_model.model.embed_tokens."
,
"model.norm."
:
"language_model.model.norm."
,
"model.norm."
:
"language_model.model.norm."
,
"lm_head."
:
"language_model.lm_head."
,
"lm_head."
:
"language_model.lm_head."
,
}
},
orig_to_new_substr
=
{
".fc1."
:
".mlp.fc1."
,
".fc2."
:
".mlp.fc2."
,
},
)
)
# Audio placeholder token sequence
# Audio placeholder token sequence
...
@@ -401,6 +423,14 @@ class KimiAudioForConditionalGeneration(
...
@@ -401,6 +423,14 @@ class KimiAudioForConditionalGeneration(
self
.
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
self
.
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
self
.
model_path
=
vllm_config
.
model_config
.
model
self
.
model_path
=
vllm_config
.
model_config
.
model
self
.
secondary_weights
=
[
DefaultModelLoader
.
Source
(
model_or_path
=
vllm_config
.
model_config
.
model
,
subfolder
=
"whisper-large-v3"
,
revision
=
None
,
)
]
self
.
audio_tower
=
KimiAudioWhisperEncoder
(
self
.
audio_tower
=
KimiAudioWhisperEncoder
(
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"audio_tower"
),
prefix
=
maybe_prefix
(
prefix
,
"audio_tower"
),
...
@@ -577,99 +607,19 @@ class KimiAudioForConditionalGeneration(
...
@@ -577,99 +607,19 @@ class KimiAudioForConditionalGeneration(
"""Load weights, skipping MIMO layers (TTS-only) for ASR."""
"""Load weights, skipping MIMO layers (TTS-only) for ASR."""
# Filter out MIMO/TTS weights since we only do ASR (speech-to-text)
# Filter out MIMO/TTS weights since we only do ASR (speech-to-text)
skipped_patterns
=
[
skipped_patterns
=
[
# Audio tower
"model."
,
# MIMO/TTS
"mimo_layers."
,
"mimo_layers."
,
"mimo_output."
,
"mimo_output."
,
"mimo_norm."
,
"mimo_norm."
,
"audio_decoder."
,
]
# Filter weights
filtered_weights
=
[
(
name
,
param
)
for
name
,
param
in
weights
if
not
any
(
pattern
in
name
for
pattern
in
skipped_patterns
)
]
# Separate main weights (non-Whisper) from Whisper weights
main_weights
=
[
(
name
,
param
)
for
name
,
param
in
filtered_weights
if
not
name
.
startswith
(
"audio_tower."
)
]
]
# Load main model weights (LLM + projector) with mapper
# Load main model weights (LLM + projector) with mapper
loader
=
AutoWeightsLoader
(
self
)
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
skipped_patterns
)
loaded
=
loader
.
load_weights
(
main_weights
,
mapper
=
self
.
hf_to_vllm_mapper
)
loaded
=
loader
.
load_weights
(
weights
,
mapper
=
self
.
hf_to_vllm_mapper
)
# Load Whisper encoder weights from subfolder
whisper_dir
=
_get_whisper_local_path
(
self
.
model_path
)
whisper_path
=
os
.
path
.
join
(
whisper_dir
,
"model.safetensors"
)
if
os
.
path
.
exists
(
whisper_path
):
whisper_loaded
=
self
.
_load_whisper_weights_from_file
(
whisper_path
)
loaded
.
update
(
whisper_loaded
)
return
loaded
return
loaded
def
_load_whisper_weights_from_file
(
self
,
whisper_path
:
str
)
->
set
[
str
]:
"""Load Whisper encoder weights from safetensors file with transformations."""
if
not
os
.
path
.
exists
(
whisper_path
):
return
set
()
# Step 1: Load raw weights from safetensors file
whisper_weights
=
[]
with
safe_open
(
whisper_path
,
framework
=
"pt"
)
as
f
:
for
key
in
f
.
keys
():
# noqa: SIM118
if
key
.
startswith
(
"model.encoder."
)
and
"embed_positions"
not
in
key
:
new_key
=
key
.
replace
(
"model.encoder."
,
""
)
whisper_weights
.
append
((
new_key
,
f
.
get_tensor
(
key
)))
# Step 2: Apply fc → mlp mapping using WeightsMapper
fc_mapper
=
WeightsMapper
(
orig_to_new_substr
=
{
".fc1."
:
".mlp.fc1."
,
".fc2."
:
".mlp.fc2."
}
)
whisper_mapped
=
list
(
fc_mapper
.
apply
(
whisper_weights
))
# Step 3: Apply Q/K/V fusion manually
stacked_params_mapping
=
[
(
".self_attn.qkv_proj"
,
".self_attn.q_proj"
,
"q"
),
(
".self_attn.qkv_proj"
,
".self_attn.k_proj"
,
"k"
),
(
".self_attn.qkv_proj"
,
".self_attn.v_proj"
,
"v"
),
]
params_dict
=
dict
(
self
.
audio_tower
.
named_parameters
())
whisper_loaded
:
set
[
str
]
=
set
()
for
name
,
loaded_weight
in
whisper_mapped
:
fused
=
False
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
fused_name
=
name
.
replace
(
weight_name
,
param_name
)
if
fused_name
not
in
params_dict
:
continue
param
=
params_dict
[
fused_name
]
param
.
weight_loader
(
param
,
loaded_weight
,
shard_id
)
whisper_loaded
.
add
(
f
"audio_tower.
{
fused_name
}
"
)
fused
=
True
break
if
not
fused
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
whisper_loaded
.
add
(
f
"audio_tower.
{
name
}
"
)
# Add embed_positions which is initialized randomly
whisper_loaded
.
add
(
"audio_tower.embed_positions.weight"
)
return
whisper_loaded
@
classmethod
@
classmethod
def
get_speech_to_text_config
(
def
get_speech_to_text_config
(
cls
,
model_config
:
ModelConfig
,
task_type
:
str
cls
,
model_config
:
ModelConfig
,
task_type
:
str
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment