Unverified Commit cd4cfee6 authored by wang.yuqi's avatar wang.yuqi Committed by GitHub
Browse files

[Model][1/N] Automatic conversion of CrossEncoding model (#20012)


Signed-off-by: default avatarwang.yuqi <noooop@126.com>
parent e1109306
...@@ -43,7 +43,7 @@ class VllmMtebEncoder(mteb.Encoder): ...@@ -43,7 +43,7 @@ class VllmMtebEncoder(mteb.Encoder):
# issues by randomizing the order. # issues by randomizing the order.
r = self.rng.permutation(len(sentences)) r = self.rng.permutation(len(sentences))
sentences = [sentences[i] for i in r] sentences = [sentences[i] for i in r]
outputs = self.model.encode(sentences, use_tqdm=False) outputs = self.model.embed(sentences, use_tqdm=False)
embeds = np.array(outputs) embeds = np.array(outputs)
embeds = embeds[np.argsort(r)] embeds = embeds[np.argsort(r)]
return embeds return embeds
...@@ -250,16 +250,19 @@ def mteb_test_rerank_models(hf_runner, ...@@ -250,16 +250,19 @@ def mteb_test_rerank_models(hf_runner,
with vllm_runner(model_info.name, with vllm_runner(model_info.name,
task="score", task="score",
max_model_len=None, max_model_len=None,
max_num_seqs=8,
**vllm_extra_kwargs) as vllm_model: **vllm_extra_kwargs) as vllm_model:
model_config = vllm_model.model.llm_engine.model_config
if model_info.architecture: if model_info.architecture:
assert (model_info.architecture assert (model_info.architecture in model_config.architectures)
in vllm_model.model.llm_engine.model_config.architectures) assert model_config.hf_config.num_labels == 1
vllm_main_score = run_mteb_rerank(VllmMtebEncoder(vllm_model), vllm_main_score = run_mteb_rerank(VllmMtebEncoder(vllm_model),
tasks=MTEB_RERANK_TASKS, tasks=MTEB_RERANK_TASKS,
languages=MTEB_RERANK_LANGS) languages=MTEB_RERANK_LANGS)
vllm_dtype = vllm_model.model.llm_engine.model_config.dtype vllm_dtype = model_config.dtype
with hf_runner(model_info.name, is_cross_encoder=True, with hf_runner(model_info.name, is_cross_encoder=True,
dtype="float32") as hf_model: dtype="float32") as hf_model:
......
...@@ -569,6 +569,10 @@ class ModelConfig: ...@@ -569,6 +569,10 @@ class ModelConfig:
else: else:
self.truncation_side = "right" self.truncation_side = "right"
model_info, arch = self.registry.inspect_model_cls(self.architectures)
self._model_info = model_info
self._architecture = arch
self.pooler_config = self._init_pooler_config() self.pooler_config = self._init_pooler_config()
self.dtype = _get_and_verify_dtype( self.dtype = _get_and_verify_dtype(
...@@ -660,8 +664,18 @@ class ModelConfig: ...@@ -660,8 +664,18 @@ class ModelConfig:
@property @property
def architectures(self) -> list[str]: def architectures(self) -> list[str]:
# architectures in the model config.
return getattr(self.hf_config, "architectures", []) return getattr(self.hf_config, "architectures", [])
@property
def architecture(self) -> str:
# The architecture vllm actually used.
return self._architecture
@property
def model_info(self) -> dict[str, Any]:
return self._model_info
def maybe_pull_model_tokenizer_for_s3(self, model: str, def maybe_pull_model_tokenizer_for_s3(self, model: str,
tokenizer: str) -> None: tokenizer: str) -> None:
"""Pull model/tokenizer from S3 to temporary directory when needed. """Pull model/tokenizer from S3 to temporary directory when needed.
...@@ -4450,6 +4464,9 @@ class VllmConfig: ...@@ -4450,6 +4464,9 @@ class VllmConfig:
def __post_init__(self): def __post_init__(self):
"""Verify configs are valid & consistent with each other. """Verify configs are valid & consistent with each other.
""" """
self.try_verify_and_update_config()
if self.model_config is not None: if self.model_config is not None:
self.model_config.verify_async_output_proc(self.parallel_config, self.model_config.verify_async_output_proc(self.parallel_config,
self.speculative_config, self.speculative_config,
...@@ -4694,11 +4711,21 @@ class VllmConfig: ...@@ -4694,11 +4711,21 @@ class VllmConfig:
batch_size_capture_list) batch_size_capture_list)
def recalculate_max_model_len(self, max_model_len: int): def recalculate_max_model_len(self, max_model_len: int):
# Can only be called in try_verify_and_update_config
model_config = self.model_config model_config = self.model_config
max_model_len = model_config.get_and_verify_max_len(max_model_len) max_model_len = model_config.get_and_verify_max_len(max_model_len)
self.model_config.max_model_len = max_model_len self.model_config.max_model_len = max_model_len
self.scheduler_config.max_model_len = max_model_len self.scheduler_config.max_model_len = max_model_len
self.compute_hash()
def try_verify_and_update_config(self):
architecture = getattr(self.model_config, "architecture", None)
if architecture is None:
return
from vllm.model_executor.models.config import MODELS_CONFIG_MAP
cls = MODELS_CONFIG_MAP.get(architecture, None)
if cls is not None:
cls.verify_and_update_config(self)
def __str__(self): def __str__(self):
return ( return (
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable from collections.abc import Iterable
from copy import deepcopy
from typing import Optional from typing import Optional
import torch import torch
...@@ -12,7 +11,6 @@ from vllm.attention import Attention, AttentionType ...@@ -12,7 +11,6 @@ from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import (get_act_and_mul_fn, from vllm.model_executor.layers.activation import (get_act_and_mul_fn,
get_act_fn) get_act_fn)
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -30,8 +28,6 @@ from vllm.model_executor.models.interfaces import SupportsQuant ...@@ -30,8 +28,6 @@ from vllm.model_executor.models.interfaces import SupportsQuant
from vllm.model_executor.models.utils import WeightsMapper from vllm.model_executor.models.utils import WeightsMapper
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
logger = init_logger(__name__)
class BertWithRopeEmbedding(nn.Module): class BertWithRopeEmbedding(nn.Module):
...@@ -408,7 +404,7 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant): ...@@ -408,7 +404,7 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.config = self.config_verify(vllm_config) self.config = vllm_config.model_config.hf_config
self.embeddings = BertWithRopeEmbedding(self.config) self.embeddings = BertWithRopeEmbedding(self.config)
self.encoder = BertWithRopeEncoder( self.encoder = BertWithRopeEncoder(
vllm_config=vllm_config, vllm_config=vllm_config,
...@@ -416,9 +412,6 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant): ...@@ -416,9 +412,6 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
rotary_kwargs=self.config.rotary_kwargs, rotary_kwargs=self.config.rotary_kwargs,
prefix=f"{prefix}.encoder") prefix=f"{prefix}.encoder")
def config_verify(self, vllm_config):
raise NotImplementedError
def forward( def forward(
self, self,
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
...@@ -490,95 +483,6 @@ class NomicBertModel(BertWithRope): ...@@ -490,95 +483,6 @@ class NomicBertModel(BertWithRope):
"norm2": "mlp_ln", "norm2": "mlp_ln",
}) })
def config_verify(self, vllm_config):
config = vllm_config.model_config.hf_config
assert config.__class__.__name__ == "NomicBertConfig"
assert config.activation_function in ["swiglu", "gelu"]
config.position_embedding_type = getattr(config,
"position_embedding_type",
"rope")
if config.activation_function == "swiglu":
config.hidden_act = "silu"
else:
config.hidden_act = config.activation_function
assert (config.mlp_fc1_bias == config.mlp_fc2_bias ==
config.qkv_proj_bias)
config.bias = config.qkv_proj_bias
assert config.rotary_emb_scale_base is None
assert not config.rotary_emb_interleaved
config.layer_norm_eps = config.layer_norm_epsilon
config.intermediate_size = config.n_inner
config.hidden_size = config.n_embd
config.num_hidden_layers = config.n_layer
head_dim = config.hidden_size // config.num_attention_heads
rotary_emb_dim = head_dim * config.rotary_emb_fraction
max_trained_positions = getattr(config, "max_trained_positions", 2048)
config.rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": rotary_emb_dim,
"max_position": max_trained_positions,
"base": getattr(config, "rope_theta", config.rotary_emb_base),
"rope_scaling": getattr(config, "rope_scaling", None)
}
# we ignore config.rotary_scaling_factor so that for datasets shorter
# than max_trained_positions 2048, the results are consistent
# with SentenceTransformer.
# The context extension uses vllm style rope_theta and rope_scaling.
# See #17785 #18755
if (not vllm_config.model_config.hf_overrides
and vllm_config.model_config.original_max_model_len is None):
# Default
# Reset max_model_len to max_trained_positions.
# nomic-embed-text-v2-moe the length is set to 512
# by sentence_bert_config.json.
max_model_len_before = vllm_config.model_config.max_model_len
max_model_len = min(vllm_config.model_config.max_model_len,
max_trained_positions)
vllm_config.recalculate_max_model_len(max_model_len)
logger.warning(
"Nomic context extension is disabled. "
"Changing max_model_len from %s to %s. "
"To enable context extension, see: "
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html",
max_model_len_before, vllm_config.model_config.max_model_len)
else:
# We need to re-verify max_model_len to avoid lengths
# greater than position_embedding.
model_config = vllm_config.model_config
hf_text_config = model_config.hf_text_config
if isinstance(model_config.hf_overrides, dict):
# hf_overrides_kw
max_model_len = model_config.hf_overrides.get(
"max_model_len", vllm_config.model_config.max_model_len)
else:
# hf_overrides_fn
# This might be overridden by sentence_bert_config.json.
max_model_len = vllm_config.model_config.max_model_len
# reset hf_text_config for recalculate_max_model_len.
if hasattr(hf_text_config, "max_model_len"):
delattr(hf_text_config, "max_model_len")
hf_text_config.max_position_embeddings = max_trained_positions
hf_text_config.rope_scaling = config.rotary_kwargs["rope_scaling"]
# The priority of sentence_bert_config.json is higher
# than max_position_embeddings
encoder_config = deepcopy(model_config.encoder_config)
encoder_config.pop("max_seq_length", None)
model_config.encoder_config = encoder_config
vllm_config.recalculate_max_model_len(max_model_len)
return config
class GteNewModel(BertWithRope): class GteNewModel(BertWithRope):
# for https://huggingface.co/Alibaba-NLP/new-impl # for https://huggingface.co/Alibaba-NLP/new-impl
...@@ -600,24 +504,6 @@ class GteNewModel(BertWithRope): ...@@ -600,24 +504,6 @@ class GteNewModel(BertWithRope):
layer.mlp.gate_up_proj.bias = None layer.mlp.gate_up_proj.bias = None
layer.mlp.gate_up_proj.skip_bias_add = True layer.mlp.gate_up_proj.skip_bias_add = True
def config_verify(self, vllm_config):
config = vllm_config.model_config.hf_config
assert config.__class__.__name__ == "NewConfig"
assert config.hidden_act == "gelu"
config.hidden_act = "geglu"
head_dim = config.hidden_size // config.num_attention_heads
config.rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings,
"base": config.rope_theta,
"rope_scaling": getattr(config, "rope_scaling", None)
}
return config
def split_up_gate_proj(self, weights: Iterable[tuple[str, torch.Tensor]]): def split_up_gate_proj(self, weights: Iterable[tuple[str, torch.Tensor]]):
n = "mlp.up_gate_proj" n = "mlp.up_gate_proj"
for name, weight in weights: for name, weight in weights:
...@@ -652,24 +538,6 @@ class SnowflakeGteNewModel(GteNewModel): ...@@ -652,24 +538,6 @@ class SnowflakeGteNewModel(GteNewModel):
"attention.o_proj": "attn.out_proj", "attention.o_proj": "attn.out_proj",
}) })
def config_verify(self, vllm_config):
config = vllm_config.model_config.hf_config
assert config.__class__.__name__ == "GteConfig"
assert config.hidden_act == "gelu"
config.hidden_act = "geglu"
head_dim = config.hidden_size // config.num_attention_heads
config.rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings,
"base": config.rope_theta,
"rope_scaling": getattr(config, "rope_scaling", None)
}
return config
class JinaRobertaModel(BertWithRope): class JinaRobertaModel(BertWithRope):
# for https://huggingface.co/jinaai/jina-embeddings-v3 # for https://huggingface.co/jinaai/jina-embeddings-v3
...@@ -685,21 +553,6 @@ class JinaRobertaModel(BertWithRope): ...@@ -685,21 +553,6 @@ class JinaRobertaModel(BertWithRope):
"norm2": "mlp_ln", "norm2": "mlp_ln",
}) })
def config_verify(self, vllm_config):
config = vllm_config.model_config.hf_config
assert config.__class__.__name__ == "XLMRobertaFlashConfig"
head_dim = config.hidden_size // config.num_attention_heads
config.rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings,
"base": getattr(config, "rope_theta", config.rotary_emb_base),
"rope_scaling": getattr(config, "rope_scaling", None)
}
return config
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import deepcopy
from typing import TYPE_CHECKING
from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.config import VllmConfig
logger = init_logger(__name__)
class VerifyAndUpdateConfig:
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
raise NotImplementedError
class GteNewModelConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
config = vllm_config.model_config.hf_config
assert config.__class__.__name__ == "NewConfig"
assert config.hidden_act == "gelu"
config.hidden_act = "geglu"
head_dim = config.hidden_size // config.num_attention_heads
config.rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings,
"base": config.rope_theta,
"rope_scaling": getattr(config, "rope_scaling", None)
}
class JinaRobertaModelConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
config = vllm_config.model_config.hf_config
if config.position_embedding_type == "rotary":
assert config.__class__.__name__ == "XLMRobertaFlashConfig"
head_dim = config.hidden_size // config.num_attention_heads
config.rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings,
"base": getattr(config, "rope_theta", config.rotary_emb_base),
"rope_scaling": getattr(config, "rope_scaling", None)
}
class NomicBertModelConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
config = vllm_config.model_config.hf_config
assert config.__class__.__name__ == "NomicBertConfig"
assert config.activation_function in ["swiglu", "gelu"]
config.position_embedding_type = getattr(config,
"position_embedding_type",
"rope")
if config.activation_function == "swiglu":
config.hidden_act = "silu"
else:
config.hidden_act = config.activation_function
assert (config.mlp_fc1_bias == config.mlp_fc2_bias ==
config.qkv_proj_bias)
config.bias = config.qkv_proj_bias
assert config.rotary_emb_scale_base is None
assert not config.rotary_emb_interleaved
config.layer_norm_eps = config.layer_norm_epsilon
config.intermediate_size = config.n_inner
config.hidden_size = config.n_embd
config.num_hidden_layers = config.n_layer
head_dim = config.hidden_size // config.num_attention_heads
rotary_emb_dim = head_dim * config.rotary_emb_fraction
max_trained_positions = getattr(config, "max_trained_positions", 2048)
config.rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": rotary_emb_dim,
"max_position": max_trained_positions,
"base": getattr(config, "rope_theta", config.rotary_emb_base),
"rope_scaling": getattr(config, "rope_scaling", None)
}
# we ignore config.rotary_scaling_factor so that for datasets shorter
# than max_trained_positions 2048, the results are consistent
# with SentenceTransformer.
# The context extension uses vllm style rope_theta and rope_scaling.
# See #17785 #18755
if (not vllm_config.model_config.hf_overrides
and vllm_config.model_config.original_max_model_len is None):
# Default
# Reset max_model_len to max_trained_positions.
# nomic-embed-text-v2-moe the length is set to 512
# by sentence_bert_config.json.
max_model_len_before = vllm_config.model_config.max_model_len
max_model_len = min(vllm_config.model_config.max_model_len,
max_trained_positions)
vllm_config.recalculate_max_model_len(max_model_len)
logger.warning(
"Nomic context extension is disabled. "
"Changing max_model_len from %s to %s. "
"To enable context extension, see: "
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html",
max_model_len_before, vllm_config.model_config.max_model_len)
else:
# We need to re-verify max_model_len to avoid lengths
# greater than position_embedding.
model_config = vllm_config.model_config
hf_text_config = model_config.hf_text_config
if isinstance(model_config.hf_overrides, dict):
# hf_overrides_kw
max_model_len = model_config.hf_overrides.get(
"max_model_len", vllm_config.model_config.max_model_len)
else:
# hf_overrides_fn
# This might be overridden by sentence_bert_config.json.
max_model_len = vllm_config.model_config.max_model_len
# reset hf_text_config for recalculate_max_model_len.
if hasattr(hf_text_config, "max_model_len"):
delattr(hf_text_config, "max_model_len")
hf_text_config.max_position_embeddings = max_trained_positions
hf_text_config.rope_scaling = config.rotary_kwargs["rope_scaling"]
# The priority of sentence_bert_config.json is higher
# than max_position_embeddings
encoder_config = deepcopy(model_config.encoder_config)
encoder_config.pop("max_seq_length", None)
model_config.encoder_config = encoder_config
vllm_config.recalculate_max_model_len(max_model_len)
class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
config = vllm_config.model_config.hf_config
is_original_qwen3_reranker = getattr(config,
"is_original_qwen3_reranker",
False)
if not is_original_qwen3_reranker:
return
tokens = getattr(config, "classifier_from_token", None)
assert tokens is not None and len(tokens) == 2, \
("Try loading the original Qwen3 Reranker?, see: "
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py")
config.num_labels = 1
class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
config = vllm_config.model_config.hf_config
assert config.__class__.__name__ == "GteConfig"
assert config.hidden_act == "gelu"
config.hidden_act = "geglu"
head_dim = config.hidden_size // config.num_attention_heads
config.rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings,
"base": config.rope_theta,
"rope_scaling": getattr(config, "rope_scaling", None)
}
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"GteModel": SnowflakeGteNewModelConfig,
"GteNewModel": GteNewModelConfig,
"NomicBertModel": NomicBertModelConfig,
"Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig,
"XLMRobertaModel": JinaRobertaModelConfig,
}
...@@ -400,22 +400,10 @@ class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA, ...@@ -400,22 +400,10 @@ class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA,
def load_weights_from_original_qwen3_reranker( def load_weights_from_original_qwen3_reranker(
self, weights: Iterable[tuple[str, torch.Tensor]]): self, weights: Iterable[tuple[str, torch.Tensor]]):
tokens = getattr(self.config, "classifier_from_token", None)
assert tokens is not None and len(tokens) == 2, \
("Try loading the original Qwen3 Reranker?, see: "
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py")
self.config.num_labels = 1
model_config = self.vllm_config.model_config model_config = self.vllm_config.model_config
tokens = getattr(self.config, "classifier_from_token", None)
device = self.score.weight.device device = self.score.weight.device
self.score = RowParallelLinear(self.config.hidden_size,
self.config.num_labels,
quant_config=self.quant_config,
input_is_parallel=False,
bias=False,
prefix=maybe_prefix(
self.prefix, "score")).to(device)
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens self.lm_head = self.model.embed_tokens
...@@ -443,5 +431,6 @@ class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA, ...@@ -443,5 +431,6 @@ class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA,
self.score.weight.data.copy_(weight) self.score.weight.data.copy_(weight)
del self.lm_head del self.lm_head
loaded_weights.add("classifier.weight") loaded_weights.add("score.weight")
loaded_weights.discard("lm_head.weight") loaded_weights.discard("lm_head.weight")
return loaded_weights
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment