Unverified Commit 63ad6222 authored by wang.yuqi's avatar wang.yuqi Committed by GitHub
Browse files

[New Model]: support GTE NewModel (#17986)

parent e7ef61c1
...@@ -701,12 +701,22 @@ Specified using `--task embed`. ...@@ -701,12 +701,22 @@ Specified using `--task embed`.
* ✅︎ * ✅︎
* ✅︎ * ✅︎
- * `GteModel` - * `GteModel`
* GteModel * Arctic-Embed-2.0-M
* `Snowflake/snowflake-arctic-embed-m-v2.0`. * `Snowflake/snowflake-arctic-embed-m-v2.0`.
* *
* *
- * `GteNewModel`
* mGTE-TRM (see note)
* `Alibaba-NLP/gte-multilingual-base`, etc.
*
*
- * `ModernBertModel`
* ModernBERT-based
* `Alibaba-NLP/gte-modernbert-base`, etc.
*
*
- * `NomicBertModel` - * `NomicBertModel`
* NomicBertModel * Nomic BERT
* `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. * `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc.
* *
* *
...@@ -749,6 +759,10 @@ See [relevant issue on HF Transformers](https://github.com/huggingface/transform ...@@ -749,6 +759,10 @@ See [relevant issue on HF Transformers](https://github.com/huggingface/transform
`jinaai/jina-embeddings-v3` supports multiple tasks through lora, while vllm temporarily only supports text-matching tasks by merging lora weights. `jinaai/jina-embeddings-v3` supports multiple tasks through lora, while vllm temporarily only supports text-matching tasks by merging lora weights.
::: :::
:::{note}
The second-generation GTE model (mGTE-TRM) is named `NewModel`. The name `NewModel` is too generic, you should set `--hf-overrides '{"architectures": ["GteNewModel"]}'` to specify the use of the `GteNewModel` architecture.
:::
If your model is not in the above list, we will try to automatically convert the model using If your model is not in the above list, we will try to automatically convert the model using
{func}`~vllm.model_executor.models.adapters.as_embedding_model`. By default, the embeddings {func}`~vllm.model_executor.models.adapters.as_embedding_model`. By default, the embeddings
of the whole prompt are extracted from the normalized hidden state corresponding to the last token. of the whole prompt are extracted from the normalized hidden state corresponding to the last token.
......
...@@ -7,6 +7,7 @@ import numpy as np ...@@ -7,6 +7,7 @@ import numpy as np
import pytest import pytest
from tests.models.utils import EmbedModelInfo from tests.models.utils import EmbedModelInfo
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
# Most models on the STS12 task (See #17175): # Most models on the STS12 task (See #17175):
# - Model implementation and minor changes in tensor dtype # - Model implementation and minor changes in tensor dtype
...@@ -77,16 +78,22 @@ def run_mteb_embed_task_st(model_name, tasks): ...@@ -77,16 +78,22 @@ def run_mteb_embed_task_st(model_name, tasks):
return run_mteb_embed_task(model, tasks) return run_mteb_embed_task(model, tasks)
def mteb_test_embed_models(hf_runner, vllm_runner, model_info: EmbedModelInfo): def mteb_test_embed_models(hf_runner,
vllm_runner,
model_info: EmbedModelInfo,
vllm_extra_kwargs=None):
if not model_info.enable_test: if not model_info.enable_test:
# A model family has many models with the same architecture, # A model family has many models with the same architecture,
# and we don't need to test each one. # and we don't need to test each one.
pytest.skip("Skipping test.") pytest.skip("Skipping test.")
vllm_extra_kwargs = vllm_extra_kwargs or {}
with vllm_runner(model_info.name, with vllm_runner(model_info.name,
task="embed", task="embed",
max_model_len=None, max_model_len=None,
dtype=model_info.dtype) as vllm_model: dtype=model_info.dtype,
**vllm_extra_kwargs) as vllm_model:
if model_info.architecture: if model_info.architecture:
assert (model_info.architecture assert (model_info.architecture
...@@ -99,9 +106,9 @@ def mteb_test_embed_models(hf_runner, vllm_runner, model_info: EmbedModelInfo): ...@@ -99,9 +106,9 @@ def mteb_test_embed_models(hf_runner, vllm_runner, model_info: EmbedModelInfo):
vllm_model.model.llm_engine.model_config.hf_config, "torch_dtype", vllm_model.model.llm_engine.model_config.hf_config, "torch_dtype",
vllm_dtype) vllm_dtype)
with hf_runner(model_info.name, with set_default_torch_dtype(model_dtype) and hf_runner(
is_sentence_transformer=True, model_info.name, is_sentence_transformer=True,
dtype=model_dtype) as hf_model: dtype=model_dtype) as hf_model:
st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS) st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)
print("VLLM:", vllm_dtype, vllm_main_score) print("VLLM:", vllm_dtype, vllm_main_score)
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any
import pytest
from ...utils import EmbedModelInfo, run_embedding_correctness_test
MODELS = [
########## BertModel
EmbedModelInfo("thenlper/gte-large",
architecture="BertModel",
dtype="float32",
enable_test=True),
EmbedModelInfo("thenlper/gte-base",
architecture="BertModel",
dtype="float32",
enable_test=False),
EmbedModelInfo("thenlper/gte-small",
architecture="BertModel",
dtype="float32",
enable_test=False),
EmbedModelInfo("thenlper/gte-large-zh",
architecture="BertModel",
dtype="float32",
enable_test=False),
EmbedModelInfo("thenlper/gte-base-zh",
architecture="BertModel",
dtype="float32",
enable_test=False),
EmbedModelInfo("thenlper/gte-small-zh",
architecture="BertModel",
dtype="float32",
enable_test=False),
########### NewModel
EmbedModelInfo("Alibaba-NLP/gte-multilingual-base",
architecture="GteNewModel",
enable_test=True),
EmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5",
architecture="GteNewModel",
enable_test=True),
EmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5",
architecture="GteNewModel",
enable_test=True),
########### Qwen2ForCausalLM
EmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
architecture="Qwen2ForCausalLM",
enable_test=True),
EmbedModelInfo("Alibaba-NLP/gte-Qwen2-7B-instruct",
architecture="Qwen2ForCausalLM",
enable_test=False),
########## ModernBertModel
EmbedModelInfo("Alibaba-NLP/gte-modernbert-base",
architecture="ModernBertModel",
enable_test=True),
]
@pytest.mark.parametrize("model_info", MODELS)
def test_models_mteb(hf_runner, vllm_runner,
model_info: EmbedModelInfo) -> None:
pytest.skip("Skipping mteb test.")
from .mteb_utils import mteb_test_embed_models
vllm_extra_kwargs: dict[str, Any] = {}
if model_info.name == "Alibaba-NLP/gte-Qwen2-1.5B-instruct":
vllm_extra_kwargs["hf_overrides"] = {"is_causal": True}
if model_info.architecture == "GteNewModel":
vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
mteb_test_embed_models(hf_runner, vllm_runner, model_info,
vllm_extra_kwargs)
@pytest.mark.parametrize("model_info", MODELS)
def test_models_correctness(hf_runner, vllm_runner, model_info: EmbedModelInfo,
example_prompts) -> None:
if not model_info.enable_test:
pytest.skip("Skipping test.")
# ST will strip the input texts, see test_embedding.py
example_prompts = [str(s).strip() for s in example_prompts]
vllm_extra_kwargs: dict[str, Any] = {}
if model_info.name == "Alibaba-NLP/gte-Qwen2-1.5B-instruct":
vllm_extra_kwargs["hf_overrides"] = {"is_causal": True}
if model_info.architecture == "GteNewModel":
vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
with vllm_runner(model_info.name,
task="embed",
dtype=model_info.dtype,
max_model_len=None,
**vllm_extra_kwargs) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)
with hf_runner(
model_info.name,
dtype=model_info.dtype,
is_sentence_transformer=True,
) as hf_model:
run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)
...@@ -23,6 +23,7 @@ MODELS = [ ...@@ -23,6 +23,7 @@ MODELS = [
@pytest.mark.parametrize("model_info", MODELS) @pytest.mark.parametrize("model_info", MODELS)
def test_models_mteb(hf_runner, vllm_runner, def test_models_mteb(hf_runner, vllm_runner,
model_info: EmbedModelInfo) -> None: model_info: EmbedModelInfo) -> None:
pytest.skip("Skipping mteb test.")
from .mteb_utils import mteb_test_embed_models from .mteb_utils import mteb_test_embed_models
mteb_test_embed_models(hf_runner, vllm_runner, model_info) mteb_test_embed_models(hf_runner, vllm_runner, model_info)
...@@ -33,6 +34,9 @@ def test_models_correctness(hf_runner, vllm_runner, model_info: EmbedModelInfo, ...@@ -33,6 +34,9 @@ def test_models_correctness(hf_runner, vllm_runner, model_info: EmbedModelInfo,
if not model_info.enable_test: if not model_info.enable_test:
pytest.skip("Skipping test.") pytest.skip("Skipping test.")
# ST will strip the input texts, see test_embedding.py
example_prompts = [str(s).strip() for s in example_prompts]
with vllm_runner(model_info.name, with vllm_runner(model_info.name,
task="embed", task="embed",
dtype=model_info.dtype, dtype=model_info.dtype,
......
...@@ -46,6 +46,7 @@ def test_models_mteb( ...@@ -46,6 +46,7 @@ def test_models_mteb(
vllm_runner, vllm_runner,
model_info: EmbedModelInfo, model_info: EmbedModelInfo,
) -> None: ) -> None:
pytest.skip("Skipping mteb test.")
from .mteb_utils import mteb_test_embed_models from .mteb_utils import mteb_test_embed_models
mteb_test_embed_models(hf_runner, vllm_runner, model_info) mteb_test_embed_models(hf_runner, vllm_runner, model_info)
...@@ -60,6 +61,9 @@ def test_models_correctness( ...@@ -60,6 +61,9 @@ def test_models_correctness(
if not model_info.enable_test: if not model_info.enable_test:
pytest.skip("Skipping test.") pytest.skip("Skipping test.")
# ST will strip the input texts, see test_embedding.py
example_prompts = [str(s).strip() for s in example_prompts]
with vllm_runner(model_info.name, with vllm_runner(model_info.name,
task="embed", task="embed",
dtype=model_info.dtype, dtype=model_info.dtype,
......
...@@ -256,11 +256,17 @@ _EMBEDDING_EXAMPLE_MODELS = { ...@@ -256,11 +256,17 @@ _EMBEDDING_EXAMPLE_MODELS = {
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"), "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
"GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0", "GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
trust_remote_code=True), trust_remote_code=True),
"GteNewModel": _HfExamplesInfo("Alibaba-NLP/gte-base-en-v1.5",
trust_remote_code=True,
hf_overrides={"architectures":
["GteNewModel"]}),
"InternLM2ForRewardModel": _HfExamplesInfo("internlm/internlm2-1_8b-reward", "InternLM2ForRewardModel": _HfExamplesInfo("internlm/internlm2-1_8b-reward",
trust_remote_code=True), trust_remote_code=True),
"JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), # noqa: E501 "JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), # noqa: E501
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False), "LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"), "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
"ModernBertModel": _HfExamplesInfo("Alibaba-NLP/gte-modernbert-base",
trust_remote_code=True),
"NomicBertModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-long", # noqa: E501 "NomicBertModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-long", # noqa: E501
trust_remote_code=True), trust_remote_code=True),
"Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"), "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
......
...@@ -354,7 +354,7 @@ def get_act_fn(act_fn_name: str) -> nn.Module: ...@@ -354,7 +354,7 @@ def get_act_fn(act_fn_name: str) -> nn.Module:
_ACTIVATION_AND_MUL_REGISTRY = LazyDict({ _ACTIVATION_AND_MUL_REGISTRY = LazyDict({
"gelu": lambda: GeluAndMul(), "gelu": lambda: GeluAndMul(),
"silu": lambda: SiluAndMul(), "silu": lambda: SiluAndMul(),
"gelu_and_mul": lambda: GeluAndMul(), "geglu": lambda: GeluAndMul(),
}) })
......
...@@ -456,6 +456,40 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding): ...@@ -456,6 +456,40 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
return self._scaling_factor_to_offset return self._scaling_factor_to_offset
class NTKScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with fixed and mixed NTK scaling.
https://kexue.fm/archives/9706 """
def __init__(self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
mixed_b: Optional[float] = None) -> None:
self.scaling_factor = scaling_factor
self.mixed_b = mixed_b
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
base = self.base * (self.scaling_factor if self.mixed_b is None else 1)
inv_freq = super()._compute_inv_freq(base)
if self.mixed_b is None:
inv_freq = inv_freq / self.scaling_factor**(2 / self.rotary_dim)
else:
a = torch.tensor(self.scaling_factor).log() / (self.rotary_dim /
2)**self.mixed_b
lambda_1_m = (a * torch.arange(
1, self.rotary_dim // 2 + 1).float()**self.mixed_b).exp()
inv_freq = inv_freq / lambda_1_m
return inv_freq
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with Dynamic NTK scaling. """RotaryEmbedding extended with Dynamic NTK scaling.
...@@ -1765,6 +1799,14 @@ def get_rope( ...@@ -1765,6 +1799,14 @@ def get_rope(
max_position, base, max_position, base,
is_neox_style, is_neox_style,
scaling_factor, dtype) scaling_factor, dtype)
elif scaling_type == "ntk":
scaling_factor = rope_scaling["factor"]
mixed_b = rope_scaling.get('mixed_b', None)
rotary_emb = NTKScalingRotaryEmbedding(head_size, rotary_dim,
max_position, base,
is_neox_style,
scaling_factor, dtype,
mixed_b)
elif scaling_type == "dynamic": elif scaling_type == "dynamic":
scaling_factor = rope_scaling["factor"] scaling_factor = rope_scaling["factor"]
rotary_emb = DynamicNTKScalingRotaryEmbedding( rotary_emb = DynamicNTKScalingRotaryEmbedding(
......
...@@ -32,11 +32,18 @@ class BertWithRopeEmbedding(nn.Module): ...@@ -32,11 +32,18 @@ class BertWithRopeEmbedding(nn.Module):
def __init__(self, config: PretrainedConfig): def __init__(self, config: PretrainedConfig):
super().__init__() super().__init__()
assert config.type_vocab_size > 0 if config.position_embedding_type not in ["rope", "rotary"]:
raise ValueError("Only 'rotary'('rope') position_embedding_type" +
" is supported")
self.word_embeddings = VocabParallelEmbedding(config.vocab_size, self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
config.hidden_size) config.hidden_size)
self.token_type_embeddings = VocabParallelEmbedding( if config.type_vocab_size > 0:
config.type_vocab_size, config.hidden_size) self.token_type_embeddings = VocabParallelEmbedding(
config.type_vocab_size, config.hidden_size)
else:
self.token_type_embeddings = None
self.LayerNorm = nn.LayerNorm(config.hidden_size, self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
...@@ -47,13 +54,17 @@ class BertWithRopeEmbedding(nn.Module): ...@@ -47,13 +54,17 @@ class BertWithRopeEmbedding(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
input_shape = input_ids.size() input_shape = input_ids.size()
inputs_embeds = self.word_embeddings(input_ids) inputs_embeds = self.word_embeddings(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape,
dtype=torch.long,
device=inputs_embeds.device)
token_type_embeddings = self.token_type_embeddings(token_type_ids) embeddings = inputs_embeds
embeddings = inputs_embeds + token_type_embeddings if self.token_type_embeddings is not None:
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape,
dtype=torch.long,
device=inputs_embeds.device)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings += token_type_embeddings
embeddings = self.LayerNorm(embeddings) embeddings = self.LayerNorm(embeddings)
return embeddings return embeddings
...@@ -321,7 +332,7 @@ class BertWithRopeBlock(nn.Module): ...@@ -321,7 +332,7 @@ class BertWithRopeBlock(nn.Module):
if moe: if moe:
self.mlp = NomicMoELayer(config=config, ) self.mlp = NomicMoELayer(config=config, )
else: else:
if config.hidden_act in ["silu", "gelu_and_mul"]: if config.hidden_act in ["silu", "geglu"]:
self.mlp = BertWithRopeGatedMLP( self.mlp = BertWithRopeGatedMLP(
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
...@@ -390,6 +401,7 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant): ...@@ -390,6 +401,7 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
self.vllm_config = vllm_config
self.config = self.config_verify(vllm_config) self.config = self.config_verify(vllm_config)
self.embeddings = BertWithRopeEmbedding(self.config) self.embeddings = BertWithRopeEmbedding(self.config)
self.encoder = BertWithRopeEncoder( self.encoder = BertWithRopeEncoder(
...@@ -420,7 +432,7 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant): ...@@ -420,7 +432,7 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
weights = self.hf_to_vllm_mapper.apply(weights) weights = self.hf_to_vllm_mapper.apply(weights)
if self.config.hidden_act in ["silu", "gelu_and_mul"]: if self.config.hidden_act in ["silu", "geglu"]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "gate_proj", 0),
...@@ -458,6 +470,8 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant): ...@@ -458,6 +470,8 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
class NomicBertModel(BertWithRope): class NomicBertModel(BertWithRope):
# for https://huggingface.co/nomic-ai/nomic-bert-2048
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={ orig_to_new_substr={
"emb_ln": "embeddings.LayerNorm", "emb_ln": "embeddings.LayerNorm",
...@@ -475,6 +489,9 @@ class NomicBertModel(BertWithRope): ...@@ -475,6 +489,9 @@ class NomicBertModel(BertWithRope):
assert config.__class__.__name__ == "NomicBertConfig" assert config.__class__.__name__ == "NomicBertConfig"
assert config.activation_function in ["swiglu", "gelu"] assert config.activation_function in ["swiglu", "gelu"]
config.position_embedding_type = getattr(config,
"position_embedding_type",
"rope")
if config.activation_function == "swiglu": if config.activation_function == "swiglu":
config.hidden_act = "silu" config.hidden_act = "silu"
...@@ -512,10 +529,13 @@ class NomicBertModel(BertWithRope): ...@@ -512,10 +529,13 @@ class NomicBertModel(BertWithRope):
return config return config
class GteModel(BertWithRope): class GteNewModel(BertWithRope):
# for https://huggingface.co/Alibaba-NLP/new-impl
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={ orig_to_new_substr={
"layer": 'layers', "new.": "",
"layer": "layers",
"attention.qkv_proj": "attn.qkv_proj", "attention.qkv_proj": "attn.qkv_proj",
"attention.o_proj": "attn.out_proj", "attention.o_proj": "attn.out_proj",
}) })
...@@ -523,7 +543,7 @@ class GteModel(BertWithRope): ...@@ -523,7 +543,7 @@ class GteModel(BertWithRope):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix) super().__init__(vllm_config=vllm_config, prefix=prefix)
# GteModel only gate_up_proj does not have bias. # GteNewModel only gate_up_proj does not have bias.
# Hack method learned from vllm/model_executor/models/glm.py # Hack method learned from vllm/model_executor/models/glm.py
for layer in self.encoder.layers: for layer in self.encoder.layers:
layer.mlp.gate_up_proj.bias = None layer.mlp.gate_up_proj.bias = None
...@@ -532,12 +552,10 @@ class GteModel(BertWithRope): ...@@ -532,12 +552,10 @@ class GteModel(BertWithRope):
def config_verify(self, vllm_config): def config_verify(self, vllm_config):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
assert config.__class__.__name__ == "GteConfig" assert config.__class__.__name__ == "NewConfig"
assert config.position_embedding_type == "rope"
assert config.hidden_act == "gelu" assert config.hidden_act == "gelu"
config.position_embedding_type = "rotary" config.hidden_act = "geglu"
config.hidden_act = "gelu_and_mul"
head_dim = config.hidden_size // config.num_attention_heads head_dim = config.hidden_size // config.num_attention_heads
config.rotary_kwargs = { config.rotary_kwargs = {
...@@ -559,13 +577,52 @@ class GteModel(BertWithRope): ...@@ -559,13 +577,52 @@ class GteModel(BertWithRope):
else: else:
yield name, weight yield name, weight
def ignore_unnecessary_layers(self,
weights: Iterable[Tuple[str, torch.Tensor]]):
for name, weight in weights:
if name.startswith("classifier"):
continue
yield name, weight
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
weights = self.ignore_unnecessary_layers(weights)
weights = self.split_up_gate_proj(weights) weights = self.split_up_gate_proj(weights)
return super().load_weights(weights) return super().load_weights(weights)
class SnowflakeGteNewModel(GteNewModel):
# for Snowflake/snowflake-arctic-embed-m-v2.0
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
"layer": "layers",
"attention.qkv_proj": "attn.qkv_proj",
"attention.o_proj": "attn.out_proj",
})
def config_verify(self, vllm_config):
config = vllm_config.model_config.hf_config
assert config.__class__.__name__ == "GteConfig"
assert config.hidden_act == "gelu"
config.hidden_act = "geglu"
head_dim = config.hidden_size // config.num_attention_heads
config.rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings,
"base": config.rope_theta,
"rope_scaling": getattr(config, "rope_scaling", None)
}
return config
class JinaRobertaModel(BertWithRope): class JinaRobertaModel(BertWithRope):
# for https://huggingface.co/jinaai/jina-embeddings-v3
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={ orig_to_new_substr={
"emb_ln": "embeddings.LayerNorm", "emb_ln": "embeddings.LayerNorm",
...@@ -579,6 +636,9 @@ class JinaRobertaModel(BertWithRope): ...@@ -579,6 +636,9 @@ class JinaRobertaModel(BertWithRope):
def config_verify(self, vllm_config): def config_verify(self, vllm_config):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
assert config.__class__.__name__ == "XLMRobertaFlashConfig"
head_dim = config.hidden_size // config.num_attention_heads head_dim = config.hidden_size // config.num_attention_heads
config.rotary_kwargs = { config.rotary_kwargs = {
"head_size": head_dim, "head_size": head_dim,
...@@ -611,6 +671,7 @@ class JinaRobertaModel(BertWithRope): ...@@ -611,6 +671,7 @@ class JinaRobertaModel(BertWithRope):
# This is a temporary solution until we have a better way to handle # This is a temporary solution until we have a better way to handle
scaling = self.config.lora_alpha / self.config.lora_rank scaling = self.config.lora_alpha / self.config.lora_rank
device = self.vllm_config.device_config.device
weights = {name: weight for name, weight in weights} weights = {name: weight for name, weight in weights}
...@@ -628,13 +689,13 @@ class JinaRobertaModel(BertWithRope): ...@@ -628,13 +689,13 @@ class JinaRobertaModel(BertWithRope):
weight_name = name[:-len(o)] weight_name = name[:-len(o)]
if "embeddings" in weight_name: if "embeddings" in weight_name:
B = weights[weight_name + a][i].cuda().float() B = weights[weight_name + a][i].to(device).float()
A = weights[weight_name + b][i].cuda().float() A = weights[weight_name + b][i].to(device).float()
else: else:
B = weights[weight_name + b][i].cuda().float() B = weights[weight_name + b][i].to(device).float()
A = weights[weight_name + a][i].cuda().float() A = weights[weight_name + a][i].to(device).float()
weight = (weights[weight_name + o].cuda() + weight = (weights[weight_name + o].to(device) +
torch.matmul(B, A).view(shape) * scaling) torch.matmul(B, A).view(shape) * scaling)
weight = weight.cpu().to(dtype) weight = weight.cpu().to(dtype)
......
...@@ -230,9 +230,12 @@ class ModernBertModel(nn.Module): ...@@ -230,9 +230,12 @@ class ModernBertModel(nn.Module):
def forward( def forward(
self, self,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
positions: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
position_ids = positions if positions is not None else position_ids
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
......
...@@ -127,7 +127,8 @@ _EMBEDDING_MODELS = { ...@@ -127,7 +127,8 @@ _EMBEDDING_MODELS = {
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"), "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"),
"GritLM": ("gritlm", "GritLM"), "GritLM": ("gritlm", "GritLM"),
"GteModel": ("bert_with_rope", "GteModel"), "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
"GteNewModel": ("bert_with_rope", "GteNewModel"),
"InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"), "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
"JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501 "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501
"LlamaModel": ("llama", "LlamaForCausalLM"), "LlamaModel": ("llama", "LlamaForCausalLM"),
...@@ -137,6 +138,7 @@ _EMBEDDING_MODELS = { ...@@ -137,6 +138,7 @@ _EMBEDDING_MODELS = {
if arch == "LlamaForCausalLM" if arch == "LlamaForCausalLM"
}, },
"MistralModel": ("llama", "LlamaForCausalLM"), "MistralModel": ("llama", "LlamaForCausalLM"),
"ModernBertModel": ("modernbert", "ModernBertModel"),
"NomicBertModel": ("bert_with_rope", "NomicBertModel"), "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"), "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment