Unverified Commit d37b3787 authored by Ilya Boytsov's avatar Ilya Boytsov Committed by GitHub
Browse files

[Model] Update ColModernVBERT to support latest HF checkpoint (#39307)


Signed-off-by: default avatarIlya Boytsov <ilyaboytsov1805@gmail.com>
parent 92fbec39
...@@ -15,10 +15,6 @@ from vllm.entrypoints.pooling.scoring.utils import compute_maxsim_score ...@@ -15,10 +15,6 @@ from vllm.entrypoints.pooling.scoring.utils import compute_maxsim_score
MODEL_NAME = "ModernVBERT/colmodernvbert-merged" MODEL_NAME = "ModernVBERT/colmodernvbert-merged"
COLBERT_DIM = 128 COLBERT_DIM = 128
DTYPE = "half" DTYPE = "half"
# Fixme:
# Update colmodernvbert code to support the latest HF version
# and remove revision set.
REVISION = "4a0a9f3ac7a7992fec410bfa8e3d080ac9a5bcee"
# ----------------------------------------------------------------------- # -----------------------------------------------------------------------
...@@ -30,7 +26,6 @@ def test_colmodernvbert_text_token_embed(vllm_runner): ...@@ -30,7 +26,6 @@ def test_colmodernvbert_text_token_embed(vllm_runner):
"""Text query produces per-token embeddings with shape (seq_len, 128).""" """Text query produces per-token embeddings with shape (seq_len, 128)."""
with vllm_runner( with vllm_runner(
MODEL_NAME, MODEL_NAME,
revision=REVISION,
runner="pooling", runner="pooling",
dtype=DTYPE, dtype=DTYPE,
enforce_eager=True, enforce_eager=True,
...@@ -54,7 +49,6 @@ def test_colmodernvbert_text_relevance_ordering(vllm_runner): ...@@ -54,7 +49,6 @@ def test_colmodernvbert_text_relevance_ordering(vllm_runner):
with vllm_runner( with vllm_runner(
MODEL_NAME, MODEL_NAME,
revision=REVISION,
runner="pooling", runner="pooling",
dtype=DTYPE, dtype=DTYPE,
enforce_eager=True, enforce_eager=True,
...@@ -72,7 +66,6 @@ def test_colmodernvbert_text_late_interaction(vllm_runner): ...@@ -72,7 +66,6 @@ def test_colmodernvbert_text_late_interaction(vllm_runner):
with vllm_runner( with vllm_runner(
MODEL_NAME, MODEL_NAME,
revision=REVISION,
runner="pooling", runner="pooling",
dtype=DTYPE, dtype=DTYPE,
enforce_eager=True, enforce_eager=True,
...@@ -99,7 +92,6 @@ def test_colmodernvbert_image_token_embed(vllm_runner, image_assets): ...@@ -99,7 +92,6 @@ def test_colmodernvbert_image_token_embed(vllm_runner, image_assets):
"""Image input produces per-token embeddings including vision tokens.""" """Image input produces per-token embeddings including vision tokens."""
with vllm_runner( with vllm_runner(
MODEL_NAME, MODEL_NAME,
revision=REVISION,
runner="pooling", runner="pooling",
dtype=DTYPE, dtype=DTYPE,
enforce_eager=True, enforce_eager=True,
......
...@@ -648,7 +648,6 @@ _LATE_INTERACTION_EXAMPLE_MODELS = { ...@@ -648,7 +648,6 @@ _LATE_INTERACTION_EXAMPLE_MODELS = {
# [Multimodal] # [Multimodal]
"ColModernVBertForRetrieval": _HfExamplesInfo( "ColModernVBertForRetrieval": _HfExamplesInfo(
"ModernVBERT/colmodernvbert-merged", "ModernVBERT/colmodernvbert-merged",
revision="4a0a9f3ac7a7992fec410bfa8e3d080ac9a5bcee",
), ),
"ColPaliForRetrieval": _HfExamplesInfo("vidore/colpali-v1.3-hf"), "ColPaliForRetrieval": _HfExamplesInfo("vidore/colpali-v1.3-hf"),
"ColQwen3": _HfExamplesInfo( "ColQwen3": _HfExamplesInfo(
......
...@@ -18,7 +18,6 @@ from vllm.config import VllmConfig ...@@ -18,7 +18,6 @@ from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs import MultiModalDataDict from vllm.inputs import MultiModalDataDict
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import ( from vllm.multimodal.inputs import (
MultiModalFieldConfig, MultiModalFieldConfig,
...@@ -358,70 +357,23 @@ class ColModernVBertForRetrieval( ...@@ -358,70 +357,23 @@ class ColModernVBertForRetrieval(
"model.text_model.layers.": "text_layers.", "model.text_model.layers.": "text_layers.",
"model.text_model.embeddings.": "text_embeddings.", "model.text_model.embeddings.": "text_embeddings.",
"model.text_model.final_norm.": "text_final_norm.", "model.text_model.final_norm.": "text_final_norm.",
"model.connector.modality_projection.": "connector.", "model.connector.modality_projection.": "connector.proj.",
"model.custom_text_proj.": "custom_text_proj.", "model.custom_text_proj.": "custom_text_proj.",
"model.vision_model.": "vision_model.vision_model.", "model.vision_model.vision_model.": "vision_model.vision_model.",
"model.": "", "model.": "",
}, },
) )
# Checkpoint names for DecoupledEmbedding parts
_BASE_EMB = "model.text_model.embeddings.tok_embeddings.weight"
_EXTRA_EMB = (
"model.text_model.embeddings.tok_embeddings.additional_embedding.weight"
)
def load_weights( def load_weights(
self, self,
weights: Iterable[tuple[str, torch.Tensor]], weights: Iterable[tuple[str, torch.Tensor]],
) -> set[str]: ) -> set[str]:
# DecoupledEmbedding requires concatenating base + additional
# embedding tensors before loading, so we extract them first.
base_embedding_weight: torch.Tensor | None = None
additional_embedding_weight: torch.Tensor | None = None
remaining: list[tuple[str, torch.Tensor]] = []
for name, tensor in weights:
if name == self._BASE_EMB:
base_embedding_weight = tensor
elif name == self._EXTRA_EMB:
additional_embedding_weight = tensor
else:
remaining.append((name, tensor))
# Load all non-embedding weights via AutoWeightsLoader
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
loaded_params = loader.load_weights( loaded_params = loader.load_weights(
remaining, weights,
mapper=self.hf_to_vllm_mapper, mapper=self.hf_to_vllm_mapper,
) )
# Concatenate and load DecoupledEmbedding weights
if base_embedding_weight is not None:
combined = base_embedding_weight
if additional_embedding_weight is not None:
combined = torch.cat(
[base_embedding_weight, additional_embedding_weight],
dim=0,
)
param_name = "text_embeddings.tok_embeddings.weight"
params_dict = dict(self.named_parameters())
if param_name in params_dict:
param = params_dict[param_name]
weight_loader = getattr(
param,
"weight_loader",
default_weight_loader,
)
weight_loader(param, combined)
loaded_params.add(param_name)
elif additional_embedding_weight is not None:
raise ValueError(
"Found 'text_model.embeddings.tok_embeddings"
".additional_embedding.weight' but not "
"'text_model.embeddings.tok_embeddings.weight'"
)
# The pooler wraps ``custom_text_proj`` as its head projector. # The pooler wraps ``custom_text_proj`` as its head projector.
# Mark those params as loaded under the pooler path too. # Mark those params as loaded under the pooler path too.
if hasattr(self, "pooler") and hasattr(self.pooler, "head"): if hasattr(self, "pooler") and hasattr(self.pooler, "head"):
......
...@@ -82,7 +82,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( ...@@ -82,7 +82,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
bagel="BagelConfig", bagel="BagelConfig",
umm="CheersConfig", umm="CheersConfig",
chatglm="ChatGLMConfig", chatglm="ChatGLMConfig",
colmodernvbert="ColModernVBertConfig", modernvbert="ColModernVBertConfig",
colpali="ColPaliConfig", colpali="ColPaliConfig",
colqwen3="ColQwen3Config", colqwen3="ColQwen3Config",
ops_colqwen3="OpsColQwen3Config", ops_colqwen3="OpsColQwen3Config",
......
...@@ -17,43 +17,41 @@ class ColModernVBertConfig(PretrainedConfig): ...@@ -17,43 +17,41 @@ class ColModernVBertConfig(PretrainedConfig):
def __init__( def __init__(
self, self,
embedding_dim: int = 128, embedding_dim: int = 128,
vlm_config: dict | None = None, image_token_id: int = 50407,
pixel_shuffle_factor: int = 4,
text_config: dict | None = None,
vision_config: dict | None = None,
**kwargs, **kwargs,
): ):
self.embedding_dim = embedding_dim self.embedding_dim = embedding_dim
self.image_token_id = image_token_id
self.pixel_shuffle_factor = pixel_shuffle_factor
if vlm_config is None: text_config = text_config or {}
vlm_config = {} self.hidden_size = text_config.get("hidden_size", 768)
# Top-level VLM fields
self.image_token_id = vlm_config.get("image_token_id", 50407)
self.pixel_shuffle_factor = vlm_config.get("pixel_shuffle_factor", 4)
self.hidden_size = vlm_config.get("hidden_size", 768)
additional_vocab_size = vlm_config.get("additional_vocab_size", 40)
# Text config (ModernBERT)
text_cfg = vlm_config.get("text_config", {})
base_vocab = text_cfg.get("vocab_size", 50368)
self.text_config = ModernBertConfig( self.text_config = ModernBertConfig(
vocab_size=base_vocab + additional_vocab_size, vocab_size=text_config.get("vocab_size", 50408),
hidden_size=text_cfg.get("hidden_size", 768), hidden_size=text_config.get("hidden_size", 768),
intermediate_size=text_cfg.get("intermediate_size", 1152), intermediate_size=text_config.get("intermediate_size", 1152),
num_hidden_layers=text_cfg.get("num_hidden_layers", 22), num_hidden_layers=text_config.get("num_hidden_layers", 22),
num_attention_heads=text_cfg.get("num_attention_heads", 12), num_attention_heads=text_config.get("num_attention_heads", 12),
mlp_bias=text_cfg.get("mlp_bias", False), mlp_bias=text_config.get("mlp_bias", False),
max_position_embeddings=vlm_config.get("max_position_embeddings", 8192), max_position_embeddings=text_config.get("max_position_embeddings", 8192),
) )
# Vision config (SigLIP) vision_config = vision_config or {}
vis_cfg = vlm_config.get("vision_config", {})
self.vision_config = SiglipVisionConfig( self.vision_config = SiglipVisionConfig(
hidden_size=vis_cfg.get("embed_dim", 768), hidden_size=vision_config.get("hidden_size", 768),
image_size=vis_cfg.get("image_size", 512), image_size=vision_config.get("image_size", 512),
patch_size=vis_cfg.get("patch_size", 16), patch_size=vision_config.get("patch_size", 16),
num_hidden_layers=vis_cfg.get("num_hidden_layers", 12), num_hidden_layers=vision_config.get("num_hidden_layers", 12),
intermediate_size=vis_cfg.get("intermediate_size", 3072), intermediate_size=vision_config.get("intermediate_size", 3072),
num_attention_heads=vis_cfg.get("num_attention_heads", 12), num_attention_heads=vision_config.get("num_attention_heads", 12),
) )
# Ensure architectures is set so vLLM routes to our model class
kwargs.setdefault("architectures", ["ColModernVBertForRetrieval"])
super().__init__(**kwargs) super().__init__(**kwargs)
@property @property
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment