Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6d6c6b05
Unverified
Commit
6d6c6b05
authored
Sep 06, 2025
by
wang.yuqi
Committed by
GitHub
Sep 05, 2025
Browse files
[New Model]: google/embeddinggemma-300m (#24318)
Signed-off-by:
wang.yuqi
<
noooop@126.com
>
parent
53b19ccd
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
73 additions
and
29 deletions
+73
-29
docs/models/supported_models.md
docs/models/supported_models.md
+1
-0
tests/models/language/pooling/mteb_utils.py
tests/models/language/pooling/mteb_utils.py
+16
-2
tests/models/language/pooling/test_st_projector.py
tests/models/language/pooling/test_st_projector.py
+6
-1
tests/models/registry.py
tests/models/registry.py
+1
-0
vllm/config/__init__.py
vllm/config/__init__.py
+2
-0
vllm/model_executor/models/adapters.py
vllm/model_executor/models/adapters.py
+17
-15
vllm/model_executor/models/config.py
vllm/model_executor/models/config.py
+9
-0
vllm/model_executor/models/gemma3.py
vllm/model_executor/models/gemma3.py
+20
-11
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-0
No files found.
docs/models/supported_models.md
View file @
6d6c6b05
...
@@ -440,6 +440,7 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A
...
@@ -440,6 +440,7 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
|
`BertModel`
<sup>
C
</sup>
| BERT-based |
`BAAI/bge-base-en-v1.5`
,
`Snowflake/snowflake-arctic-embed-xs`
, etc. | | | ✅︎ |
|
`BertModel`
<sup>
C
</sup>
| BERT-based |
`BAAI/bge-base-en-v1.5`
,
`Snowflake/snowflake-arctic-embed-xs`
, etc. | | | ✅︎ |
|
`Gemma2Model`
<sup>
C
</sup>
| Gemma 2-based |
`BAAI/bge-multilingual-gemma2`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`Gemma2Model`
<sup>
C
</sup>
| Gemma 2-based |
`BAAI/bge-multilingual-gemma2`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`Gemma3TextModel`
<sup>
C
</sup>
| Gemma 3-based |
`google/embeddinggemma-300m`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`GritLM`
| GritLM |
`parasail-ai/GritLM-7B-vllm`
. | ✅︎ | ✅︎ | ✅︎ |
|
`GritLM`
| GritLM |
`parasail-ai/GritLM-7B-vllm`
. | ✅︎ | ✅︎ | ✅︎ |
|
`GteModel`
<sup>
C
</sup>
| Arctic-Embed-2.0-M |
`Snowflake/snowflake-arctic-embed-m-v2.0`
. | | | ✅︎ |
|
`GteModel`
<sup>
C
</sup>
| Arctic-Embed-2.0-M |
`Snowflake/snowflake-arctic-embed-m-v2.0`
. | | | ✅︎ |
|
`GteNewModel`
<sup>
C
</sup>
| mGTE-TRM (see note) |
`Alibaba-NLP/gte-multilingual-base`
, etc. | | | ✅︎ |
|
`GteNewModel`
<sup>
C
</sup>
| mGTE-TRM (see note) |
`Alibaba-NLP/gte-multilingual-base`
, etc. | | | ✅︎ |
...
...
tests/models/language/pooling/mteb_utils.py
View file @
6d6c6b05
...
@@ -10,7 +10,8 @@ import numpy as np
...
@@ -10,7 +10,8 @@ import numpy as np
import
pytest
import
pytest
import
requests
import
requests
from
tests.models.utils
import
EmbedModelInfo
,
RerankModelInfo
from
tests.models.utils
import
(
EmbedModelInfo
,
RerankModelInfo
,
check_embeddings_close
)
# Most embedding models on the STS12 task (See #17175):
# Most embedding models on the STS12 task (See #17175):
# - Model implementation and minor changes in tensor dtype
# - Model implementation and minor changes in tensor dtype
...
@@ -163,12 +164,14 @@ def mteb_test_embed_models(hf_runner,
...
@@ -163,12 +164,14 @@ def mteb_test_embed_models(hf_runner,
model_info
:
EmbedModelInfo
,
model_info
:
EmbedModelInfo
,
vllm_extra_kwargs
=
None
,
vllm_extra_kwargs
=
None
,
hf_model_callback
=
None
,
hf_model_callback
=
None
,
atol
=
MTEB_
RERANK
_TOL
):
atol
=
MTEB_
EMBED
_TOL
):
if
not
model_info
.
enable_test
:
if
not
model_info
.
enable_test
:
# A model family has many models with the same architecture,
# A model family has many models with the same architecture,
# and we don't need to test each one.
# and we don't need to test each one.
pytest
.
skip
(
"Skipping test."
)
pytest
.
skip
(
"Skipping test."
)
example_prompts
=
[
"The chef prepared a delicious meal."
]
vllm_extra_kwargs
=
vllm_extra_kwargs
or
{}
vllm_extra_kwargs
=
vllm_extra_kwargs
or
{}
vllm_extra_kwargs
[
"dtype"
]
=
model_info
.
dtype
vllm_extra_kwargs
[
"dtype"
]
=
model_info
.
dtype
...
@@ -191,6 +194,7 @@ def mteb_test_embed_models(hf_runner,
...
@@ -191,6 +194,7 @@ def mteb_test_embed_models(hf_runner,
vllm_main_score
=
run_mteb_embed_task
(
VllmMtebEncoder
(
vllm_model
),
vllm_main_score
=
run_mteb_embed_task
(
VllmMtebEncoder
(
vllm_model
),
MTEB_EMBED_TASKS
)
MTEB_EMBED_TASKS
)
vllm_dtype
=
vllm_model
.
llm
.
llm_engine
.
model_config
.
dtype
vllm_dtype
=
vllm_model
.
llm
.
llm_engine
.
model_config
.
dtype
vllm_outputs
=
vllm_model
.
embed
(
example_prompts
)
if
model_info
.
mteb_score
is
None
:
if
model_info
.
mteb_score
is
None
:
with
hf_runner
(
model_info
.
name
,
with
hf_runner
(
model_info
.
name
,
...
@@ -202,6 +206,16 @@ def mteb_test_embed_models(hf_runner,
...
@@ -202,6 +206,16 @@ def mteb_test_embed_models(hf_runner,
st_main_score
=
run_mteb_embed_task
(
hf_model
,
MTEB_EMBED_TASKS
)
st_main_score
=
run_mteb_embed_task
(
hf_model
,
MTEB_EMBED_TASKS
)
st_dtype
=
next
(
hf_model
.
model
.
parameters
()).
dtype
st_dtype
=
next
(
hf_model
.
model
.
parameters
()).
dtype
# Test embed_dims and whether to use normalize
hf_outputs
=
hf_model
.
encode
(
example_prompts
)
check_embeddings_close
(
embeddings_0_lst
=
hf_outputs
,
embeddings_1_lst
=
vllm_outputs
,
name_0
=
"hf"
,
name_1
=
"vllm"
,
tol
=
1e-2
,
)
else
:
else
:
st_main_score
=
model_info
.
mteb_score
st_main_score
=
model_info
.
mteb_score
st_dtype
=
"Constant"
st_dtype
=
"Constant"
...
...
tests/models/language/pooling/test_st_projector.py
View file @
6d6c6b05
...
@@ -2,7 +2,8 @@
...
@@ -2,7 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
pytest
from
...utils
import
CLSPoolingEmbedModelInfo
,
EmbedModelInfo
from
...utils
import
(
CLSPoolingEmbedModelInfo
,
EmbedModelInfo
,
LASTPoolingEmbedModelInfo
)
from
.mteb_utils
import
mteb_test_embed_models
from
.mteb_utils
import
mteb_test_embed_models
# ST models with projector (Dense) layers
# ST models with projector (Dense) layers
...
@@ -13,6 +14,10 @@ ST_PROJECTOR_MODELS = [
...
@@ -13,6 +14,10 @@ ST_PROJECTOR_MODELS = [
mteb_score
=
0.688611955
,
mteb_score
=
0.688611955
,
enable_test
=
True
,
enable_test
=
True
,
),
),
LASTPoolingEmbedModelInfo
(
"google/embeddinggemma-300m"
,
architecture
=
"Gemma3TextModel"
,
mteb_score
=
0.7473819294684156
,
enable_test
=
True
)
]
]
...
...
tests/models/registry.py
View file @
6d6c6b05
...
@@ -352,6 +352,7 @@ _EMBEDDING_EXAMPLE_MODELS = {
...
@@ -352,6 +352,7 @@ _EMBEDDING_EXAMPLE_MODELS = {
# [Text-only]
# [Text-only]
"BertModel"
:
_HfExamplesInfo
(
"BAAI/bge-base-en-v1.5"
),
"BertModel"
:
_HfExamplesInfo
(
"BAAI/bge-base-en-v1.5"
),
"Gemma2Model"
:
_HfExamplesInfo
(
"BAAI/bge-multilingual-gemma2"
),
# noqa: E501
"Gemma2Model"
:
_HfExamplesInfo
(
"BAAI/bge-multilingual-gemma2"
),
# noqa: E501
"Gemma3TextModel"
:
_HfExamplesInfo
(
"google/embeddinggemma-300m"
),
"GritLM"
:
_HfExamplesInfo
(
"parasail-ai/GritLM-7B-vllm"
),
"GritLM"
:
_HfExamplesInfo
(
"parasail-ai/GritLM-7B-vllm"
),
"GteModel"
:
_HfExamplesInfo
(
"Snowflake/snowflake-arctic-embed-m-v2.0"
,
"GteModel"
:
_HfExamplesInfo
(
"Snowflake/snowflake-arctic-embed-m-v2.0"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
...
...
vllm/config/__init__.py
View file @
6d6c6b05
...
@@ -2750,6 +2750,8 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
...
@@ -2750,6 +2750,8 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
_FLOAT16_NOT_SUPPORTED_MODELS
=
{
_FLOAT16_NOT_SUPPORTED_MODELS
=
{
"gemma2"
:
"Numerical instability. Please use bfloat16 or float32 instead."
,
"gemma2"
:
"Numerical instability. Please use bfloat16 or float32 instead."
,
"gemma3"
:
"Numerical instability. Please use bfloat16 or float32 instead."
,
"gemma3"
:
"Numerical instability. Please use bfloat16 or float32 instead."
,
"gemma3_text"
:
"Numerical instability. Please use bfloat16 or float32 instead."
,
"plamo2"
:
"Numerical instability. Please use bfloat16 or float32 instead."
,
"plamo2"
:
"Numerical instability. Please use bfloat16 or float32 instead."
,
"glm4"
:
"Numerical instability. Please use bfloat16 or float32 instead."
,
"glm4"
:
"Numerical instability. Please use bfloat16 or float32 instead."
,
}
}
...
...
vllm/model_executor/models/adapters.py
View file @
6d6c6b05
...
@@ -49,26 +49,28 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
...
@@ -49,26 +49,28 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
if
not
dense_modules
:
if
not
dense_modules
:
return
None
return
None
module
=
dense_modules
[
0
]
layers
=
[]
folder
=
module
.
get
(
"path"
,
""
)
for
module
in
dense_modules
:
folder
=
module
.
get
(
"path"
,
""
)
config_path
=
f
"
{
folder
}
/config.json"
if
folder
else
"config.json"
layer_config
=
get_hf_file_to_dict
(
config_path
,
model_config
.
model
,
model_config
.
revision
)
if
not
layer_config
:
continue
config_path
=
f
"
{
folder
}
/config.json"
if
folder
else
"config.json"
linear
=
nn
.
Linear
(
layer_config
.
get
(
"in_features"
,
768
),
layer_config
=
get_hf_file_to_dict
(
config_path
,
model_config
.
model
,
layer_config
.
get
(
"out_features"
,
768
),
model_config
.
revision
)
bias
=
layer_config
.
get
(
"bias"
,
True
),
if
not
layer_config
:
dtype
=
torch
.
float32
)
return
None
linear
=
nn
.
Linear
(
layer_config
.
get
(
"in_features"
,
768
),
if
not
_load_dense_weights
(
linear
,
folder
,
model_config
):
layer_config
.
get
(
"out_features"
,
768
),
continue
bias
=
layer_config
.
get
(
"bias"
,
True
),
dtype
=
torch
.
float32
)
if
_load_dense_weights
(
linear
,
folder
,
model_config
):
layers
.
append
(
linear
)
layers
=
[
linear
]
if
act_name
:
=
layer_config
.
get
(
"activation_function"
):
if
act_name
:
=
layer_config
.
get
(
"activation_function"
):
layers
.
append
(
get_act_fn
(
act_name
))
layers
.
append
(
get_act_fn
(
act_name
))
return
nn
.
Sequential
(
*
layers
).
to
(
dtype
=
torch
.
float32
)
return
nn
.
Sequential
(
*
layers
).
to
(
dtype
=
torch
.
float32
)
except
Exception
:
except
Exception
:
logger
.
exception
(
"ST projector loading failed"
)
logger
.
exception
(
"ST projector loading failed"
)
...
...
vllm/model_executor/models/config.py
View file @
6d6c6b05
...
@@ -24,6 +24,14 @@ class VerifyAndUpdateConfig:
...
@@ -24,6 +24,14 @@ class VerifyAndUpdateConfig:
raise
NotImplementedError
raise
NotImplementedError
class
Gemma3TextModelConfig
:
@
staticmethod
def
verify_and_update_config
(
vllm_config
:
"VllmConfig"
)
->
None
:
hf_config
=
vllm_config
.
model_config
.
hf_config
hf_config
.
is_causal
=
not
hf_config
.
use_bidirectional_attention
class
GteNewModelConfig
(
VerifyAndUpdateConfig
):
class
GteNewModelConfig
(
VerifyAndUpdateConfig
):
@
staticmethod
@
staticmethod
...
@@ -409,6 +417,7 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
...
@@ -409,6 +417,7 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"GteModel"
:
SnowflakeGteNewModelConfig
,
"GteModel"
:
SnowflakeGteNewModelConfig
,
"GteNewModel"
:
GteNewModelConfig
,
"GteNewModel"
:
GteNewModelConfig
,
"GteNewForSequenceClassification"
:
GteNewModelConfig
,
"GteNewForSequenceClassification"
:
GteNewModelConfig
,
"Gemma3TextModel"
:
Gemma3TextModelConfig
,
"NomicBertModel"
:
NomicBertModelConfig
,
"NomicBertModel"
:
NomicBertModelConfig
,
"Qwen2ForProcessRewardModel"
:
Qwen2ForProcessRewardModelConfig
,
"Qwen2ForProcessRewardModel"
:
Qwen2ForProcessRewardModelConfig
,
"Qwen2ForRewardModel"
:
Qwen2ForRewardModelConfig
,
"Qwen2ForRewardModel"
:
Qwen2ForRewardModelConfig
,
...
...
vllm/model_executor/models/gemma3.py
View file @
6d6c6b05
...
@@ -24,7 +24,7 @@ import torch.nn.functional as F
...
@@ -24,7 +24,7 @@ import torch.nn.functional as F
from
torch
import
nn
from
torch
import
nn
from
transformers
import
Gemma3TextConfig
from
transformers
import
Gemma3TextConfig
from
vllm.attention
import
Attention
from
vllm.attention
import
Attention
,
AttentionType
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
...
@@ -44,6 +44,7 @@ from vllm.model_executor.model_loader.weight_utils import (
...
@@ -44,6 +44,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
...attention.layers.encoder_only_attention
import
EncoderOnlyAttention
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
extract_layer_index
,
from
.utils
import
(
AutoWeightsLoader
,
extract_layer_index
,
is_pp_missing_parameter
,
is_pp_missing_parameter
,
...
@@ -169,16 +170,24 @@ class Gemma3Attention(nn.Module):
...
@@ -169,16 +170,24 @@ class Gemma3Attention(nn.Module):
rope_scaling
=
self
.
rope_scaling
,
rope_scaling
=
self
.
rope_scaling
,
)
)
# Initialize the attention.
if
getattr
(
config
,
"is_causal"
,
True
):
self
.
attn
=
Attention
(
self
.
num_heads
,
attn_type
=
AttentionType
.
DECODER
self
.
head_dim
,
else
:
self
.
scaling
,
attn_type
=
AttentionType
.
ENCODER_ONLY
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
attn_cls
=
(
EncoderOnlyAttention
quant_config
=
quant_config
,
if
attn_type
==
AttentionType
.
ENCODER_ONLY
else
Attention
)
logits_soft_cap
=
attn_logits_soft_cap
,
per_layer_sliding_window
=
sliding_window
,
self
.
attn
=
attn_cls
(
self
.
num_heads
,
prefix
=
f
"
{
prefix
}
.attn"
)
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
attn_type
=
attn_type
,
logits_soft_cap
=
attn_logits_soft_cap
,
per_layer_sliding_window
=
sliding_window
,
prefix
=
f
"
{
prefix
}
.attn"
)
def
forward
(
def
forward
(
self
,
self
,
...
...
vllm/model_executor/models/registry.py
View file @
6d6c6b05
...
@@ -155,6 +155,7 @@ _EMBEDDING_MODELS = {
...
@@ -155,6 +155,7 @@ _EMBEDDING_MODELS = {
"BertModel"
:
(
"bert"
,
"BertEmbeddingModel"
),
"BertModel"
:
(
"bert"
,
"BertEmbeddingModel"
),
"DeciLMForCausalLM"
:
(
"nemotron_nas"
,
"DeciLMForCausalLM"
),
"DeciLMForCausalLM"
:
(
"nemotron_nas"
,
"DeciLMForCausalLM"
),
"Gemma2Model"
:
(
"gemma2"
,
"Gemma2ForCausalLM"
),
"Gemma2Model"
:
(
"gemma2"
,
"Gemma2ForCausalLM"
),
"Gemma3TextModel"
:
(
"gemma3"
,
"Gemma3Model"
),
"GlmForCausalLM"
:
(
"glm"
,
"GlmForCausalLM"
),
"GlmForCausalLM"
:
(
"glm"
,
"GlmForCausalLM"
),
"GPT2ForSequenceClassification"
:
(
"gpt2"
,
"GPT2ForSequenceClassification"
),
"GPT2ForSequenceClassification"
:
(
"gpt2"
,
"GPT2ForSequenceClassification"
),
"GritLM"
:
(
"gritlm"
,
"GritLM"
),
"GritLM"
:
(
"gritlm"
,
"GritLM"
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment