Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fbf722c6
Unverified
Commit
fbf722c6
authored
Apr 12, 2025
by
wang.yuqi
Committed by
GitHub
Apr 11, 2025
Browse files
[Frontend] support matryoshka representation / support embedding API dimensions (#16331)
parent
e92d7085
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
253 additions
and
22 deletions
+253
-22
examples/offline_inference/embed_matryoshka_fy.py
examples/offline_inference/embed_matryoshka_fy.py
+48
-0
tests/conftest.py
tests/conftest.py
+8
-8
tests/entrypoints/openai/test_embedding_dimensions.py
tests/entrypoints/openai/test_embedding_dimensions.py
+82
-0
tests/models/embedding/language/test_jina.py
tests/models/embedding/language/test_jina.py
+39
-1
tests/models/embedding/utils.py
tests/models/embedding/utils.py
+7
-0
vllm/config.py
vllm/config.py
+9
-0
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+10
-0
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+4
-2
vllm/entrypoints/openai/serving_embedding.py
vllm/entrypoints/openai/serving_embedding.py
+7
-5
vllm/model_executor/layers/pooler.py
vllm/model_executor/layers/pooler.py
+18
-4
vllm/pooling_params.py
vllm/pooling_params.py
+21
-2
No files found.
examples/offline_inference/embed_matryoshka_fy.py
0 → 100644
View file @
fbf722c6
# SPDX-License-Identifier: Apache-2.0
from
argparse
import
Namespace
from
vllm
import
LLM
,
EngineArgs
,
PoolingParams
from
vllm.utils
import
FlexibleArgumentParser
def
main
(
args
:
Namespace
):
# Sample prompts.
prompts
=
[
"Follow the white rabbit."
,
# English
"Sigue al conejo blanco."
,
# Spanish
"Suis le lapin blanc."
,
# French
"跟着白兔走。"
,
# Chinese
"اتبع الأرنب الأبيض."
,
# Arabic
"Folge dem weißen Kaninchen."
,
# German
]
# Create an LLM.
# You should pass task="embed" for embedding models
model
=
LLM
(
**
vars
(
args
))
# Generate embedding. The output is a list of EmbeddingRequestOutputs.
outputs
=
model
.
embed
(
prompts
,
pooling_params
=
PoolingParams
(
dimensions
=
32
))
# Print the outputs.
print
(
"
\n
Generated Outputs:"
)
print
(
"-"
*
60
)
for
prompt
,
output
in
zip
(
prompts
,
outputs
):
embeds
=
output
.
outputs
.
embedding
embeds_trimmed
=
((
str
(
embeds
[:
16
])[:
-
1
]
+
", ...]"
)
if
len
(
embeds
)
>
16
else
embeds
)
print
(
f
"Prompt:
{
prompt
!
r
}
\n
"
f
"Embeddings:
{
embeds_trimmed
}
"
f
"(size=
{
len
(
embeds
)
}
)"
)
print
(
"-"
*
60
)
if
__name__
==
"__main__"
:
parser
=
FlexibleArgumentParser
()
parser
=
EngineArgs
.
add_cli_args
(
parser
)
# Set example specific arguments
parser
.
set_defaults
(
model
=
"jinaai/jina-embeddings-v3"
,
task
=
"embed"
,
trust_remote_code
=
True
)
args
=
parser
.
parse_args
()
main
(
args
)
tests/conftest.py
View file @
fbf722c6
...
@@ -960,19 +960,19 @@ class VllmRunner:
...
@@ -960,19 +960,19 @@ class VllmRunner:
req_outputs
=
self
.
model
.
classify
(
prompts
)
req_outputs
=
self
.
model
.
classify
(
prompts
)
return
[
req_output
.
outputs
.
probs
for
req_output
in
req_outputs
]
return
[
req_output
.
outputs
.
probs
for
req_output
in
req_outputs
]
def
encode
(
def
encode
(
self
,
self
,
prompts
:
list
[
str
]
,
prompts
:
list
[
str
]
,
images
:
Optional
[
PromptImageInput
]
=
None
,
image
s
:
Optional
[
Prompt
Image
Input
]
=
None
,
video
s
:
Optional
[
Prompt
Video
Input
]
=
None
,
vide
os
:
Optional
[
Prompt
Vide
oInput
]
=
None
,
audi
os
:
Optional
[
Prompt
Audi
oInput
]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
*
args
,
)
->
list
[
list
[
float
]]:
**
kwargs
)
->
list
[
list
[
float
]]:
inputs
=
self
.
get_inputs
(
prompts
,
inputs
=
self
.
get_inputs
(
prompts
,
images
=
images
,
images
=
images
,
videos
=
videos
,
videos
=
videos
,
audios
=
audios
)
audios
=
audios
)
req_outputs
=
self
.
model
.
embed
(
inputs
)
req_outputs
=
self
.
model
.
embed
(
inputs
,
*
args
,
**
kwargs
)
return
[
req_output
.
outputs
.
embedding
for
req_output
in
req_outputs
]
return
[
req_output
.
outputs
.
embedding
for
req_output
in
req_outputs
]
def
score
(
def
score
(
...
...
tests/entrypoints/openai/test_embedding_dimensions.py
0 → 100644
View file @
fbf722c6
# SPDX-License-Identifier: Apache-2.0
"""
Run `pytest tests/entrypoints/openai/test_embedding_dimensions.py`.
"""
from
typing
import
NamedTuple
import
openai
import
pytest
from
vllm.entrypoints.openai.protocol
import
EmbeddingResponse
from
...utils
import
RemoteOpenAIServer
class
ModelInfo
(
NamedTuple
):
name
:
str
is_matryoshka
:
bool
MODELS
=
[
ModelInfo
(
name
=
"BAAI/bge-m3"
,
is_matryoshka
=
False
),
ModelInfo
(
name
=
"jinaai/jina-embeddings-v3"
,
is_matryoshka
=
True
),
]
input_texts
=
[
"The chef prepared a delicious meal."
,
]
*
3
@pytest.mark.asyncio
@pytest.mark.parametrize("model", MODELS)
async def test_validating_dimensions(model: ModelInfo):
    """Check that the embeddings endpoint validates the `dimensions` field.

    A server is started for ``model.name`` and embedding requests are sent
    with several ``dimensions`` values. Values expected to be accepted must
    return embeddings of the requested length; every value expected to be
    rejected must individually raise ``openai.BadRequestError``.
    """
    args = [
        "--task",
        "embed",
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--enforce-eager",
        "--max-model-len",
        "512",
        "--trust_remote_code",
    ]

    with RemoteOpenAIServer(model.name, args) as remote_server:
        client = remote_server.get_async_client()

        async def make_request(dimensions):
            embedding_response = await client.embeddings.create(
                model=model.name,
                input=input_texts,
                dimensions=dimensions,
                encoding_format="float",
            )
            embeddings = EmbeddingResponse.model_validate(
                embedding_response.model_dump(mode="json"))

            assert embeddings.id is not None
            assert len(embeddings.data) == len(input_texts)
            assert len(embeddings.data[0].embedding) > 0
            assert embeddings.usage.completion_tokens == 0
            assert embeddings.usage.prompt_tokens > 0
            assert embeddings.usage.total_tokens > 0

            if dimensions is not None:
                assert len(embeddings.data[0].embedding) == dimensions

        if model.is_matryoshka:
            accepted = [None, 16]
            rejected = [-1]
        else:
            accepted = [None]
            rejected = [-1, 16]

        for dimensions in accepted:
            await make_request(dimensions)

        # One `pytest.raises` per value: a single context around the whole
        # loop would exit on the first error and silently skip the rest.
        for dimensions in rejected:
            with pytest.raises(openai.BadRequestError):
                await make_request(dimensions)
tests/models/embedding/language/test_jina.py
View file @
fbf722c6
...
@@ -8,7 +8,8 @@ import math
...
@@ -8,7 +8,8 @@ import math
import
pytest
import
pytest
from
tests.models.embedding.utils
import
check_embeddings_close
from
tests.models.embedding.utils
import
check_embeddings_close
,
matryoshka_fy
from
vllm
import
PoolingParams
SCORING_MODELS
=
[
SCORING_MODELS
=
[
"jinaai/jina-reranker-v2-base-multilingual"
,
# Roberta
"jinaai/jina-reranker-v2-base-multilingual"
,
# Roberta
...
@@ -126,3 +127,40 @@ def test_embeddings(
...
@@ -126,3 +127,40 @@ def test_embeddings(
name_1
=
"vllm"
,
name_1
=
"vllm"
,
tol
=
1e-2
,
tol
=
1e-2
,
)
)
@pytest.mark.parametrize("model", EMBEDDING_MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("dimensions", [16, 32])
def test_matryoshka(
    hf_runner,
    vllm_runner,
    model,
    dtype: str,
    dimensions: int,
    monkeypatch,
) -> None:
    """Compare vLLM's reduced-dimension embeddings with the HF reference.

    The HF embeddings are passed through `matryoshka_fy` with the requested
    `dimensions`; vLLM is asked for the same size through `PoolingParams`.
    The two result sets must agree within the tolerance given below.
    """
    prompts = EMBEDDING_PROMPTS
    params = PoolingParams(dimensions=dimensions)

    with hf_runner(model, dtype=dtype,
                   is_sentence_transformer=True) as hf_model:
        full_size = hf_model.encode(prompts, task="text-matching")
        reference = matryoshka_fy(full_size, dimensions)

    with vllm_runner(model, task="embed", dtype=dtype,
                     max_model_len=None) as vllm_model:
        candidate = vllm_model.encode(prompts, pooling_params=params)

    check_embeddings_close(
        embeddings_0_lst=reference,
        embeddings_1_lst=candidate,
        name_0="hf",
        name_1="vllm",
        tol=1e-2,
    )
tests/models/embedding/utils.py
View file @
fbf722c6
...
@@ -30,3 +30,10 @@ def check_embeddings_close(
...
@@ -30,3 +30,10 @@ def check_embeddings_close(
f
"
\n
{
name_1
}
:
\t
{
embeddings_1
[:
16
]
!
r
}
"
)
f
"
\n
{
name_1
}
:
\t
{
embeddings_1
[:
16
]
!
r
}
"
)
assert
sim
>=
1
-
tol
,
fail_msg
assert
sim
>=
1
-
tol
,
fail_msg
def matryoshka_fy(tensor, dimensions):
    """Truncate embeddings to `dimensions` and L2-normalize them again.

    Args:
        tensor: Embedding data accepted by `torch.as_tensor`. The last axis
            is treated as the embedding axis.
        dimensions: Number of leading components to keep on the last axis.

    Returns:
        A tensor whose last axis holds at most `dimensions` components, with
        unit L2 norm along that axis.
    """
    # as_tensor avoids the copy-construct warning when given a tensor.
    tensor = torch.as_tensor(tensor)
    tensor = tensor[..., :dimensions]
    # Normalize over the last axis so a single 1-D embedding works as well
    # as a batch.
    tensor = F.normalize(tensor, p=2, dim=-1)
    return tensor
vllm/config.py
View file @
fbf722c6
...
@@ -583,6 +583,15 @@ class ModelConfig:
...
@@ -583,6 +583,15 @@ class ModelConfig:
if
getattr
(
user_config
,
k
)
is
None
:
if
getattr
(
user_config
,
k
)
is
None
:
setattr
(
user_config
,
k
,
v
)
setattr
(
user_config
,
k
,
v
)
if
self
.
is_matryoshka
:
if
user_config
.
normalize
is
None
:
user_config
.
normalize
=
True
elif
not
user_config
.
normalize
:
raise
ValueError
(
"`normalize` must be enabled (set to True) "
"for models that are compatible with "
"Matryoshka Representation."
)
return
user_config
return
user_config
return
None
return
None
...
...
vllm/entrypoints/llm.py
View file @
fbf722c6
...
@@ -921,6 +921,11 @@ class LLM:
...
@@ -921,6 +921,11 @@ class LLM:
if
pooling_params
is
None
:
if
pooling_params
is
None
:
# Use default pooling params.
# Use default pooling params.
pooling_params
=
PoolingParams
()
pooling_params
=
PoolingParams
()
elif
isinstance
(
pooling_params
,
PoolingParams
):
pooling_params
.
verify
(
self
.
llm_engine
.
model_config
)
else
:
for
pooling_param
in
pooling_params
:
pooling_param
.
verify
(
self
.
llm_engine
.
model_config
)
self
.
_validate_and_add_requests
(
self
.
_validate_and_add_requests
(
prompts
=
parsed_prompts
,
prompts
=
parsed_prompts
,
...
@@ -939,6 +944,8 @@ class LLM:
...
@@ -939,6 +944,8 @@ class LLM:
/
,
/
,
*
,
*
,
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
Sequence
[
PoolingParams
]]]
=
None
,
lora_request
:
Optional
[
Union
[
list
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
list
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
list
[
EmbeddingRequestOutput
]:
)
->
list
[
EmbeddingRequestOutput
]:
...
@@ -953,6 +960,8 @@ class LLM:
...
@@ -953,6 +960,8 @@ class LLM:
prompts: The prompts to the LLM. You may pass a sequence of prompts
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See :class:`~vllm.inputs.PromptType`
for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each prompts.
for more details about the format of each prompts.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
prompt_adapter_request: Prompt Adapter request to use for
...
@@ -968,6 +977,7 @@ class LLM:
...
@@ -968,6 +977,7 @@ class LLM:
items
=
self
.
encode
(
prompts
,
items
=
self
.
encode
(
prompts
,
use_tqdm
=
use_tqdm
,
use_tqdm
=
use_tqdm
,
pooling_params
=
pooling_params
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
prompt_adapter_request
=
prompt_adapter_request
)
...
...
vllm/entrypoints/openai/protocol.py
View file @
fbf722c6
...
@@ -1006,7 +1006,8 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
...
@@ -1006,7 +1006,8 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
# doc: end-embedding-extra-params
# doc: end-embedding-extra-params
def
to_pooling_params
(
self
):
def
to_pooling_params
(
self
):
return
PoolingParams
(
additional_data
=
self
.
additional_data
)
return
PoolingParams
(
dimensions
=
self
.
dimensions
,
additional_data
=
self
.
additional_data
)
class
EmbeddingChatRequest
(
OpenAIBaseModel
):
class
EmbeddingChatRequest
(
OpenAIBaseModel
):
...
@@ -1068,7 +1069,8 @@ class EmbeddingChatRequest(OpenAIBaseModel):
...
@@ -1068,7 +1069,8 @@ class EmbeddingChatRequest(OpenAIBaseModel):
return
data
return
data
def
to_pooling_params
(
self
):
def
to_pooling_params
(
self
):
return
PoolingParams
(
additional_data
=
self
.
additional_data
)
return
PoolingParams
(
dimensions
=
self
.
dimensions
,
additional_data
=
self
.
additional_data
)
EmbeddingRequest
=
Union
[
EmbeddingCompletionRequest
,
EmbeddingChatRequest
]
EmbeddingRequest
=
Union
[
EmbeddingCompletionRequest
,
EmbeddingChatRequest
]
...
...
vllm/entrypoints/openai/serving_embedding.py
View file @
fbf722c6
...
@@ -80,9 +80,6 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -80,9 +80,6 @@ class OpenAIServingEmbedding(OpenAIServing):
return
error_check_ret
return
error_check_ret
encoding_format
=
request
.
encoding_format
encoding_format
=
request
.
encoding_format
if
request
.
dimensions
is
not
None
:
return
self
.
create_error_response
(
"dimensions is currently not supported"
)
model_name
=
self
.
_get_model_name
(
request
.
model
)
model_name
=
self
.
_get_model_name
(
request
.
model
)
request_id
=
f
"embd-
{
self
.
_base_request_id
(
raw_request
)
}
"
request_id
=
f
"embd-
{
self
.
_base_request_id
(
raw_request
)
}
"
...
@@ -99,6 +96,13 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -99,6 +96,13 @@ class OpenAIServingEmbedding(OpenAIServing):
"greater than max_model_len."
"greater than max_model_len."
" Please, select a smaller truncation size."
)
" Please, select a smaller truncation size."
)
pooling_params
=
request
.
to_pooling_params
()
try
:
pooling_params
.
verify
(
self
.
model_config
)
except
ValueError
as
e
:
return
self
.
create_error_response
(
str
(
e
))
try
:
try
:
(
(
lora_request
,
lora_request
,
...
@@ -146,8 +150,6 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -146,8 +150,6 @@ class OpenAIServingEmbedding(OpenAIServing):
# Schedule the request and get the result generator.
# Schedule the request and get the result generator.
generators
:
list
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]
=
[]
generators
:
list
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]
=
[]
try
:
try
:
pooling_params
=
request
.
to_pooling_params
()
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
...
...
vllm/model_executor/layers/pooler.py
View file @
fbf722c6
...
@@ -97,7 +97,7 @@ class SimplePooler(nn.Module):
...
@@ -97,7 +97,7 @@ class SimplePooler(nn.Module):
pooling_metadata
:
PoolingMetadata
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
)
->
PoolerOutput
:
pooled_data
=
self
.
extract_states
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
extract_states
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
head
(
pooled_data
)
pooled_data
=
self
.
head
(
pooled_data
,
pooling_metadata
)
pooled_outputs
=
[
self
.
build_output
(
data
)
for
data
in
pooled_data
]
pooled_outputs
=
[
self
.
build_output
(
data
)
for
data
in
pooled_data
]
return
PoolerOutput
(
outputs
=
pooled_outputs
)
return
PoolerOutput
(
outputs
=
pooled_outputs
)
...
@@ -217,14 +217,28 @@ class PoolerHead(nn.Module):
...
@@ -217,14 +217,28 @@ class PoolerHead(nn.Module):
self
.
normalize
=
normalize
self
.
normalize
=
normalize
self
.
softmax
=
softmax
self
.
softmax
=
softmax
def
forward
(
self
,
pooled_data
:
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]):
def
forward
(
self
,
pooled_data
:
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
],
pooling_metadata
:
PoolingMetadata
):
dimensions_list
=
[
pooling_param
.
dimensions
for
_
,
pooling_param
in
pooling_metadata
.
seq_groups
]
if
any
(
d
is
not
None
for
d
in
dimensions_list
):
# change the output dimension
assert
len
(
pooled_data
)
==
len
(
dimensions_list
)
pooled_data
=
[
vecs
if
d
is
None
else
vecs
[...,
:
d
]
for
vecs
,
d
in
zip
(
pooled_data
,
dimensions_list
)
]
if
self
.
normalize
:
if
self
.
normalize
:
if
isinstance
(
pooled_data
,
list
):
if
isinstance
(
pooled_data
,
list
):
pooled_data
=
[
pooled_data
=
[
F
.
normalize
(
data
,
p
=
2
,
dim
=
1
)
for
data
in
pooled_data
F
.
normalize
(
data
,
p
=
2
,
dim
=
-
1
)
for
data
in
pooled_data
]
]
else
:
else
:
pooled_data
=
F
.
normalize
(
pooled_data
,
p
=
2
,
dim
=
1
)
pooled_data
=
F
.
normalize
(
pooled_data
,
p
=
2
,
dim
=
-
1
)
if
self
.
softmax
:
if
self
.
softmax
:
if
isinstance
(
pooled_data
,
list
):
if
isinstance
(
pooled_data
,
list
):
...
...
vllm/pooling_params.py
View file @
fbf722c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Any
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
import
msgspec
import
msgspec
if
TYPE_CHECKING
:
from
vllm.config
import
ModelConfig
class
PoolingParams
(
class
PoolingParams
(
msgspec
.
Struct
,
msgspec
.
Struct
,
...
@@ -12,14 +15,30 @@ class PoolingParams(
...
@@ -12,14 +15,30 @@ class PoolingParams(
"""API parameters for pooling models. This is currently a placeholder.
"""API parameters for pooling models. This is currently a placeholder.
Attributes:
Attributes:
dimensions: Reduce the dimensions of embeddings
if the model supports matryoshka representation.
additional_data: Any additional data needed for pooling.
additional_data: Any additional data needed for pooling.
"""
"""
dimensions
:
Optional
[
int
]
=
None
additional_data
:
Optional
[
Any
]
=
None
additional_data
:
Optional
[
Any
]
=
None
def
clone
(
self
)
->
"PoolingParams"
:
def
clone
(
self
)
->
"PoolingParams"
:
"""Returns a deep copy of the PoolingParams instance."""
"""Returns a deep copy of the PoolingParams instance."""
return
PoolingParams
(
additional_data
=
self
.
additional_data
)
return
PoolingParams
(
dimensions
=
self
.
dimensions
,
additional_data
=
self
.
additional_data
)
def verify(self, model_config: "ModelConfig") -> None:
    """Validate these pooling parameters against the target model.

    Nothing is checked when `dimensions` is unset.

    Args:
        model_config: Model configuration; its `is_matryoshka` and
            `served_model_name` attributes are read.

    Raises:
        ValueError: If `dimensions` is set but the model is not flagged as
            matryoshka-capable, or if `dimensions` is less than 1.
    """
    if self.dimensions is None:
        return

    if not model_config.is_matryoshka:
        raise ValueError(
            f'Model "{model_config.served_model_name}" does not '
            'support matryoshka representation, '
            'changing output dimensions will lead to poor results.')

    if self.dimensions < 1:
        raise ValueError("Dimensions must be greater than 0")
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
f
"PoolingParams("
return
(
f
"PoolingParams("
f
"dimensions=
{
self
.
dimensions
}
, "
f
"additional_metadata=
{
self
.
additional_data
}
)"
)
f
"additional_metadata=
{
self
.
additional_data
}
)"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment