Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
35cf32df
Unverified
Commit
35cf32df
authored
Jun 04, 2025
by
wang.yuqi
Committed by
GitHub
Jun 04, 2025
Browse files
Improve the output precision of embedding models (#19092)
parent
8711bc5e
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
69 additions
and
28 deletions
+69
-28
tests/models/language/pooling/embed_utils.py
tests/models/language/pooling/embed_utils.py
+1
-5
tests/models/language/pooling/mteb_utils.py
tests/models/language/pooling/mteb_utils.py
+6
-6
tests/models/language/pooling/test_gte.py
tests/models/language/pooling/test_gte.py
+0
-7
tests/models/language/pooling/test_intfloat.py
tests/models/language/pooling/test_intfloat.py
+46
-0
tests/models/language/pooling/test_jina.py
tests/models/language/pooling/test_jina.py
+1
-2
tests/models/language/pooling/test_nomic.py
tests/models/language/pooling/test_nomic.py
+0
-3
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+9
-4
vllm/model_executor/models/bert_with_rope.py
vllm/model_executor/models/bert_with_rope.py
+6
-1
No files found.
tests/models/language/pooling/embed_utils.py
View file @
35cf32df
...
@@ -56,14 +56,10 @@ def correctness_test_embed_models(hf_runner,
...
@@ -56,14 +56,10 @@ def correctness_test_embed_models(hf_runner,
max_model_len
=
None
,
max_model_len
=
None
,
**
vllm_extra_kwargs
)
as
vllm_model
:
**
vllm_extra_kwargs
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
encode
(
example_prompts
)
vllm_outputs
=
vllm_model
.
encode
(
example_prompts
)
vllm_dtype
=
vllm_model
.
model
.
llm_engine
.
model_config
.
dtype
model_dtype
=
getattr
(
vllm_model
.
model
.
llm_engine
.
model_config
.
hf_config
,
"torch_dtype"
,
vllm_dtype
)
with
hf_runner
(
with
hf_runner
(
model_info
.
name
,
model_info
.
name
,
dtype
=
model_dtype
,
dtype
=
"float32"
,
is_sentence_transformer
=
True
,
is_sentence_transformer
=
True
,
)
as
hf_model
:
)
as
hf_model
:
...
...
tests/models/language/pooling/mteb_utils.py
View file @
35cf32df
...
@@ -7,7 +7,6 @@ import numpy as np
...
@@ -7,7 +7,6 @@ import numpy as np
import
pytest
import
pytest
from
tests.models.utils
import
EmbedModelInfo
from
tests.models.utils
import
EmbedModelInfo
from
vllm.model_executor.model_loader.utils
import
set_default_torch_dtype
# Most models on the STS12 task (See #17175):
# Most models on the STS12 task (See #17175):
# - Model implementation and minor changes in tensor dtype
# - Model implementation and minor changes in tensor dtype
...
@@ -104,17 +103,18 @@ def mteb_test_embed_models(hf_runner,
...
@@ -104,17 +103,18 @@ def mteb_test_embed_models(hf_runner,
MTEB_EMBED_TASKS
)
MTEB_EMBED_TASKS
)
vllm_dtype
=
vllm_model
.
model
.
llm_engine
.
model_config
.
dtype
vllm_dtype
=
vllm_model
.
model
.
llm_engine
.
model_config
.
dtype
with
set_default_torch_dtype
(
vllm_dtype
)
and
hf_runner
(
with
hf_runner
(
model_info
.
name
,
model_info
.
name
,
is_sentence_transformer
=
True
,
is_sentence_transformer
=
True
,
dtype
=
vllm_dtype
)
as
hf_model
:
dtype
=
"float32"
)
as
hf_model
:
if
hf_model_callback
is
not
None
:
if
hf_model_callback
is
not
None
:
hf_model_callback
(
hf_model
)
hf_model_callback
(
hf_model
)
st_main_score
=
run_mteb_embed_task
(
hf_model
,
MTEB_EMBED_TASKS
)
st_main_score
=
run_mteb_embed_task
(
hf_model
,
MTEB_EMBED_TASKS
)
st_dtype
=
next
(
hf_model
.
model
.
parameters
()).
dtype
print
(
"VLLM:"
,
vllm_main_score
)
print
(
"VLLM:"
,
vllm_dtype
,
vllm_main_score
)
print
(
"SentenceTransformers:"
,
st_main_score
)
print
(
"SentenceTransformers:"
,
st_dtype
,
st_main_score
)
print
(
"Difference:"
,
st_main_score
-
vllm_main_score
)
print
(
"Difference:"
,
st_main_score
-
vllm_main_score
)
assert
st_main_score
==
pytest
.
approx
(
vllm_main_score
,
abs
=
MTEB_EMBED_TOL
)
assert
st_main_score
==
pytest
.
approx
(
vllm_main_score
,
abs
=
MTEB_EMBED_TOL
)
tests/models/language/pooling/test_gte.py
View file @
35cf32df
...
@@ -11,27 +11,21 @@ MODELS = [
...
@@ -11,27 +11,21 @@ MODELS = [
########## BertModel
########## BertModel
EmbedModelInfo
(
"thenlper/gte-large"
,
EmbedModelInfo
(
"thenlper/gte-large"
,
architecture
=
"BertModel"
,
architecture
=
"BertModel"
,
dtype
=
"float32"
,
enable_test
=
True
),
enable_test
=
True
),
EmbedModelInfo
(
"thenlper/gte-base"
,
EmbedModelInfo
(
"thenlper/gte-base"
,
architecture
=
"BertModel"
,
architecture
=
"BertModel"
,
dtype
=
"float32"
,
enable_test
=
False
),
enable_test
=
False
),
EmbedModelInfo
(
"thenlper/gte-small"
,
EmbedModelInfo
(
"thenlper/gte-small"
,
architecture
=
"BertModel"
,
architecture
=
"BertModel"
,
dtype
=
"float32"
,
enable_test
=
False
),
enable_test
=
False
),
EmbedModelInfo
(
"thenlper/gte-large-zh"
,
EmbedModelInfo
(
"thenlper/gte-large-zh"
,
architecture
=
"BertModel"
,
architecture
=
"BertModel"
,
dtype
=
"float32"
,
enable_test
=
False
),
enable_test
=
False
),
EmbedModelInfo
(
"thenlper/gte-base-zh"
,
EmbedModelInfo
(
"thenlper/gte-base-zh"
,
architecture
=
"BertModel"
,
architecture
=
"BertModel"
,
dtype
=
"float32"
,
enable_test
=
False
),
enable_test
=
False
),
EmbedModelInfo
(
"thenlper/gte-small-zh"
,
EmbedModelInfo
(
"thenlper/gte-small-zh"
,
architecture
=
"BertModel"
,
architecture
=
"BertModel"
,
dtype
=
"float32"
,
enable_test
=
False
),
enable_test
=
False
),
########### NewModel
########### NewModel
EmbedModelInfo
(
"Alibaba-NLP/gte-multilingual-base"
,
EmbedModelInfo
(
"Alibaba-NLP/gte-multilingual-base"
,
...
@@ -46,7 +40,6 @@ MODELS = [
...
@@ -46,7 +40,6 @@ MODELS = [
########### Qwen2ForCausalLM
########### Qwen2ForCausalLM
EmbedModelInfo
(
"Alibaba-NLP/gte-Qwen2-1.5B-instruct"
,
EmbedModelInfo
(
"Alibaba-NLP/gte-Qwen2-1.5B-instruct"
,
architecture
=
"Qwen2ForCausalLM"
,
architecture
=
"Qwen2ForCausalLM"
,
dtype
=
"float32"
,
enable_test
=
True
),
enable_test
=
True
),
########## ModernBertModel
########## ModernBertModel
EmbedModelInfo
(
"Alibaba-NLP/gte-modernbert-base"
,
EmbedModelInfo
(
"Alibaba-NLP/gte-modernbert-base"
,
...
...
tests/models/language/pooling/test_intfloat.py
0 → 100644
View file @
35cf32df
# SPDX-License-Identifier: Apache-2.0
import
pytest
from
...utils
import
EmbedModelInfo
from
.embed_utils
import
correctness_test_embed_models
from
.mteb_utils
import
mteb_test_embed_models
MODELS
=
[
########## BertModel
EmbedModelInfo
(
"intfloat/e5-small"
,
architecture
=
"BertModel"
,
enable_test
=
True
),
EmbedModelInfo
(
"intfloat/e5-base"
,
architecture
=
"BertModel"
,
enable_test
=
False
),
EmbedModelInfo
(
"intfloat/e5-large"
,
architecture
=
"BertModel"
,
enable_test
=
False
),
EmbedModelInfo
(
"intfloat/multilingual-e5-small"
,
architecture
=
"BertModel"
,
enable_test
=
False
),
########## XLMRobertaModel
EmbedModelInfo
(
"intfloat/multilingual-e5-base"
,
architecture
=
"XLMRobertaModel"
,
enable_test
=
True
),
EmbedModelInfo
(
"intfloat/multilingual-e5-large"
,
architecture
=
"XLMRobertaModel"
,
enable_test
=
False
),
EmbedModelInfo
(
"intfloat/multilingual-e5-large-instruct"
,
architecture
=
"XLMRobertaModel"
,
enable_test
=
False
),
]
@
pytest
.
mark
.
parametrize
(
"model_info"
,
MODELS
)
def
test_embed_models_mteb
(
hf_runner
,
vllm_runner
,
model_info
:
EmbedModelInfo
)
->
None
:
mteb_test_embed_models
(
hf_runner
,
vllm_runner
,
model_info
)
@
pytest
.
mark
.
parametrize
(
"model_info"
,
MODELS
)
def
test_embed_models_correctness
(
hf_runner
,
vllm_runner
,
model_info
:
EmbedModelInfo
,
example_prompts
)
->
None
:
correctness_test_embed_models
(
hf_runner
,
vllm_runner
,
model_info
,
example_prompts
)
tests/models/language/pooling/test_jina.py
View file @
35cf32df
...
@@ -32,8 +32,7 @@ TEXTS_2 = [
...
@@ -32,8 +32,7 @@ TEXTS_2 = [
EMBEDDING_MODELS
=
[
EMBEDDING_MODELS
=
[
EmbedModelInfo
(
"jinaai/jina-embeddings-v3"
,
EmbedModelInfo
(
"jinaai/jina-embeddings-v3"
,
architecture
=
"XLMRobertaModel"
,
architecture
=
"XLMRobertaModel"
,
is_matryoshka
=
True
,
is_matryoshka
=
True
)
dtype
=
"float32"
)
]
]
...
...
tests/models/language/pooling/test_nomic.py
View file @
35cf32df
...
@@ -9,18 +9,15 @@ from .mteb_utils import mteb_test_embed_models
...
@@ -9,18 +9,15 @@ from .mteb_utils import mteb_test_embed_models
MODELS
=
[
MODELS
=
[
EmbedModelInfo
(
"nomic-ai/nomic-embed-text-v1"
,
EmbedModelInfo
(
"nomic-ai/nomic-embed-text-v1"
,
architecture
=
"NomicBertModel"
,
architecture
=
"NomicBertModel"
,
dtype
=
"float32"
,
enable_test
=
True
),
enable_test
=
True
),
EmbedModelInfo
(
"nomic-ai/nomic-embed-text-v1.5"
,
EmbedModelInfo
(
"nomic-ai/nomic-embed-text-v1.5"
,
architecture
=
"NomicBertModel"
,
architecture
=
"NomicBertModel"
,
dtype
=
"float32"
,
enable_test
=
False
),
enable_test
=
False
),
EmbedModelInfo
(
"nomic-ai/CodeRankEmbed"
,
EmbedModelInfo
(
"nomic-ai/CodeRankEmbed"
,
architecture
=
"NomicBertModel"
,
architecture
=
"NomicBertModel"
,
enable_test
=
False
),
enable_test
=
False
),
EmbedModelInfo
(
"nomic-ai/nomic-embed-text-v2-moe"
,
EmbedModelInfo
(
"nomic-ai/nomic-embed-text-v2-moe"
,
architecture
=
"NomicBertModel"
,
architecture
=
"NomicBertModel"
,
dtype
=
"float32"
,
enable_test
=
True
)
enable_test
=
True
)
]
]
...
...
vllm/model_executor/models/bert.py
View file @
35cf32df
...
@@ -414,10 +414,15 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
...
@@ -414,10 +414,15 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
self
.
model
(
input_ids
=
input_ids
,
hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
position_ids
=
positions
,
position_ids
=
positions
,
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
intermediate_tensors
=
intermediate_tensors
)
intermediate_tensors
=
intermediate_tensors
)
# convert the embedding output to float32,
# otherwise precision will be lost significantly
hidden_states
=
hidden_states
.
to
(
torch
.
float32
)
return
hidden_states
def
pooler
(
def
pooler
(
self
,
self
,
...
...
vllm/model_executor/models/bert_with_rope.py
View file @
35cf32df
...
@@ -432,7 +432,12 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
...
@@ -432,7 +432,12 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
else
:
else
:
hidden_states
=
self
.
embeddings
(
input_ids
=
input_ids
,
hidden_states
=
self
.
embeddings
(
input_ids
=
input_ids
,
token_type_ids
=
token_type_ids
)
token_type_ids
=
token_type_ids
)
return
self
.
encoder
(
positions
,
hidden_states
)
hidden_states
=
self
.
encoder
(
positions
,
hidden_states
)
# convert the embedding output to float32,
# otherwise precision will be lost significantly
hidden_states
=
hidden_states
.
to
(
torch
.
float32
)
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
torch
.
Tensor
]])
->
set
[
str
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment