Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
302ecf64
Unverified
Commit
302ecf64
authored
Feb 01, 2026
by
Eduardo Salinas
Committed by
GitHub
Feb 01, 2026
Browse files
[Models]: lfm2_siglip2 return intermediate encoder layers (#33370)
Signed-off-by: Eduardo Salinas <edus@microsoft.com>
parent
b6bb2842
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
76 additions
and
7 deletions
+76
-7
vllm/model_executor/models/lfm2_siglip2.py
vllm/model_executor/models/lfm2_siglip2.py
+76
-7
No files found.
vllm/model_executor/models/lfm2_siglip2.py
View file @
302ecf64
...
@@ -22,7 +22,11 @@ from vllm.model_executor.layers.linear import (
...
@@ -22,7 +22,11 @@ from vllm.model_executor.layers.linear import (
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
.vision
import
is_vit_use_data_parallel
,
should_torch_compile_mm_vit
from
.vision
import
(
is_vit_use_data_parallel
,
resolve_visual_encoder_outputs
,
should_torch_compile_mm_vit
,
)
class
Siglip2VisionEmbeddings
(
nn
.
Module
):
class
Siglip2VisionEmbeddings
(
nn
.
Module
):
...
@@ -331,10 +335,17 @@ class Siglip2Encoder(nn.Module):
...
@@ -331,10 +335,17 @@ class Siglip2Encoder(nn.Module):
self
,
self
,
config
:
Siglip2VisionConfig
,
config
:
Siglip2VisionConfig
,
quant_config
:
QuantizationConfig
|
None
=
None
,
quant_config
:
QuantizationConfig
|
None
=
None
,
num_hidden_layers_override
:
int
|
None
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
if
num_hidden_layers_override
is
None
:
num_hidden_layers
=
config
.
num_hidden_layers
else
:
num_hidden_layers
=
num_hidden_layers_override
self
.
layers
=
nn
.
ModuleList
(
self
.
layers
=
nn
.
ModuleList
(
[
[
Siglip2EncoderLayer
(
Siglip2EncoderLayer
(
...
@@ -342,7 +353,7 @@ class Siglip2Encoder(nn.Module):
...
@@ -342,7 +353,7 @@ class Siglip2Encoder(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.layers.
{
idx
}
"
,
prefix
=
f
"
{
prefix
}
.layers.
{
idx
}
"
,
)
)
for
idx
in
range
(
config
.
num_hidden_layers
)
for
idx
in
range
(
num_hidden_layers
)
]
]
)
)
...
@@ -351,15 +362,21 @@ class Siglip2Encoder(nn.Module):
...
@@ -351,15 +362,21 @@ class Siglip2Encoder(nn.Module):
inputs_embeds
:
torch
.
Tensor
,
inputs_embeds
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
max_seqlen
:
int
|
torch
.
Tensor
,
max_seqlen
:
int
|
torch
.
Tensor
,
)
->
torch
.
Tensor
:
return_all_hidden_states
:
bool
=
False
,
)
->
torch
.
Tensor
|
list
[
torch
.
Tensor
]:
hidden_states_pool
=
[
inputs_embeds
]
hidden_states
=
inputs_embeds
hidden_states
=
inputs_embeds
for
encoder_layer
in
self
.
layers
:
for
encoder_layer
in
self
.
layers
:
layer_outputs
=
encoder_layer
(
hidden_states
=
encoder_layer
(
hidden_states
,
hidden_states
,
cu_seqlens
=
cu_seqlens
,
cu_seqlens
=
cu_seqlens
,
max_seqlen
=
max_seqlen
,
max_seqlen
=
max_seqlen
,
)
)
hidden_states
=
layer_outputs
if
return_all_hidden_states
:
hidden_states_pool
.
append
(
hidden_states
)
if
return_all_hidden_states
:
return
hidden_states_pool
return
hidden_states
return
hidden_states
...
@@ -368,6 +385,8 @@ class Siglip2VisionTransformer(nn.Module):
...
@@ -368,6 +385,8 @@ class Siglip2VisionTransformer(nn.Module):
self
,
self
,
config
:
Siglip2VisionConfig
,
config
:
Siglip2VisionConfig
,
quant_config
:
QuantizationConfig
|
None
=
None
,
quant_config
:
QuantizationConfig
|
None
=
None
,
num_hidden_layers_override
:
int
|
None
=
None
,
require_post_norm
:
bool
|
None
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -381,6 +400,7 @@ class Siglip2VisionTransformer(nn.Module):
...
@@ -381,6 +400,7 @@ class Siglip2VisionTransformer(nn.Module):
self
.
encoder
=
Siglip2Encoder
(
self
.
encoder
=
Siglip2Encoder
(
config
,
config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
num_hidden_layers_override
=
num_hidden_layers_override
,
prefix
=
f
"
{
prefix
}
.encoder"
,
prefix
=
f
"
{
prefix
}
.encoder"
,
)
)
num_hidden_layers
=
config
.
num_hidden_layers
num_hidden_layers
=
config
.
num_hidden_layers
...
@@ -390,7 +410,13 @@ class Siglip2VisionTransformer(nn.Module):
...
@@ -390,7 +410,13 @@ class Siglip2VisionTransformer(nn.Module):
f
"layers, but you requested
{
len
(
self
.
encoder
.
layers
)
}
layers."
f
"layers, but you requested
{
len
(
self
.
encoder
.
layers
)
}
layers."
)
)
if
require_post_norm
is
None
:
require_post_norm
=
len
(
self
.
encoder
.
layers
)
==
num_hidden_layers
if
require_post_norm
:
self
.
post_layernorm
=
nn
.
LayerNorm
(
embed_dim
,
eps
=
config
.
layer_norm_eps
)
self
.
post_layernorm
=
nn
.
LayerNorm
(
embed_dim
,
eps
=
config
.
layer_norm_eps
)
else
:
self
.
post_layernorm
=
None
def
get_input_embeddings
(
self
):
def
get_input_embeddings
(
self
):
return
self
.
embeddings
return
self
.
embeddings
...
@@ -401,19 +427,34 @@ class Siglip2VisionTransformer(nn.Module):
...
@@ -401,19 +427,34 @@ class Siglip2VisionTransformer(nn.Module):
spatial_shapes
:
torch
.
LongTensor
,
spatial_shapes
:
torch
.
LongTensor
,
cu_seqlens
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
max_seqlen
:
torch
.
Tensor
,
max_seqlen
:
torch
.
Tensor
,
select_layers
:
list
[
int
]
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
r
"""
r
"""
spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
Tensor containing the spatial dimensions (height, width)
Tensor containing the spatial dimensions (height, width)
of the input images.
of the input images.
select_layers (`list[int]` or `None`, defaults to `None`):
Layer indices to select hidden states from. Supports negative
indices (e.g., -1 for last layer, -2 for second-to-last).
If None, returns the last layer output.
"""
"""
hidden_states
=
self
.
embeddings
(
pixel_values_packed
,
spatial_shapes
)
hidden_states
=
self
.
embeddings
(
pixel_values_packed
,
spatial_shapes
)
encoder_outputs
=
self
.
encoder
(
encoder_outputs
=
self
.
encoder
(
inputs_embeds
=
hidden_states
,
inputs_embeds
=
hidden_states
,
cu_seqlens
=
cu_seqlens
,
cu_seqlens
=
cu_seqlens
,
max_seqlen
=
max_seqlen
,
max_seqlen
=
max_seqlen
,
return_all_hidden_states
=
select_layers
is
not
None
,
)
)
return
self
.
post_layernorm
(
encoder_outputs
)
encoder_outputs
=
resolve_visual_encoder_outputs
(
encoder_outputs
,
self
.
post_layernorm
,
select_layers
=
select_layers
,
max_possible_layers
=
self
.
config
.
num_hidden_layers
,
)
return
encoder_outputs
class
Siglip2Model
(
torch
.
nn
.
Module
):
class
Siglip2Model
(
torch
.
nn
.
Module
):
...
@@ -421,6 +462,8 @@ class Siglip2Model(torch.nn.Module):
...
@@ -421,6 +462,8 @@ class Siglip2Model(torch.nn.Module):
self
,
self
,
config
:
Siglip2VisionConfig
,
config
:
Siglip2VisionConfig
,
quant_config
:
QuantizationConfig
|
None
=
None
,
quant_config
:
QuantizationConfig
|
None
=
None
,
num_hidden_layers_override
:
int
|
None
=
None
,
require_post_norm
:
bool
|
None
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -428,6 +471,8 @@ class Siglip2Model(torch.nn.Module):
...
@@ -428,6 +471,8 @@ class Siglip2Model(torch.nn.Module):
self
.
vision_model
=
Siglip2VisionTransformer
(
self
.
vision_model
=
Siglip2VisionTransformer
(
config
,
config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
num_hidden_layers_override
=
num_hidden_layers_override
,
require_post_norm
=
require_post_norm
,
prefix
=
f
"
{
prefix
}
.vision_model"
,
prefix
=
f
"
{
prefix
}
.vision_model"
,
)
)
...
@@ -437,12 +482,22 @@ class Siglip2Model(torch.nn.Module):
...
@@ -437,12 +482,22 @@ class Siglip2Model(torch.nn.Module):
spatial_shapes
:
torch
.
LongTensor
,
spatial_shapes
:
torch
.
LongTensor
,
cu_seqlens
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
max_seqlen
:
torch
.
Tensor
,
max_seqlen
:
torch
.
Tensor
,
select_layers
:
list
[
int
]
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass through the vision model.
Args:
select_layers: Layer indices to select hidden states from.
Supports negative indices (e.g., [-2] for second-to-last).
If None, returns the last layer output with post_layernorm.
Multiple layers can be selected and will be concatenated.
"""
return
self
.
vision_model
(
return
self
.
vision_model
(
pixel_values_packed
=
pixel_values_packed
,
pixel_values_packed
=
pixel_values_packed
,
spatial_shapes
=
spatial_shapes
,
spatial_shapes
=
spatial_shapes
,
cu_seqlens
=
cu_seqlens
,
cu_seqlens
=
cu_seqlens
,
max_seqlen
=
max_seqlen
,
max_seqlen
=
max_seqlen
,
select_layers
=
select_layers
,
)
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
...
@@ -454,8 +509,22 @@ class Siglip2Model(torch.nn.Module):
...
@@ -454,8 +509,22 @@ class Siglip2Model(torch.nn.Module):
]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
set
[
str
]
=
set
()
loaded_params
:
set
[
str
]
=
set
()
layer_count
=
len
(
self
.
vision_model
.
encoder
.
layers
)
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
# post_layernorm is optional in Siglip2Model
if
(
name
.
startswith
(
"vision_model.post_layernorm"
)
and
self
.
vision_model
.
post_layernorm
is
None
):
continue
# omit layers when num_hidden_layers_override is set
if
name
.
startswith
(
"vision_model.encoder.layers"
):
layer_idx
=
int
(
name
.
split
(
"."
)[
3
])
if
layer_idx
>=
layer_count
:
continue
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
if
weight_name
not
in
name
:
continue
continue
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment