Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d76fc11e
Commit
d76fc11e
authored
Jan 28, 2026
by
zhuwenwen
Browse files
Merge tag 'v0.15.0rc1' into v0.15.0rc1-dev
parents
38166ec4
58996f35
Changes
313
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1283 additions
and
24 deletions
+1283
-24
vllm/model_executor/models/kimi_k25.py
vllm/model_executor/models/kimi_k25.py
+581
-0
vllm/model_executor/models/kimi_k25_vit.py
vllm/model_executor/models/kimi_k25_vit.py
+678
-0
vllm/model_executor/models/kimi_linear.py
vllm/model_executor/models/kimi_linear.py
+1
-1
vllm/model_executor/models/kimi_vl.py
vllm/model_executor/models/kimi_vl.py
+1
-1
vllm/model_executor/models/lfm2.py
vllm/model_executor/models/lfm2.py
+2
-2
vllm/model_executor/models/lfm2_moe.py
vllm/model_executor/models/lfm2_moe.py
+2
-2
vllm/model_executor/models/lfm2_vl.py
vllm/model_executor/models/lfm2_vl.py
+1
-1
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+1
-1
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+1
-1
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+1
-1
vllm/model_executor/models/llava_next_video.py
vllm/model_executor/models/llava_next_video.py
+1
-1
vllm/model_executor/models/llava_onevision.py
vllm/model_executor/models/llava_onevision.py
+1
-1
vllm/model_executor/models/longcat_flash.py
vllm/model_executor/models/longcat_flash.py
+2
-2
vllm/model_executor/models/longcat_flash_mtp.py
vllm/model_executor/models/longcat_flash_mtp.py
+1
-1
vllm/model_executor/models/mamba.py
vllm/model_executor/models/mamba.py
+2
-2
vllm/model_executor/models/mamba2.py
vllm/model_executor/models/mamba2.py
+2
-2
vllm/model_executor/models/midashenglm.py
vllm/model_executor/models/midashenglm.py
+1
-1
vllm/model_executor/models/mimo.py
vllm/model_executor/models/mimo.py
+1
-1
vllm/model_executor/models/mimo_mtp.py
vllm/model_executor/models/mimo_mtp.py
+1
-1
vllm/model_executor/models/mimo_v2_flash.py
vllm/model_executor/models/mimo_v2_flash.py
+2
-2
No files found.
vllm/model_executor/models/kimi_k25.py
0 → 100644
View file @
d76fc11e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
"""
Kimi-K2.5 Model Implementation for vLLM.
Kimi-K2.5 extends Kimi-K2 with vision support
This module defines:
- KimiK25ProcessingInfo/KimiK25MultiModalProcessor: Processing logic
- KimiK25ForConditionalGeneration: Main model class
"""
import
copy
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
dataclasses
import
dataclass
from
typing
import
Annotated
,
Any
,
Literal
import
torch
from
torch
import
nn
from
transformers
import
BatchFeature
from
transformers.processing_utils
import
ProcessorMixin
from
vllm.config
import
VllmConfig
from
vllm.config.multimodal
import
BaseDummyOptions
from
vllm.distributed
import
get_pp_group
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
SharedFusedMoE
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
,
)
from
vllm.model_executor.models.deepseek_v2
import
DeepseekV2Model
from
vllm.model_executor.models.interfaces
import
SupportsMultiModal
,
SupportsPP
from
vllm.model_executor.models.kimi_k25_vit
import
(
KimiK25MultiModalProjector
,
MoonViT3dPretrainedModel
,
vision_tower_forward
,
)
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargsItems
,
NestedTensors
,
VisionChunk
,
VisionChunkImage
,
VisionChunkVideo
,
)
from
vllm.multimodal.parse
import
MultiModalDataItems
,
VisionChunkProcessorItems
from
vllm.multimodal.processing
import
(
BaseDummyInputsBuilder
,
BaseMultiModalProcessor
,
BaseProcessingInfo
,
InputProcessingContext
,
PromptReplacement
,
PromptUpdate
,
)
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs
import
KimiK25Config
from
vllm.transformers_utils.processor
import
cached_get_image_processor
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.utils
import
PPMissingLayer
,
is_pp_missing_parameter
,
maybe_prefix
# Module-level logger, created through vLLM's init_logger with this module's name.
logger = init_logger(__name__)
# Dummy input dimensions for profiling.
@dataclass
class MaxImageTokenMeta:
    """Size of the placeholder media used when profiling the model.

    Both fields are read as class-level defaults by the dummy-input
    builder when it creates dummy frames.
    """

    # Width passed to the dummy image factory.
    width: int = 3000
    # Height passed to the dummy image factory.
    height: int = 3000
class KimiK25MediaPixelInputs(TensorSchema):
    """
    Media input schema for K2-VL model.

    The model's input parser builds one instance per call after flattening
    the pixel patches and reshaping the grid tensor to two dimensions.

    Dimensions:
        - np: Number of patches (flattened from all media items)
        - ps: Patch size
        - nm: Number of media items
    """

    # Discriminator tag; "pixel_values" is the only declared value.
    type: Literal["pixel_values"] = "pixel_values"

    # Patches of every media item, as one tensor or a list of tensors,
    # declared with shape (np, 3, ps, ps).
    pixel_values: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("np", 3, "ps", "ps"),
    ]

    # One row of three grid values per media item, declared as (nm, 3).
    grid_thws: Annotated[torch.Tensor, TensorShape("nm", 3)]
class MoonshotKimiVAutoProcessor(ProcessorMixin):
    """Processor pairing a tokenizer with a separate media processor.

    Only the tokenizer is declared as a ``ProcessorMixin`` attribute; the
    media processor is stored as a plain instance attribute.
    """

    attributes = ["tokenizer"]
    tokenizer_class = "AutoTokenizer"

    def __init__(self, media_processor=None, tokenizer=None):
        super().__init__(tokenizer)
        self.media_processor = media_processor

    # We do not support str input for text here
    def __call__(
        self,
        vision_chunks: list[VisionChunk] | None = None,
        *,
        text: list[int],
        **kwargs,
    ) -> BatchFeature:
        """Bundle token ids and preprocessed media into one ``BatchFeature``.

        Args:
            vision_chunks: List of VisionChunk items to be processed.
                For image: VisionChunkImage with type='image', image=PIL.Image
                For video_chunk: VisionChunkVideo with type='video_chunk',
                video_chunk=list[PIL.Image]
            text: The token ids to be fed to a model (required).
            **kwargs: Accepted and ignored.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- token ids wrapped as a tensor with a leading
              batch dimension of 1.
            - whatever fields ``media_processor.preprocess`` returns (the
              original documentation names **pixel_values** and
              **grid_thws**). Present only when `vision_chunks` is not `None`.
        """
        if vision_chunks is None:
            media_fields = {}
        else:
            assert isinstance(vision_chunks, list)
            media_fields = self.media_processor.preprocess(vision_chunks)

        # XXX: _apply_hf_processor_text_mm will call tolist() on input_ids
        payload = {"input_ids": torch.tensor([text])}
        payload.update(media_fields)
        return BatchFeature(data=payload)
class KimiK25ProcessingInfo(BaseProcessingInfo):
    """Processing information for Kimi-K2.5 model.

    Resolves the HF config, the media processor and the combined
    tokenizer/media processor once at construction time and exposes them
    to the multi-modal processor and the dummy-input builder.
    """

    def __init__(self, ctx: InputProcessingContext) -> None:
        super().__init__(ctx)

        self.hf_config = self.get_hf_config()
        self.media_token_id = self.hf_config.media_placeholder_token_id

        # NOTE(review): trust_remote_code is hard-coded to True here rather
        # than taken from the model config — confirm this is intentional.
        self.media_processor = cached_get_image_processor(
            self.ctx.model_config.model,
            trust_remote_code=True,
        )
        self.hf_processor = MoonshotKimiVAutoProcessor(
            media_processor=self.media_processor,
            tokenizer=self.get_tokenizer(),
        )
        # Callable supplied by the media processor; elsewhere in this file
        # it is called with a single vision chunk to get its token count.
        self.media_tokens_calculator = self.media_processor.media_tokens_calculator

    def get_hf_processor(self):
        """Return the processor instance built in ``__init__``."""
        return self.hf_processor

    def get_hf_config(self):
        """Return the HF config, requested from the context as ``KimiK25Config``."""
        return self.ctx.get_hf_config(KimiK25Config)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the per-modality item limits; ``None`` means unlimited."""
        limits = {"vision_chunk": None}
        return limits
class KimiK25DummyInputsBuilder(BaseDummyInputsBuilder[KimiK25ProcessingInfo]):
    """Builds dummy inputs for Kimi-K2.5 model profiling."""

    def __init__(self, info: KimiK25ProcessingInfo) -> None:
        super().__init__(info)
        self.media_token_id = self.info.media_token_id
        self.frame_per_chunk = self.info.media_processor.num_frames_per_chunk

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> list[int]:
        """Return one media placeholder token id per requested vision chunk."""
        return [self.media_token_id] * mm_counts.get("vision_chunk", 0)

    def get_dummy_mm_items(self):
        """Return a single dummy item: whichever of a video chunk or an
        image the token calculator rates at more tokens (ties go to the
        video chunk)."""
        frame_size = {
            "height": MaxImageTokenMeta.height,
            "width": MaxImageTokenMeta.width,
        }

        video_item = VisionChunkVideo(
            type="video_chunk",
            video_chunk=self._get_dummy_images(
                num_images=self.frame_per_chunk, **frame_size
            ),
        )
        video_tokens = self.info.media_tokens_calculator(video_item)

        image_item = VisionChunkImage(
            type="image",
            image=self._get_dummy_images(num_images=1, **frame_size)[0],
        )
        image_tokens = self.info.media_tokens_calculator(image_item)

        # return the larger one
        return [video_item] if video_tokens >= image_tokens else [image_item]

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return the dummy multi-modal data under the ``vision_chunk`` key.

        ``seq_len``, ``mm_counts`` and ``mm_options`` are currently unused.
        """
        # TODO: Support mm_options for vision_chunk to allow user configuration
        return {"vision_chunk": self.get_dummy_mm_items()}
class KimiK25MultiModalProcessor(BaseMultiModalProcessor[KimiK25ProcessingInfo]):
    """Multi-modal processor for Kimi-K2.5.

    Handles both image and video-chunk modalities.
    """

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Indicates how to slice media input into multiple items.

        pixel_values: [N, 3, patch_size, patch_size], all patches collected
            from B medias.
        grid_thws: [B, 3], each item: [N_t, N_h, N_w], the grid size in the
            time/height/width direction for that item.

        The product N_t * N_h * N_w is the number of patches of one media
        item, so consecutive runs of that length in pixel_values belong to
        consecutive items. When no grid is present an empty (0, 3) tensor
        is used.
        """
        grid_thws = hf_inputs.get("grid_thws", torch.empty((0, 3)))
        patches_per_item = grid_thws.prod(-1)
        return {
            "pixel_values": MultiModalFieldConfig.flat_from_sizes(
                "vision_chunk", patches_per_item
            ),
            "grid_thws": MultiModalFieldConfig.batched("vision_chunk"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each media placeholder token with one placeholder per
        media token of the corresponding vision chunk."""
        media_token_id = self.info.get_hf_config().media_placeholder_token_id

        def get_replacement(item_idx: int):
            # Looked up on every call, exactly as many placeholders as the
            # calculator reports for this chunk.
            chunks = mm_items.get_items("vision_chunk", (VisionChunkProcessorItems,))
            token_count = self.info.media_tokens_calculator(chunks[item_idx])
            return [media_token_id] * token_count

        update = PromptReplacement(
            modality="vision_chunk",
            target=[media_token_id],
            replacement=get_replacement,
        )
        return [update]

    def split_video_chunks(self, video):
        """Delegate video splitting to the media processor."""
        media_processor = self.info.media_processor
        return media_processor.split_video_chunks(video)
@MULTIMODAL_REGISTRY.register_processor(
    KimiK25MultiModalProcessor,
    info=KimiK25ProcessingInfo,
    dummy_inputs=KimiK25DummyInputsBuilder,
)
class KimiK25ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """Kimi-K2.5 model for conditional generation.

    Supports both image and video-chunk modalities.
    Video-chunks are temporal segments (typically 4 frames) that are
    processed with temporal pooling.

    The model is composed of a MoonViT vision tower, a multi-modal
    projector, a DeepseekV2 language model and an LM head (present only on
    the last pipeline-parallel rank).
    """

    supports_encoder_tp_data = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder text for ``modality``.

        Args:
            modality: ``"image"`` or ``"video"``.
            i: Unused by this implementation.

        Raises:
            ValueError: For any other modality.
        """
        # Kimi-K2.5 uses video_chunk for all media types
        if modality == "image":
            return "<|media_begin|>image<|media_content|><|media_pad|><|media_end|>"
        elif modality == "video":
            # return a placeholder, to be replaced in the future.
            return "<|kimi_k25_video_placeholder|>"

        raise ValueError(f"Unsupported modality: {modality}")

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> None:
        """Build the vision tower, projector, language model and LM head.

        Args:
            vllm_config: Engine configuration; its ``model_config.hf_config``
                is used as the ``KimiK25Config``.
            prefix: Parameter-name prefix forwarded to every sub-module.
        """
        super().__init__()
        model_config = vllm_config.model_config
        config: KimiK25Config = model_config.hf_config
        self.config = config
        quant_config = vllm_config.quant_config

        # Check for MoonViT config compatibility
        self.use_data_parallel = (
            model_config.multimodal_config.mm_encoder_tp_mode == "data"
        )
        self.hidden_size = config.text_config.hidden_size
        # Current CUDA device; the vision tower and projector are moved
        # onto it below.
        self.device = torch.cuda.current_device()
        # Build vision tower directly with KimiK25VisionConfig
        self.vision_tower = MoonViT3dPretrainedModel(
            config.vision_config,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )
        self.vision_tower = self.vision_tower.to(
            device=self.device, dtype=model_config.dtype
        )

        self.mm_projector = KimiK25MultiModalProjector(
            config=config.vision_config,
            use_data_parallel=self.use_data_parallel,
            prefix=maybe_prefix(prefix, "mm_projector"),
        )
        self.mm_projector = self.mm_projector.to(
            device=self.device, dtype=model_config.dtype
        )

        self.quant_config = quant_config
        # The language model receives a deep copy of the config whose
        # hf_config is replaced by the text sub-config, so the caller's
        # vllm_config is left untouched.
        sub_vllm_config = copy.deepcopy(vllm_config)
        sub_vllm_config.model_config.hf_config = (
            sub_vllm_config.model_config.hf_config.text_config
        )
        self.language_model = DeepseekV2Model(
            vllm_config=sub_vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        # Only the last pipeline-parallel rank owns a real LM head.
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.text_config.hidden_size,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        else:
            self.lm_head = PPMissingLayer()
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )
        # Falls back to a scale of 1.0 when the config has no logit_scale.
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale)
        self.media_placeholder: int = self.config.media_placeholder_token_id

    def _parse_and_validate_media_input(
        self, **kwargs: object
    ) -> KimiK25MediaPixelInputs | None:
        """Extract and normalise ``pixel_values`` / ``grid_thws`` from kwargs.

        Returns:
            ``None`` when no ``pixel_values`` were passed, otherwise a
            ``KimiK25MediaPixelInputs`` whose pixel values are cast to the
            vision tower's parameter dtype and whose grid is 2-D.

        Raises:
            AssertionError: If ``grid_thws`` is not a tensor or does not
                reshape to ``(N, 3)``.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        grid_thws = kwargs.pop("grid_thws", None)
        if pixel_values is None:
            return None

        # A list of per-item tensors is concatenated along dim 0.
        if isinstance(pixel_values, list):
            pixel_values = torch.cat(pixel_values, dim=0)

        # 5-D or 3-D inputs get their first two dimensions merged.
        if len(pixel_values.shape) == 5 or len(pixel_values.shape) == 3:
            pixel_values = pixel_values.reshape(
                pixel_values.shape[0] * pixel_values.shape[1], *pixel_values.shape[2:]
            )

        # The batch dimension of pixel_values has been flattened into shape[0]
        target_dtype = next(self.vision_tower.parameters()).dtype
        pixel_values = pixel_values.to(target_dtype)
        assert isinstance(grid_thws, torch.Tensor), (
            f"expect grid_thws to be a tensor, get {type(grid_thws)}"
        )
        # In some cases (e.g. with merger), grid_thws has an extra middle dimension
        grid_thws = grid_thws.reshape(-1, grid_thws.shape[-1])
        assert grid_thws.ndim == 2 and grid_thws.size(1) == 3, (
            f"unexpected shape for grid_thws: {grid_thws.shape}"
        )

        return KimiK25MediaPixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            grid_thws=grid_thws,
        )

    def _process_media_input(
        self, media_input: KimiK25MediaPixelInputs
    ) -> list[torch.Tensor]:
        """Run the validated media input through ``vision_tower_forward``
        (vision tower plus projector) and return its result."""
        # NOTE(moyan): This forward will automatically batch the forward pass internally
        media_features = vision_tower_forward(
            self.vision_tower,
            media_input["pixel_values"],
            media_input["grid_thws"],
            mm_projector=self.mm_projector,
            use_data_parallel=self.use_data_parallel,
        )
        return media_features

    def embed_multimodal(self, **kwargs: object) -> NestedTensors | None:
        """Return the media embeddings for the given multimodal kwargs, or
        ``None`` when the kwargs carry no ``pixel_values``."""
        # Validate the multimodal input keyword arguments
        media_input = self._parse_and_validate_media_input(**kwargs)
        if media_input is None:
            return None

        # Run multimodal inputs through encoder and projector
        vision_embeddings = self._process_media_input(media_input)
        return vision_embeddings

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model module."""
        return self.language_model

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> IntermediateTensors:
        """Run the language model and return what it returns.

        ``inputs_embeds`` is dropped whenever ``intermediate_tensors`` is
        supplied. Extra ``kwargs`` are accepted but not forwarded.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None
        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        """Apply the logits processor to ``hidden_states`` using the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states, **kwargs)
        return logits

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the expert parameter mapping used by ``load_weights``.

        Returns an empty list when the text config has no (or a falsy)
        ``n_routed_experts``.
        """
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        config = self.config.text_config
        if not getattr(config, "n_routed_experts", None):
            return []
        return SharedFusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=config.n_routed_experts,
            num_redundant_experts=0,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        Each item of ``weights`` is ``(name, tensor)``; an optional third
        element, when present, is used as extra keyword arguments for the
        weight loader. Names are remapped, then dispatched in order to the
        stacked-parameter loader, the expert loader, or the default loader.
        """
        config = self.config.text_config
        # Substring rewrites from checkpoint names to this module's names.
        _KEYS_TO_MODIFY_MAPPING = {
            "language_model.lm_head": "lm_head",
            "language_model.model": "language_model",
            # mm_projector -> mm_projector mapping
            # "mm_projector": "mm_projector",
            "mm_projector.proj.0": "mm_projector.linear_1",
            "mm_projector.proj.2": "mm_projector.linear_2",
        }
        # (fused param name, checkpoint weight name, shard id)
        stacked_params_mapping = [
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        # The fused q/kv "a" projection entries are added only when both
        # LoRA ranks are set (truthy) in the text config.
        if getattr(config, "kv_lora_rank", None) and getattr(
            config, "q_lora_rank", None
        ):
            stacked_params_mapping += [
                (".fused_qkv_a_proj", ".q_a_proj", 0),
                (".fused_qkv_a_proj", ".kv_a_proj_with_mqa", 1),
            ]
        expert_params_mapping = self.get_expert_mapping()

        params_dict = dict(self.named_parameters())

        for args in weights:
            name, loaded_weight = args[:2]
            kwargs = args[2] if len(args) > 2 else {}
            if "rotary_emb.inv_freq" in name:
                continue

            spec_layer = get_spec_layer_idx_from_weight_name(config, name)
            if spec_layer is not None:
                continue  # skip spec decode layers for main model

            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                continue
            # Every matching rewrite is applied, in dict order.
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            use_default_weight_loading = False
            if "vision" in name:
                # Names containing "vision" bypass the stacked/expert paths.
                if self.vision_tower is not None:
                    use_default_weight_loading = True
            else:
                for param_name, weight_name, shard_id in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    # Expert weights not present under this name are left
                    # for the expert mapping below.
                    if ("mlp.experts." in name) and name not in params_dict:
                        continue
                    name = name.replace(weight_name, param_name)
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id, **kwargs)
                    break
                else:
                    # Reached only when no stacked mapping loaded the weight.
                    for _, (
                        param_name,
                        weight_name,
                        expert_id,
                        shard_id,
                    ) in enumerate(expert_params_mapping):
                        if weight_name not in name:
                            continue
                        name = name.replace(weight_name, param_name)

                        if is_pp_missing_parameter(name, self):
                            continue

                        param = params_dict[name]
                        weight_loader = param.weight_loader
                        weight_loader(
                            param,
                            loaded_weight,
                            name,
                            expert_id=expert_id,
                            shard_id=shard_id,
                            **kwargs,
                        )
                        break
                    else:
                        # Neither stacked nor expert mapping loaded it.
                        use_default_weight_loading = True
            if use_default_weight_loading:
                if name.endswith(".bias") and name not in params_dict:
                    continue
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                # Parameters without their own loader use the default one.
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight, **kwargs)
def get_spec_layer_idx_from_weight_name(
    config: KimiK25Config, weight_name: str
) -> int | None:
    """Return the speculative-decode layer index a weight belongs to.

    The candidate layers are the ``num_nextn_predict_layers`` layers
    numbered from ``config.num_hidden_layers`` upwards.

    Args:
        config: Text config. ``num_nextn_predict_layers`` may be missing,
            ``None`` or non-positive, all of which mean "no such layers".
        weight_name: Checkpoint weight name to inspect.

    Returns:
        The matching layer index, or ``None`` when the weight does not
        belong to one of those layers.
    """
    # `or 0` also covers a config that defines the attribute as None, which
    # would otherwise fail on the `> 0` comparison.
    num_spec_layers = getattr(config, "num_nextn_predict_layers", 0) or 0
    if num_spec_layers <= 0:
        return None

    first_spec_layer = config.num_hidden_layers
    for layer_idx in range(first_spec_layer, first_spec_layer + num_spec_layers):
        # might start with language_model.model.layers
        if f"model.layers.{layer_idx}." in weight_name:
            return layer_idx
    return None
vllm/model_executor/models/kimi_k25_vit.py
0 → 100644
View file @
d76fc11e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Vision tower implementation for Kimi-K2.5 model.
This module provides the vision encoder components for Kimi-K2.5,
including 3D patch embedding, RoPE position embedding, and
temporal pooling for video chunks.
"""
from
collections.abc
import
Sequence
from
copy
import
deepcopy
from
typing
import
Any
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
transformers.activations
import
GELUActivation
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.attention.mm_encoder_attention
import
MMEncoderAttention
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
,
)
from
vllm.model_executor.models.utils
import
maybe_prefix
from
vllm.model_executor.models.vision
import
(
is_vit_use_data_parallel
,
run_dp_sharded_mrope_vision_model
,
)
from
vllm.transformers_utils.configs.kimi_k25
import
KimiK25VisionConfig
# Module-level logger, created through vLLM's init_logger with this module's name.
logger = init_logger(__name__)
def
_apply_rope_input_validation
(
x
,
freqs_cis
):
assert
x
.
ndim
==
freqs_cis
.
ndim
+
1
,
(
x
.
shape
,
freqs_cis
.
shape
)
assert
x
.
shape
[:
-
2
]
==
freqs_cis
.
shape
[:
-
1
],
(
x
.
shape
,
freqs_cis
.
shape
)
assert
x
.
shape
[
-
1
]
==
2
*
freqs_cis
.
shape
[
-
1
],
(
x
.
shape
,
freqs_cis
.
shape
)
assert
freqs_cis
.
dtype
==
torch
.
complex64
,
freqs_cis
.
dtype
def get_rope_shape_decorate(func):
    """Wrap ``func`` so each new call configuration first runs at (64, 64).

    A configuration is the triple ``(org.requires_grad, grad mode enabled,
    interpolation_mode)``. The first time one is seen, ``func`` is invoked
    once with ``shape=(64, 64)`` and that result is discarded; the real
    call then follows. Later calls with the same configuration go straight
    to ``func``.
    """
    seen_configs = set()

    def wrapper(org, interpolation_mode, shape):
        config_key = (
            org.requires_grad,
            torch.is_grad_enabled(),
            interpolation_mode,
        )
        is_new_config = config_key not in seen_configs
        if is_new_config:
            # Recorded before the extra call, so a failing first call is
            # not retried on the next invocation.
            seen_configs.add(config_key)
            func(org, interpolation_mode, shape=(64, 64))
        return func(org, interpolation_mode, shape)

    return wrapper
@get_rope_shape_decorate
@torch.compile(dynamic=True)
def get_rope_shape(org, interpolation_mode, shape):
    """Resize a channels-last 2D table to ``shape`` and flatten it.

    Args:
        org: Tensor laid out as (height, width, channels).
        interpolation_mode: ``mode`` passed to ``F.interpolate``.
        shape: Target ``(height, width)``.

    Returns:
        Tensor of shape (shape[0] * shape[1], channels).
    """
    # F.interpolate wants (batch, channels, H, W).
    channels_first = org.permute((2, 0, 1)).unsqueeze(0)
    resized = F.interpolate(
        channels_first,
        size=shape,
        mode=interpolation_mode,
    )
    # Back to (H, W, channels), then merge the two spatial dims.
    channels_last = resized.squeeze(0).permute((1, 2, 0))
    return channels_last.flatten(end_dim=1)
def apply_rope(
    xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embedding to query and key tensors.

    Args: (The leading dimensions of all inputs should be the same)
        xq: query, tensor of shape (..., num_heads, head_dim)
        xk: key, tensor of shape (..., num_heads, head_dim); its number of
            heads may differ from ``xq``'s.
        freqs_cis: tensor of shape (..., head_dim/2), dtype=torch.complex64.

    Returns:
        xq_out, xk_out: tensors with the same shapes and dtypes as ``xq``
        and ``xk`` respectively.

    Raises:
        AssertionError: If either input fails the shape/dtype validation
            against ``freqs_cis``.
    """
    _apply_rope_input_validation(xq, freqs_cis)
    _apply_rope_input_validation(xk, freqs_cis)

    freqs_cis = freqs_cis.unsqueeze(-2)  # ..., 1, head_dim/2
    # Pair up the last dimension so it can be viewed as complex numbers.
    # ..., num_heads, head_dim/2
    xq_ = torch.view_as_complex(xq.float().view(*xq.shape[:-1], -1, 2))
    # Reshape the key with its own shape (previously xq's shape was used,
    # which only worked when query and key had identical shapes).
    xk_ = torch.view_as_complex(xk.float().view(*xk.shape[:-1], -1, 2))
    # The rotation is a complex multiply; computed in float32 and cast back.
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)  # ..., num_heads, head_dim
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)  # ..., num_heads, head_dim
    return xq_out.type_as(xq), xk_out.type_as(xk)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Generate 1D sincos positional embedding from grid positions.

    Args:
        embed_dim: Output width D; must be even.
        pos: Array of positions; flattened to M entries.

    Returns:
        Array of shape (M, D): the first D/2 columns are sines, the last
        D/2 columns are cosines.

    Raises:
        AssertionError: If ``embed_dim`` is odd.
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2

    # Frequencies 1 / 10000^(k / (D/2)) for k = 0 .. D/2 - 1.
    exponents = np.arange(half_dim, dtype=np.float32) / (embed_dim / 2.0)
    omega = 1.0 / 10000**exponents  # (D/2,)

    flat_pos = pos.reshape(-1)  # (M,)
    angles = np.outer(flat_pos, omega)  # (M, D/2)

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
def get_1d_sincos_pos_embed(embed_dim, t_size, cls_token=False):
    """Generate 1D sincos positional embedding.

    Args:
        embed_dim: Embedding width.
        t_size: Number of positions (0 .. t_size - 1).
        cls_token: When true, prepend one all-zero row.

    Returns:
        Array of shape (t_size, embed_dim), or (t_size + 1, embed_dim)
        when ``cls_token`` is set.
    """
    positions = np.arange(t_size, dtype=np.float32)
    table = get_1d_sincos_pos_embed_from_grid(embed_dim, positions)
    if not cls_token:
        return table

    cls_row = np.zeros([1, embed_dim])
    return np.concatenate([cls_row, table], axis=0)
class Learnable2DInterpPosEmbDivided_fixed(nn.Module):
    """2D learnable position embedding with temporal extension.

    A learnable (height, width, dim) table supplies the spatial part and
    is resized whenever an item's grid differs from the table's own size.
    A fixed sin-cos table, kept as a non-persistent buffer, supplies the
    temporal part for items with more than one frame.
    """

    def __init__(
        self,
        height: int,
        width: int,
        num_frames: int,
        dim: int,
        interpolation_mode: str = "bicubic",
    ) -> None:
        super().__init__()
        self.height = height
        self.width = width
        self.num_frames = num_frames
        self.dim = dim
        self.interpolation_mode = interpolation_mode

        self.weight = nn.Parameter(torch.empty(height, width, dim))

        # (num_frames, 1, dim): the middle axis broadcasts over positions.
        temporal_table = torch.from_numpy(
            get_1d_sincos_pos_embed(self.dim, self.num_frames)
        )
        self.register_buffer(
            "time_weight",
            temporal_table.float().unsqueeze(1),
            persistent=False,
        )

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the learnable table from a standard normal."""
        nn.init.normal_(self.weight)

    def forward(self, x: torch.Tensor, grid_thws: torch.Tensor) -> torch.Tensor:
        """Add position embeddings to the packed patch features ``x``.

        Args:
            x: Patch features; rows are expected to follow the item order
                of ``grid_thws``.
            grid_thws: One ``(t, h, w)`` row per media item.

        Raises:
            AssertionError: If any ``t`` exceeds ``num_frames``.
        """
        native_grid = self.weight.shape[:-1]
        per_item = []
        for t, h, w in grid_thws.tolist():
            assert t <= self.num_frames, f"t:{t} > self.num_frames:{self.num_frames}"

            if (h, w) == native_grid:
                emb = self.weight.flatten(end_dim=1)
            else:
                emb = get_rope_shape(
                    self.weight,
                    interpolation_mode=self.interpolation_mode,
                    shape=(h, w),
                )

            # Single-frame items use the spatial embedding unchanged.
            if t != 1:
                emb = emb.unsqueeze(0).repeat(t, 1, 1) + self.time_weight[0:t]

            per_item.append(emb.reshape(-1, emb.shape[-1]))

        return x + torch.cat(per_item)
class MoonVision3dPatchEmbed(nn.Module):
    """3D patch embedding for vision tower.

    Each patch is projected by a convolution whose kernel and stride equal
    the patch size; position embeddings are then added.
    """

    def __init__(
        self,
        out_dim: int,
        in_dim: int = 3,
        patch_size: int | tuple[int, int] = (14, 14),
        pos_emb_height: int = 14,
        pos_emb_width: int = 14,
        pos_emb_time: int = 4,
        pos_emb_type: str = "divided_fixed",
    ):
        super().__init__()
        assert isinstance(patch_size, int | Sequence), (
            f"Invalid patch_size type: {type(patch_size)}"
        )
        # A bare int means a square patch.
        patch_hw = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size
        assert len(patch_hw) == 2, (
            f"Expected patch_size to be a tuple of 2, got {patch_hw}"
        )
        self.patch_size = patch_hw

        self.proj = nn.Conv2d(
            in_dim,
            out_dim,
            kernel_size=patch_hw,
            stride=patch_hw,
        )

        # "divided_fixed" is the only position-embedding type implemented.
        if pos_emb_type != "divided_fixed":
            raise NotImplementedError(f"Not support pos_emb_type: {pos_emb_type}")
        self.pos_emb = Learnable2DInterpPosEmbDivided_fixed(
            height=pos_emb_height,
            width=pos_emb_width,
            num_frames=pos_emb_time,
            dim=out_dim,
        )

    def forward(self, x: torch.Tensor, grid_thws: torch.Tensor) -> torch.Tensor:
        """Project patches to one vector each and add position embeddings.

        Args:
            x: Patches with the patch index on dim 0.
            grid_thws: Per-item grid sizes forwarded to the position
                embedding.
        """
        num_patches = x.size(0)
        embedded = self.proj(x).view(num_patches, -1)
        # apply positional embedding
        return self.pos_emb(embedded, grid_thws)
class Rope2DPosEmbRepeated(nn.Module):
    """2D rotary position embedding with multi-resolution support.

    A table of complex rotations for every position of a
    ``max_height x max_width`` grid is built on first use; per-item tables
    are obtained by slicing its top-left corner and repeating it per frame.
    """

    def __init__(self, dim: int, max_height: int, max_width: int, theta_base=10000):
        super().__init__()
        self.dim = dim
        assert self.dim % 4 == 0, "dim must be divisible by 4"
        self.max_height = max_height
        self.max_width = max_width
        self.theta_base = theta_base

    def extra_repr(self):
        """Return the configuration as ``name=value`` pairs."""
        fields = (
            ("dim", self.dim),
            ("max_height", self.max_height),
            ("max_width", self.max_width),
            ("theta_base", self.theta_base),
        )
        return ", ".join(f"{key}={value}" for key, value in fields)

    def _precompute_freqs_cis(self, device: torch.device) -> torch.Tensor:
        """Calculate the cis(freqs) for each position in the 2D grid.

        Returns:
            Complex tensor of shape (max_height, max_width, dim // 2) whose
            last axis alternates x-rotation and y-rotation per frequency.
        """
        num_positions = self.max_height * self.max_width
        flat_pos = torch.arange(0, num_positions).float().to(device)
        col_pos = flat_pos % self.max_width
        row_pos = flat_pos // self.max_width

        exponents = (
            torch.arange(0, self.dim, 4)[: (self.dim // 4)].float().to(device)
        )  # C/4
        inv_freq = 1.0 / (self.theta_base ** (exponents / self.dim))

        col_angles = torch.outer(col_pos, inv_freq).float()  # N, C/4
        row_angles = torch.outer(row_pos, inv_freq).float()  # N, C/4
        col_cis = torch.polar(torch.ones_like(col_angles), col_angles)  # N, C/4
        row_cis = torch.polar(torch.ones_like(row_angles), row_angles)  # N, C/4

        # N, C/4, 2 -> max_height, max_width, C/2
        interleaved = torch.stack([col_cis, row_cis], dim=-1)
        return interleaved.reshape(self.max_height, self.max_width, -1)

    def get_freqs_cis(
        self, grid_thws: torch.Tensor, device: torch.device
    ) -> torch.Tensor:
        """Return the rotations for every patch of every item.

        Args:
            grid_thws (torch.Tensor): grid time, height and width
            device: Device on which the table is built the first time this
                method runs.

        Returns:
            freqs_cis: tensor of shape (sum(t * height * width), dim//2)

        Raises:
            AssertionError: If any height or width lies outside
                ``[1, max_height]`` / ``[1, max_width]``.
        """
        # Built lazily and stored as a non-persistent buffer.
        if not hasattr(self, "freqs_cis"):
            table = self._precompute_freqs_cis(device)
            self.register_buffer("freqs_cis", table, persistent=False)

        shapes = grid_thws.tolist()
        in_bounds = all(
            1 <= h <= self.max_height and 1 <= w <= self.max_width
            for t, h, w in shapes
        )
        assert in_bounds, (
            shapes,
            self.max_height,
            self.max_width,
        )

        half_dim = self.dim // 2
        per_item = [
            self.freqs_cis[:h, :w].reshape(-1, half_dim).repeat(t, 1)
            for t, h, w in shapes
        ]
        return torch.cat(per_item, dim=0)
class MLP2(nn.Module):
    """Two-layer MLP with tensor parallel support.

    A column-parallel linear layer, the supplied activation, then a
    row-parallel linear layer. Both layers are built with tensor
    parallelism disabled when ``use_data_parallel`` is set.
    """

    def __init__(
        self,
        dims: list[int],
        activation,
        bias: bool = True,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        # dims is [input width, hidden width, output width].
        assert len(dims) == 3
        in_features, hidden_features, out_features = dims

        self.use_data_parallel = use_data_parallel
        self.fc0 = ColumnParallelLinear(
            in_features,
            hidden_features,
            bias=bias,
            prefix=maybe_prefix(prefix, "fc0"),
            disable_tp=self.use_data_parallel,
        )
        self.fc1 = RowParallelLinear(
            hidden_features,
            out_features,
            bias=bias,
            prefix=maybe_prefix(prefix, "fc1"),
            disable_tp=self.use_data_parallel,
        )
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply fc0, the activation and fc1 in sequence."""
        # Each linear layer returns a pair; only the first element is used.
        hidden, _ = self.fc0(x)
        hidden = self.activation(hidden)
        out, _ = self.fc1(hidden)
        return out
class
MoonViTEncoderLayer
(
nn
.
Module
):
"""Single encoder layer for MoonViT with TP/DP support."""
def
__init__
(
self
,
num_heads
:
int
,
hidden_dim
:
int
,
mlp_dim
:
int
,
prefix
:
str
=
""
,
*
,
activation
=
F
.
gelu
,
attn_bias
:
bool
=
False
,
):
super
().
__init__
()
self
.
use_data_parallel
=
is_vit_use_data_parallel
()
self
.
num_heads
=
num_heads
self
.
hidden_dim
=
hidden_dim
self
.
hidden_size_per_attention_head
=
self
.
hidden_dim
//
self
.
num_heads
self
.
tp_size
=
(
1
if
self
.
use_data_parallel
else
get_tensor_model_parallel_world_size
()
)
self
.
num_attention_heads_per_partition
=
divide
(
num_heads
,
self
.
tp_size
)
self
.
norm0
=
nn
.
LayerNorm
(
hidden_dim
)
self
.
norm1
=
nn
.
LayerNorm
(
hidden_dim
)
self
.
mlp
=
MLP2
(
[
hidden_dim
,
mlp_dim
,
hidden_dim
],
activation
,
prefix
=
f
"
{
prefix
}
.mlp"
,
use_data_parallel
=
self
.
use_data_parallel
,
)
self
.
wqkv
=
QKVParallelLinear
(
hidden_size
=
hidden_dim
,
head_size
=
self
.
hidden_size_per_attention_head
,
total_num_heads
=
num_heads
,
total_num_kv_heads
=
num_heads
,
bias
=
attn_bias
,
prefix
=
f
"
{
prefix
}
.wqkv"
,
disable_tp
=
self
.
use_data_parallel
,
)
self
.
wo
=
RowParallelLinear
(
hidden_dim
,
hidden_dim
,
bias
=
attn_bias
,
prefix
=
f
"
{
prefix
}
.wo"
,
disable_tp
=
self
.
use_data_parallel
,
)
self
.
attn
=
MMEncoderAttention
(
num_heads
=
self
.
num_attention_heads_per_partition
,
head_size
=
self
.
hidden_size_per_attention_head
,
scale
=
self
.
hidden_size_per_attention_head
**-
0.5
,
prefix
=
f
"
{
prefix
}
.attn"
,
)
def
attention_qkvpacked
(
self
,
x
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
rope_freqs_cis
:
torch
.
Tensor
|
None
=
None
,
):
"""Compute self-attention with packed QKV.
Args:
x (torch.Tensor): (seqlen, hidden_dim)
cu_seqlens (torch.Tensor): cumulative sequence lengths
"""
seq_length
=
x
.
size
(
0
)
xqkv
,
_
=
self
.
wqkv
(
x
)
qkv_shape
=
xqkv
.
size
()[:
-
1
]
+
(
3
,
self
.
num_attention_heads_per_partition
,
self
.
hidden_size_per_attention_head
,
)
# xqkv: (seqlen, 3, nheads, headdim)
xqkv
=
xqkv
.
view
(
*
qkv_shape
)
xq
,
xk
,
xv
=
torch
.
unbind
(
xqkv
,
dim
=-
3
)
xq
,
xk
=
apply_rope
(
xq
,
xk
,
rope_freqs_cis
)
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
()
attn_out
=
self
.
attn
(
xq
.
unsqueeze
(
0
),
xk
.
unsqueeze
(
0
),
xv
.
unsqueeze
(
0
),
cu_seqlens
=
cu_seqlens
,
max_seqlen
=
max_seqlen
,
)
attn_out
=
attn_out
.
reshape
(
seq_length
,
self
.
num_attention_heads_per_partition
*
self
.
hidden_size_per_attention_head
,
)
attn_out
,
_
=
self
.
wo
(
attn_out
)
return
attn_out
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
rope_freqs_cis
:
torch
.
Tensor
|
None
=
None
,
):
residual
=
hidden_states
hidden_states
=
self
.
norm0
(
hidden_states
)
hidden_states
=
self
.
attention_qkvpacked
(
hidden_states
,
cu_seqlens
,
rope_freqs_cis
)
hidden_states
=
residual
+
hidden_states
residual
=
hidden_states
hidden_states
=
self
.
norm1
(
hidden_states
)
hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
residual
+
hidden_states
return
hidden_states
class MoonViT3dEncoder(nn.Module):
    """Full encoder stack for MoonViT 3D.

    Holds the 2D rotary embedding helper, ``num_layers`` encoder layers and
    a final LayerNorm.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_layers: int,
        block_cfg: dict,
        video_attn_type: str = "spatial_temporal",
        prefix: str = "",
    ) -> None:
        """Build the stack.

        Args:
            hidden_dim: width used by the final LayerNorm.
            num_layers: number of ``MoonViTEncoderLayer`` blocks.
            block_cfg: keyword arguments forwarded to every layer; must
                contain ``"hidden_dim"`` and ``"num_heads"``.
            video_attn_type: only ``"spatial_temporal"`` is accepted.
            prefix: parameter-name prefix for the sub-layers.
        """
        super().__init__()
        assert video_attn_type == "spatial_temporal", (
            f'video_attn_type must be "spatial_temporal", got {video_attn_type}'
        )
        self.video_attn_type = video_attn_type

        head_dim = block_cfg["hidden_dim"] // block_cfg["num_heads"]
        # NOTE(review): the two 512s are presumably the maximum grid height
        # and width covered by the rotary table; confirm against
        # Rope2DPosEmbRepeated.
        self.rope_2d = Rope2DPosEmbRepeated(head_dim, 512, 512)

        layers = []
        for layer_idx in range(num_layers):
            layer = MoonViTEncoderLayer(
                **block_cfg,
                prefix=f"{prefix}.blocks.{layer_idx}",
            )
            layers.append(layer)
        self.blocks = nn.ModuleList(layers)
        self.final_layernorm = nn.LayerNorm(hidden_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        grid_thws: torch.Tensor,
    ) -> torch.Tensor:
        """Run all encoder layers and the final norm.

        Args:
            hidden_states: token embeddings passed through every block.
            grid_thws: one (t, h, w) row per item; ``t * h * w`` gives the
                number of tokens contributed by that item.
        """
        rope_freqs_cis = self.rope_2d.get_freqs_cis(
            grid_thws=grid_thws, device=hidden_states.device
        )

        # Cumulative token offsets with a leading zero, as int32.
        tokens_per_item = grid_thws[:, 0] * grid_thws[:, 1] * grid_thws[:, 2]
        leading_zero = torch.zeros(
            1, dtype=grid_thws.dtype, device=grid_thws.device
        )
        lengths = torch.cat((leading_zero, tokens_per_item))
        cu_seqlens = lengths.to(hidden_states.device).cumsum(
            dim=0, dtype=torch.int32
        )

        for block in self.blocks:
            hidden_states = block(
                hidden_states, cu_seqlens, rope_freqs_cis=rope_freqs_cis
            )
        return self.final_layernorm(hidden_states)
def tpool_patch_merger(
    x: torch.Tensor,
    grid_thws: torch.Tensor,
    merge_kernel_size: tuple[int, int] = (2, 2),
) -> list[torch.Tensor]:
    """Temporal pooling patch merger.

    Splits the packed token sequence ``x`` into one chunk per item, averages
    each chunk over its ``t`` frames, and regroups the remaining ``h x w``
    patches into ``kh x kw`` neighbourhoods.

    Args:
        x: packed tokens; the leading dimension is the sum of ``t * h * w``
            over all rows of ``grid_thws``.
        grid_thws: one (t, h, w) row per item.
        merge_kernel_size: (kh, kw) spatial merge window.

    Returns:
        One tensor per item of shape ``(nh * nw, kh * kw, d)`` where
        ``nh = h // kh`` and ``nw = w // kw``.

    Raises:
        ValueError: if an item's height or width is not a multiple of the
            merge kernel. Without this check the ``-1`` in ``view`` could
            silently absorb the mismatch into the feature dimension and
            return wrongly grouped data.
    """
    kh, kw = merge_kernel_size
    # A single host transfer provides both the split sizes and the shapes.
    shapes = grid_thws.tolist()
    lengths = [t * h * w for t, h, w in shapes]

    outputs = []
    for seq, (t, h, w) in zip(x.split(lengths, dim=0), shapes):
        if h % kh != 0 or w % kw != 0:
            raise ValueError(
                f"grid (t={t}, h={h}, w={w}) is not divisible by "
                f"merge_kernel_size {(kh, kw)}"
            )
        nh, nw = h // kh, w // kw
        # Reshape: (t*h*w, d) -> (t, nh, kh, nw, kw, d)
        v = seq.view(t, nh, kh, nw, kw, -1)
        # Temporal pooling first (reduces tensor size before permute)
        v = v.mean(dim=0)  # (nh, kh, nw, kw, d)
        # Spatial rearrangement: (nh, kh, nw, kw, d) -> (nh, nw, kh, kw, d)
        out = v.permute(0, 2, 1, 3, 4).reshape(nh * nw, kh * kw, -1)
        outputs.append(out)
    return outputs
class MoonViT3dPretrainedModel(nn.Module):
    """Main vision tower model.

    Uses KimiK25VisionConfig directly from transformers_utils/configs/kimi_k25.py.
    """

    def __init__(
        self,
        config: KimiK25VisionConfig,
        prefix: str = "",
    ):
        """Build the patch embedding and encoder from ``config``.

        Args:
            config: vision configuration; a deep copy is stored on the model.
            prefix: parameter-name prefix for the sub-modules.
        """
        super().__init__()
        config = deepcopy(config)
        self.config = config
        # Required for run_dp_sharded_mrope_vision_model
        self.merge_kernel_size = config.merge_kernel_size
        self.patch_size = config.patch_size
        self.merge_type = config.merge_type
        self.patch_embed = MoonVision3dPatchEmbed(
            out_dim=config.hidden_size,
            patch_size=config.patch_size,
            pos_emb_height=config.init_pos_emb_height,
            pos_emb_width=config.init_pos_emb_width,
            pos_emb_time=config.init_pos_emb_time,
            pos_emb_type=config.pos_emb_type,
        )
        self.encoder = MoonViT3dEncoder(
            hidden_dim=config.hidden_size,
            num_layers=config.num_hidden_layers,
            block_cfg={
                "num_heads": config.num_attention_heads,
                "hidden_dim": config.hidden_size,
                "mlp_dim": config.intermediate_size,
                "activation": get_act_fn("gelu_pytorch_tanh"),
                "attn_bias": True,
            },
            video_attn_type=config.video_attn_type,
            prefix=maybe_prefix(prefix, "encoder"),
        )

    def forward(
        self, pixel_values: torch.Tensor, grid_thws: torch.Tensor
    ) -> list[torch.Tensor]:
        """Embed, encode and merge the input patches.

        Args:
            pixel_values (torch.Tensor): The input pixel values.
            grid_thws (torch.Tensor): Temporal, height and width.
        Returns:
            list[torch.Tensor]: The merged output tokens, one tensor per row
            of ``grid_thws`` (the result of ``tpool_patch_merger``).
        Raises:
            NotImplementedError: if ``merge_type`` is not ``"sd2_tpool"``.
        """
        # Fail fast: reject an unsupported merge type before paying for the
        # patch embedding and the encoder stack.
        if self.merge_type != "sd2_tpool":
            raise NotImplementedError(f"Not support {self.merge_type}")

        hidden_states = self.patch_embed(pixel_values, grid_thws)
        hidden_states = self.encoder(hidden_states, grid_thws)
        # spatial downsampling 2x with temporal pooling all
        return tpool_patch_merger(
            hidden_states, grid_thws, merge_kernel_size=self.merge_kernel_size
        )
@torch.inference_mode()
def mm_projector_forward(mm_projector: torch.nn.Module, vt_output: list[torch.Tensor]):
    """Apply MM projector to vision tower outputs.

    All per-item tensors are concatenated along dim 0 and projected in a
    single call; the result is flattened to 2D and split back so that item
    ``i`` gets as many rows as ``vt_output[i]`` had along dim 0.

    Returns:
        The tuple produced by ``torch.split``, one tensor per input item.
    """
    rows_per_item = [feat.shape[0] for feat in vt_output]
    projected = mm_projector(torch.cat(vt_output, dim=0))
    last_dim = projected.shape[-1]
    flattened = projected.reshape(-1, last_dim)
    return torch.split(flattened, rows_per_item)
@torch.inference_mode()
def vision_tower_forward(
    vision_tower: Any,
    pixel_values: torch.Tensor,
    grid_thw: torch.Tensor,
    mm_projector: Any,
    use_data_parallel: bool,
) -> list[torch.Tensor]:
    """Run the vision tower, then the MM projector, on a batch.

    With ``use_data_parallel`` the tower is invoked through
    ``run_dp_sharded_mrope_vision_model`` (using ``rope_type="rope_2d"`` and
    the grid converted to a Python list); otherwise it is called directly.
    Either way the tower outputs are passed to ``mm_projector_forward``.

    Returns:
        The projected features as a list, one tensor per tower output.
    """
    if not use_data_parallel:
        vt_outputs = vision_tower(pixel_values, grid_thw)
    else:
        vt_outputs = run_dp_sharded_mrope_vision_model(
            vision_model=vision_tower,
            pixel_values=pixel_values,
            grid_thw_list=grid_thw.tolist(),
            rope_type="rope_2d",
        )

    projected = mm_projector_forward(mm_projector, list(vt_outputs))
    return list(projected)
class KimiK25MultiModalProjector(nn.Module):
    """Multi-modal projector with patch merging for Kimi-K2.5.

    Normalises features at the vision hidden size, regroups them to the
    merged width (``hidden_size * merge_h * merge_w``) and maps them through
    two replicated linear layers with a GELU in between.
    """

    def __init__(
        self,
        config: KimiK25VisionConfig,
        use_data_parallel: bool = False,
        prefix: str = "",
    ):
        """Create the norm, both linear layers and the activation.

        Args:
            config: vision configuration supplying ``hidden_size``,
                ``merge_kernel_size`` and ``mm_hidden_size``.
            use_data_parallel: stored on the module.
            prefix: parameter-name prefix for the linear layers.
        """
        super().__init__()
        self.use_data_parallel = use_data_parallel

        # Hidden size after patch merging
        kernel_h, kernel_w = config.merge_kernel_size
        merged_width = config.hidden_size * kernel_h * kernel_w
        self.hidden_size = merged_width

        # The norm acts on the un-merged vision width.
        self.pre_norm = torch.nn.LayerNorm(config.hidden_size, eps=1e-5)
        self.linear_1 = ReplicatedLinear(
            merged_width,
            merged_width,
            bias=True,
            prefix=maybe_prefix(prefix, "linear_1"),
        )
        self.linear_2 = ReplicatedLinear(
            merged_width,
            config.mm_hidden_size,
            bias=True,
            prefix=maybe_prefix(prefix, "linear_2"),
        )
        self.act = GELUActivation()

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project vision features to the multi-modal hidden size."""
        normed = self.pre_norm(image_features)
        # Fold each merge window into a single row of the merged width.
        merged = normed.view(-1, self.hidden_size)
        projected, _ = self.linear_1(merged)
        projected = self.act(projected)
        output, _ = self.linear_2(projected)
        return output
vllm/model_executor/models/kimi_linear.py
View file @
d76fc11e
...
...
@@ -506,7 +506,7 @@ class KimiLinearForCausalLM(
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
...
...
vllm/model_executor/models/kimi_vl.py
View file @
d76fc11e
...
...
@@ -389,7 +389,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
...
...
vllm/model_executor/models/lfm2.py
View file @
d76fc11e
...
...
@@ -342,7 +342,7 @@ class Lfm2Model(nn.Module):
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
...
...
@@ -503,7 +503,7 @@ class Lfm2ForCausalLM(
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
...
...
vllm/model_executor/models/lfm2_moe.py
View file @
d76fc11e
...
...
@@ -457,7 +457,7 @@ class Lfm2MoeModel(nn.Module):
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
...
...
@@ -730,7 +730,7 @@ class Lfm2MoeForCausalLM(
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
...
...
vllm/model_executor/models/lfm2_vl.py
View file @
d76fc11e
...
...
@@ -769,7 +769,7 @@ class Lfm2VLForConditionalGeneration(
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
...
...
vllm/model_executor/models/llama.py
View file @
d76fc11e
...
...
@@ -651,7 +651,7 @@ class LlamaForCausalLM(
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
...
...
vllm/model_executor/models/llava.py
View file @
d76fc11e
...
...
@@ -662,7 +662,7 @@ class LlavaForConditionalGeneration(
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
...
...
vllm/model_executor/models/llava_next.py
View file @
d76fc11e
...
...
@@ -509,7 +509,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
...
...
vllm/model_executor/models/llava_next_video.py
View file @
d76fc11e
...
...
@@ -426,7 +426,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
...
...
vllm/model_executor/models/llava_onevision.py
View file @
d76fc11e
...
...
@@ -887,7 +887,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
...
...
vllm/model_executor/models/longcat_flash.py
View file @
d76fc11e
...
...
@@ -520,7 +520,7 @@ class FlashModel(nn.Module):
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
...
...
@@ -605,7 +605,7 @@ class LongcatFlashForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
...
...
vllm/model_executor/models/longcat_flash_mtp.py
View file @
d76fc11e
...
...
@@ -150,7 +150,7 @@ class LongCatFlashMTP(nn.Module):
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
...
...
vllm/model_executor/models/mamba.py
View file @
d76fc11e
...
...
@@ -142,7 +142,7 @@ class MambaModel(nn.Module):
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
...
...
@@ -225,7 +225,7 @@ class MambaForCausalLM(
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
...
...
vllm/model_executor/models/mamba2.py
View file @
d76fc11e
...
...
@@ -137,7 +137,7 @@ class Mamba2Model(nn.Module):
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
...
...
@@ -267,7 +267,7 @@ class Mamba2ForCausalLM(
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
...
...
vllm/model_executor/models/midashenglm.py
View file @
d76fc11e
...
...
@@ -796,7 +796,7 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
...
...
vllm/model_executor/models/mimo.py
View file @
d76fc11e
...
...
@@ -61,7 +61,7 @@ logger = init_logger(__name__)
class
MiMoModel
(
Qwen2Model
):
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
...
...
vllm/model_executor/models/mimo_mtp.py
View file @
d76fc11e
...
...
@@ -169,7 +169,7 @@ class MiMoMTP(nn.Module):
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
...
...
vllm/model_executor/models/mimo_v2_flash.py
View file @
d76fc11e
...
...
@@ -478,7 +478,7 @@ class MiMoV2Model(nn.Module):
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
...
...
@@ -693,7 +693,7 @@ class MiMoV2FlashForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
...
...
Prev
1
…
7
8
9
10
11
12
13
14
15
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment