Unverified Commit 67309a1c authored by wang.yuqi's avatar wang.yuqi Committed by GitHub
Browse files

[Frontend] Using matryoshka_dimensions control the allowed output dimensions. (#16970)

parent b724afe3
...@@ -159,14 +159,14 @@ For example, setting `dimensions` parameter while using the `BAAI/bge-m3` model ...@@ -159,14 +159,14 @@ For example, setting `dimensions` parameter while using the `BAAI/bge-m3` model
### Manually enable Matryoshka Embeddings ### Manually enable Matryoshka Embeddings
There is currently no official interface for specifying support for Matryoshka Embeddings. In vLLM, we simply check the existence of the fields `is_matryoshka` or `matryoshka_dimensions` inside `config.json`. There is currently no official interface for specifying support for Matryoshka Embeddings. In vLLM, if `is_matryoshka` is `True` in `config.json`, the output can be changed to arbitrary dimensions. The `matryoshka_dimensions` field can be used to restrict the allowed output dimensions.
For models that support Matryoshka Embeddings but are not recognized by vLLM, please manually override the config using `hf_overrides={"is_matryoshka": True}` (offline) or `--hf_overrides '{"is_matryoshka": true}'` (online). For models that support Matryoshka Embeddings but are not recognized by vLLM, please manually override the config using `hf_overrides={"is_matryoshka": True}` or `hf_overrides={"matryoshka_dimensions": [<allowed output dimensions>]}` (offline), or `--hf_overrides '{"is_matryoshka": true}'` or `--hf_overrides '{"matryoshka_dimensions": [<allowed output dimensions>]}'` (online).
Here is an example to serve a model with Matryoshka Embeddings enabled. Here is an example to serve a model with Matryoshka Embeddings enabled.
```text ```text
vllm serve Snowflake/snowflake-arctic-embed-m-v1.5 --hf_overrides '{"is_matryoshka":true}' vllm serve Snowflake/snowflake-arctic-embed-m-v1.5 --hf_overrides '{"matryoshka_dimensions":[256]}'
``` ```
### Offline Inference ### Offline Inference
...@@ -204,14 +204,14 @@ curl http://127.0.0.1:8000/v1/embeddings \ ...@@ -204,14 +204,14 @@ curl http://127.0.0.1:8000/v1/embeddings \
"input": "Follow the white rabbit.", "input": "Follow the white rabbit.",
"model": "jinaai/jina-embeddings-v3", "model": "jinaai/jina-embeddings-v3",
"encoding_format": "float", "encoding_format": "float",
"dimensions": 1 "dimensions": 32
}' }'
``` ```
Expected output: Expected output:
```json ```json
{"id":"embd-0aab28c384d348c3b8f0eb783109dc5f","object":"list","created":1744195454,"model":"jinaai/jina-embeddings-v3","data":[{"index":0,"object":"embedding","embedding":[-1.0]}],"usage":{"prompt_tokens":10,"total_tokens":10,"completion_tokens":0,"prompt_tokens_details":null}} {"id":"embd-5c21fc9a5c9d4384a1b021daccaf9f64","object":"list","created":1745476417,"model":"jinaai/jina-embeddings-v3","data":[{"index":0,"object":"embedding","embedding":[-0.3828125,-0.1357421875,0.03759765625,0.125,0.21875,0.09521484375,-0.003662109375,0.1591796875,-0.130859375,-0.0869140625,-0.1982421875,0.1689453125,-0.220703125,0.1728515625,-0.2275390625,-0.0712890625,-0.162109375,-0.283203125,-0.055419921875,-0.0693359375,0.031982421875,-0.04052734375,-0.2734375,0.1826171875,-0.091796875,0.220703125,0.37890625,-0.0888671875,-0.12890625,-0.021484375,-0.0091552734375,0.23046875]}],"usage":{"prompt_tokens":8,"total_tokens":8,"completion_tokens":0,"prompt_tokens_details":null}}
``` ```
An OpenAI client example can be found here: <gh-file:examples/online_serving/openai_embedding_matryoshka_fy.py> An OpenAI client example can be found here: <gh-file:examples/online_serving/openai_embedding_matryoshka_fy.py>
...@@ -25,11 +25,11 @@ def main(): ...@@ -25,11 +25,11 @@ def main():
responses = client.embeddings.create( responses = client.embeddings.create(
input=["Follow the white rabbit."], input=["Follow the white rabbit."],
model=model, model=model,
dimensions=1, dimensions=32,
) )
for data in responses.data: for data in responses.data:
print(data.embedding) # List of float of len 1 print(data.embedding) # List of float of len 32
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -11,11 +11,12 @@ import requests ...@@ -11,11 +11,12 @@ import requests
from vllm.entrypoints.openai.protocol import EmbeddingResponse from vllm.entrypoints.openai.protocol import EmbeddingResponse
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
from ...models.embedding.utils import check_embeddings_close from ...models.embedding.utils import correctness_test
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
MODEL_NAME = "intfloat/multilingual-e5-small" MODEL_NAME = "intfloat/multilingual-e5-small"
DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501 DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501
DTYPE = "bfloat16"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
...@@ -25,7 +26,7 @@ def server(): ...@@ -25,7 +26,7 @@ def server():
"embed", "embed",
# use half precision for speed and memory savings in CI environment # use half precision for speed and memory savings in CI environment
"--dtype", "--dtype",
"bfloat16", DTYPE,
"--enforce-eager", "--enforce-eager",
"--max-model-len", "--max-model-len",
"512", "512",
...@@ -43,9 +44,17 @@ async def client(server): ...@@ -43,9 +44,17 @@ async def client(server):
yield async_client yield async_client
@pytest.fixture(scope="module")
def hf_model(hf_runner):
    """Yield a module-scoped HF runner for ``MODEL_NAME``.

    The runner is created with ``dtype=DTYPE`` and
    ``is_sentence_transformer=True`` and is used as a context manager, so it
    is exited once the fixture is torn down.
    """
    runner_ctx = hf_runner(MODEL_NAME,
                           dtype=DTYPE,
                           is_sentence_transformer=True)
    with runner_ctx as reference_model:
        yield reference_model
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str): async def test_single_embedding(hf_model, client: openai.AsyncOpenAI,
model_name: str):
input_texts = [ input_texts = [
"The chef prepared a delicious meal.", "The chef prepared a delicious meal.",
] ]
...@@ -66,6 +75,9 @@ async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str): ...@@ -66,6 +75,9 @@ async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str):
assert embeddings.usage.prompt_tokens == 11 assert embeddings.usage.prompt_tokens == 11
assert embeddings.usage.total_tokens == 11 assert embeddings.usage.total_tokens == 11
vllm_outputs = [d.embedding for d in embeddings.data]
correctness_test(hf_model, input_texts, vllm_outputs)
# test using token IDs # test using token IDs
input_tokens = [1, 1, 1, 1, 1] input_tokens = [1, 1, 1, 1, 1]
embedding_response = await client.embeddings.create( embedding_response = await client.embeddings.create(
...@@ -86,7 +98,8 @@ async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str): ...@@ -86,7 +98,8 @@ async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_batch_embedding(client: openai.AsyncOpenAI, model_name: str): async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI,
model_name: str):
# test list[str] # test list[str]
input_texts = [ input_texts = [
"The cat sat on the mat.", "A feline was resting on a rug.", "The cat sat on the mat.", "A feline was resting on a rug.",
...@@ -107,6 +120,9 @@ async def test_batch_embedding(client: openai.AsyncOpenAI, model_name: str): ...@@ -107,6 +120,9 @@ async def test_batch_embedding(client: openai.AsyncOpenAI, model_name: str):
assert embeddings.usage.prompt_tokens == 33 assert embeddings.usage.prompt_tokens == 33
assert embeddings.usage.total_tokens == 33 assert embeddings.usage.total_tokens == 33
vllm_outputs = [d.embedding for d in embeddings.data]
correctness_test(hf_model, input_texts, vllm_outputs)
# test list[list[int]] # test list[list[int]]
input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24], input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
[25, 32, 64, 77]] [25, 32, 64, 77]]
...@@ -181,7 +197,7 @@ async def test_conversation_embedding(server: RemoteOpenAIServer, ...@@ -181,7 +197,7 @@ async def test_conversation_embedding(server: RemoteOpenAIServer,
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_batch_base64_embedding(client: openai.AsyncOpenAI, async def test_batch_base64_embedding(hf_model, client: openai.AsyncOpenAI,
model_name: str): model_name: str):
input_texts = [ input_texts = [
"Hello my name is", "Hello my name is",
...@@ -192,6 +208,7 @@ async def test_batch_base64_embedding(client: openai.AsyncOpenAI, ...@@ -192,6 +208,7 @@ async def test_batch_base64_embedding(client: openai.AsyncOpenAI,
model=model_name, model=model_name,
encoding_format="float") encoding_format="float")
float_data = [d.embedding for d in responses_float.data] float_data = [d.embedding for d in responses_float.data]
correctness_test(hf_model, input_texts, float_data)
responses_base64 = await client.embeddings.create(input=input_texts, responses_base64 = await client.embeddings.create(input=input_texts,
model=model_name, model=model_name,
...@@ -202,24 +219,13 @@ async def test_batch_base64_embedding(client: openai.AsyncOpenAI, ...@@ -202,24 +219,13 @@ async def test_batch_base64_embedding(client: openai.AsyncOpenAI,
np.frombuffer(base64.b64decode(data.embedding), np.frombuffer(base64.b64decode(data.embedding),
dtype="float32").tolist()) dtype="float32").tolist())
check_embeddings_close( correctness_test(hf_model, input_texts, base64_data)
embeddings_0_lst=float_data,
embeddings_1_lst=base64_data,
name_0="float",
name_1="base64",
)
# Default response is float32 decoded from base64 by OpenAI Client # Default response is float32 decoded from base64 by OpenAI Client
responses_default = await client.embeddings.create(input=input_texts, responses_default = await client.embeddings.create(input=input_texts,
model=model_name) model=model_name)
default_data = [d.embedding for d in responses_default.data] default_data = [d.embedding for d in responses_default.data]
correctness_test(hf_model, input_texts, default_data)
check_embeddings_close(
embeddings_0_lst=float_data,
embeddings_1_lst=default_data,
name_0="float",
name_1="default",
)
@pytest.mark.asyncio @pytest.mark.asyncio
......
...@@ -3,45 +3,81 @@ ...@@ -3,45 +3,81 @@
Run `pytest tests/entrypoints/openai/test_embedding_dimensions.py`. Run `pytest tests/entrypoints/openai/test_embedding_dimensions.py`.
""" """
from typing import Optional
import openai import openai
import pytest import pytest
from vllm.entrypoints.openai.protocol import EmbeddingResponse from vllm.entrypoints.openai.protocol import EmbeddingResponse
from ...models.embedding.utils import EmbedModelInfo from ...conftest import HfRunner
from ...models.embedding.utils import EmbedModelInfo, correctness_test
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
MODELS = [ MODELS = [
EmbedModelInfo(name="BAAI/bge-m3", is_matryoshka=False), EmbedModelInfo("intfloat/multilingual-e5-small", is_matryoshka=False),
EmbedModelInfo(name="jinaai/jina-embeddings-v3", is_matryoshka=True), EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5",
is_matryoshka=True,
matryoshka_dimensions=[256]),
] ]
input_texts = [ input_texts = [
"The chef prepared a delicious meal.", "The chef prepared a delicious meal.",
] * 3 ]
@pytest.mark.asyncio @pytest.fixture(scope="module", params=MODELS)
@pytest.mark.parametrize("model", MODELS) def model_info(request):
async def test_validating_dimensions(model: EmbedModelInfo): return request.param
# Module-scoped, parametrized fixture: returns the current parameter value
# (only "bfloat16" is listed), which the fixtures requesting `dtype` receive.
@pytest.fixture(scope="module", params=["bfloat16"])
def dtype(request):
return request.param
@pytest.fixture(scope="module")
def server(model_info, dtype: str):
args = [ args = [
"--task", "--task",
"embed", "embed",
# use half precision for speed and memory savings in CI environment # use half precision for speed and memory savings in CI environment
"--dtype", "--dtype",
"bfloat16", dtype,
"--enforce-eager", "--enforce-eager",
"--max-model-len", "--max-model-len",
"512", "512"
"--trust_remote_code"
] ]
with RemoteOpenAIServer(model.name, args) as remote_server:
client = remote_server.get_async_client()
async def make_request(dimensions): if model_info.name == "Snowflake/snowflake-arctic-embed-m-v1.5":
# Manually enable Matryoshka Embeddings
args.extend([
"--trust_remote_code", "--hf_overrides",
'{"matryoshka_dimensions":[256]}'
])
with RemoteOpenAIServer(model_info.name, args) as remote_server:
yield remote_server
@pytest.fixture(scope="module")
def hf_model(hf_runner, model_info, dtype: str):
    """Yield a module-scoped HF runner for the parametrized model.

    The runner is built from ``model_info.name`` with the requested ``dtype``
    and ``is_sentence_transformer=True``; it is used as a context manager, so
    it is exited once the fixture is torn down.
    """
    runner_ctx = hf_runner(model_info.name,
                           dtype=dtype,
                           is_sentence_transformer=True)
    with runner_ctx as reference_model:
        yield reference_model
@pytest.mark.asyncio
async def test_matryoshka(model_info: EmbedModelInfo,
server: RemoteOpenAIServer, hf_model: HfRunner):
client = server.get_async_client()
async def make_request_and_correctness_test(dimensions):
prompts = input_texts * 3
embedding_response = await client.embeddings.create( embedding_response = await client.embeddings.create(
model=model.name, model=model_info.name,
input=input_texts, input=prompts,
dimensions=dimensions, dimensions=dimensions,
encoding_format="float", encoding_format="float",
) )
...@@ -58,18 +94,30 @@ async def test_validating_dimensions(model: EmbedModelInfo): ...@@ -58,18 +94,30 @@ async def test_validating_dimensions(model: EmbedModelInfo):
if dimensions is not None: if dimensions is not None:
assert len(embeddings.data[0].embedding) == dimensions assert len(embeddings.data[0].embedding) == dimensions
if model.is_matryoshka: vllm_outputs = [d.embedding for d in embeddings.data]
for dimensions in [None, 16]: correctness_test(hf_model, prompts, vllm_outputs, dimensions)
await make_request(dimensions)
if model_info.is_matryoshka:
valid_dimensions: list[Optional[int]] = [None]
if model_info.matryoshka_dimensions is not None:
valid_dimensions += model_info.matryoshka_dimensions[:2]
for dimensions in valid_dimensions:
await make_request_and_correctness_test(dimensions)
invalid_dimensions: list[Optional[int]] = [-1]
if model_info.matryoshka_dimensions is not None:
assert 5 not in model_info.matryoshka_dimensions
invalid_dimensions.append(5)
for dimensions in invalid_dimensions:
with pytest.raises(openai.BadRequestError): with pytest.raises(openai.BadRequestError):
for dimensions in [-1]: await make_request_and_correctness_test(dimensions)
await make_request(dimensions)
else: else:
for dimensions in [None]: for dimensions in [None]:
await make_request(dimensions) await make_request_and_correctness_test(dimensions)
with pytest.raises(openai.BadRequestError):
for dimensions in [-1, 16]: for dimensions in [-1, 16]:
await make_request(dimensions) with pytest.raises(openai.BadRequestError):
await make_request_and_correctness_test(dimensions)
...@@ -153,6 +153,16 @@ def test_matryoshka( ...@@ -153,6 +153,16 @@ def test_matryoshka(
with vllm_runner(model, task="embed", dtype=dtype, with vllm_runner(model, task="embed", dtype=dtype,
max_model_len=None) as vllm_model: max_model_len=None) as vllm_model:
matryoshka_dimensions = (
vllm_model.model.llm_engine.model_config.matryoshka_dimensions)
assert matryoshka_dimensions is not None
if dimensions not in matryoshka_dimensions:
with pytest.raises(ValueError):
vllm_model.encode(
example_prompts,
pooling_params=PoolingParams(dimensions=dimensions))
else:
vllm_outputs = vllm_model.encode( vllm_outputs = vllm_model.encode(
example_prompts, example_prompts,
pooling_params=PoolingParams(dimensions=dimensions)) pooling_params=PoolingParams(dimensions=dimensions))
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from collections.abc import Sequence from collections.abc import Sequence
from typing import NamedTuple from typing import NamedTuple, Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -43,5 +43,24 @@ def matryoshka_fy(tensor, dimensions): ...@@ -43,5 +43,24 @@ def matryoshka_fy(tensor, dimensions):
class EmbedModelInfo(NamedTuple): class EmbedModelInfo(NamedTuple):
name: str name: str
is_matryoshka: bool is_matryoshka: bool
matryoshka_dimensions: Optional[list[int]] = None
architecture: str = "" architecture: str = ""
enable_test: bool = True enable_test: bool = True
def correctness_test(hf_model,
                     inputs,
                     vllm_outputs: Sequence[list[float]],
                     dimensions: Optional[int] = None):
    """Compare vLLM embeddings with embeddings from the HF reference model.

    Args:
        hf_model: Object exposing ``encode(inputs)``; its output is used as
            the reference.
        inputs: Prompts passed to ``hf_model.encode``.
        vllm_outputs: Embeddings produced by vLLM for the same inputs.
        dimensions: When truthy, the reference embeddings are passed through
            ``matryoshka_fy`` with this value before being compared.

    The comparison itself is delegated to ``check_embeddings_close`` with
    ``tol=1e-2``.
    """
    reference = hf_model.encode(inputs)
    # Only reduce the reference when an output dimension was requested.
    expected = (matryoshka_fy(reference, dimensions)
                if dimensions else reference)

    check_embeddings_close(
        embeddings_0_lst=expected,
        embeddings_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
        tol=1e-2,
    )
...@@ -1248,6 +1248,10 @@ class ModelConfig: ...@@ -1248,6 +1248,10 @@ class ModelConfig:
return (hasattr(self.hf_config, "matryoshka_dimensions") return (hasattr(self.hf_config, "matryoshka_dimensions")
or getattr(self.hf_config, "is_matryoshka", False)) or getattr(self.hf_config, "is_matryoshka", False))
@property
def matryoshka_dimensions(self):
    """Return ``hf_config.matryoshka_dimensions``, or ``None`` if unset."""
    hf_config = self.hf_config
    # Missing attribute is reported as None rather than raising.
    return getattr(hf_config, "matryoshka_dimensions", None)
BlockSize = Literal[1, 8, 16, 32, 64, 128] BlockSize = Literal[1, 8, 16, 32, 64, 128]
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"] CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"]
......
...@@ -35,7 +35,16 @@ class PoolingParams( ...@@ -35,7 +35,16 @@ class PoolingParams(
f'Model "{model_config.served_model_name}" does not ' f'Model "{model_config.served_model_name}" does not '
f'support matryoshka representation, ' f'support matryoshka representation, '
f'changing output dimensions will lead to poor results.') f'changing output dimensions will lead to poor results.')
if self.dimensions < 1:
mds = model_config.matryoshka_dimensions
if mds is not None:
if self.dimensions not in mds:
raise ValueError(
f'Model "{model_config.served_model_name}" '
f'only supports {str(mds)} matryoshka dimensions, '
f'use other output dimensions will '
f'lead to poor results.')
elif self.dimensions < 1:
raise ValueError("Dimensions must be greater than 0") raise ValueError("Dimensions must be greater than 0")
def __repr__(self) -> str: def __repr__(self) -> str:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment