Unverified Commit 3d3ab368 authored by wang.yuqi's avatar wang.yuqi Committed by GitHub
Browse files

[New Model]: Snowflake Arctic Embed (Family) (#16649)

parent 686623c5
...@@ -3,24 +3,17 @@ ...@@ -3,24 +3,17 @@
Run `pytest tests/entrypoints/openai/test_embedding_dimensions.py`. Run `pytest tests/entrypoints/openai/test_embedding_dimensions.py`.
""" """
from typing import NamedTuple
import openai import openai
import pytest import pytest
from vllm.entrypoints.openai.protocol import EmbeddingResponse from vllm.entrypoints.openai.protocol import EmbeddingResponse
from ...models.embedding.utils import EmbedModelInfo
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
class ModelInfo(NamedTuple):
name: str
is_matryoshka: bool
MODELS = [ MODELS = [
ModelInfo(name="BAAI/bge-m3", is_matryoshka=False), EmbedModelInfo(name="BAAI/bge-m3", is_matryoshka=False),
ModelInfo(name="jinaai/jina-embeddings-v3", is_matryoshka=True), EmbedModelInfo(name="jinaai/jina-embeddings-v3", is_matryoshka=True),
] ]
input_texts = [ input_texts = [
...@@ -30,7 +23,7 @@ input_texts = [ ...@@ -30,7 +23,7 @@ input_texts = [
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
async def test_validating_dimensions(model: ModelInfo): async def test_validating_dimensions(model: EmbedModelInfo):
args = [ args = [
"--task", "--task",
"embed", "embed",
......
# SPDX-License-Identifier: Apache-2.0
"""Compare the embedding outputs of HF and vLLM models.
Run `pytest tests/models/embedding/language/test_snowflake_arctic_embed.py`.
"""
import pytest
from tests.models.embedding.utils import EmbedModelInfo
from ..utils import check_embeddings_close
# Extra prompts that test_models appends to the shared ``example_prompts``
# fixture before encoding.
EMBEDDING_PROMPTS = [
    "what is snowflake?",
    "Where can I get the best tacos?",
    "The Data Cloud!",
    "Mexico City of Course!",
]

# Snowflake Arctic Embed checkpoints covered by this module. Each entry records
# the expected vLLM architecture name and whether the checkpoint is actually
# exercised (``enable_test``) or only listed for reference.
MODELS = [
    EmbedModelInfo(
        name="Snowflake/snowflake-arctic-embed-xs",
        is_matryoshka=False,
        architecture="BertModel",
        enable_test=True,
    ),
    EmbedModelInfo(
        name="Snowflake/snowflake-arctic-embed-s",
        is_matryoshka=False,
        architecture="BertModel",
        enable_test=False,
    ),
    EmbedModelInfo(
        name="Snowflake/snowflake-arctic-embed-m",
        is_matryoshka=False,
        architecture="BertModel",
        enable_test=False,
    ),
    EmbedModelInfo(
        name="Snowflake/snowflake-arctic-embed-m-long",
        is_matryoshka=False,
        architecture="NomicBertModel",
        enable_test=True,
    ),
    EmbedModelInfo(
        name="Snowflake/snowflake-arctic-embed-l",
        is_matryoshka=False,
        architecture="BertModel",
        enable_test=False,
    ),
    EmbedModelInfo(
        name="Snowflake/snowflake-arctic-embed-m-v1.5",
        is_matryoshka=True,
        architecture="BertModel",
        enable_test=True,
    ),
    EmbedModelInfo(
        name="Snowflake/snowflake-arctic-embed-l-v2.0",
        is_matryoshka=True,
        architecture="XLMRobertaModel",
        enable_test=True,
    ),
    EmbedModelInfo(
        name="Snowflake/snowflake-arctic-embed-m-v2.0",
        is_matryoshka=True,
        architecture="GteModel",
        enable_test=True,
    ),
]
@pytest.mark.parametrize("model_info", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model_info: EmbedModelInfo,
    dtype: str,
    monkeypatch,
) -> None:
    """Check that vLLM embeddings match the HF sentence-transformer output.

    The shared example prompts plus ``EMBEDDING_PROMPTS`` are encoded with
    both runners and compared with ``check_embeddings_close`` (tol=1e-2).
    While the vLLM model is loaded, the resolved model config is also
    checked against the matryoshka flag and architecture recorded in
    ``model_info``.
    """
    if not model_info.enable_test:
        # A model family has many models with the same architecture,
        # and we don't need to test each one.
        pytest.skip("Skipping test.")

    prompts = example_prompts + EMBEDDING_PROMPTS
    hf_overrides = {"is_matryoshka": model_info.is_matryoshka}

    # Reference embeddings from the HF sentence-transformers runner.
    with hf_runner(
            model_info.name,
            dtype=dtype,
            is_sentence_transformer=True,
    ) as hf_model:
        hf_outputs = hf_model.encode(prompts)

    with vllm_runner(
            model_info.name,
            task="embed",
            dtype=dtype,
            max_model_len=None,
            hf_overrides=hf_overrides,
    ) as vllm_model:
        model_config = vllm_model.model.llm_engine.model_config
        assert model_config.is_matryoshka == model_info.is_matryoshka

        # An empty architecture string means "do not check".
        if model_info.architecture:
            assert model_info.architecture in model_config.architectures

        vllm_outputs = vllm_model.encode(prompts)

    check_embeddings_close(
        embeddings_0_lst=hf_outputs,
        embeddings_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
        tol=1e-2,
    )
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from collections.abc import Sequence from collections.abc import Sequence
from typing import NamedTuple
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -37,3 +38,10 @@ def matryoshka_fy(tensor, dimensions): ...@@ -37,3 +38,10 @@ def matryoshka_fy(tensor, dimensions):
tensor = tensor[..., :dimensions] tensor = tensor[..., :dimensions]
tensor = F.normalize(tensor, p=2, dim=1) tensor = F.normalize(tensor, p=2, dim=1)
return tensor return tensor
class EmbedModelInfo(NamedTuple):
    """Immutable description of an embedding model used by the tests."""

    # Model identifier passed to the test runners.
    name: str
    # Whether the model is expected to be flagged as matryoshka.
    is_matryoshka: bool
    # Expected architecture name; the empty default means "unspecified".
    architecture: str = ""
    # Lets a test skip this entry when set to False.
    enable_test: bool = True
...@@ -247,11 +247,15 @@ _EMBEDDING_EXAMPLE_MODELS = { ...@@ -247,11 +247,15 @@ _EMBEDDING_EXAMPLE_MODELS = {
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"), "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"), "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"), "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
"GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
trust_remote_code=True),
"InternLM2ForRewardModel": _HfExamplesInfo("internlm/internlm2-1_8b-reward", "InternLM2ForRewardModel": _HfExamplesInfo("internlm/internlm2-1_8b-reward",
trust_remote_code=True), trust_remote_code=True),
"JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), # noqa: E501 "JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), # noqa: E501
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False), "LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"), "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
"NomicBertModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-long", # noqa: E501
trust_remote_code=True),
"Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"), "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
"Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"), "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"),
"Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B"), "Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B"),
......
...@@ -354,6 +354,7 @@ def get_act_fn(act_fn_name: str) -> nn.Module: ...@@ -354,6 +354,7 @@ def get_act_fn(act_fn_name: str) -> nn.Module:
_ACTIVATION_AND_MUL_REGISTRY = LazyDict({ _ACTIVATION_AND_MUL_REGISTRY = LazyDict({
"gelu": lambda: GeluAndMul(), "gelu": lambda: GeluAndMul(),
"silu": lambda: SiluAndMul(), "silu": lambda: SiluAndMul(),
"gelu_and_mul": lambda: GeluAndMul(),
}) })
......
...@@ -11,8 +11,10 @@ from vllm.compilation.decorators import support_torch_compile ...@@ -11,8 +11,10 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, PoolerConfig, VllmConfig from vllm.config import CacheConfig, PoolerConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import (get_act_and_mul_fn,
get_act_fn)
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler, from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler,
...@@ -108,6 +110,7 @@ class BertEncoder(nn.Module): ...@@ -108,6 +110,7 @@ class BertEncoder(nn.Module):
def __init__(self, def __init__(self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
bias: bool = True,
rotary_kwargs: Optional[dict] = None, rotary_kwargs: Optional[dict] = None,
prefix: str = ""): prefix: str = ""):
super().__init__() super().__init__()
...@@ -118,6 +121,7 @@ class BertEncoder(nn.Module): ...@@ -118,6 +121,7 @@ class BertEncoder(nn.Module):
BertLayer(config=config, BertLayer(config=config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
bias=bias,
rotary_kwargs=rotary_kwargs, rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.layer.{layer_idx}") prefix=f"{prefix}.layer.{layer_idx}")
for layer_idx in range(config.num_hidden_layers) for layer_idx in range(config.num_hidden_layers)
...@@ -139,6 +143,7 @@ class BertLayer(nn.Module): ...@@ -139,6 +143,7 @@ class BertLayer(nn.Module):
config: BertConfig, config: BertConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = True,
rotary_kwargs: Optional[dict] = None, rotary_kwargs: Optional[dict] = None,
prefix: str = ""): prefix: str = ""):
super().__init__() super().__init__()
...@@ -149,19 +154,31 @@ class BertLayer(nn.Module): ...@@ -149,19 +154,31 @@ class BertLayer(nn.Module):
layer_norm_eps=config.layer_norm_eps, layer_norm_eps=config.layer_norm_eps,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
bias=bias,
rotary_kwargs=rotary_kwargs, rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.attention") prefix=f"{prefix}.attention")
self.intermediate = BertIntermediate( if config.hidden_act in ["silu", "gelu_and_mul"]:
hidden_size=config.hidden_size, self.intermediate = BertGatedIntermediate(
intermediate_size=config.intermediate_size, hidden_size=config.hidden_size,
hidden_act=config.hidden_act, intermediate_size=config.intermediate_size,
quant_config=quant_config, hidden_act=config.hidden_act,
prefix=f"{prefix}.intermediate") bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.intermediate")
else:
self.intermediate = BertIntermediate(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.intermediate")
self.output = BertOutput(hidden_size=config.hidden_size, self.output = BertOutput(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
layer_norm_eps=config.layer_norm_eps, layer_norm_eps=config.layer_norm_eps,
bias=bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.output") prefix=f"{prefix}.output")
...@@ -181,6 +198,7 @@ class BertAttention(nn.Module): ...@@ -181,6 +198,7 @@ class BertAttention(nn.Module):
layer_norm_eps: float, layer_norm_eps: float,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = True,
rotary_kwargs: Optional[dict] = None, rotary_kwargs: Optional[dict] = None,
prefix: str = "", prefix: str = "",
): ):
...@@ -190,11 +208,13 @@ class BertAttention(nn.Module): ...@@ -190,11 +208,13 @@ class BertAttention(nn.Module):
num_attention_heads=num_attention_heads, num_attention_heads=num_attention_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
bias=bias,
rotary_kwargs=rotary_kwargs, rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.output") prefix=f"{prefix}.output")
self.output = BertSelfOutput(hidden_size=hidden_size, self.output = BertSelfOutput(hidden_size=hidden_size,
layer_norm_eps=layer_norm_eps, layer_norm_eps=layer_norm_eps,
bias=bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.output") prefix=f"{prefix}.output")
...@@ -215,6 +235,7 @@ class BertSelfAttention(nn.Module): ...@@ -215,6 +235,7 @@ class BertSelfAttention(nn.Module):
num_attention_heads: int, num_attention_heads: int,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = True,
rotary_kwargs: Optional[dict] = None, rotary_kwargs: Optional[dict] = None,
prefix: str = "", prefix: str = "",
): ):
...@@ -240,7 +261,7 @@ class BertSelfAttention(nn.Module): ...@@ -240,7 +261,7 @@ class BertSelfAttention(nn.Module):
head_size=self.head_dim, head_size=self.head_dim,
total_num_heads=self.total_num_heads, total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads, total_num_kv_heads=self.total_num_kv_heads,
bias=True, bias=bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv_proj") prefix=f"{prefix}.qkv_proj")
...@@ -278,12 +299,13 @@ class BertSelfOutput(nn.Module): ...@@ -278,12 +299,13 @@ class BertSelfOutput(nn.Module):
def __init__(self, def __init__(self,
hidden_size: int, hidden_size: int,
layer_norm_eps: float, layer_norm_eps: float,
bias: bool = True,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""): prefix: str = ""):
super().__init__() super().__init__()
self.dense = RowParallelLinear(input_size=hidden_size, self.dense = RowParallelLinear(input_size=hidden_size,
output_size=hidden_size, output_size=hidden_size,
bias=True, bias=bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.dense") prefix=f"{prefix}.dense")
self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
...@@ -301,12 +323,13 @@ class BertIntermediate(nn.Module): ...@@ -301,12 +323,13 @@ class BertIntermediate(nn.Module):
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
hidden_act: str, hidden_act: str,
bias: bool = True,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""): prefix: str = ""):
super().__init__() super().__init__()
self.dense = ColumnParallelLinear(input_size=hidden_size, self.dense = ColumnParallelLinear(input_size=hidden_size,
output_size=intermediate_size, output_size=intermediate_size,
bias=True, bias=bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.dense") prefix=f"{prefix}.dense")
self.intermediate_act_fn = get_act_fn(hidden_act) self.intermediate_act_fn = get_act_fn(hidden_act)
...@@ -317,19 +340,46 @@ class BertIntermediate(nn.Module): ...@@ -317,19 +340,46 @@ class BertIntermediate(nn.Module):
return hidden_states return hidden_states
class BertGatedIntermediate(nn.Module):
    """Gated intermediate layer, used for NomicBert and GteModel.

    The input is projected by a single merged linear layer that holds two
    ``intermediate_size``-wide output partitions (loaded from the
    ``gate_proj`` and ``up_proj`` weights). The merged result is then passed
    through the activation returned by ``get_act_and_mul_fn(hidden_act)``.
    """

    def __init__(self,
                 hidden_size: int,
                 intermediate_size: int,
                 hidden_act: str,
                 bias: bool = True,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.act_fn = get_act_and_mul_fn(hidden_act)

        # Two equally sized output partitions: one for gate, one for up.
        merged_output_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            merged_output_sizes,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The linear layer returns a pair; only the first item is used here.
        projected, _ = self.gate_up_proj(hidden_states)
        return self.act_fn(projected)
class BertOutput(nn.Module): class BertOutput(nn.Module):
def __init__(self, def __init__(self,
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
layer_norm_eps: float, layer_norm_eps: float,
bias: bool = True,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""): prefix: str = ""):
super().__init__() super().__init__()
self.dense = RowParallelLinear(input_size=intermediate_size, self.dense = RowParallelLinear(input_size=intermediate_size,
output_size=hidden_size, output_size=hidden_size,
bias=True, bias=bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.dense") prefix=f"{prefix}.dense")
...@@ -343,19 +393,32 @@ class BertOutput(nn.Module): ...@@ -343,19 +393,32 @@ class BertOutput(nn.Module):
class BertModel(nn.Module, SupportsQuant): class BertModel(nn.Module, SupportsQuant):
packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]} packed_modules_mapping = {
"qkv_proj": ["query", "key", "value"],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self, def __init__(self,
*, *,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = "", prefix: str = "",
embedding_class: type = BertEmbedding, embedding_class: type = BertEmbedding,
bias: bool = True,
rotary_kwargs: Optional[dict] = None, rotary_kwargs: Optional[dict] = None,
add_pooling_layer: bool = False): add_pooling_layer: bool = False):
super().__init__() super().__init__()
"""
For BertModel, all linear layers have bias.
For NomicBertModel, all linear layers do not have bias.
"""
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
self.embeddings = embedding_class(config) self.embeddings = embedding_class(config)
self.encoder = BertEncoder(vllm_config=vllm_config, self.encoder = BertEncoder(vllm_config=vllm_config,
bias=bias,
rotary_kwargs=rotary_kwargs, rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.encoder") prefix=f"{prefix}.encoder")
self.pooler = BertPooler(config) if add_pooling_layer else None self.pooler = BertPooler(config) if add_pooling_layer else None
...@@ -387,6 +450,8 @@ class BertModel(nn.Module, SupportsQuant): ...@@ -387,6 +450,8 @@ class BertModel(nn.Module, SupportsQuant):
("qkv_proj", "query", "q"), ("qkv_proj", "query", "q"),
("qkv_proj", "key", "k"), ("qkv_proj", "key", "k"),
("qkv_proj", "value", "v"), ("qkv_proj", "value", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
...@@ -546,3 +611,115 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, ...@@ -546,3 +611,115 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
token_type_ids=token_type_ids) token_type_ids=token_type_ids)
class NomicBertEmbeddingModel(BertEmbeddingModel):
    """Embedding model for NomicBert checkpoints, built on ``BertModel``.

    Checkpoint weight names are rewritten to the BertModel layout through
    ``hf_to_vllm_mapper``, and the NomicBert config fields are copied onto
    the attribute names that BertModel reads.
    """

    # Substring renames from the checkpoint's parameter names to the names
    # used by BertModel (fc11 -> up_proj, fc12 -> gate_proj).
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            "emb_ln": "embeddings.LayerNorm",
            "layers": "layer",
            "attn.Wqkv": "attention.self.qkv_proj",
            "attn.out_proj": "attention.output.dense",
            "norm1": "attention.output.LayerNorm",
            "mlp.fc11": "intermediate.up_proj",
            "mlp.fc12": "intermediate.gate_proj",
            "mlp.fc2": "output.dense",
            "norm2": "output.LayerNorm",
        })

    def _build_model(self,
                     vllm_config: VllmConfig,
                     prefix: str = "") -> BertModel:
        """Validate the NomicBert config, adapt it, and build a BertModel.

        The HF config object is modified in place. The returned model is
        constructed with ``bias=False`` and rotary position embeddings.
        """
        config = vllm_config.model_config.hf_config

        assert config.__class__.__name__ == "NomicBertConfig"
        assert config.activation_function == "swiglu"

        # The model is built bias-free below, so every bias flag in the
        # config has to be off.
        assert not config.mlp_fc1_bias
        assert not config.mlp_fc2_bias
        assert not config.qkv_proj_bias

        # Copy NomicBert field names onto the names BertModel expects.
        config.layer_norm_eps = config.layer_norm_epsilon
        config.position_embedding_type = "rotary"
        config.intermediate_size = config.n_inner
        config.hidden_act = "silu"
        config.hidden_size = config.n_embd
        config.num_hidden_layers = config.n_layer

        head_dim = config.hidden_size // config.num_attention_heads
        # Rotate the full head unless the config specifies rotary_emb_dim.
        rotary_dim = getattr(config, "rotary_emb_dim", head_dim)
        max_position = config.max_trained_positions
        rope_base = config.rotary_emb_base
        scaling_factor = config.rotary_scaling_factor

        rotary_kwargs = {
            "head_size": head_dim,
            "rotary_dim": rotary_dim,
            "max_position": max_position,
            "base": rope_base,
            "rope_scaling": {
                "rope_type": "dynamic",
                "factor": scaling_factor,
            },
        }

        return BertModel(vllm_config=vllm_config,
                         prefix=prefix,
                         bias=False,
                         rotary_kwargs=rotary_kwargs,
                         embedding_class=BertEmbedding)
class GteEmbeddingModel(BertEmbeddingModel):
    """Embedding model for GteModel checkpoints, built on ``BertModel``.

    Checkpoint weight names are rewritten through ``hf_to_vllm_mapper``, and
    the fused ``mlp.up_gate_proj`` weight is split into separate ``up_proj``
    and ``gate_proj`` tensors before being handed to ``BertModel``.
    """

    # Substring renames from the checkpoint's parameter names to the names
    # used by BertModel.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            "attention.qkv_proj": "attention.self.qkv_proj",
            "attention.o_proj": "attention.output.dense",
            "attn_ln": "attention.output.LayerNorm",
            "mlp.down_proj": "output.dense",
            "mlp_ln": "output.LayerNorm",
        })

    def _build_model(self,
                     vllm_config: VllmConfig,
                     prefix: str = "") -> BertModel:
        """Validate the Gte config, adapt it, and build a BertModel.

        The HF config object is modified in place: the position embedding
        type becomes "rotary" and the activation becomes "gelu_and_mul".
        """
        config = vllm_config.model_config.hf_config

        assert config.__class__.__name__ == "GteConfig"
        assert config.position_embedding_type == "rope"
        assert config.hidden_act == "gelu"

        config.position_embedding_type = "rotary"
        config.hidden_act = "gelu_and_mul"

        head_dim = config.hidden_size // config.num_attention_heads
        rotary_kwargs = {
            "head_size": head_dim,
            # Rotate the full head unless the config says otherwise.
            "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
            "max_position": config.max_position_embeddings,
            "base": config.rope_theta,
        }

        model = BertModel(vllm_config=vllm_config,
                          prefix=prefix,
                          rotary_kwargs=rotary_kwargs,
                          embedding_class=BertEmbedding)

        # GteModel only gate_up_proj does not have bias.
        # Hack method learned from vllm/model_executor/models/glm.py
        # The flag is set on the projection itself: the wrapping
        # intermediate module never reads skip_bias_add, so setting it
        # there would be a dead attribute.
        for layer in model.encoder.layer:
            gate_up_proj = layer.intermediate.gate_up_proj
            gate_up_proj.bias = None
            gate_up_proj.skip_bias_add = True
        return model

    def split_up_gate_proj(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Yield weights with each fused up/gate tensor split in two.

        A tensor whose name contains ``mlp.up_gate_proj`` is cut in half
        along dim 0; the first half is yielded as ``intermediate.up_proj``
        and the second as ``intermediate.gate_proj``. All other entries
        pass through unchanged.
        """
        fused_name = "mlp.up_gate_proj"
        for name, weight in weights:
            if fused_name not in name:
                yield name, weight
                continue
            up, gate = weight.chunk(2, dim=0)
            yield name.replace(fused_name, "intermediate.up_proj"), up
            yield name.replace(fused_name, "intermediate.gate_proj"), gate

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Rename and split the checkpoint weights, then load them."""
        weights = self.hf_to_vllm_mapper.apply(weights)
        weights = self.split_up_gate_proj(weights)
        self.model.load_weights(weights)
...@@ -122,13 +122,11 @@ _TEXT_GENERATION_MODELS = { ...@@ -122,13 +122,11 @@ _TEXT_GENERATION_MODELS = {
_EMBEDDING_MODELS = { _EMBEDDING_MODELS = {
# [Text-only] # [Text-only]
"BertModel": ("bert", "BertEmbeddingModel"), "BertModel": ("bert", "BertEmbeddingModel"),
"RobertaModel": ("roberta", "RobertaEmbeddingModel"),
"RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
"DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"), "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"), "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"),
"GritLM": ("gritlm", "GritLM"), "GritLM": ("gritlm", "GritLM"),
"GteModel": ("bert", "GteEmbeddingModel"),
"InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"), "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
"JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501 "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501
"LlamaModel": ("llama", "LlamaForCausalLM"), "LlamaModel": ("llama", "LlamaForCausalLM"),
...@@ -138,12 +136,16 @@ _EMBEDDING_MODELS = { ...@@ -138,12 +136,16 @@ _EMBEDDING_MODELS = {
if arch == "LlamaForCausalLM" if arch == "LlamaForCausalLM"
}, },
"MistralModel": ("llama", "LlamaForCausalLM"), "MistralModel": ("llama", "LlamaForCausalLM"),
"NomicBertModel": ("bert", "NomicBertEmbeddingModel"),
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"), "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"), "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
"Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"), "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
"RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
"RobertaModel": ("roberta", "RobertaEmbeddingModel"),
"TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"), "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
# [Multimodal] # [Multimodal]
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment