Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8bfc8d56
Unverified
Commit
8bfc8d56
authored
Jan 30, 2026
by
Isotr0py
Committed by
GitHub
Jan 30, 2026
Browse files
[Models] Refactor Kimi-K2.5 weight loading (#33346)
Signed-off-by:
Isotr0py
<
mozf@mail2.sysu.edu.cn
>
parent
ec51831a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
40 additions
and
176 deletions
+40
-176
vllm/model_executor/models/kimi_k25.py
vllm/model_executor/models/kimi_k25.py
+38
-174
vllm/model_executor/models/kimi_k25_vit.py
vllm/model_executor/models/kimi_k25_vit.py
+2
-2
No files found.
vllm/model_executor/models/kimi_k25.py
View file @
8bfc8d56
...
...
@@ -23,16 +23,7 @@ from transformers.processing_utils import ProcessorMixin
from
vllm.config
import
VllmConfig
from
vllm.config.multimodal
import
BaseDummyOptions
from
vllm.distributed
import
get_pp_group
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
SharedFusedMoE
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
,
)
from
vllm.model_executor.models.deepseek_v2
import
DeepseekV2Model
from
vllm.model_executor.models.interfaces
import
SupportsMultiModal
,
SupportsPP
from
vllm.model_executor.models.kimi_k25_vit
import
(
KimiK25MultiModalProjector
,
...
...
@@ -64,7 +55,12 @@ from vllm.transformers_utils.configs import KimiK25Config
from
vllm.transformers_utils.processor
import
cached_get_image_processor
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.utils
import
PPMissingLayer
,
is_pp_missing_parameter
,
maybe_prefix
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
init_vllm_registered_model
,
maybe_prefix
,
)
logger
=
init_logger
(
__name__
)
...
...
@@ -294,6 +290,13 @@ class KimiK25ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
supports_encoder_tp_data
=
True
weights_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"mm_projector.proj.0"
:
"mm_projector.linear_1"
,
"mm_projector.proj.2"
:
"mm_projector.linear_2"
,
}
)
@
classmethod
def
get_placeholder_str
(
cls
,
modality
:
str
,
i
:
int
)
->
str
|
None
:
# Kimi-K2.5 uses video_chunk for all media types
...
...
@@ -323,6 +326,7 @@ class KimiK25ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
self
.
hidden_size
=
config
.
text_config
.
hidden_size
self
.
device
=
current_platform
.
current_device
()
# Build vision tower directly with KimiK25VisionConfig
with
self
.
_mark_tower_model
(
vllm_config
,
"vision_chunk"
):
self
.
vision_tower
=
MoonViT3dPretrainedModel
(
config
.
vision_config
,
prefix
=
maybe_prefix
(
prefix
,
"vision_tower"
),
...
...
@@ -345,23 +349,16 @@ class KimiK25ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
sub_vllm_config
.
model_config
.
hf_config
=
(
sub_vllm_config
.
model_config
.
hf_config
.
text_config
)
self
.
language_model
=
DeepseekV2Model
(
vllm_config
=
sub_vllm_config
,
with
self
.
_mark_language_model
(
vllm_config
):
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
architectures
=
[
"DeepseekV2ForCausalLM"
],
)
if
get_pp_group
().
is_last_rank
:
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
text_config
.
hidden_size
,
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
),
)
else
:
self
.
lm_head
=
PPMissingLayer
()
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
,
scale
=
logit_scale
)
self
.
media_placeholder
:
int
=
self
.
config
.
media_placeholder_token_id
def
_parse_and_validate_media_input
(
...
...
@@ -421,9 +418,6 @@ class KimiK25ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
vision_embeddings
=
self
.
_process_media_input
(
media_input
)
return
vision_embeddings
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -444,139 +438,9 @@ class KimiK25ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
**
kwargs
)
->
torch
.
Tensor
:
logits
=
self
.
l
ogits_processor
(
self
.
lm_head
,
hidden_states
,
**
kwargs
)
logits
=
self
.
l
anguage_model
.
compute_logits
(
hidden_states
)
return
logits
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
config
=
self
.
config
.
text_config
if
not
getattr
(
config
,
"n_routed_experts"
,
None
):
return
[]
return
SharedFusedMoE
.
make_expert_params_mapping
(
self
,
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
config
.
n_routed_experts
,
num_redundant_experts
=
0
,
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
config
=
self
.
config
.
text_config
_KEYS_TO_MODIFY_MAPPING
=
{
"language_model.lm_head"
:
"lm_head"
,
"language_model.model"
:
"language_model"
,
# mm_projector -> mm_projector mapping
# "mm_projector": "mm_projector",
"mm_projector.proj.0"
:
"mm_projector.linear_1"
,
"mm_projector.proj.2"
:
"mm_projector.linear_2"
,
}
stacked_params_mapping
=
[
(
".gate_up_proj"
,
".gate_proj"
,
0
),
(
".gate_up_proj"
,
".up_proj"
,
1
),
]
if
getattr
(
config
,
"kv_lora_rank"
,
None
)
and
getattr
(
config
,
"q_lora_rank"
,
None
):
stacked_params_mapping
+=
[
(
".fused_qkv_a_proj"
,
".q_a_proj"
,
0
),
(
".fused_qkv_a_proj"
,
".kv_a_proj_with_mqa"
,
1
),
]
expert_params_mapping
=
self
.
get_expert_mapping
()
params_dict
=
dict
(
self
.
named_parameters
())
for
args
in
weights
:
name
,
loaded_weight
=
args
[:
2
]
kwargs
=
args
[
2
]
if
len
(
args
)
>
2
else
{}
if
"rotary_emb.inv_freq"
in
name
:
continue
spec_layer
=
get_spec_layer_idx_from_weight_name
(
config
,
name
)
if
spec_layer
is
not
None
:
continue
# skip spec decode layers for main model
if
"rotary_emb.cos_cached"
in
name
or
"rotary_emb.sin_cached"
in
name
:
continue
for
key_to_modify
,
new_key
in
_KEYS_TO_MODIFY_MAPPING
.
items
():
if
key_to_modify
in
name
:
name
=
name
.
replace
(
key_to_modify
,
new_key
)
use_default_weight_loading
=
False
if
"vision"
in
name
:
if
self
.
vision_tower
is
not
None
:
use_default_weight_loading
=
True
else
:
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
if
(
"mlp.experts."
in
name
)
and
name
not
in
params_dict
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
,
**
kwargs
)
break
else
:
for
_
,
(
param_name
,
weight_name
,
expert_id
,
shard_id
,
)
in
enumerate
(
expert_params_mapping
):
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
name
,
expert_id
=
expert_id
,
shard_id
=
shard_id
,
**
kwargs
,
)
break
else
:
use_default_weight_loading
=
True
if
use_default_weight_loading
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
name
is
None
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
,
**
kwargs
)
def
get_spec_layer_idx_from_weight_name
(
config
:
KimiK25Config
,
weight_name
:
str
)
->
int
|
None
:
if
hasattr
(
config
,
"num_nextn_predict_layers"
)
and
(
config
.
num_nextn_predict_layers
>
0
):
layer_idx
=
config
.
num_hidden_layers
for
i
in
range
(
config
.
num_nextn_predict_layers
):
# might start with language_model.model.layers
if
f
"model.layers.
{
layer_idx
+
i
}
."
in
weight_name
:
return
layer_idx
+
i
return
None
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
,
mapper
=
self
.
weights_mapper
)
vllm/model_executor/models/kimi_k25_vit.py
View file @
8bfc8d56
...
...
@@ -660,13 +660,13 @@ class KimiK25MultiModalProjector(nn.Module):
self
.
hidden_size
,
self
.
hidden_size
,
bias
=
True
,
prefix
=
maybe_prefix
(
prefix
,
"
linear_1"
)
,
prefix
=
f
"
{
prefix
}
.
linear_1"
,
)
self
.
linear_2
=
ReplicatedLinear
(
self
.
hidden_size
,
config
.
mm_hidden_size
,
bias
=
True
,
prefix
=
maybe_prefix
(
prefix
,
"
linear_2"
)
,
prefix
=
f
"
{
prefix
}
.
linear_2"
,
)
self
.
act
=
GELUActivation
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment