Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7f1bcd18
Unverified
Commit
7f1bcd18
authored
Jan 20, 2026
by
Cyrus Leung
Committed by
GitHub
Jan 20, 2026
Browse files
[3/N] Initialize MM components in context managers (I-L) (#32650)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
8be263c3
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
202 additions
and
388 deletions
+202
-388
vllm/model_executor/models/interns1.py
vllm/model_executor/models/interns1.py
+13
-18
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+16
-19
vllm/model_executor/models/isaac.py
vllm/model_executor/models/isaac.py
+18
-25
vllm/model_executor/models/kanana_v.py
vllm/model_executor/models/kanana_v.py
+15
-13
vllm/model_executor/models/keye.py
vllm/model_executor/models/keye.py
+21
-23
vllm/model_executor/models/kimi_vl.py
vllm/model_executor/models/kimi_vl.py
+27
-188
vllm/model_executor/models/lfm2_vl.py
vllm/model_executor/models/lfm2_vl.py
+25
-27
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+23
-25
vllm/model_executor/models/llava_next_video.py
vllm/model_executor/models/llava_next_video.py
+24
-28
vllm/model_executor/models/llava_onevision.py
vllm/model_executor/models/llava_onevision.py
+20
-22
No files found.
vllm/model_executor/models/interns1.py
View file @
7f1bcd18
...
...
@@ -547,20 +547,20 @@ class InternS1ForConditionalGeneration(
)
self
.
downsample_ratio
=
config
.
downsample_ratio
self
.
llm_arch_name
=
config
.
text_config
.
architectures
[
0
]
self
.
vision_tower
=
self
.
_init_vision_model
(
config
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"vision_tower"
),
)
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
with
self
.
_mark_tower_model
(
vllm_config
,
{
"image"
,
"video"
}):
self
.
vision_tower
=
self
.
_init_vision_model
(
config
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"vision_tower"
),
)
self
.
multi_modal_projector
=
self
.
_init_mlp1
(
config
)
self
.
multi_modal_projector
=
self
.
_init_mlp1
(
config
)
with
self
.
_mark_language_model
(
vllm_config
):
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
self
.
img_context_token_id
=
None
self
.
video_context_token_id
=
None
...
...
@@ -699,8 +699,6 @@ class InternS1ForConditionalGeneration(
):
return
image_input
[
"data"
]
assert
self
.
vision_tower
is
not
None
image_embeds
=
self
.
extract_feature
(
image_input
[
"pixel_values"
])
num_patches
=
image_input
[
"num_patches"
]
...
...
@@ -737,9 +735,6 @@ class InternS1ForConditionalGeneration(
def
_set_visual_token_mask
(
self
,
input_ids
:
torch
.
Tensor
)
->
None
:
self
.
visual_token_mask
=
None
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
embed_multimodal
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
modalities
=
self
.
_parse_and_validate_multimodal_inputs
(
**
kwargs
)
if
not
modalities
:
...
...
vllm/model_executor/models/internvl.py
View file @
7f1bcd18
...
...
@@ -1092,22 +1092,24 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
self
.
downsample_ratio
=
config
.
downsample_ratio
self
.
ps_version
=
config
.
ps_version
self
.
llm_arch_name
=
config
.
text_config
.
architectures
[
0
]
self
.
is_mono
=
self
.
llm_arch_name
==
"InternLM2VEForCausalLM"
self
.
vision_model
=
self
.
_init_vision_model
(
config
,
quant_config
=
quant_config
,
is_mono
=
self
.
is_mono
,
prefix
=
maybe_prefix
(
prefix
,
"vision_model"
),
)
llm_arch_name
=
config
.
text_config
.
architectures
[
0
]
self
.
is_mono
=
llm_arch_name
==
"InternLM2VEForCausalLM"
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
with
self
.
_mark_tower_model
(
vllm_config
,
{
"image"
,
"video"
}):
self
.
vision_model
=
self
.
_init_vision_model
(
config
,
quant_config
=
quant_config
,
is_mono
=
self
.
is_mono
,
prefix
=
maybe_prefix
(
prefix
,
"vision_model"
),
)
self
.
mlp1
=
self
.
_init_mlp1
(
config
)
self
.
mlp1
=
self
.
_init_mlp1
(
config
)
with
self
.
_mark_language_model
(
vllm_config
):
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
self
.
img_context_token_id
=
None
self
.
video_context_token_id
=
None
...
...
@@ -1281,8 +1283,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
):
return
image_input
[
"data"
]
assert
self
.
vision_model
is
not
None
image_embeds
=
self
.
extract_feature
(
image_input
[
"pixel_values_flat"
])
num_patches
=
image_input
[
"num_patches"
]
...
...
@@ -1325,9 +1325,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
else
:
self
.
visual_token_mask
=
None
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
embed_multimodal
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
modalities
=
self
.
_parse_and_validate_multimodal_inputs
(
**
kwargs
)
if
not
modalities
:
...
...
vllm/model_executor/models/isaac.py
View file @
7f1bcd18
...
...
@@ -1342,11 +1342,14 @@ class IsaacForConditionalGeneration(
"mrope_interleaved"
,
rope_scaling
[
"mrope_interleaved"
]
)
target_cfg
.
rope_parameters
=
rope_parameters
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
architectures
=
[
"Qwen3ForCausalLM"
],
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
with
self
.
_mark_language_model
(
vllm_config
):
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
architectures
=
[
"Qwen3ForCausalLM"
],
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
...
...
@@ -1363,14 +1366,16 @@ class IsaacForConditionalGeneration(
vision_cfg
.
_attn_implementation
=
attn_impl
hidden_dim
=
vision_cfg
.
hidden_size
*
(
vision_cfg
.
pixel_shuffle_scale_factor
**
2
)
self
.
vision_embedding
=
IsaacVisionEmbedding
(
vision_cfg
=
vision_cfg
,
hidden_dim
=
hidden_dim
,
output_dim
=
config
.
hidden_size
,
quant_config
=
quant_config
,
multimodal_config
=
self
.
multimodal_config
,
prefix
=
maybe_prefix
(
prefix
,
"vision_embedding"
),
)
with
self
.
_mark_tower_model
(
vllm_config
,
"image"
):
self
.
vision_embedding
=
IsaacVisionEmbedding
(
vision_cfg
=
vision_cfg
,
hidden_dim
=
hidden_dim
,
output_dim
=
config
.
hidden_size
,
quant_config
=
quant_config
,
multimodal_config
=
self
.
multimodal_config
,
prefix
=
maybe_prefix
(
prefix
,
"vision_embedding"
),
)
def
iter_mm_grid_hw
(
self
,
input_tokens
:
list
[
int
],
mm_features
:
list
[
MultiModalFeatureSpec
]
...
...
@@ -1457,18 +1462,6 @@ class IsaacForConditionalGeneration(
return
()
return
self
.
_process_image_input
(
image_input
)
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
|
None
:
# Backward compatibility for older runners.
embeddings
=
self
.
embed_multimodal
(
**
kwargs
)
if
not
embeddings
:
return
[]
return
embeddings
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/kanana_v.py
View file @
7f1bcd18
...
...
@@ -586,16 +586,21 @@ class KananaVForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
config
=
vllm_config
.
model_config
.
hf_config
self
.
config
=
config
self
.
vision_model
=
CustomQwen2VLVE
.
_from_config
(
config
.
vision_config
)
self
.
abstractor
=
DynamicCAbstractor
(
config
.
projector_config
,
num_input_tokens
=
self
.
vision_model
.
get_num_tokens
()
)
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
),
architectures
=
[
"LlamaForCausalLM"
],
)
with
self
.
_mark_tower_model
(
vllm_config
,
"image"
):
self
.
vision_model
=
CustomQwen2VLVE
.
_from_config
(
config
.
vision_config
)
self
.
abstractor
=
DynamicCAbstractor
(
config
.
projector_config
,
num_input_tokens
=
self
.
vision_model
.
get_num_tokens
(),
)
with
self
.
_mark_language_model
(
vllm_config
):
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
),
architectures
=
[
"LlamaForCausalLM"
],
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
...
...
@@ -718,9 +723,6 @@ class KananaVForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
visual_embeds
=
self
.
forward_projector
(
visual_features
,
image_metas
=
image_metas
)
return
visual_embeds
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
embed_multimodal
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
...
...
vllm/model_executor/models/keye.py
View file @
7f1bcd18
...
...
@@ -1242,7 +1242,7 @@ class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]):
return
_keye_field_config
(
hf_inputs
)
class
BaseKeyeModule
(
nn
.
Module
):
class
BaseKeyeModule
(
nn
.
Module
,
SupportsMultiModal
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
...
...
@@ -1280,25 +1280,26 @@ class BaseKeyeModule(nn.Module):
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
self
.
visual
=
KeyeSiglipVisionModel
(
config
.
vision_config
,
quant_config
=
quant
_config
,
multimodal_config
=
multimodal
_config
,
prefix
=
maybe_prefix
(
prefix
,
"visual"
)
,
)
self
.
mlp_AR
=
self
.
_build_projector
(
config
,
config
.
vision_config
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"mlp_AR"
),
)
with
self
.
_mark_tower_model
(
vllm_config
,
{
"image"
,
"video"
}):
self
.
visual
=
KeyeSiglipVisionModel
(
config
.
vision
_config
,
quant_config
=
quant
_config
,
multimodal_config
=
multimodal_config
,
prefix
=
maybe_prefix
(
prefix
,
"visual"
),
)
self
.
mlp_AR
=
self
.
_build_projector
(
config
,
config
.
vision_config
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"mlp_AR"
),
)
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
architectures
=
[
"Qwen3ForCausalLM"
],
)
with
self
.
_mark_language_model
(
vllm_config
):
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
architectures
=
[
"Qwen3ForCausalLM"
],
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
...
...
@@ -1312,7 +1313,7 @@ class BaseKeyeModule(nn.Module):
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
)
->
nn
.
Module
:
raise
Value
Error
(
"Need projector"
)
raise
NotImplemented
Error
(
"Need projector"
)
def
_process_image_input
(
self
,
image_input
:
Any
)
->
tuple
[
torch
.
Tensor
,
...]:
siglip_position_ids
=
list
()
...
...
@@ -1429,9 +1430,6 @@ class BaseKeyeModule(nn.Module):
return
modalities
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
embed_multimodal
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
|
None
:
modalities
=
self
.
_parse_and_validate_multimodal_inputs
(
**
kwargs
)
if
not
modalities
:
...
...
vllm/model_executor/models/kimi_vl.py
View file @
7f1bcd18
...
...
@@ -42,7 +42,6 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import
copy
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
dataclasses
import
dataclass
...
...
@@ -50,23 +49,12 @@ from typing import Annotated, Any, Literal
import
torch
from
torch
import
nn
from
transformers
import
BatchFeature
,
DeepseekV2Config
from
transformers
import
BatchFeature
from
transformers.activations
import
GELUActivation
from
vllm.config
import
VllmConfig
from
vllm.config.multimodal
import
BaseDummyOptions
from
vllm.distributed
import
get_pp_group
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
)
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
,
)
from
vllm.model_executor.models.deepseek_v2
import
DeepseekV2Model
from
vllm.model_executor.models.interfaces
import
SupportsMultiModal
,
SupportsPP
from
vllm.model_executor.models.moonvit
import
MoonVitPretrainedModel
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
...
...
@@ -92,7 +80,7 @@ from vllm.sequence import IntermediateTensors
from
vllm.transformers_utils.configs
import
KimiVLConfig
,
MoonViTConfig
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.utils
import
PPMissingLayer
,
is_pp_missing_parameter
,
maybe_prefix
from
.utils
import
AutoWeightsLoader
,
init_vllm_registered_model
,
maybe_prefix
from
.vision
import
run_dp_sharded_mrope_vision_model
...
...
@@ -315,48 +303,41 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
super
().
__init__
()
model_config
=
vllm_config
.
model_config
config
:
KimiVLConfig
=
model_config
.
hf_config
self
.
config
=
config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
quant_config
=
quant_config
assert
isinstance
(
config
.
vision_config
,
MoonViTConfig
)
self
.
use_data_parallel
=
(
model_config
.
multimodal_config
.
mm_encoder_tp_mode
==
"data"
)
self
.
hidden_size
=
config
.
text_config
.
hidden_size
self
.
vision_tower
=
MoonVitPretrainedModel
(
config
.
vision_config
,
multimodal_config
=
model_config
.
multimodal_config
,
prefix
=
maybe_prefix
(
prefix
,
"vision_tower"
),
)
self
.
multi_modal_projector
=
KimiVLMultiModalProjector
(
config
=
config
,
use_data_parallel
=
self
.
use_data_parallel
,
prefix
=
maybe_prefix
(
prefix
,
"multi_modal_projector"
),
)
with
self
.
_mark_tower_model
(
vllm_config
,
"image"
):
self
.
vision_tower
=
MoonVitPretrainedModel
(
config
.
vision_config
,
multimodal_config
=
model_config
.
multimodal_config
,
prefix
=
maybe_prefix
(
prefix
,
"vision_tower"
),
)
self
.
multi_modal_projector
=
KimiVLMultiModalProjector
(
config
=
config
,
use_data_parallel
=
self
.
use_data_parallel
,
prefix
=
maybe_prefix
(
prefix
,
"multi_modal_projector"
),
)
self
.
quant_config
=
quant_config
sub_vllm_config
=
copy
.
deepcopy
(
vllm_config
)
sub_vllm_config
.
model_config
.
hf_config
=
(
sub_vllm_config
.
model_config
.
hf_config
.
text_config
)
self
.
language_model
=
DeepseekV2Model
(
vllm_config
=
sub_vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
if
get_pp_group
().
is_last_rank
:
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
text_config
.
hidden_size
,
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
),
with
self
.
_mark_language_model
(
vllm_config
):
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
architectures
=
[
"DeepseekV2ForCausalLM"
],
)
else
:
self
.
lm_head
=
PPMissingLayer
()
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
,
scale
=
logit_scale
)
self
.
media_placeholder
:
int
=
self
.
config
.
media_placeholder_token_id
def
_parse_and_validate_image_input
(
...
...
@@ -378,8 +359,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# perform vt on processored pixel_values
@
torch
.
inference_mode
()
def
_process_image_pixels
(
self
,
inputs
:
KimiVLImagePixelInputs
)
->
torch
.
Tensor
:
assert
self
.
vision_tower
is
not
None
pixel_values
=
inputs
[
"pixel_values"
]
image_grid_hws
=
inputs
[
"image_grid_hws"
]
if
self
.
use_data_parallel
:
...
...
@@ -399,9 +378,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
lengths
=
[
x
.
shape
[
0
]
for
x
in
image_features
]
return
self
.
multi_modal_projector
(
torch
.
cat
(
image_features
)).
split
(
lengths
)
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
embed_multimodal
(
self
,
**
kwargs
:
object
)
->
NestedTensors
|
None
:
# Validate the multimodal input keyword arguments
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
...
...
@@ -433,145 +409,8 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
**
kwargs
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
**
kwargs
)
return
logits
return
self
.
language_model
.
compute_logits
(
hidden_states
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
config
=
self
.
config
.
text_config
_KEYS_TO_MODIFY_MAPPING
=
{
"language_model.lm_head"
:
"lm_head"
,
"language_model.model"
:
"language_model"
,
}
# only doing this for language model part for now.
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
".gate_up_proj"
,
".gate_proj"
,
0
),
(
".gate_up_proj"
,
".up_proj"
,
1
),
]
use_mha
=
(
config
.
model_type
==
"deepseek"
or
config
.
qk_nope_head_dim
+
config
.
qk_rope_head_dim
==
0
)
if
use_mha
:
stacked_params_mapping
+=
[
(
".qkv_proj"
,
".q_proj"
,
"q"
),
(
".qkv_proj"
,
".k_proj"
,
"k"
),
(
".qkv_proj"
,
".v_proj"
,
"v"
),
]
if
getattr
(
config
,
"n_routed_experts"
,
None
):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
self
,
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
config
.
n_routed_experts
,
)
else
:
expert_params_mapping
=
[]
params_dict
=
dict
(
self
.
named_parameters
())
for
args
in
weights
:
name
,
loaded_weight
=
args
[:
2
]
kwargs
=
args
[
2
]
if
len
(
args
)
>
2
else
{}
if
"rotary_emb.inv_freq"
in
name
:
continue
spec_layer
=
get_spec_layer_idx_from_weight_name
(
config
,
name
)
if
spec_layer
is
not
None
:
continue
# skip spec decode layers for main model
if
"rotary_emb.cos_cached"
in
name
or
"rotary_emb.sin_cached"
in
name
:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for
key_to_modify
,
new_key
in
_KEYS_TO_MODIFY_MAPPING
.
items
():
if
key_to_modify
in
name
:
name
=
name
.
replace
(
key_to_modify
,
new_key
)
use_default_weight_loading
=
False
if
"vision"
in
name
:
if
self
.
vision_tower
is
not
None
:
# We only do sharding for language model and
# not vision model for now.
use_default_weight_loading
=
True
else
:
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if
(
"mlp.experts."
in
name
)
and
name
not
in
params_dict
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
,
**
kwargs
)
break
else
:
for
idx
,
(
param_name
,
weight_name
,
expert_id
,
shard_id
,
)
in
enumerate
(
expert_params_mapping
):
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
name
,
expert_id
=
expert_id
,
shard_id
=
shard_id
,
**
kwargs
,
)
break
else
:
use_default_weight_loading
=
True
if
use_default_weight_loading
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Remapping the name of FP8 kv-scale.
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
name
is
None
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
,
**
kwargs
)
def
get_spec_layer_idx_from_weight_name
(
config
:
DeepseekV2Config
,
weight_name
:
str
)
->
int
|
None
:
if
hasattr
(
config
,
"num_nextn_predict_layers"
)
and
(
config
.
num_nextn_predict_layers
>
0
):
layer_idx
=
config
.
num_hidden_layers
for
i
in
range
(
config
.
num_nextn_predict_layers
):
if
weight_name
.
startswith
(
f
"model.layers.
{
layer_idx
+
i
}
."
):
return
layer_idx
+
i
return
None
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
)
vllm/model_executor/models/lfm2_vl.py
View file @
7f1bcd18
...
...
@@ -546,38 +546,37 @@ class Lfm2VLForConditionalGeneration(
self
.
multimodal_config
=
multimodal_config
self
.
use_data_parallel
=
multimodal_config
.
mm_encoder_tp_mode
==
"data"
if
vision_config
.
model_type
==
"siglip2_vision_model"
:
self
.
vision_tower
=
Siglip2Model
(
config
=
vision_config
,
quant_config
=
quant_config
,
multimodal_config
=
multimodal_config
,
prefix
=
maybe_prefix
(
prefix
,
"vision_tower"
),
)
else
:
raise
ValueError
(
f
"Unsupported visual tokenizer model_type:
{
vision_config
.
model_type
}
"
)
with
self
.
_mark_tower_model
(
vllm_config
,
"image"
):
if
vision_config
.
model_type
==
"siglip2_vision_model"
:
self
.
vision_tower
=
Siglip2Model
(
config
=
vision_config
,
quant_config
=
quant_config
,
multimodal_config
=
multimodal_config
,
prefix
=
maybe_prefix
(
prefix
,
"vision_tower"
),
)
else
:
raise
ValueError
(
f
"Unsupported visual tokenizer type:
{
vision_config
.
model_type
}
"
)
self
.
multi_modal_projector
=
Lfm2VLMultiModalProjector
(
config
=
config
,
use_data_parallel
=
self
.
use_data_parallel
,
prefix
=
f
"
{
prefix
}
.
multi_modal_projector"
,
)
self
.
multi_modal_projector
=
Lfm2VLMultiModalProjector
(
config
=
config
,
use_data_parallel
=
self
.
use_data_parallel
,
prefix
=
maybe_prefix
(
prefix
,
"
multi_modal_projector"
)
,
)
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language"
),
architectures
=
config
.
text_config
.
architectures
,
)
with
self
.
_mark_language_model
(
vllm_config
):
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language"
),
architectures
=
config
.
text_config
.
architectures
,
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
LFM2VLImageInputs
|
None
:
...
...
@@ -714,8 +713,7 @@ class Lfm2VLForConditionalGeneration(
self
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
None
:
logits
=
self
.
language_model
.
compute_logits
(
hidden_states
)
return
logits
return
self
.
language_model
.
compute_logits
(
hidden_states
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
loader
=
AutoWeightsLoader
(
self
)
...
...
vllm/model_executor/models/llava_next.py
View file @
7f1bcd18
...
...
@@ -268,27 +268,30 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
# TODO: Optionally initializes this for supporting embeddings.
self
.
vision_tower
=
init_vision_tower_for_llava
(
config
,
quant_config
=
quant_config
,
multimodal_config
=
multimodal_config
,
require_post_norm
=
False
,
prefix
=
maybe_prefix
(
prefix
,
"vision_tower"
),
)
self
.
image_newline
=
nn
.
Parameter
(
torch
.
empty
(
config
.
text_config
.
hidden_size
))
self
.
multi_modal_projector
=
LlavaMultiModalProjector
(
vision_hidden_size
=
vision_hidden_size
,
text_hidden_size
=
config
.
text_config
.
hidden_size
,
projector_hidden_act
=
config
.
projector_hidden_act
,
multimodal_projector_bias
=
config
.
multimodal_projector_bias
,
)
with
self
.
_mark_tower_model
(
vllm_config
,
"image"
):
self
.
vision_tower
=
init_vision_tower_for_llava
(
config
,
quant_config
=
quant_config
,
multimodal_config
=
multimodal_config
,
require_post_norm
=
False
,
prefix
=
maybe_prefix
(
prefix
,
"vision_tower"
),
)
self
.
image_newline
=
nn
.
Parameter
(
torch
.
empty
(
config
.
text_config
.
hidden_size
)
)
self
.
multi_modal_projector
=
LlavaMultiModalProjector
(
vision_hidden_size
=
vision_hidden_size
,
text_hidden_size
=
config
.
text_config
.
hidden_size
,
projector_hidden_act
=
config
.
projector_hidden_act
,
multimodal_projector_bias
=
config
.
multimodal_projector_bias
,
)
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
with
self
.
_mark_language_model
(
vllm_config
):
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
...
...
@@ -427,8 +430,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
self
,
inputs
:
LlavaNextImagePixelInputs
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
...]:
assert
self
.
vision_tower
is
not
None
pixel_values
=
inputs
[
"pixel_values"
]
if
isinstance
(
pixel_values
,
torch
.
Tensor
):
...
...
@@ -480,9 +481,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
for
i
,
patch_features_batch
in
enumerate
(
patch_embeddings
)
]
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
embed_multimodal
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
...
...
vllm/model_executor/models/llava_next_video.py
View file @
7f1bcd18
...
...
@@ -312,12 +312,10 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
@
classmethod
def
get_placeholder_str
(
cls
,
modality
:
str
,
i
:
int
)
->
str
|
None
:
if
modality
.
startswith
(
"image"
):
return
"<image>"
if
modality
.
startswith
(
"video"
):
return
"<video>"
raise
ValueError
(
"Only
image or
video modality is supported"
)
raise
ValueError
(
"Only video modality is supported"
)
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
)
->
None
:
super
().
__init__
()
...
...
@@ -329,26 +327,29 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
# Initialize the vision tower only up to the required feature layer
self
.
vision_tower
=
init_vision_tower_for_llava
(
config
,
quant_config
=
quant_config
,
multimodal_config
=
multimodal_config
,
require_post_norm
=
False
,
prefix
=
maybe_prefix
(
prefix
,
"vision_tower"
),
)
self
.
vision_resampler
=
LlavaNextVideoPooler
(
config
)
self
.
multi_modal_projector
=
LlavaNextMultiModalProjector
(
vision_hidden_size
=
config
.
vision_config
.
hidden_size
,
text_hidden_size
=
config
.
text_config
.
hidden_size
,
projector_hidden_act
=
config
.
projector_hidden_act
,
multimodal_projector_bias
=
config
.
multimodal_projector_bias
,
)
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
with
self
.
_mark_tower_model
(
vllm_config
,
"video"
):
# Initialize the vision tower only up to the required feature layer
self
.
vision_tower
=
init_vision_tower_for_llava
(
config
,
quant_config
=
quant_config
,
multimodal_config
=
multimodal_config
,
require_post_norm
=
False
,
prefix
=
maybe_prefix
(
prefix
,
"vision_tower"
),
)
self
.
vision_resampler
=
LlavaNextVideoPooler
(
config
)
self
.
multi_modal_projector
=
LlavaNextMultiModalProjector
(
vision_hidden_size
=
config
.
vision_config
.
hidden_size
,
text_hidden_size
=
config
.
text_config
.
hidden_size
,
projector_hidden_act
=
config
.
projector_hidden_act
,
multimodal_projector_bias
=
config
.
multimodal_projector_bias
,
)
with
self
.
_mark_language_model
(
vllm_config
):
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
model
.
make_empty_intermediate_tensors
...
...
@@ -395,8 +396,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
return
image_features
def
_process_video_pixels
(
self
,
inputs
:
LlavaNextVideoPixelInputs
):
assert
self
.
vision_tower
is
not
None
video_pixels
=
inputs
[
"pixel_values_videos"
]
if
isinstance
(
video_pixels
,
torch
.
Tensor
):
...
...
@@ -419,9 +418,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
return
[
e
.
flatten
(
0
,
1
)
for
e
in
embeds
]
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
embed_multimodal
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
video_input
=
self
.
_parse_and_validate_video_input
(
**
kwargs
)
if
video_input
is
None
:
...
...
vllm/model_executor/models/llava_onevision.py
View file @
7f1bcd18
...
...
@@ -508,21 +508,26 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
# Initialize the vision tower only up to the required feature layer
self
.
vision_tower
=
init_vision_tower_for_llava
(
config
,
quant_config
=
quant_config
,
multimodal_config
=
multimodal_config
,
require_post_norm
=
False
,
prefix
=
maybe_prefix
(
prefix
,
"vision_tower"
),
)
self
.
multi_modal_projector
=
LlavaOnevisionMultiModalProjector
(
config
)
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
self
.
image_newline
=
nn
.
Parameter
(
torch
.
empty
(
config
.
text_config
.
hidden_size
))
with
self
.
_mark_tower_model
(
vllm_config
,
{
"image"
,
"video"
}):
# Initialize the vision tower only up to the required feature layer
self
.
vision_tower
=
init_vision_tower_for_llava
(
config
,
quant_config
=
quant_config
,
multimodal_config
=
multimodal_config
,
require_post_norm
=
False
,
prefix
=
maybe_prefix
(
prefix
,
"vision_tower"
),
)
self
.
image_newline
=
nn
.
Parameter
(
torch
.
empty
(
config
.
text_config
.
hidden_size
)
)
self
.
multi_modal_projector
=
LlavaOnevisionMultiModalProjector
(
config
)
with
self
.
_mark_language_model
(
vllm_config
):
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
model
.
make_empty_intermediate_tensors
...
...
@@ -726,8 +731,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
self
,
inputs
:
LlavaOnevisionImagePixelInputs
,
)
->
torch
.
Tensor
|
list
[
torch
.
Tensor
]:
assert
self
.
vision_tower
is
not
None
pixel_values
=
inputs
[
"pixel_values"
]
if
isinstance
(
pixel_values
,
torch
.
Tensor
):
...
...
@@ -801,8 +804,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
return
video_features
def
_process_video_pixels
(
self
,
inputs
:
LlavaOnevisionVideoPixelInputs
):
assert
self
.
vision_tower
is
not
None
video_pixels
=
inputs
[
"pixel_values_videos"
]
if
isinstance
(
video_pixels
,
torch
.
Tensor
):
...
...
@@ -862,9 +863,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
image_feature
=
image_feature
.
view
(
batch_frames
,
-
1
,
dim
)
return
image_feature
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
embed_multimodal
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
mm_input_by_modality
=
self
.
_parse_and_validate_multimodal_inputs
(
**
kwargs
)
if
not
mm_input_by_modality
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment