Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
64ab3c72
Unverified
Commit
64ab3c72
authored
Aug 20, 2025
by
Cyrus Leung
Committed by
GitHub
Aug 20, 2025
Browse files
[Doc] Update V1 status of various pooling models (#23189)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
e58c5a97
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
28 additions
and
24 deletions
+28
-24
docs/models/supported_models.md
docs/models/supported_models.md
+13
-13
tests/models/language/pooling/test_gritlm.py
tests/models/language/pooling/test_gritlm.py
+5
-4
vllm/model_executor/models/gritlm.py
vllm/model_executor/models/gritlm.py
+3
-3
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+7
-4
No files found.
docs/models/supported_models.md
View file @
64ab3c72
...
...
@@ -363,7 +363,7 @@ th {
|
`GraniteMoeForCausalLM`
| Granite 3.0 MoE, PowerMoE |
`ibm-granite/granite-3.0-1b-a400m-base`
,
`ibm-granite/granite-3.0-3b-a800m-instruct`
,
`ibm/PowerMoE-3b`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`GraniteMoeHybridForCausalLM`
| Granite 4.0 MoE Hybrid |
`ibm-granite/granite-4.0-tiny-preview`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`GraniteMoeSharedForCausalLM`
| Granite MoE Shared |
`ibm-research/moe-7b-1b-active-shared-experts`
(test model) | ✅︎ | ✅︎ | ✅︎ |
|
`GritLM`
| GritLM |
`parasail-ai/GritLM-7B-vllm`
. | ✅︎ | ✅︎ | |
|
`GritLM`
| GritLM |
`parasail-ai/GritLM-7B-vllm`
. | ✅︎ | ✅︎ |
✅︎
|
|
`Grok1ModelForCausalLM`
| Grok1 |
`hpcai-tech/grok-1`
. | ✅︎ | ✅︎ | ✅︎ |
|
`HunYuanDenseV1ForCausalLM`
| Hunyuan-7B-Instruct-0124 |
`tencent/Hunyuan-7B-Instruct-0124`
| ✅︎ | | ✅︎ |
|
`HunYuanMoEV1ForCausalLM`
| Hunyuan-80B-A13B |
`tencent/Hunyuan-A13B-Instruct`
,
`tencent/Hunyuan-A13B-Pretrain`
,
`tencent/Hunyuan-A13B-Instruct-FP8`
, etc. | ✅︎ | | ✅︎ |
...
...
@@ -436,17 +436,17 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A
| Architecture | Models | Example HF Models |
[
LoRA
](
../features/lora.md
)
|
[
PP
](
../serving/parallelism_scaling.md
)
|
[
V1
](
gh-issue:8779
)
|
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
|
`BertModel`
<sup>
C
</sup>
| BERT-based |
`BAAI/bge-base-en-v1.5`
,
`Snowflake/snowflake-arctic-embed-xs`
, etc. | | | |
|
`Gemma2Model`
<sup>
C
</sup>
| Gemma 2-based |
`BAAI/bge-multilingual-gemma2`
, etc. | ✅︎ | | ✅︎ |
|
`GritLM`
| GritLM |
`parasail-ai/GritLM-7B-vllm`
. | ✅︎ | ✅︎ | |
|
`GteModel`
<sup>
C
</sup>
| Arctic-Embed-2.0-M |
`Snowflake/snowflake-arctic-embed-m-v2.0`
. | | | |
|
`GteNewModel`
<sup>
C
</sup>
| mGTE-TRM (see note) |
`Alibaba-NLP/gte-multilingual-base`
, etc. | | | |
|
`ModernBertModel`
<sup>
C
</sup>
| ModernBERT-based |
`Alibaba-NLP/gte-modernbert-base`
, etc. | | | |
|
`NomicBertModel`
<sup>
C
</sup>
| Nomic BERT |
`nomic-ai/nomic-embed-text-v1`
,
`nomic-ai/nomic-embed-text-v2-moe`
,
`Snowflake/snowflake-arctic-embed-m-long`
, etc. | | | |
|
`BertModel`
<sup>
C
</sup>
| BERT-based |
`BAAI/bge-base-en-v1.5`
,
`Snowflake/snowflake-arctic-embed-xs`
, etc. | | |
✅︎
|
|
`Gemma2Model`
<sup>
C
</sup>
| Gemma 2-based |
`BAAI/bge-multilingual-gemma2`
, etc. | ✅︎ |
✅︎
| ✅︎ |
|
`GritLM`
| GritLM |
`parasail-ai/GritLM-7B-vllm`
. | ✅︎ | ✅︎ |
✅︎
|
|
`GteModel`
<sup>
C
</sup>
| Arctic-Embed-2.0-M |
`Snowflake/snowflake-arctic-embed-m-v2.0`
. | | |
✅︎
|
|
`GteNewModel`
<sup>
C
</sup>
| mGTE-TRM (see note) |
`Alibaba-NLP/gte-multilingual-base`
, etc. | | |
✅︎
|
|
`ModernBertModel`
<sup>
C
</sup>
| ModernBERT-based |
`Alibaba-NLP/gte-modernbert-base`
, etc. | | |
✅︎
|
|
`NomicBertModel`
<sup>
C
</sup>
| Nomic BERT |
`nomic-ai/nomic-embed-text-v1`
,
`nomic-ai/nomic-embed-text-v2-moe`
,
`Snowflake/snowflake-arctic-embed-m-long`
, etc. | | |
✅︎
|
|
`LlamaModel`
<sup>
C
</sup>
,
`LlamaForCausalLM`
<sup>
C
</sup>
,
`MistralModel`
<sup>
C
</sup>
, etc. | Llama-based |
`intfloat/e5-mistral-7b-instruct`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`Qwen2Model`
<sup>
C
</sup>
,
`Qwen2ForCausalLM`
<sup>
C
</sup>
| Qwen2-based |
`ssmits/Qwen2-7B-Instruct-embed-base`
(see note),
`Alibaba-NLP/gte-Qwen2-7B-instruct`
(see note), etc. | ✅︎ | ✅︎ | ✅︎ |
|
`Qwen3Model`
<sup>
C
</sup>
,
`Qwen3ForCausalLM`
<sup>
C
</sup>
| Qwen3-based |
`Qwen/Qwen3-Embedding-0.6B`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`RobertaModel`
,
`RobertaForMaskedLM`
| RoBERTa-based |
`sentence-transformers/all-roberta-large-v1`
, etc. | | | |
|
`RobertaModel`
,
`RobertaForMaskedLM`
| RoBERTa-based |
`sentence-transformers/all-roberta-large-v1`
, etc. | | |
✅︎
|
|
`*Model`
<sup>
C
</sup>
,
`*ForCausalLM`
<sup>
C
</sup>
, etc. | Generative models | N/A |
\*
|
\*
|
\*
|
<sup>
C
</sup>
Automatically converted into an embedding model via
`--convert embed`
. (
[
details
](
./pooling_models.md#model-conversion
)
)
...
...
@@ -476,7 +476,7 @@ These models primarily support the [`LLM.classify`](./pooling_models.md#llmclass
| Architecture | Models | Example HF Models |
[
LoRA
](
../features/lora.md
)
|
[
PP
](
../serving/parallelism_scaling.md
)
|
[
V1
](
gh-issue:8779
)
|
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
|
`JambaForSequenceClassification`
| Jamba |
`ai21labs/Jamba-tiny-reward-dev`
, etc. | ✅︎ | ✅︎ | |
|
`JambaForSequenceClassification`
| Jamba |
`ai21labs/Jamba-tiny-reward-dev`
, etc. | ✅︎ | ✅︎ |
✅︎
|
|
`GPT2ForSequenceClassification`
| GPT2 |
`nie3e/sentiment-polish-gpt2-small`
| | | ✅︎ |
|
`*Model`
<sup>
C
</sup>
,
`*ForCausalLM`
<sup>
C
</sup>
, etc. | Generative models | N/A |
\*
|
\*
|
\*
|
...
...
@@ -493,12 +493,12 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
| Architecture | Models | Example HF Models |
[
LoRA
](
../features/lora.md
)
|
[
PP
](
../serving/parallelism_scaling.md
)
|
[
V1
](
gh-issue:8779
)
|
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
|
`BertForSequenceClassification`
| BERT-based |
`cross-encoder/ms-marco-MiniLM-L-6-v2`
, etc. | | | |
|
`BertForSequenceClassification`
| BERT-based |
`cross-encoder/ms-marco-MiniLM-L-6-v2`
, etc. | | |
✅︎
|
|
`GemmaForSequenceClassification`
| Gemma-based |
`BAAI/bge-reranker-v2-gemma`
(see note), etc. | ✅︎ | ✅︎ | ✅︎ |
|
`Qwen2ForSequenceClassification`
| Qwen2-based |
`mixedbread-ai/mxbai-rerank-base-v2`
(see note), etc. | ✅︎ | ✅︎ | ✅︎ |
|
`Qwen3ForSequenceClassification`
| Qwen3-based |
`tomaarsen/Qwen3-Reranker-0.6B-seq-cls`
,
`Qwen/Qwen3-Reranker-0.6B`
(see note), etc. | ✅︎ | ✅︎ | ✅︎ |
|
`RobertaForSequenceClassification`
| RoBERTa-based |
`cross-encoder/quora-roberta-base`
, etc. | | | |
|
`XLMRobertaForSequenceClassification`
| XLM-RoBERTa-based |
`BAAI/bge-reranker-v2-m3`
, etc. | | | |
|
`RobertaForSequenceClassification`
| RoBERTa-based |
`cross-encoder/quora-roberta-base`
, etc. | | |
✅︎
|
|
`XLMRobertaForSequenceClassification`
| XLM-RoBERTa-based |
`BAAI/bge-reranker-v2-m3`
, etc. | | |
✅︎
|
|
`*Model`
<sup>
C
</sup>
,
`*ForCausalLM`
<sup>
C
</sup>
, etc. | Generative models | N/A |
\*
|
\*
|
\*
|
<sup>
C
</sup>
Automatically converted into a classification model via
`--convert classify`
. (
[
details
](
./pooling_models.md#model-conversion
)
)
...
...
tests/models/language/pooling/test_gritlm.py
View file @
64ab3c72
...
...
@@ -14,6 +14,7 @@ from ....utils import RemoteOpenAIServer
MODEL_NAME
=
"parasail-ai/GritLM-7B-vllm"
MAX_MODEL_LEN
=
4000
ATOL
=
0.002
def
_arr
(
arr
):
...
...
@@ -97,16 +98,16 @@ def get_test_data():
def
validate_embed_output
(
q_rep
:
list
[
list
[
float
]],
d_rep
:
list
[
list
[
float
]]):
cosine_sim_q0_d0
=
1
-
cosine
(
q_rep
[
0
],
d_rep
[
0
])
assert
cosine_sim_q0_d0
==
pytest
.
approx
(
0.609
,
abs
=
0.001
)
assert
cosine_sim_q0_d0
==
pytest
.
approx
(
0.609
,
abs
=
ATOL
)
cosine_sim_q0_d1
=
1
-
cosine
(
q_rep
[
0
],
d_rep
[
1
])
assert
cosine_sim_q0_d1
==
pytest
.
approx
(
0.101
,
abs
=
0.001
)
assert
cosine_sim_q0_d1
==
pytest
.
approx
(
0.101
,
abs
=
ATOL
)
cosine_sim_q1_d0
=
1
-
cosine
(
q_rep
[
1
],
d_rep
[
0
])
assert
cosine_sim_q1_d0
==
pytest
.
approx
(
0.120
,
abs
=
0.001
)
assert
cosine_sim_q1_d0
==
pytest
.
approx
(
0.120
,
abs
=
ATOL
)
cosine_sim_q1_d1
=
1
-
cosine
(
q_rep
[
1
],
d_rep
[
1
])
assert
cosine_sim_q1_d1
==
pytest
.
approx
(
0.534
,
abs
=
0.001
)
assert
cosine_sim_q1_d1
==
pytest
.
approx
(
0.534
,
abs
=
ATOL
)
def
test_gritlm_offline_embedding
(
vllm_runner
):
...
...
vllm/model_executor/models/gritlm.py
View file @
64ab3c72
...
...
@@ -20,7 +20,7 @@ from vllm.sequence import PoolerOutput
from
vllm.tasks
import
PoolingTask
from
vllm.transformers_utils.tokenizer
import
cached_tokenizer_from_config
from
.interfaces
import
SupportsV0Only
from
.interfaces
import
default_pooling_type
logger
=
init_logger
(
__name__
)
...
...
@@ -215,7 +215,8 @@ class GritLMPooler(Pooler):
return
build_output
(
pooled_data
)
class
GritLM
(
LlamaForCausalLM
,
SupportsV0Only
):
@
default_pooling_type
(
"MEAN"
)
class
GritLM
(
LlamaForCausalLM
):
"""This class implements the embedding model for parasail-ai/GritLM-7B-vllm.
The class inherits from LlamaForCausalLM and provides a custom pooling
...
...
@@ -241,7 +242,6 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
prefix
:
str
=
""
,
**
kwargs
,
)
->
None
:
# Use full attention for pooling (this is why V1 is not supported yet)
if
vllm_config
.
model_config
.
runner_type
==
"pooling"
:
hf_config
=
vllm_config
.
model_config
.
hf_config
hf_config
.
is_causal
=
False
...
...
vllm/model_executor/models/interfaces.py
View file @
64ab3c72
...
...
@@ -3,7 +3,7 @@
from
collections.abc
import
Iterable
,
Mapping
,
MutableSequence
from
typing
import
(
TYPE_CHECKING
,
ClassVar
,
Literal
,
Optional
,
Protocol
,
Union
,
overload
,
runtime_checkable
)
TypeVar
,
Union
,
overload
,
runtime_checkable
)
import
numpy
as
np
import
torch
...
...
@@ -641,11 +641,14 @@ def supports_cross_encoding(
return
is_pooling_model
(
model
)
and
_supports_cross_encoding
(
model
)
def
default_pooling_type
(
pooling_type
:
str
)
->
object
:
_T
=
TypeVar
(
"_T"
,
bound
=
type
[
torch
.
nn
.
Module
])
def
default_pooling_type
(
pooling_type
:
str
):
"""Set default_pooling_type decorator. """
def
func
(
model
:
object
)
:
model
.
default_pooling_type
=
pooling_type
def
func
(
model
:
_T
)
->
_T
:
model
.
default_pooling_type
=
pooling_type
# type: ignore
return
model
return
func
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment