Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f40f763f
"docs/vscode:/vscode.git/clone" did not exist on "6e2c19ce227ecf285ed24a138b91570b3a2d57a6"
Unverified
Commit
f40f763f
authored
Jun 16, 2025
by
wang.yuqi
Committed by
GitHub
Jun 16, 2025
Browse files
[CI] Add mteb testing for rerank models (#19344)
parent
26bc46ef
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
428 additions
and
256 deletions
+428
-256
requirements/test.in
requirements/test.in
+1
-1
requirements/test.txt
requirements/test.txt
+6
-0
tests/conftest.py
tests/conftest.py
+9
-3
tests/entrypoints/openai/correctness/test_mteb_embed.py
tests/entrypoints/openai/correctness/test_mteb_embed.py
+6
-10
tests/entrypoints/openai/correctness/test_mteb_score.py
tests/entrypoints/openai/correctness/test_mteb_score.py
+59
-0
tests/models/language/pooling/mteb_utils.py
tests/models/language/pooling/mteb_utils.py
+183
-9
tests/models/language/pooling/test_baai.py
tests/models/language/pooling/test_baai.py
+23
-2
tests/models/language/pooling/test_cross_encoder.py
tests/models/language/pooling/test_cross_encoder.py
+18
-0
tests/models/language/pooling/test_jina.py
tests/models/language/pooling/test_jina.py
+16
-62
tests/models/language/pooling/test_qwen3_reranker.py
tests/models/language/pooling/test_qwen3_reranker.py
+84
-80
tests/models/language/pooling/test_qwen3_reranker_seq_cls.py
tests/models/language/pooling/test_qwen3_reranker_seq_cls.py
+0
-73
tests/models/utils.py
tests/models/utils.py
+7
-0
vllm/model_executor/layers/pooler.py
vllm/model_executor/layers/pooler.py
+11
-1
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+4
-9
vllm/model_executor/models/bert_with_rope.py
vllm/model_executor/models/bert_with_rope.py
+1
-6
No files found.
requirements/test.in
View file @
f40f763f
...
@@ -33,7 +33,7 @@ num2words # required for smolvlm test
...
@@ -33,7 +33,7 @@ num2words # required for smolvlm test
opencv-python-headless >= 4.11.0 # required for video test
opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test
datamodel_code_generator # required for minicpm3 test
lm-eval[api]==0.4.8 # required for model evaluation test
lm-eval[api]==0.4.8 # required for model evaluation test
mteb>=1.38.11, <2 # required for mteb test
mteb
[bm25s]
>=1.38.11, <2 # required for mteb test
transformers==4.52.4
transformers==4.52.4
tokenizers==0.21.1
tokenizers==0.21.1
huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads.
huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads.
...
...
requirements/test.txt
View file @
f40f763f
...
@@ -51,6 +51,8 @@ black==24.10.0
...
@@ -51,6 +51,8 @@ black==24.10.0
# via datamodel-code-generator
# via datamodel-code-generator
blobfile==3.0.0
blobfile==3.0.0
# via -r requirements/test.in
# via -r requirements/test.in
bm25s==0.2.13
# via mteb
boto3==1.35.57
boto3==1.35.57
# via tensorizer
# via tensorizer
botocore==1.35.57
botocore==1.35.57
...
@@ -344,6 +346,7 @@ numpy==1.26.4
...
@@ -344,6 +346,7 @@ numpy==1.26.4
# -r requirements/test.in
# -r requirements/test.in
# accelerate
# accelerate
# bitsandbytes
# bitsandbytes
# bm25s
# contourpy
# contourpy
# cupy-cuda12x
# cupy-cuda12x
# datasets
# datasets
...
@@ -534,6 +537,8 @@ pyparsing==3.2.0
...
@@ -534,6 +537,8 @@ pyparsing==3.2.0
# via matplotlib
# via matplotlib
pyrate-limiter==3.7.0
pyrate-limiter==3.7.0
# via schemathesis
# via schemathesis
pystemmer==3.0.0
# via mteb
pytablewriter==1.2.0
pytablewriter==1.2.0
# via lm-eval
# via lm-eval
pytest==8.3.3
pytest==8.3.3
...
@@ -668,6 +673,7 @@ scikit-learn==1.5.2
...
@@ -668,6 +673,7 @@ scikit-learn==1.5.2
# sentence-transformers
# sentence-transformers
scipy==1.13.1
scipy==1.13.1
# via
# via
# bm25s
# librosa
# librosa
# mteb
# mteb
# scikit-learn
# scikit-learn
...
...
tests/conftest.py
View file @
f40f763f
...
@@ -727,8 +727,12 @@ class HfRunner:
...
@@ -727,8 +727,12 @@ class HfRunner:
**
kwargs
)
->
list
[
list
[
torch
.
Tensor
]]:
**
kwargs
)
->
list
[
list
[
torch
.
Tensor
]]:
return
self
.
model
.
encode
(
prompts
,
*
args
,
**
kwargs
)
return
self
.
model
.
encode
(
prompts
,
*
args
,
**
kwargs
)
def
predict
(
self
,
prompts
:
list
[
list
[
str
]])
->
torch
.
Tensor
:
def
predict
(
self
,
prompts
:
list
[
list
[
str
]],
*
args
,
return
self
.
model
.
predict
(
prompts
,
convert_to_tensor
=
True
)
**
kwargs
)
->
torch
.
Tensor
:
return
self
.
model
.
predict
(
prompts
,
*
args
,
convert_to_tensor
=
True
,
**
kwargs
)
def
__enter__
(
self
):
def
__enter__
(
self
):
return
self
return
self
...
@@ -1037,8 +1041,10 @@ class VllmRunner:
...
@@ -1037,8 +1041,10 @@ class VllmRunner:
self
,
self
,
text_1
:
Union
[
str
,
list
[
str
]],
text_1
:
Union
[
str
,
list
[
str
]],
text_2
:
Union
[
str
,
list
[
str
]],
text_2
:
Union
[
str
,
list
[
str
]],
*
args
,
**
kwargs
,
)
->
list
[
float
]:
)
->
list
[
float
]:
req_outputs
=
self
.
model
.
score
(
text_1
,
text_2
)
req_outputs
=
self
.
model
.
score
(
text_1
,
text_2
,
*
args
,
**
kwargs
)
return
[
req_output
.
outputs
.
score
for
req_output
in
req_outputs
]
return
[
req_output
.
outputs
.
score
for
req_output
in
req_outputs
]
def
apply_model
(
self
,
func
:
Callable
[[
nn
.
Module
],
_R
])
->
list
[
_R
]:
def
apply_model
(
self
,
func
:
Callable
[[
nn
.
Module
],
_R
])
->
list
[
_R
]:
...
...
tests/entrypoints/openai/correctness/test_mteb.py
→
tests/entrypoints/openai/correctness/test_mteb
_embed
.py
View file @
f40f763f
...
@@ -7,34 +7,30 @@ import pytest
...
@@ -7,34 +7,30 @@ import pytest
from
tests.models.language.pooling.mteb_utils
import
(
MTEB_EMBED_TASKS
,
from
tests.models.language.pooling.mteb_utils
import
(
MTEB_EMBED_TASKS
,
MTEB_EMBED_TOL
,
MTEB_EMBED_TOL
,
OpenAIClientMtebEncoder
,
OpenAIClientMtebEncoder
,
run_mteb_embed_task
,
run_mteb_embed_task
)
run_mteb_embed_task_st
)
from
tests.utils
import
RemoteOpenAIServer
from
tests.utils
import
RemoteOpenAIServer
os
.
environ
[
"VLLM_LOGGING_LEVEL"
]
=
"WARNING"
os
.
environ
[
"VLLM_LOGGING_LEVEL"
]
=
"WARNING"
MODEL_NAME
=
"BAAI/bge-m3"
MODEL_NAME
=
"intfloat/e5-small"
DTYPE
=
"float16"
MAIN_SCORE
=
0.7422994752439667
MAIN_SCORE
=
0.7873427091972599
@
pytest
.
fixture
(
scope
=
"module"
)
@
pytest
.
fixture
(
scope
=
"module"
)
def
server
():
def
server
():
args
=
[
args
=
[
"--task"
,
"embed"
,
"--dtype"
,
DTYPE
,
"--enforce-eager"
,
"--task"
,
"embed"
,
"--enforce-eager"
,
"--disable-uvicorn-access-log"
"--max-model-len"
,
"512"
]
]
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
)
as
remote_server
:
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
)
as
remote_server
:
yield
remote_server
yield
remote_server
def
test_mteb
(
server
):
def
test_mteb
_embed
(
server
):
client
=
server
.
get_client
()
client
=
server
.
get_client
()
encoder
=
OpenAIClientMtebEncoder
(
MODEL_NAME
,
client
)
encoder
=
OpenAIClientMtebEncoder
(
MODEL_NAME
,
client
)
vllm_main_score
=
run_mteb_embed_task
(
encoder
,
MTEB_EMBED_TASKS
)
vllm_main_score
=
run_mteb_embed_task
(
encoder
,
MTEB_EMBED_TASKS
)
st_main_score
=
MAIN_SCORE
or
run_mteb_embed_task_st
(
st_main_score
=
MAIN_SCORE
MODEL_NAME
,
MTEB_EMBED_TASKS
)
print
(
"VLLM main score: "
,
vllm_main_score
)
print
(
"VLLM main score: "
,
vllm_main_score
)
print
(
"SentenceTransformer main score: "
,
st_main_score
)
print
(
"SentenceTransformer main score: "
,
st_main_score
)
...
...
tests/entrypoints/openai/correctness/test_mteb_score.py
0 → 100644
View file @
f40f763f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
import
pytest
# yapf conflicts with isort for this block
# yapf: disable
from
tests.models.language.pooling.mteb_utils
import
(
MTEB_RERANK_LANGS
,
MTEB_RERANK_TASKS
,
MTEB_RERANK_TOL
,
RerankClientMtebEncoder
,
ScoreClientMtebEncoder
,
run_mteb_rerank
)
# yapf: enable
from
tests.utils
import
RemoteOpenAIServer
os
.
environ
[
"VLLM_LOGGING_LEVEL"
]
=
"WARNING"
MODEL_NAME
=
"cross-encoder/ms-marco-MiniLM-L-6-v2"
MAIN_SCORE
=
0.33437
@
pytest
.
fixture
(
scope
=
"module"
)
def
server
():
args
=
[
"--task"
,
"score"
,
"--enforce-eager"
,
"--disable-uvicorn-access-log"
]
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
)
as
remote_server
:
yield
remote_server
def
test_mteb_score
(
server
):
url
=
server
.
url_for
(
"score"
)
encoder
=
ScoreClientMtebEncoder
(
MODEL_NAME
,
url
)
vllm_main_score
=
run_mteb_rerank
(
encoder
,
MTEB_RERANK_TASKS
,
MTEB_RERANK_LANGS
)
st_main_score
=
MAIN_SCORE
print
(
"VLLM main score: "
,
vllm_main_score
)
print
(
"SentenceTransformer main score: "
,
st_main_score
)
print
(
"Difference: "
,
st_main_score
-
vllm_main_score
)
assert
st_main_score
==
pytest
.
approx
(
vllm_main_score
,
abs
=
MTEB_RERANK_TOL
)
def
test_mteb_rerank
(
server
):
url
=
server
.
url_for
(
"rerank"
)
encoder
=
RerankClientMtebEncoder
(
MODEL_NAME
,
url
)
vllm_main_score
=
run_mteb_rerank
(
encoder
,
MTEB_RERANK_TASKS
,
MTEB_RERANK_LANGS
)
st_main_score
=
MAIN_SCORE
print
(
"VLLM main score: "
,
vllm_main_score
)
print
(
"SentenceTransformer main score: "
,
st_main_score
)
print
(
"Difference: "
,
st_main_score
-
vllm_main_score
)
assert
st_main_score
==
pytest
.
approx
(
vllm_main_score
,
abs
=
MTEB_RERANK_TOL
)
tests/models/language/pooling/mteb_utils.py
View file @
f40f763f
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
tempfile
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
from
typing
import
Optional
import
mteb
import
mteb
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
import
requests
from
tests.models.utils
import
EmbedModelInfo
from
tests.models.utils
import
EmbedModelInfo
,
RerankModelInfo
# Most models on the STS12 task (See #17175):
# Most
embedding
models on the STS12 task (See #17175):
# - Model implementation and minor changes in tensor dtype
# - Model implementation and minor changes in tensor dtype
# results in differences less than 1e-4
# results in differences less than 1e-4
# - Different model results in differences more than 1e-3
# - Different model results in differences more than 1e-3
...
@@ -16,6 +20,11 @@ from tests.models.utils import EmbedModelInfo
...
@@ -16,6 +20,11 @@ from tests.models.utils import EmbedModelInfo
MTEB_EMBED_TASKS
=
[
"STS12"
]
MTEB_EMBED_TASKS
=
[
"STS12"
]
MTEB_EMBED_TOL
=
1e-4
MTEB_EMBED_TOL
=
1e-4
# See #19344
MTEB_RERANK_TASKS
=
[
"NFCorpus"
]
MTEB_RERANK_LANGS
=
[
"en"
]
MTEB_RERANK_TOL
=
1e-3
class
VllmMtebEncoder
(
mteb
.
Encoder
):
class
VllmMtebEncoder
(
mteb
.
Encoder
):
...
@@ -39,6 +48,27 @@ class VllmMtebEncoder(mteb.Encoder):
...
@@ -39,6 +48,27 @@ class VllmMtebEncoder(mteb.Encoder):
embeds
=
embeds
[
np
.
argsort
(
r
)]
embeds
=
embeds
[
np
.
argsort
(
r
)]
return
embeds
return
embeds
def
predict
(
self
,
sentences
:
list
[
tuple
[
str
,
str
,
Optional
[
str
]]],
# query, corpus, prompt
*
args
,
**
kwargs
,
)
->
np
.
ndarray
:
r
=
self
.
rng
.
permutation
(
len
(
sentences
))
sentences
=
[
sentences
[
i
]
for
i
in
r
]
queries
=
[
s
[
0
]
for
s
in
sentences
]
corpus
=
[
s
[
1
]
for
s
in
sentences
]
outputs
=
self
.
model
.
score
(
queries
,
corpus
,
truncate_prompt_tokens
=-
1
,
use_tqdm
=
False
)
scores
=
np
.
array
(
outputs
)
scores
=
scores
[
np
.
argsort
(
r
)]
return
scores
class
OpenAIClientMtebEncoder
(
mteb
.
Encoder
):
class
OpenAIClientMtebEncoder
(
mteb
.
Encoder
):
...
@@ -62,21 +92,72 @@ class OpenAIClientMtebEncoder(mteb.Encoder):
...
@@ -62,21 +92,72 @@ class OpenAIClientMtebEncoder(mteb.Encoder):
return
embeds
return
embeds
class
ScoreClientMtebEncoder
(
mteb
.
Encoder
):
def
__init__
(
self
,
model_name
:
str
,
url
):
super
().
__init__
()
self
.
model_name
=
model_name
self
.
url
=
url
self
.
rng
=
np
.
random
.
default_rng
(
seed
=
42
)
def
predict
(
self
,
sentences
:
list
[
tuple
[
str
,
str
,
Optional
[
str
]]],
# query, corpus, prompt
*
args
,
**
kwargs
,
)
->
np
.
ndarray
:
r
=
self
.
rng
.
permutation
(
len
(
sentences
))
sentences
=
[
sentences
[
i
]
for
i
in
r
]
outputs
=
[]
for
query
,
corpus
,
prompt
in
sentences
:
outputs
.
append
(
self
.
get_score
(
query
,
corpus
))
scores
=
np
.
array
(
outputs
)
scores
=
scores
[
np
.
argsort
(
r
)]
return
scores
def
get_score
(
self
,
query
,
corpus
):
response
=
requests
.
post
(
self
.
url
,
json
=
{
"model"
:
self
.
model_name
,
"text_1"
:
query
,
"text_2"
:
corpus
,
"truncate_prompt_tokens"
:
-
1
,
}).
json
()
return
response
[
'data'
][
0
][
"score"
]
class
RerankClientMtebEncoder
(
ScoreClientMtebEncoder
):
def
get_score
(
self
,
query
,
corpus
):
response
=
requests
.
post
(
self
.
url
,
json
=
{
"model"
:
self
.
model_name
,
"query"
:
query
,
"documents"
:
[
corpus
],
"truncate_prompt_tokens"
:
-
1
,
}).
json
()
return
response
[
'results'
][
0
][
"relevance_score"
]
def
run_mteb_embed_task
(
encoder
,
tasks
):
def
run_mteb_embed_task
(
encoder
,
tasks
):
tasks
=
mteb
.
get_tasks
(
tasks
=
tasks
)
tasks
=
mteb
.
get_tasks
(
tasks
=
tasks
)
evaluation
=
mteb
.
MTEB
(
tasks
=
tasks
)
evaluation
=
mteb
.
MTEB
(
tasks
=
tasks
)
results
=
evaluation
.
run
(
encoder
,
verbosity
=
0
,
output_folder
=
None
)
results
=
evaluation
.
run
(
encoder
,
verbosity
=
0
,
output_folder
=
None
,
encode_kwargs
=
{
"show_progress_bar"
:
False
,
},
)
main_score
=
results
[
0
].
scores
[
"test"
][
0
][
"main_score"
]
main_score
=
results
[
0
].
scores
[
"test"
][
0
][
"main_score"
]
return
main_score
return
main_score
def
run_mteb_embed_task_st
(
model_name
,
tasks
):
from
sentence_transformers
import
SentenceTransformer
model
=
SentenceTransformer
(
model_name
)
return
run_mteb_embed_task
(
model
,
tasks
)
def
mteb_test_embed_models
(
hf_runner
,
def
mteb_test_embed_models
(
hf_runner
,
vllm_runner
,
vllm_runner
,
model_info
:
EmbedModelInfo
,
model_info
:
EmbedModelInfo
,
...
@@ -118,3 +199,96 @@ def mteb_test_embed_models(hf_runner,
...
@@ -118,3 +199,96 @@ def mteb_test_embed_models(hf_runner,
print
(
"Difference:"
,
st_main_score
-
vllm_main_score
)
print
(
"Difference:"
,
st_main_score
-
vllm_main_score
)
assert
st_main_score
==
pytest
.
approx
(
vllm_main_score
,
abs
=
MTEB_EMBED_TOL
)
assert
st_main_score
==
pytest
.
approx
(
vllm_main_score
,
abs
=
MTEB_EMBED_TOL
)
def
run_mteb_rerank
(
cross_encoder
,
tasks
,
languages
):
with
tempfile
.
TemporaryDirectory
()
as
results_folder
:
bm25s
=
mteb
.
get_model
(
"bm25s"
)
tasks
=
mteb
.
get_tasks
(
tasks
=
tasks
,
languages
=
languages
)
subset
=
"default"
eval_splits
=
[
"test"
]
evaluation
=
mteb
.
MTEB
(
tasks
=
tasks
)
evaluation
.
run
(
bm25s
,
verbosity
=
0
,
eval_splits
=
eval_splits
,
save_predictions
=
True
,
output_folder
=
f
"
{
results_folder
}
/stage1"
,
encode_kwargs
=
{
"show_progress_bar"
:
False
},
)
results
=
evaluation
.
run
(
cross_encoder
,
verbosity
=
0
,
eval_splits
=
eval_splits
,
top_k
=
10
,
save_predictions
=
True
,
output_folder
=
f
"
{
results_folder
}
/stage2"
,
previous_results
=
f
"
{
results_folder
}
/stage1/NFCorpus_
{
subset
}
_predictions.json"
,
encode_kwargs
=
{
"show_progress_bar"
:
False
},
)
main_score
=
results
[
0
].
scores
[
"test"
][
0
][
"main_score"
]
return
main_score
def
mteb_test_rerank_models
(
hf_runner
,
vllm_runner
,
model_info
:
RerankModelInfo
,
vllm_extra_kwargs
=
None
,
hf_model_callback
=
None
):
if
not
model_info
.
enable_test
:
# A model family has many models with the same architecture,
# and we don't need to test each one.
pytest
.
skip
(
"Skipping test."
)
vllm_extra_kwargs
=
vllm_extra_kwargs
or
{}
vllm_extra_kwargs
[
"dtype"
]
=
model_info
.
dtype
with
vllm_runner
(
model_info
.
name
,
task
=
"score"
,
max_model_len
=
None
,
**
vllm_extra_kwargs
)
as
vllm_model
:
if
model_info
.
architecture
:
assert
(
model_info
.
architecture
in
vllm_model
.
model
.
llm_engine
.
model_config
.
architectures
)
vllm_main_score
=
run_mteb_rerank
(
VllmMtebEncoder
(
vllm_model
),
tasks
=
MTEB_RERANK_TASKS
,
languages
=
MTEB_RERANK_LANGS
)
vllm_dtype
=
vllm_model
.
model
.
llm_engine
.
model_config
.
dtype
with
hf_runner
(
model_info
.
name
,
is_cross_encoder
=
True
,
dtype
=
"float32"
)
as
hf_model
:
original_predict
=
hf_model
.
predict
def
_predict
(
sentences
:
list
[
tuple
[
str
,
str
,
Optional
[
str
]]],
# query, corpus, prompt
*
args
,
**
kwargs
,
):
# vllm and st both remove the prompt, fair comparison.
prompts
=
[(
s
[
0
],
s
[
1
])
for
s
in
sentences
]
return
original_predict
(
prompts
,
*
args
,
**
kwargs
,
batch_size
=
8
)
hf_model
.
predict
=
_predict
hf_model
.
original_predict
=
original_predict
if
hf_model_callback
is
not
None
:
hf_model_callback
(
hf_model
)
st_main_score
=
run_mteb_rerank
(
hf_model
,
tasks
=
MTEB_RERANK_TASKS
,
languages
=
MTEB_RERANK_LANGS
)
st_dtype
=
next
(
hf_model
.
model
.
model
.
parameters
()).
dtype
print
(
"VLLM:"
,
vllm_dtype
,
vllm_main_score
)
print
(
"SentenceTransformers:"
,
st_dtype
,
st_main_score
)
print
(
"Difference:"
,
st_main_score
-
vllm_main_score
)
assert
st_main_score
==
pytest
.
approx
(
vllm_main_score
,
abs
=
MTEB_RERANK_TOL
)
tests/models/language/pooling/test_baai.py
View file @
f40f763f
...
@@ -2,8 +2,9 @@
...
@@ -2,8 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
pytest
from
.embed_utils
import
EmbedModelInfo
,
correctness_test_embed_models
from
...utils
import
EmbedModelInfo
,
RerankModelInfo
from
.mteb_utils
import
mteb_test_embed_models
from
.embed_utils
import
correctness_test_embed_models
from
.mteb_utils
import
mteb_test_embed_models
,
mteb_test_rerank_models
MODELS
=
[
MODELS
=
[
########## BertModel
########## BertModel
...
@@ -57,6 +58,20 @@ MODELS = [
...
@@ -57,6 +58,20 @@ MODELS = [
enable_test
=
True
),
enable_test
=
True
),
]
]
RERANK_MODELS
=
[
########## XLMRobertaForSequenceClassification
RerankModelInfo
(
"BAAI/bge-reranker-base"
,
architecture
=
"XLMRobertaForSequenceClassification"
,
enable_test
=
True
),
RerankModelInfo
(
"BAAI/bge-reranker-large"
,
architecture
=
"XLMRobertaForSequenceClassification"
,
enable_test
=
False
),
RerankModelInfo
(
"BAAI/bge-reranker-v2-m3"
,
architecture
=
"XLMRobertaForSequenceClassification"
,
dtype
=
"float32"
,
enable_test
=
False
)
]
@
pytest
.
mark
.
parametrize
(
"model_info"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model_info"
,
MODELS
)
def
test_embed_models_mteb
(
hf_runner
,
vllm_runner
,
def
test_embed_models_mteb
(
hf_runner
,
vllm_runner
,
...
@@ -70,3 +85,9 @@ def test_embed_models_correctness(hf_runner, vllm_runner,
...
@@ -70,3 +85,9 @@ def test_embed_models_correctness(hf_runner, vllm_runner,
example_prompts
)
->
None
:
example_prompts
)
->
None
:
correctness_test_embed_models
(
hf_runner
,
vllm_runner
,
model_info
,
correctness_test_embed_models
(
hf_runner
,
vllm_runner
,
model_info
,
example_prompts
)
example_prompts
)
@
pytest
.
mark
.
parametrize
(
"model_info"
,
RERANK_MODELS
)
def
test_rerank_models_mteb
(
hf_runner
,
vllm_runner
,
model_info
:
RerankModelInfo
)
->
None
:
mteb_test_rerank_models
(
hf_runner
,
vllm_runner
,
model_info
)
tests/models/language/pooling/test_cross_encoder.py
0 → 100644
View file @
f40f763f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
from
.mteb_utils
import
RerankModelInfo
,
mteb_test_rerank_models
RERANK_MODELS
=
[
RerankModelInfo
(
"cross-encoder/ms-marco-TinyBERT-L-2-v2"
,
architecture
=
"BertForSequenceClassification"
),
RerankModelInfo
(
"tomaarsen/Qwen3-Reranker-0.6B-seq-cls"
,
architecture
=
"Qwen3ForSequenceClassification"
)
]
@
pytest
.
mark
.
parametrize
(
"model_info"
,
RERANK_MODELS
)
def
test_rerank_models_mteb
(
hf_runner
,
vllm_runner
,
model_info
:
RerankModelInfo
)
->
None
:
mteb_test_rerank_models
(
hf_runner
,
vllm_runner
,
model_info
)
tests/models/language/pooling/test_jina.py
View file @
f40f763f
...
@@ -6,28 +6,10 @@ import pytest
...
@@ -6,28 +6,10 @@ import pytest
from
vllm
import
PoolingParams
from
vllm
import
PoolingParams
from
.embed_utils
import
(
EmbedModelInfo
,
check_embeddings_close
,
from
...utils
import
EmbedModelInfo
,
RerankModelInfo
from
.embed_utils
import
(
check_embeddings_close
,
correctness_test_embed_models
,
matryoshka_fy
)
correctness_test_embed_models
,
matryoshka_fy
)
from
.mteb_utils
import
mteb_test_embed_models
from
.mteb_utils
import
mteb_test_embed_models
,
mteb_test_rerank_models
SCORING_MODELS
=
[
"jinaai/jina-reranker-v2-base-multilingual"
,
# Roberta
]
TEXTS_1
=
[
"Organic skincare products for sensitive skin"
]
TEXTS_2
=
[
"Organic skincare for sensitive skin with aloe vera and chamomile."
,
"New makeup trends focus on bold colors and innovative techniques"
,
"Bio-Hautpflege für empfindliche Haut mit Aloe Vera und Kamille"
,
"Neue Make-up-Trends setzen auf kräftige Farben und innovative Techniken"
,
# noqa: E501
"Cuidado de la piel orgánico para piel sensible con aloe vera y manzanilla"
,
# noqa: E501
"Las nuevas tendencias de maquillaje se centran en colores vivos y técnicas innovadoras"
,
# noqa: E501
"针对敏感肌专门设计的天然有机护肤产品"
,
"新的化妆趋势注重鲜艳的颜色和创新的技巧"
,
"敏感肌のために特別に設計された天然有機スキンケア製品"
,
"新しいメイクのトレンドは鮮やかな色と革新的な技術に焦点を当てています"
,
]
EMBEDDING_MODELS
=
[
EMBEDDING_MODELS
=
[
EmbedModelInfo
(
"jinaai/jina-embeddings-v3"
,
EmbedModelInfo
(
"jinaai/jina-embeddings-v3"
,
...
@@ -35,47 +17,13 @@ EMBEDDING_MODELS = [
...
@@ -35,47 +17,13 @@ EMBEDDING_MODELS = [
is_matryoshka
=
True
)
is_matryoshka
=
True
)
]
]
RERANK_MODELS
=
[
@
pytest
.
fixture
(
scope
=
"module"
,
params
=
SCORING_MODELS
)
RerankModelInfo
(
def
model_name
(
request
):
"jinaai/jina-reranker-v2-base-multilingual"
,
yield
request
.
param
architecture
=
"XLMRobertaForSequenceClassification"
,
dtype
=
"float32"
,
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
]
def
test_llm_1_to_1
(
vllm_runner
,
hf_runner
,
model_name
,
dtype
:
str
):
text_pair
=
[
TEXTS_1
[
0
],
TEXTS_2
[
0
]]
with
hf_runner
(
model_name
,
dtype
=
dtype
,
is_cross_encoder
=
True
)
as
hf_model
:
hf_outputs
=
hf_model
.
predict
([
text_pair
]).
tolist
()
with
vllm_runner
(
model_name
,
task
=
"score"
,
dtype
=
dtype
,
max_model_len
=
None
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
score
(
text_pair
[
0
],
text_pair
[
1
])
assert
len
(
vllm_outputs
)
==
1
assert
len
(
hf_outputs
)
==
1
assert
hf_outputs
[
0
]
==
pytest
.
approx
(
vllm_outputs
[
0
],
rel
=
0.01
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
def
test_llm_1_to_N
(
vllm_runner
,
hf_runner
,
model_name
,
dtype
:
str
):
text_pairs
=
[[
TEXTS_1
[
0
],
text
]
for
text
in
TEXTS_2
]
with
hf_runner
(
model_name
,
dtype
=
dtype
,
is_cross_encoder
=
True
)
as
hf_model
:
hf_outputs
=
hf_model
.
predict
(
text_pairs
).
tolist
()
with
vllm_runner
(
model_name
,
task
=
"score"
,
dtype
=
dtype
,
max_model_len
=
None
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
score
(
TEXTS_1
[
0
],
TEXTS_2
)
assert
len
(
vllm_outputs
)
==
10
assert
len
(
hf_outputs
)
==
10
assert
hf_outputs
[
0
]
==
pytest
.
approx
(
vllm_outputs
[
0
],
rel
=
0.01
)
assert
hf_outputs
[
1
]
==
pytest
.
approx
(
vllm_outputs
[
1
],
rel
=
0.01
)
@
pytest
.
mark
.
parametrize
(
"model_info"
,
EMBEDDING_MODELS
)
@
pytest
.
mark
.
parametrize
(
"model_info"
,
EMBEDDING_MODELS
)
...
@@ -106,6 +54,12 @@ def test_embed_models_correctness(hf_runner, vllm_runner,
...
@@ -106,6 +54,12 @@ def test_embed_models_correctness(hf_runner, vllm_runner,
hf_model_callback
=
hf_model_callback
)
hf_model_callback
=
hf_model_callback
)
@
pytest
.
mark
.
parametrize
(
"model_info"
,
RERANK_MODELS
)
def
test_rerank_models_mteb
(
hf_runner
,
vllm_runner
,
model_info
:
RerankModelInfo
)
->
None
:
mteb_test_rerank_models
(
hf_runner
,
vllm_runner
,
model_info
)
@
pytest
.
mark
.
parametrize
(
"model_info"
,
EMBEDDING_MODELS
)
@
pytest
.
mark
.
parametrize
(
"model_info"
,
EMBEDDING_MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"dimensions"
,
[
16
,
32
])
@
pytest
.
mark
.
parametrize
(
"dimensions"
,
[
16
,
32
])
...
...
tests/models/language/pooling/test_qwen3_reranker.py
View file @
f40f763f
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
pytest
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
model_name
=
"Qwen/Qwen3-Reranker-4B"
text_1
=
"What is the capital of France?"
texts_2
=
[
"The capital of Brazil is Brasilia."
,
"The capital of France is Paris."
,
]
def
vllm_reranker
(
model_name
):
import
pytest
from
vllm
import
LLM
import
torch
model
=
LLM
(
model
=
model_name
,
from
tests.conftest
import
HfRunner
task
=
"score"
,
hf_overrides
=
{
"architectures"
:
[
"Qwen3ForSequenceClassification"
],
"classifier_from_token"
:
[
"no"
,
"yes"
],
"is_original_qwen3_reranker"
:
True
,
},
dtype
=
"float32"
)
text_1
=
"What is the capital of France?"
from
.mteb_utils
import
RerankModelInfo
,
mteb_test_rerank_models
texts_2
=
[
"The capital of Brazil is Brasilia."
,
"The capital of France is Paris."
,
]
outputs
=
model
.
score
(
text_1
,
texts_2
)
RERANK_MODELS
=
[
RerankModelInfo
(
"Qwen/Qwen3-Reranker-0.6B"
,
architecture
=
"Qwen3ForSequenceClassification"
,
dtype
=
"float32"
,
enable_test
=
True
),
RerankModelInfo
(
"Qwen/Qwen3-Reranker-4B"
,
architecture
=
"Qwen3ForSequenceClassification"
,
dtype
=
"float32"
,
enable_test
=
False
)
]
return
[
output
.
outputs
.
score
for
output
in
outputs
]
class
Qwen3RerankerHfRunner
(
HfRunner
):
def
hf_reranker
(
model_name
):
def
__init__
(
self
,
import
torch
model_name
:
str
,
dtype
:
str
=
"auto"
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
None
:
from
transformers
import
AutoModelForCausalLM
,
AutoTokenizer
from
transformers
import
AutoModelForCausalLM
,
AutoTokenizer
super
().
__init__
(
model_name
,
dtype
,
auto_cls
=
AutoModelForCausalLM
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
,
padding_side
=
'left'
)
self
.
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
,
model
=
AutoModelForCausalLM
.
from_pretrained
(
model_name
).
eval
()
padding_side
=
'left'
)
self
.
token_false_id
=
self
.
tokenizer
.
convert_tokens_to_ids
(
"no"
)
token_false_id
=
tokenizer
.
convert_tokens_to_ids
(
"no"
)
self
.
token_true_id
=
self
.
tokenizer
.
convert_tokens_to_ids
(
"yes"
)
token_true_id
=
tokenizer
.
convert_tokens_to_ids
(
"yes"
)
max_length
=
8192
def
predict
(
self
,
prompts
:
list
[
list
[
str
]],
*
args
,
**
kwargs
)
->
torch
.
Tensor
:
def
process_inputs
(
pairs
):
def
process_inputs
(
pairs
):
inputs
=
tokenizer
(
pairs
,
inputs
=
self
.
tokenizer
(
pairs
,
padding
=
False
,
padding
=
False
,
truncation
=
'longest_first'
,
truncation
=
'longest_first'
,
return_attention_mask
=
False
,
return_attention_mask
=
False
)
max_length
=
max_length
)
for
i
,
ele
in
enumerate
(
inputs
[
'input_ids'
]):
for
i
,
ele
in
enumerate
(
inputs
[
'input_ids'
]):
inputs
[
'input_ids'
][
i
]
=
ele
inputs
[
'input_ids'
][
i
]
=
ele
inputs
=
tokenizer
.
pad
(
inputs
,
inputs
=
self
.
tokenizer
.
pad
(
inputs
,
padding
=
True
,
padding
=
True
,
return_tensors
=
"pt"
,
return_tensors
=
"pt"
)
max_length
=
max_length
)
for
key
in
inputs
:
for
key
in
inputs
:
inputs
[
key
]
=
inputs
[
key
].
to
(
model
.
device
)
inputs
[
key
]
=
inputs
[
key
].
to
(
self
.
model
.
device
)
return
inputs
return
inputs
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
compute_logits
(
inputs
,
**
kwargs
):
def
compute_logits
(
inputs
):
batch_scores
=
model
(
**
inputs
).
logits
[:,
-
1
,
:]
batch_scores
=
self
.
model
(
**
inputs
).
logits
[:,
-
1
,
:]
true_vector
=
batch_scores
[:,
token_true_id
]
true_vector
=
batch_scores
[:,
self
.
token_true_id
]
false_vector
=
batch_scores
[:,
token_false_id
]
false_vector
=
batch_scores
[:,
self
.
token_false_id
]
batch_scores
=
torch
.
stack
([
false_vector
,
true_vector
],
dim
=
1
)
batch_scores
=
torch
.
stack
([
false_vector
,
true_vector
],
dim
=
1
)
batch_scores
=
torch
.
nn
.
functional
.
log_softmax
(
batch_scores
,
dim
=
1
)
batch_scores
=
torch
.
nn
.
functional
.
log_softmax
(
batch_scores
,
dim
=
1
)
scores
=
batch_scores
[:,
1
].
exp
()
.
tolist
()
scores
=
batch_scores
[:,
1
].
exp
()
return
scores
return
scores
pairs
=
[(
text_1
,
texts_2
[
0
]),
(
text_1
,
texts_2
[
1
])]
scores
=
[]
inputs
=
process_inputs
(
pairs
)
for
prompt
in
prompts
:
scores
=
compute_logits
(
inputs
)
inputs
=
process_inputs
([
prompt
])
score
=
compute_logits
(
inputs
)
scores
.
append
(
score
[
0
].
item
())
return
torch
.
Tensor
(
scores
)
return
scores
@
pytest
.
mark
.
parametrize
(
"model_info"
,
RERANK_MODELS
)
def
test_rerank_models_mteb
(
vllm_runner
,
model_info
:
RerankModelInfo
)
->
None
:
assert
model_info
.
architecture
==
"Qwen3ForSequenceClassification"
vllm_extra_kwargs
:
dict
[
str
,
Any
]
=
{
"hf_overrides"
:
{
"architectures"
:
[
"Qwen3ForSequenceClassification"
],
"classifier_from_token"
:
[
"no"
,
"yes"
],
"is_original_qwen3_reranker"
:
True
,
}
}
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
model_name
])
if
model_info
.
name
==
"Qwen/Qwen3-Reranker-4B"
:
def
test_model
(
model_name
):
vllm_extra_kwargs
[
"max_num_seqs"
]
=
1
hf_outputs
=
hf_reranker
(
model_name
)
vllm_outputs
=
vllm_reranker
(
model_name
)
assert
hf_outputs
[
0
]
==
pytest
.
approx
(
vllm_outputs
[
0
],
rel
=
0.01
)
mteb_test_rerank_models
(
Qwen3RerankerHfRunner
,
vllm_runner
,
model_info
,
assert
hf_outputs
[
1
]
==
pytest
.
approx
(
vllm_outputs
[
1
],
rel
=
0.01
)
vllm_extra_kwargs
)
tests/models/language/pooling/test_qwen3_reranker_seq_cls.py
deleted
100644 → 0
View file @
26bc46ef
# SPDX-License-Identifier: Apache-2.0
import
pytest
model_name
=
"tomaarsen/Qwen3-Reranker-0.6B-seq-cls"
text_1
=
"What is the capital of France?"
texts_2
=
[
"The capital of Brazil is Brasilia."
,
"The capital of France is Paris."
,
]
def
vllm_reranker
(
model_name
):
from
vllm
import
LLM
model
=
LLM
(
model
=
model_name
,
task
=
"score"
)
outputs
=
model
.
score
(
text_1
,
texts_2
)
return
[
output
.
outputs
.
score
for
output
in
outputs
]
def
hf_reranker
(
model_name
):
import
torch
from
transformers
import
AutoModelForCausalLM
,
AutoTokenizer
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
,
padding_side
=
'left'
)
model
=
AutoModelForCausalLM
.
from_pretrained
(
model_name
).
eval
()
token_false_id
=
tokenizer
.
convert_tokens_to_ids
(
"no"
)
token_true_id
=
tokenizer
.
convert_tokens_to_ids
(
"yes"
)
max_length
=
8192
def
process_inputs
(
pairs
):
inputs
=
tokenizer
(
pairs
,
padding
=
False
,
truncation
=
'longest_first'
,
return_attention_mask
=
False
,
max_length
=
max_length
)
for
i
,
ele
in
enumerate
(
inputs
[
'input_ids'
]):
inputs
[
'input_ids'
][
i
]
=
ele
inputs
=
tokenizer
.
pad
(
inputs
,
padding
=
True
,
return_tensors
=
"pt"
,
max_length
=
max_length
)
for
key
in
inputs
:
inputs
[
key
]
=
inputs
[
key
].
to
(
model
.
device
)
return
inputs
@
torch
.
no_grad
()
def
compute_logits
(
inputs
,
**
kwargs
):
batch_scores
=
model
(
**
inputs
).
logits
[:,
-
1
,
:]
true_vector
=
batch_scores
[:,
token_true_id
]
false_vector
=
batch_scores
[:,
token_false_id
]
batch_scores
=
torch
.
stack
([
false_vector
,
true_vector
],
dim
=
1
)
batch_scores
=
torch
.
nn
.
functional
.
log_softmax
(
batch_scores
,
dim
=
1
)
scores
=
batch_scores
[:,
1
].
exp
().
tolist
()
return
scores
pairs
=
[(
text_1
,
texts_2
[
0
]),
(
text_1
,
texts_2
[
1
])]
inputs
=
process_inputs
(
pairs
)
scores
=
compute_logits
(
inputs
)
return
scores
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
model_name
])
def
test_model
(
model_name
):
hf_outputs
=
hf_reranker
(
model_name
)
vllm_outputs
=
vllm_reranker
(
model_name
)
assert
hf_outputs
[
0
]
==
pytest
.
approx
(
vllm_outputs
[
0
],
rel
=
0.01
)
assert
hf_outputs
[
1
]
==
pytest
.
approx
(
vllm_outputs
[
1
],
rel
=
0.01
)
tests/models/utils.py
View file @
f40f763f
...
@@ -336,3 +336,10 @@ class EmbedModelInfo(NamedTuple):
...
@@ -336,3 +336,10 @@ class EmbedModelInfo(NamedTuple):
architecture
:
str
=
""
architecture
:
str
=
""
dtype
:
str
=
"auto"
dtype
:
str
=
"auto"
enable_test
:
bool
=
True
enable_test
:
bool
=
True
class
RerankModelInfo
(
NamedTuple
):
name
:
str
architecture
:
str
=
""
dtype
:
str
=
"auto"
enable_test
:
bool
=
True
vllm/model_executor/layers/pooler.py
View file @
f40f763f
...
@@ -156,7 +156,10 @@ class MeanPool(SimplePooler):
...
@@ -156,7 +156,10 @@ class MeanPool(SimplePooler):
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
prompt_lens
=
self
.
get_prompt_lens
(
hidden_states
,
pooling_metadata
)
prompt_lens
=
self
.
get_prompt_lens
(
hidden_states
,
pooling_metadata
)
cumsum
=
torch
.
cumsum
(
hidden_states
,
dim
=
0
)
# Use float32 for torch.cumsum in MeanPool,
# otherwise precision will be lost significantly.
cumsum
=
torch
.
cumsum
(
hidden_states
,
dim
=
0
,
dtype
=
torch
.
float32
)
start_indices
=
torch
.
cat
([
start_indices
=
torch
.
cat
([
torch
.
tensor
([
0
],
device
=
hidden_states
.
device
),
torch
.
tensor
([
0
],
device
=
hidden_states
.
device
),
torch
.
cumsum
(
prompt_lens
[:
-
1
],
dim
=
0
)
torch
.
cumsum
(
prompt_lens
[:
-
1
],
dim
=
0
)
...
@@ -220,6 +223,13 @@ class PoolerHead(nn.Module):
...
@@ -220,6 +223,13 @@ class PoolerHead(nn.Module):
def
forward
(
self
,
pooled_data
:
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
],
def
forward
(
self
,
pooled_data
:
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
],
pooling_metadata
:
PoolingMetadata
):
pooling_metadata
:
PoolingMetadata
):
# Using float32 in PoolerHead
if
isinstance
(
pooled_data
,
list
):
for
i
in
range
(
len
(
pooled_data
)):
pooled_data
[
i
]
=
pooled_data
[
i
].
to
(
torch
.
float32
)
else
:
pooled_data
=
pooled_data
.
to
(
torch
.
float32
)
dimensions_list
=
[
dimensions_list
=
[
pooling_param
.
dimensions
pooling_param
.
dimensions
for
_
,
pooling_param
in
pooling_metadata
.
seq_groups
for
_
,
pooling_param
in
pooling_metadata
.
seq_groups
...
...
vllm/model_executor/models/bert.py
View file @
f40f763f
...
@@ -414,16 +414,11 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
...
@@ -414,16 +414,11 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
return
self
.
model
(
input_ids
=
input_ids
,
position_ids
=
positions
,
position_ids
=
positions
,
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
intermediate_tensors
=
intermediate_tensors
)
intermediate_tensors
=
intermediate_tensors
)
# convert the embedding output to float32,
# otherwise precision will be lost significantly
hidden_states
=
hidden_states
.
to
(
torch
.
float32
)
return
hidden_states
def
pooler
(
def
pooler
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/bert_with_rope.py
View file @
f40f763f
...
@@ -432,12 +432,7 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
...
@@ -432,12 +432,7 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
else
:
else
:
hidden_states
=
self
.
embeddings
(
input_ids
=
input_ids
,
hidden_states
=
self
.
embeddings
(
input_ids
=
input_ids
,
token_type_ids
=
token_type_ids
)
token_type_ids
=
token_type_ids
)
hidden_states
=
self
.
encoder
(
positions
,
hidden_states
)
return
self
.
encoder
(
positions
,
hidden_states
)
# convert the embedding output to float32,
# otherwise precision will be lost significantly
hidden_states
=
hidden_states
.
to
(
torch
.
float32
)
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
torch
.
Tensor
]])
->
set
[
str
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment