Unverified commit 302ecf64, authored by Eduardo Salinas, committed by GitHub
Browse files

[Models]: lfm2_siglip2 return intermediate encoder layers (#33370)


Signed-off-by: Eduardo Salinas <edus@microsoft.com>
parent b6bb2842
...@@ -22,7 +22,11 @@ from vllm.model_executor.layers.linear import ( ...@@ -22,7 +22,11 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from .vision import is_vit_use_data_parallel, should_torch_compile_mm_vit from .vision import (
is_vit_use_data_parallel,
resolve_visual_encoder_outputs,
should_torch_compile_mm_vit,
)
class Siglip2VisionEmbeddings(nn.Module): class Siglip2VisionEmbeddings(nn.Module):
...@@ -331,10 +335,17 @@ class Siglip2Encoder(nn.Module): ...@@ -331,10 +335,17 @@ class Siglip2Encoder(nn.Module):
self, self,
config: Siglip2VisionConfig, config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
num_hidden_layers_override: int | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
if num_hidden_layers_override is None:
num_hidden_layers = config.num_hidden_layers
else:
num_hidden_layers = num_hidden_layers_override
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
Siglip2EncoderLayer( Siglip2EncoderLayer(
...@@ -342,7 +353,7 @@ class Siglip2Encoder(nn.Module): ...@@ -342,7 +353,7 @@ class Siglip2Encoder(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.layers.{idx}", prefix=f"{prefix}.layers.{idx}",
) )
for idx in range(config.num_hidden_layers) for idx in range(num_hidden_layers)
] ]
) )
...@@ -351,15 +362,21 @@ class Siglip2Encoder(nn.Module): ...@@ -351,15 +362,21 @@ class Siglip2Encoder(nn.Module):
inputs_embeds: torch.Tensor, inputs_embeds: torch.Tensor,
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
max_seqlen: int | torch.Tensor, max_seqlen: int | torch.Tensor,
) -> torch.Tensor: return_all_hidden_states: bool = False,
) -> torch.Tensor | list[torch.Tensor]:
hidden_states_pool = [inputs_embeds]
hidden_states = inputs_embeds hidden_states = inputs_embeds
for encoder_layer in self.layers: for encoder_layer in self.layers:
layer_outputs = encoder_layer( hidden_states = encoder_layer(
hidden_states, hidden_states,
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
) )
hidden_states = layer_outputs if return_all_hidden_states:
hidden_states_pool.append(hidden_states)
if return_all_hidden_states:
return hidden_states_pool
return hidden_states return hidden_states
...@@ -368,6 +385,8 @@ class Siglip2VisionTransformer(nn.Module): ...@@ -368,6 +385,8 @@ class Siglip2VisionTransformer(nn.Module):
self, self,
config: Siglip2VisionConfig, config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -381,6 +400,7 @@ class Siglip2VisionTransformer(nn.Module): ...@@ -381,6 +400,7 @@ class Siglip2VisionTransformer(nn.Module):
self.encoder = Siglip2Encoder( self.encoder = Siglip2Encoder(
config, config,
quant_config=quant_config, quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder", prefix=f"{prefix}.encoder",
) )
num_hidden_layers = config.num_hidden_layers num_hidden_layers = config.num_hidden_layers
...@@ -390,7 +410,13 @@ class Siglip2VisionTransformer(nn.Module): ...@@ -390,7 +410,13 @@ class Siglip2VisionTransformer(nn.Module):
f"layers, but you requested {len(self.encoder.layers)} layers." f"layers, but you requested {len(self.encoder.layers)} layers."
) )
if require_post_norm is None:
require_post_norm = len(self.encoder.layers) == num_hidden_layers
if require_post_norm:
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
else:
self.post_layernorm = None
def get_input_embeddings(self): def get_input_embeddings(self):
return self.embeddings return self.embeddings
...@@ -401,19 +427,34 @@ class Siglip2VisionTransformer(nn.Module): ...@@ -401,19 +427,34 @@ class Siglip2VisionTransformer(nn.Module):
spatial_shapes: torch.LongTensor, spatial_shapes: torch.LongTensor,
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
max_seqlen: torch.Tensor, max_seqlen: torch.Tensor,
select_layers: list[int] | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
r""" r"""
spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`): spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
Tensor containing the spatial dimensions (height, width) Tensor containing the spatial dimensions (height, width)
of the input images. of the input images.
select_layers (`list[int]` or `None`, defaults to `None`):
Layer indices to select hidden states from. Supports negative
indices (e.g., -1 for last layer, -2 for second-to-last).
If None, returns the last layer output.
""" """
hidden_states = self.embeddings(pixel_values_packed, spatial_shapes) hidden_states = self.embeddings(pixel_values_packed, spatial_shapes)
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
inputs_embeds=hidden_states, inputs_embeds=hidden_states,
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
return_all_hidden_states=select_layers is not None,
) )
return self.post_layernorm(encoder_outputs)
encoder_outputs = resolve_visual_encoder_outputs(
encoder_outputs,
self.post_layernorm,
select_layers=select_layers,
max_possible_layers=self.config.num_hidden_layers,
)
return encoder_outputs
class Siglip2Model(torch.nn.Module): class Siglip2Model(torch.nn.Module):
...@@ -421,6 +462,8 @@ class Siglip2Model(torch.nn.Module): ...@@ -421,6 +462,8 @@ class Siglip2Model(torch.nn.Module):
self, self,
config: Siglip2VisionConfig, config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -428,6 +471,8 @@ class Siglip2Model(torch.nn.Module): ...@@ -428,6 +471,8 @@ class Siglip2Model(torch.nn.Module):
self.vision_model = Siglip2VisionTransformer( self.vision_model = Siglip2VisionTransformer(
config, config,
quant_config=quant_config, quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
require_post_norm=require_post_norm,
prefix=f"{prefix}.vision_model", prefix=f"{prefix}.vision_model",
) )
...@@ -437,12 +482,22 @@ class Siglip2Model(torch.nn.Module): ...@@ -437,12 +482,22 @@ class Siglip2Model(torch.nn.Module):
spatial_shapes: torch.LongTensor, spatial_shapes: torch.LongTensor,
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
max_seqlen: torch.Tensor, max_seqlen: torch.Tensor,
select_layers: list[int] | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass through the vision model.
Args:
select_layers: Layer indices to select hidden states from.
Supports negative indices (e.g., [-2] for second-to-last).
If None, returns the last layer output with post_layernorm.
Multiple layers can be selected and will be concatenated.
"""
return self.vision_model( return self.vision_model(
pixel_values_packed=pixel_values_packed, pixel_values_packed=pixel_values_packed,
spatial_shapes=spatial_shapes, spatial_shapes=spatial_shapes,
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
select_layers=select_layers,
) )
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
...@@ -454,8 +509,22 @@ class Siglip2Model(torch.nn.Module): ...@@ -454,8 +509,22 @@ class Siglip2Model(torch.nn.Module):
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: set[str] = set() loaded_params: set[str] = set()
layer_count = len(self.vision_model.encoder.layers)
for name, loaded_weight in weights: for name, loaded_weight in weights:
# post_layernorm is optional in Siglip2Model
if (
name.startswith("vision_model.post_layernorm")
and self.vision_model.post_layernorm is None
):
continue
# omit layers when num_hidden_layers_override is set
if name.startswith("vision_model.encoder.layers"):
layer_idx = int(name.split(".")[3])
if layer_idx >= layer_count:
continue
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment