Unverified Commit 9fc3e8aa authored by satyamk7054, committed by GitHub

Add support for Matryoshka embeddings (#126) (#11142)


Co-authored-by: Satyam Kumar <satyamk@linkedin.com>
parent c11b34d5
......@@ -18,6 +18,7 @@ Usage:
import asyncio
import logging
from typing import Optional
from transformers import AutoTokenizer
from util import (
......@@ -52,11 +53,14 @@ config.freeze_gc = True # Enable GC freeze functionality
HTTP_URL = "http://localhost:30000/v1/embeddings"
# Embeddings API Config
EMBEDDINGS_MODEL_PATH = "/Qwen/Qwen3-Embedding-0.6B"
EMBEDDINGS_MODEL_PATH = "Qwen/Qwen3-Embedding-0.6B"
BATCH_SIZE = [1] # Number of items per request (batch size)
# Configurable input token length
EMBEDDINGS_INPUT_TOKENS = 500 # Default token length
MATRYOSHKA_DIMENSIONS: Optional[int] = (
None # Set to None to disable matryoshka embeddings
)
# Load tokenizer once for embeddings text generation
print("Loading tokenizer for embeddings input generation...")
......@@ -85,6 +89,7 @@ def build_embeddings_request(index: int, item_count: int) -> tuple:
req = {
"input": input_data,
"model": EMBEDDINGS_MODEL_PATH,
"dimensions": MATRYOSHKA_DIMENSIONS,
}
return (index, req)
except Exception as e:
......@@ -94,7 +99,12 @@ def build_embeddings_request(index: int, item_count: int) -> tuple:
def validate_embeddings_response(response_data: dict) -> bool:
"""Validate embeddings API response."""
return "data" in response_data
return "data" in response_data and (
MATRYOSHKA_DIMENSIONS is None
or len(response_data["data"][0]["embedding"]) == MATRYOSHKA_DIMENSIONS
)
def build_warmup_embeddings_request() -> dict:
......@@ -102,6 +112,7 @@ def build_warmup_embeddings_request() -> dict:
return {
"input": EMBEDDINGS_INPUT_TEXT,
"model": EMBEDDINGS_MODEL_PATH,
"dimensions": MATRYOSHKA_DIMENSIONS,
}
......
......@@ -75,6 +75,45 @@ response = requests.post(url + "/v1/embeddings", json=payload).json()
print("Embeddings:", [x.get("embedding") for x in response.get("data", [])])
```
## Matryoshka Embedding Example
[Matryoshka Embeddings](https://sbert.net/examples/sentence_transformer/training/matryoshka/README.html#matryoshka-embeddings), also known as [Matryoshka Representation Learning (MRL)](https://arxiv.org/abs/2205.13147), is a technique used when training embedding models. It allows users to trade off between embedding quality and cost.
### 1. Launch a Matryoshka‑capable model
If the model config already includes `matryoshka_dimensions` or `is_matryoshka`, no override is needed. Otherwise, you can pass `--json-model-override-args` as shown below:
```shell
python3 -m sglang.launch_server \
--model-path Qwen/Qwen3-Embedding-0.6B \
--is-embedding \
--host 0.0.0.0 \
--port 30000 \
--json-model-override-args '{"matryoshka_dimensions": [128, 256, 512, 1024, 1536]}'
```
1. Setting `"is_matryoshka": true` allows truncating to any dimension. Otherwise, the server will validate that the specified dimension in the request is one of `matryoshka_dimensions`.
2. Omitting `dimensions` in a request returns the full vector.
### 2. Make requests with different output dimensions
```python
import requests
url = "http://127.0.0.1:30000"
# Request a truncated (Matryoshka) embedding by specifying a supported dimension.
payload = {
"model": "Qwen/Qwen3-Embedding-0.6B",
"input": "Explain diffusion models simply.",
"dimensions": 512 # change to 128 / 1024 / omit for full size
}
response = requests.post(url + "/v1/embeddings", json=payload).json()
print("Embedding:", response["data"][0]["embedding"])
```
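The truncated embedding is just the leading slice of the full embedding (re-normalized when the model's pooler normalizes its outputs), so you can sanity-check this behavior client-side. Below is a minimal sketch of such a check; it assumes the server launched in step 1 is still running at the same URL and model path, and the `embed` helper is illustrative only.

```python
import numpy as np
import requests

url = "http://127.0.0.1:30000"

def embed(text, dimensions=None):
    # Request an embedding, optionally truncated to `dimensions` entries.
    payload = {"model": "Qwen/Qwen3-Embedding-0.6B", "input": text}
    if dimensions is not None:
        payload["dimensions"] = dimensions
    data = requests.post(url + "/v1/embeddings", json=payload).json()
    return np.array(data["data"][0]["embedding"])

full = embed("Explain diffusion models simply.")        # full-size embedding
short = embed("Explain diffusion models simply.", 128)  # truncated to 128 dims

# The truncated embedding should point in the same direction as the first 128
# entries of the full embedding (truncate-then-normalize behavior).
head = full[:128]
cosine = float(short @ head / (np.linalg.norm(short) * np.linalg.norm(head)))
print("cosine similarity:", cosine)  # expected to be close to 1.0
```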
## Supported Models
| Model Family | Example Model | Chat Template | Description |
......
......@@ -205,6 +205,14 @@ class ModelConfig:
self.hf_config, "image_token_id", None
) or getattr(self.hf_config, "image_token_index", None)
# matryoshka embeddings
self.matryoshka_dimensions = getattr(
self.hf_config, "matryoshka_dimensions", None
)
self.is_matryoshka = bool(self.matryoshka_dimensions) or getattr(
self.hf_config, "is_matryoshka", False
)
@staticmethod
def from_server_args(
server_args: ServerArgs,
......
......@@ -312,6 +312,7 @@ class Engine(EngineBase):
image_data: Optional[MultimodalDataInputFormat] = None,
audio_data: Optional[MultimodalDataInputFormat] = None,
video_data: Optional[MultimodalDataInputFormat] = None,
dimensions: Optional[int] = None,
) -> Dict:
"""
The arguments of this function are the same as those of `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
......@@ -322,6 +323,7 @@ class Engine(EngineBase):
image_data=image_data,
audio_data=audio_data,
video_data=video_data,
dimensions=dimensions,
)
generator = self.tokenizer_manager.generate_request(obj, None)
ret = self.loop.run_until_complete(generator.__anext__())
......@@ -333,6 +335,7 @@ class Engine(EngineBase):
image_data: Optional[MultimodalDataInputFormat] = None,
audio_data: Optional[MultimodalDataInputFormat] = None,
video_data: Optional[MultimodalDataInputFormat] = None,
dimensions: Optional[int] = None,
) -> Dict:
"""
Asynchronous version of encode method.
......@@ -345,6 +348,7 @@ class Engine(EngineBase):
image_data=image_data,
audio_data=audio_data,
video_data=video_data,
dimensions=dimensions,
)
generator = self.tokenizer_manager.generate_request(obj, None)
return await generator.__anext__()
......
......@@ -126,6 +126,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
**prompt_kwargs,
rid=request.rid,
priority=request.priority,
dimensions=request.dimensions,
)
return adapted_request, request
......
......@@ -20,7 +20,9 @@ class PoolingType(IntEnum):
@dataclass
class EmbeddingPoolerOutput:
embeddings: torch.Tensor
# The pooler may return list[torch.Tensor] instead of a single tensor when the
# tensors in a batch have different sizes due to per-request matryoshka dimension truncation
embeddings: torch.Tensor | list[torch.Tensor]
class Pooler(nn.Module):
......@@ -42,6 +44,7 @@ class Pooler(nn.Module):
def forward(
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
) -> EmbeddingPoolerOutput:
if self.pooling_type == PoolingType.LAST:
last_token_indices = torch.cumsum(forward_batch.extend_seq_lens, dim=0) - 1
pooled_data = hidden_states[last_token_indices]
......@@ -53,8 +56,24 @@ class Pooler(nn.Module):
else:
raise ValueError(f"Invalid pooling type: {self.pooling_type}")
if forward_batch.dimensions is not None:
all_same_dimensions = len(set(forward_batch.dimensions)) == 1
if all_same_dimensions:
pooled_data = pooled_data[..., : forward_batch.dimensions[0]]
else:
pooled_data = [
tensor[..., :dim]
for tensor, dim in zip(pooled_data, forward_batch.dimensions)
]
if self.normalize:
pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
if isinstance(pooled_data, list):
pooled_data = [
nn.functional.normalize(tensor, p=2, dim=-1)
for tensor in pooled_data
]
else:
pooled_data = nn.functional.normalize(pooled_data, p=2, dim=-1)
return EmbeddingPoolerOutput(embeddings=pooled_data)
......
......@@ -695,6 +695,9 @@ class EmbeddingReqInput(BaseReq):
# tracing context
trace_context: Optional[Dict] = None
# The number of dimensions the output embeddings should have. Only applicable for Matryoshka embeddings.
dimensions: Optional[int] = None
def normalize_batch_and_arguments(self):
# at least one of text, input_ids, or image should be provided
if self.text is None and self.input_ids is None and self.image_data is None:
......@@ -771,6 +774,7 @@ class EmbeddingReqInput(BaseReq):
video_data=self.video_data[i] if self.video_data is not None else None,
sampling_params=self.sampling_params[i],
rid=self.rid[i],
dimensions=self.dimensions,
http_worker_ipc=self.http_worker_ipc,
)
......@@ -791,6 +795,8 @@ class TokenizedEmbeddingReqInput(BaseReq):
data_parallel_rank: Optional[int] = None
# Priority for the request
priority: Optional[int] = None
# The number of dimensions the output embeddings should have. Only applicable for Matryoshka embeddings.
dimensions: Optional[int] = None
@dataclass
......
......@@ -442,6 +442,7 @@ class Req:
priority: Optional[int] = None,
metrics_collector: Optional[SchedulerMetricsCollector] = None,
extra_key: Optional[str] = None,
dimensions: Optional[int] = None,
http_worker_ipc: Optional[str] = None,
):
# Input and output info
......@@ -650,6 +651,9 @@ class Req:
self.tmp_end_idx: int = -1
self.metadata_buffer_index: int = -1
# For Matryoshka embeddings
self.dimensions = dimensions
@property
def seqlen(self):
return len(self.origin_input_ids) + len(self.output_ids)
......@@ -1014,6 +1018,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
encoder_lens_cpu: Optional[List[int]] = None
encoder_out_cache_loc: Optional[torch.Tensor] = None
# For matryoshka embeddings
dimensions: Optional[list[int]] = None
# For split prefill
split_index: int = 0
split_prefill_finished: bool = False
......@@ -1177,6 +1184,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
prefix_lens = [len(r.prefix_indices) for r in reqs]
extend_lens = [r.extend_input_len for r in reqs]
# For matryoshka embeddings
if self.model_config.is_matryoshka and any(
r.dimensions is not None for r in reqs
):
self.dimensions = [
r.dimensions if r.dimensions else self.model_config.hidden_size
for r in reqs
]
token_type_ids = [
r.token_type_ids for r in reqs if r.token_type_ids is not None
]
......@@ -1765,6 +1781,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
),
extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
is_prefill_only=self.is_prefill_only,
dimensions=self.dimensions,
)
def copy(self):
......@@ -1873,5 +1890,8 @@ class ModelWorkerBatch:
capture_hidden_mode: CaptureHiddenMode = None
hicache_consumer_index: int = -1
# For matryoshka embeddings
dimensions: Optional[list[int]] = None
# Whether this batch is prefill-only (no token generation needed)
is_prefill_only: bool = False
......@@ -1475,6 +1475,7 @@ class Scheduler(
recv_req.sampling_params,
token_type_ids=recv_req.token_type_ids,
priority=recv_req.priority,
dimensions=recv_req.dimensions,
http_worker_ipc=recv_req.http_worker_ipc,
)
req.tokenizer = self.tokenizer
......
......@@ -203,7 +203,10 @@ class SchedulerOutputProcessorMixin:
i
].item()
else:
if isinstance(embeddings, torch.Tensor):
embeddings = embeddings.tolist()
else:
embeddings = [tensor.tolist() for tensor in embeddings]
# Check finish conditions
for i, req in enumerate(batch.reqs):
......
......@@ -666,6 +666,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
)
raise ValueError(error_msg)
# Matryoshka embeddings validations
if isinstance(obj, EmbeddingReqInput):
self._validate_for_matryoshka_dim(obj)
if isinstance(obj, GenerateReqInput):
if (
obj.return_hidden_states
......@@ -684,6 +688,34 @@ class TokenizerManager(TokenizerCommunicatorMixin):
"Please set `--enable-custom-logit-processor` to enable this feature."
)
def _validate_for_matryoshka_dim(self, obj: EmbeddingReqInput) -> None:
"""Validate the request for Matryoshka dim if it has the field set."""
if obj.dimensions is None:
return
if not self.model_config.is_matryoshka:
raise ValueError(
f"Model '{self.model_config.model_path}' does not support matryoshka representations; "
f"truncating the output dimensions would lead to poor results."
)
if obj.dimensions < 1:
raise ValueError("Requested dimensions must be greater than 0")
if (
self.model_config.matryoshka_dimensions
and obj.dimensions not in self.model_config.matryoshka_dimensions
):
raise ValueError(
f"Model '{self.model_config.model_path}' only supports matryoshka dimensions {self.model_config.matryoshka_dimensions}; "
f"using other output dimensions will lead to poor results."
)
if obj.dimensions > self.model_config.hidden_size:
raise ValueError(
f"Provided dimensions are greater than max embedding dimension: {self.model_config.hidden_size}"
)
def _validate_input_ids_in_vocab(
self, input_ids: List[int], vocab_size: int
) -> None:
......@@ -752,6 +784,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
sampling_params,
rid=obj.rid,
priority=obj.priority,
dimensions=obj.dimensions,
http_worker_ipc=obj.http_worker_ipc,
)
......
......@@ -320,6 +320,9 @@ class ForwardBatch:
tbo_parent_token_range: Optional[Tuple[int, int]] = None
tbo_children: Optional[List[ForwardBatch]] = None
# For matryoshka embeddings
dimensions: Optional[list[int]] = None
@classmethod
def init_new(
cls,
......@@ -361,6 +364,7 @@ class ForwardBatch:
input_embeds=batch.input_embeds,
token_type_ids=batch.token_type_ids,
tbo_split_seq_index=batch.tbo_split_seq_index,
dimensions=batch.dimensions,
)
device = model_runner.device
......
......@@ -12,10 +12,11 @@
# limitations under the License.
# ==============================================================================
import json
import multiprocessing as mp
import os
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import Any, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
......@@ -89,7 +90,9 @@ def get_token_ids_logprobs(logits, token_ids):
return logprobs
def _get_sentence_transformer_embedding_model(model_path, torch_dtype):
def _get_sentence_transformer_embedding_model(
model_path, torch_dtype, matryoshka_dim: Optional[int] = None
):
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import is_sentence_transformer_model
......@@ -97,6 +100,7 @@ def _get_sentence_transformer_embedding_model(model_path, torch_dtype):
model = SentenceTransformer(
model_path,
model_kwargs={"torch_dtype": torch_dtype},
truncate_dim=matryoshka_dim,
)
else: # if no pre-trained sentence-transformers model
from sentence_transformers import models
......@@ -106,7 +110,9 @@ def _get_sentence_transformer_embedding_model(model_path, torch_dtype):
word_embedding_model.get_word_embedding_dimension(),
pooling_mode="lasttoken",
)
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
model = SentenceTransformer(
modules=[word_embedding_model, pooling_model], truncate_dim=matryoshka_dim
)
return model.cuda()
......@@ -135,6 +141,7 @@ class HFRunner:
output_str_only: bool = False,
trust_remote_code: bool = False,
patch_model_do_sample_false: bool = False,
matryoshka_dim: Optional[int] = None,
):
self.model_type = model_type
self.output_str_only = output_str_only
......@@ -151,6 +158,7 @@ class HFRunner:
self.out_queue,
model_path,
torch_dtype,
matryoshka_dim,
),
)
self.model_proc.start()
......@@ -225,7 +233,14 @@ class HFRunner:
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
return embeddings.contiguous()
def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
def start_model_process(
self,
in_queue,
out_queue,
model_path,
torch_dtype,
matryoshka_dim: Optional[int] = None,
):
# Apply model-specific patches
monkey_patch_gemma2_sdpa()
......@@ -259,7 +274,7 @@ class HFRunner:
self.processor = AutoProcessor.from_pretrained(model_path)
else:
self.model = _get_sentence_transformer_embedding_model(
model_path, torch_dtype
model_path, torch_dtype, matryoshka_dim=matryoshka_dim
)
elif self.model_type == "reward" or self.model_type == "cross_encoder":
from transformers import AutoModelForSequenceClassification
......@@ -519,6 +534,7 @@ class SRTRunner:
lora_target_modules: Optional[List[str]] = None,
enable_lora: Optional[bool] = None,
max_loaded_loras: Optional[int] = None,
json_model_override_args: Optional[dict[str, Any]] = None,
lora_eviction_policy: str = "lru",
):
self.model_type = model_type
......@@ -566,6 +582,11 @@ class SRTRunner:
lora_target_modules=lora_target_modules,
enable_lora=enable_lora,
max_loaded_loras=max_loaded_loras,
json_model_override_args=(
json.dumps(json_model_override_args)
if json_model_override_args
else "{}"
),
lora_eviction_policy=lora_eviction_policy,
**spec_kwargs,
)
......@@ -594,6 +615,7 @@ class SRTRunner:
logprob_start_len: int = 0,
top_k: Optional[int] = None,
token_ids_logprob: Optional[List[int]] = None,
dimensions: Optional[int] = None,
):
if self.is_generation:
return self.forward_generation_raw(
......@@ -607,7 +629,9 @@ class SRTRunner:
)
else:
if self.model_type == "embedding":
response = self.engine.encode(prompt=prompts, image_data=image_data)
response = self.engine.encode(
prompt=prompts, image_data=image_data, dimensions=dimensions
)
if isinstance(response, list):
logits = [x["embedding"] for x in response]
else:
......
......@@ -15,6 +15,7 @@
import multiprocessing as mp
import random
import unittest
from typing import Optional
import torch
from transformers import AutoConfig, AutoTokenizer
......@@ -69,6 +70,7 @@ class TestEmbeddingModels(CustomTestCase):
tp_size,
torch_dtype,
prefill_tolerance,
matryoshka_dim: Optional[int] = None,
) -> None:
truncated_prompts = self._truncate_prompts(prompts, model_path)
......@@ -76,6 +78,7 @@ class TestEmbeddingModels(CustomTestCase):
model_path,
torch_dtype=torch_dtype,
model_type="embedding",
matryoshka_dim=matryoshka_dim,
) as hf_runner:
hf_outputs = hf_runner.forward(truncated_prompts)
......@@ -86,8 +89,13 @@ class TestEmbeddingModels(CustomTestCase):
torch_dtype=torch_dtype,
model_type="embedding",
attention_backend=attention_backend,
json_model_override_args=(
{"matryoshka_dimensions": [matryoshka_dim]} if matryoshka_dim else None
),
) as srt_runner:
srt_outputs = srt_runner.forward(truncated_prompts)
srt_outputs = srt_runner.forward(
truncated_prompts, dimensions=matryoshka_dim
)
for i in range(len(prompts)):
hf_logits = torch.Tensor(hf_outputs.embed_logits[i])
......@@ -113,6 +121,25 @@ class TestEmbeddingModels(CustomTestCase):
DEFAULT_PROMPTS, model, tp_size, torch_dtype, prefill_tolerance
)
def test_matryoshka_embedding(self):
models_to_test = [
model
for model in MODELS
if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == model[0]
]
assert len(models_to_test) == 1
for model, tp_size, prefill_tolerance in models_to_test:
for torch_dtype in TORCH_DTYPES:
self.assert_close_prefill_logits(
DEFAULT_PROMPTS,
model,
tp_size,
torch_dtype,
prefill_tolerance,
matryoshka_dim=128,
)
if __name__ == "__main__":
unittest.main()
import json
import os
import unittest
import numpy as np
import openai
from sglang.srt.utils import kill_process_tree
......@@ -92,6 +95,105 @@ class TestOpenAIEmbedding(CustomTestCase):
# check the status code
self.assertEqual(cm.exception.status_code, 400)
def test_embedding_with_dimensions_parameter(self):
"""Test that non-Matryoshka models reject dimensions parameter."""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
# Test that specifying dimensions fails for non-Matryoshka models
with self.assertRaises(openai.BadRequestError) as cm:
client.embeddings.create(
model=self.model, input="Hello world", dimensions=512
)
self.assertEqual(cm.exception.status_code, 400)
class TestMatryoshkaEmbeddingModel(CustomTestCase):
"""Test class for Model that supports Matryoshka embedding functionality, using OpenAI API."""
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.matryoshka_dims = [128, 256, 512, 768, 1024]
# Configure embedding-specific args with Matryoshka support via json_model_override_args
matryoshka_config = {
"is_matryoshka": True,
"matryoshka_dimensions": cls.matryoshka_dims,
}
other_args = [
"--is-embedding",
"--enable-metrics",
"--json-model-override-args",
json.dumps(matryoshka_config),
]
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
other_args=other_args,
)
cls.base_url += "/v1"
@classmethod
def tearDownClass(cls):
if hasattr(cls, "process"):
kill_process_tree(cls.process.pid)
def test_matryoshka_embedding_valid_dimensions(self):
"""Test Matryoshka embedding with valid dimensions."""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
# Test with various valid dimensions
for dimensions in self.matryoshka_dims:
with self.subTest(dimensions=dimensions):
response = client.embeddings.create(
model=self.model, input="Hello world", dimensions=dimensions
)
self.assertEqual(len(response.data), 1)
self.assertEqual(len(response.data[0].embedding), dimensions)
def test_matryoshka_embedding_batch_same_dimensions(self):
"""Test Matryoshka embedding with batch input and same dimensions."""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.embeddings.create(
model=self.model,
input=["Hello world", "Test text", "Another example"],
dimensions=256,
)
self.assertEqual(len(response.data), 3)
for embedding_data in response.data:
self.assertEqual(len(embedding_data.embedding), 256)
def test_matryoshka_embedding_no_dimensions(self):
"""Test embedding without specifying dimensions (should use full size)."""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.embeddings.create(model=self.model, input="Hello world")
self.assertEqual(len(response.data), 1)
# Should return full embedding size when no dimensions specified
self.assertEqual(len(response.data[0].embedding), 1536)
def test_matryoshka_embedding_invalid_dimensions(self):
"""Test Matryoshka embedding with invalid dimensions."""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
for dimensions in [100, 0, -1, 10000]:
with self.assertRaises(openai.BadRequestError) as cm:
client.embeddings.create(
model=self.model,
input="Hello world",
dimensions=dimensions,
)
self.assertEqual(cm.exception.status_code, 400)
if __name__ == "__main__":
unittest.main()