Unverified Commit 9fc3e8aa authored by satyamk7054, committed by GitHub

Add support for Matryoshka embeddings (#126) (#11142)


Co-authored-by: Satyam Kumar <satyamk@linkedin.com>
parent c11b34d5
@@ -18,6 +18,7 @@ Usage:
import asyncio
import logging
from typing import Optional
from transformers import AutoTokenizer
from util import (
@@ -52,11 +53,14 @@ config.freeze_gc = True  # Enable GC freeze functionality
HTTP_URL = "http://localhost:30000/v1/embeddings"

# Embeddings API Config
EMBEDDINGS_MODEL_PATH = "Qwen/Qwen3-Embedding-0.6B"
BATCH_SIZE = [1]  # Number of items per request (batch size)

# Configurable input token length
EMBEDDINGS_INPUT_TOKENS = 500  # Default token length
MATRYOSHKA_DIMENSIONS: Optional[int] = (
    None  # Set to None to disable matryoshka embeddings
)

# Load tokenizer once for embeddings text generation
print("Loading tokenizer for embeddings input generation...")
@@ -85,6 +89,7 @@ def build_embeddings_request(index: int, item_count: int) -> tuple:
        req = {
            "input": input_data,
            "model": EMBEDDINGS_MODEL_PATH,
            "dimensions": MATRYOSHKA_DIMENSIONS,
        }
        return (index, req)
    except Exception as e:
@@ -94,7 +99,12 @@ def build_embeddings_request(index: int, item_count: int) -> tuple:


def validate_embeddings_response(response_data: dict) -> bool:
    """Validate embeddings API response."""
    if "data" not in response_data:
        return False
    if MATRYOSHKA_DIMENSIONS is None:
        return True
    return len(response_data["data"][0]["embedding"]) == MATRYOSHKA_DIMENSIONS
def build_warmup_embeddings_request() -> dict:
@@ -102,6 +112,7 @@ def build_warmup_embeddings_request() -> dict:
    return {
        "input": EMBEDDINGS_INPUT_TEXT,
        "model": EMBEDDINGS_MODEL_PATH,
        "dimensions": MATRYOSHKA_DIMENSIONS,
    }
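As a usage illustration (an editor's sketch, not part of the diff), the helpers above can be exercised against a running server; `HTTP_URL`, `build_embeddings_request`, and `validate_embeddings_response` are the names defined in this benchmark script:

```python
import requests

# Build one request with the script's helper and POST it to the server.
_, req = build_embeddings_request(index=0, item_count=1)
resp = requests.post(HTTP_URL, json=req).json()
print(validate_embeddings_response(resp))  # True for a well-formed response
```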
...
@@ -75,6 +75,45 @@ response = requests.post(url + "/v1/embeddings", json=payload).json()
print("Embeddings:", [x.get("embedding") for x in response.get("data", [])])
```
## Matryoshka Embedding Example
[Matryoshka Embeddings](https://sbert.net/examples/sentence_transformer/training/matryoshka/README.html#matryoshka-embeddings), also known as [Matryoshka Representation Learning (MRL)](https://arxiv.org/abs/2205.13147), is a technique for training embedding models whose vectors can be truncated to smaller dimensions with little quality loss, letting users trade off quality against storage and compute cost.
### 1. Launch a Matryoshka‑capable model
If the model config already includes `matryoshka_dimensions` or `is_matryoshka`, no override is needed. Otherwise, you can supply them with `--json-model-override-args`, as shown below:
```shell
python3 -m sglang.launch_server \
--model-path Qwen/Qwen3-Embedding-0.6B \
--is-embedding \
--host 0.0.0.0 \
--port 30000 \
--json-model-override-args '{"matryoshka_dimensions": [128, 256, 512, 1024, 1536]}'
```
1. Setting `"is_matryoshka": true` allows truncation to any dimension. Otherwise, the server validates that the `dimensions` value in a request is one of `matryoshka_dimensions` and rejects it if not (see the example after this list).
2. Omitting `dimensions` in a request returns the full-size vector.
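For example, a request whose `dimensions` value is not in the configured list is rejected; the OpenAI-compatible endpoint returns HTTP 400 (333 below is just an arbitrary unsupported dimension):

```python
import requests

url = "http://127.0.0.1:30000"

# 333 is not in matryoshka_dimensions, so the server rejects the request.
payload = {
    "model": "Qwen/Qwen3-Embedding-0.6B",
    "input": "hello",
    "dimensions": 333,
}
response = requests.post(url + "/v1/embeddings", json=payload)
print(response.status_code)  # expected: 400
```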
### 2. Make requests with different output dimensions
```python
import requests
url = "http://127.0.0.1:30000"
# Request a truncated (Matryoshka) embedding by specifying a supported dimension.
payload = {
    "model": "Qwen/Qwen3-Embedding-0.6B",
    "input": "Explain diffusion models simply.",
    "dimensions": 512,  # change to 128 / 1024, or omit for the full size
}
response = requests.post(url + "/v1/embeddings", json=payload).json()
print("Embedding:", response["data"][0]["embedding"])
```
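For reference, here is a minimal sketch (not taken from the codebase) of what the server does with such a request, mirroring the `Pooler` change in this commit: the pooled embedding is truncated to the requested dimension and then L2-normalized, so truncated vectors stay unit-norm. The tensor is random stand-in data:

```python
import torch

full = torch.randn(1536)  # stand-in for a full-size pooled embedding
dim = 512                 # the requested "dimensions" value

truncated = torch.nn.functional.normalize(full[:dim], p=2, dim=-1)
print(truncated.shape, truncated.norm())  # torch.Size([512]) tensor(1.0000)
```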
## Supported Models

| Model Family | Example Model | Chat Template | Description |
...
@@ -205,6 +205,14 @@ class ModelConfig:
            self.hf_config, "image_token_id", None
        ) or getattr(self.hf_config, "image_token_index", None)

        # matryoshka embeddings
        self.matryoshka_dimensions = getattr(
            self.hf_config, "matryoshka_dimensions", None
        )
        self.is_matryoshka = bool(self.matryoshka_dimensions) or getattr(
            self.hf_config, "is_matryoshka", False
        )

    @staticmethod
    def from_server_args(
        server_args: ServerArgs,
...
@@ -312,6 +312,7 @@ class Engine(EngineBase):
        image_data: Optional[MultimodalDataInputFormat] = None,
        audio_data: Optional[MultimodalDataInputFormat] = None,
        video_data: Optional[MultimodalDataInputFormat] = None,
        dimensions: Optional[int] = None,
    ) -> Dict:
        """
        The arguments of this function are the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
@@ -322,6 +323,7 @@ class Engine(EngineBase):
            image_data=image_data,
            audio_data=audio_data,
            video_data=video_data,
            dimensions=dimensions,
        )
        generator = self.tokenizer_manager.generate_request(obj, None)
        ret = self.loop.run_until_complete(generator.__anext__())
@@ -333,6 +335,7 @@ class Engine(EngineBase):
        image_data: Optional[MultimodalDataInputFormat] = None,
        audio_data: Optional[MultimodalDataInputFormat] = None,
        video_data: Optional[MultimodalDataInputFormat] = None,
        dimensions: Optional[int] = None,
    ) -> Dict:
        """
        Asynchronous version of the encode method.
@@ -345,6 +348,7 @@ class Engine(EngineBase):
            image_data=image_data,
            audio_data=audio_data,
            video_data=video_data,
            dimensions=dimensions,
        )
        generator = self.tokenizer_manager.generate_request(obj, None)
        return await generator.__anext__()
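As an illustration (an editor's sketch, not part of the diff), the new `dimensions` argument can also be exercised through the offline engine API; the model path and the `is_matryoshka` override are assumptions borrowed from the docs section above:

```python
import sglang as sgl

# A minimal sketch, assuming Qwen/Qwen3-Embedding-0.6B is served as an
# embedding model with the Matryoshka override from the docs above.
llm = sgl.Engine(
    model_path="Qwen/Qwen3-Embedding-0.6B",
    is_embedding=True,
    json_model_override_args='{"is_matryoshka": true}',
)
out = llm.encode("Explain diffusion models simply.", dimensions=256)
print(len(out["embedding"]))  # expected: 256
llm.shutdown()
```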
...
@@ -126,6 +126,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
            **prompt_kwargs,
            rid=request.rid,
            priority=request.priority,
            dimensions=request.dimensions,
        )
        return adapted_request, request
...
@@ -20,7 +20,9 @@ class PoolingType(IntEnum):


@dataclass
class EmbeddingPoolerOutput:
    # The pooler can return list[torch.Tensor] instead of a single tensor when
    # tensors in the batch have different dimensions due to per-request
    # matryoshka dim truncation.
    embeddings: torch.Tensor | list[torch.Tensor]


class Pooler(nn.Module):
@@ -42,6 +44,7 @@ class Pooler(nn.Module):
    def forward(
        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> EmbeddingPoolerOutput:

        if self.pooling_type == PoolingType.LAST:
            last_token_indices = torch.cumsum(forward_batch.extend_seq_lens, dim=0) - 1
            pooled_data = hidden_states[last_token_indices]
@@ -53,8 +56,24 @@ class Pooler(nn.Module):
        else:
            raise ValueError(f"Invalid pooling type: {self.pooling_type}")

        if forward_batch.dimensions is not None:
            all_same_dimensions = len(set(forward_batch.dimensions)) == 1
            if all_same_dimensions:
                pooled_data = pooled_data[..., : forward_batch.dimensions[0]]
            else:
                pooled_data = [
                    tensor[..., :dim]
                    for tensor, dim in zip(pooled_data, forward_batch.dimensions)
                ]

        if self.normalize:
            if isinstance(pooled_data, list):
                pooled_data = [
                    nn.functional.normalize(tensor, p=2, dim=-1)
                    for tensor in pooled_data
                ]
            else:
                pooled_data = nn.functional.normalize(pooled_data, p=2, dim=-1)

        return EmbeddingPoolerOutput(embeddings=pooled_data)
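A standalone sketch (an editor's illustration, not part of the diff) of the mixed-dimension path above: when requests in one batch ask for different Matryoshka dims, truncation yields a ragged list of tensors, each normalized independently:

```python
import torch

pooled = torch.randn(3, 1024)  # pooled embeddings for a batch of 3 requests
dims = [128, 512, 1024]        # hypothetical per-request dimensions

ragged = [t[..., :d] for t, d in zip(pooled, dims)]
ragged = [torch.nn.functional.normalize(t, p=2, dim=-1) for t in ragged]
print([tuple(t.shape) for t in ragged])  # [(128,), (512,), (1024,)]
```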
...
@@ -695,6 +695,9 @@ class EmbeddingReqInput(BaseReq):
    # tracing context
    trace_context: Optional[Dict] = None

    # The number of dimensions the output embeddings should have. Only applicable to Matryoshka embedding models.
    dimensions: Optional[int] = None

    def normalize_batch_and_arguments(self):
        # at least one of text, input_ids, or image should be provided
        if self.text is None and self.input_ids is None and self.image_data is None:
@@ -771,6 +774,7 @@ class EmbeddingReqInput(BaseReq):
            video_data=self.video_data[i] if self.video_data is not None else None,
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
            dimensions=self.dimensions,
            http_worker_ipc=self.http_worker_ipc,
        )
@@ -791,6 +795,8 @@ class TokenizedEmbeddingReqInput(BaseReq):
    data_parallel_rank: Optional[int] = None
    # Priority for the request
    priority: Optional[int] = None
    # The number of dimensions the output embeddings should have. Only applicable to Matryoshka embedding models.
    dimensions: Optional[int] = None


@dataclass
...
@@ -442,6 +442,7 @@ class Req:
        priority: Optional[int] = None,
        metrics_collector: Optional[SchedulerMetricsCollector] = None,
        extra_key: Optional[str] = None,
        dimensions: Optional[int] = None,
        http_worker_ipc: Optional[str] = None,
    ):
        # Input and output info
@@ -650,6 +651,9 @@ class Req:
        self.tmp_end_idx: int = -1
        self.metadata_buffer_index: int = -1

        # For Matryoshka embeddings
        self.dimensions = dimensions

    @property
    def seqlen(self):
        return len(self.origin_input_ids) + len(self.output_ids)
@@ -1014,6 +1018,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # For matryoshka embeddings
    dimensions: Optional[list[int]] = None

    # For split prefill
    split_index: int = 0
    split_prefill_finished: bool = False
@@ -1177,6 +1184,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
        prefix_lens = [len(r.prefix_indices) for r in reqs]
        extend_lens = [r.extend_input_len for r in reqs]

        # For matryoshka embeddings
        if self.model_config.is_matryoshka and any(
            r.dimensions is not None for r in reqs
        ):
            self.dimensions = [
                r.dimensions if r.dimensions else self.model_config.hidden_size
                for r in reqs
            ]

        token_type_ids = [
            r.token_type_ids for r in reqs if r.token_type_ids is not None
        ]
@@ -1765,6 +1781,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
            ),
            extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
            is_prefill_only=self.is_prefill_only,
            dimensions=self.dimensions,
        )

    def copy(self):
@@ -1873,5 +1890,8 @@ class ModelWorkerBatch:
    capture_hidden_mode: CaptureHiddenMode = None
    hicache_consumer_index: int = -1

    # For matryoshka embeddings
    dimensions: Optional[list[int]] = None

    # Whether this batch is prefill-only (no token generation needed)
    is_prefill_only: bool = False
@@ -1475,6 +1475,7 @@ class Scheduler(
            recv_req.sampling_params,
            token_type_ids=recv_req.token_type_ids,
            priority=recv_req.priority,
            dimensions=recv_req.dimensions,
            http_worker_ipc=recv_req.http_worker_ipc,
        )
        req.tokenizer = self.tokenizer
...
@@ -203,7 +203,10 @@ class SchedulerOutputProcessorMixin:
                        i
                    ].item()
                else:
                    if isinstance(embeddings, torch.Tensor):
                        embeddings = embeddings.tolist()
                    else:
                        embeddings = [tensor.tolist() for tensor in embeddings]

            # Check finish conditions
            for i, req in enumerate(batch.reqs):
...
@@ -666,6 +666,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
            )
            raise ValueError(error_msg)

        # Matryoshka embeddings validations
        if isinstance(obj, EmbeddingReqInput):
            self._validate_for_matryoshka_dim(obj)

        if isinstance(obj, GenerateReqInput):
            if (
                obj.return_hidden_states
@@ -684,6 +688,34 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                "Please set `--enable-custom-logit-processor` to enable this feature."
            )
    def _validate_for_matryoshka_dim(self, obj: EmbeddingReqInput) -> None:
        """Validate the request's Matryoshka dimensions if the field is set."""
        if obj.dimensions is None:
            return

        if not self.model_config.is_matryoshka:
            raise ValueError(
                f"Model '{self.model_config.model_path}' does not support Matryoshka representation; "
                f"changing the output dimensions would lead to poor results."
            )

        if obj.dimensions < 1:
            raise ValueError("Requested dimensions must be greater than 0")

        if (
            self.model_config.matryoshka_dimensions
            and obj.dimensions not in self.model_config.matryoshka_dimensions
        ):
            raise ValueError(
                f"Model '{self.model_config.model_path}' only supports matryoshka dimensions "
                f"{self.model_config.matryoshka_dimensions}; other output dimensions would lead to poor results."
            )

        if obj.dimensions > self.model_config.hidden_size:
            raise ValueError(
                f"Requested dimensions ({obj.dimensions}) exceed the maximum embedding dimension "
                f"({self.model_config.hidden_size})"
            )
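An editor's sketch (not part of the diff) condensing the validation rules above into a boolean predicate; `dimensions_ok` is a hypothetical name, and the real code raises `ValueError` with descriptive messages instead:

```python
from typing import Optional

def dimensions_ok(
    dimensions: Optional[int],
    is_matryoshka: bool,
    matryoshka_dimensions: Optional[list[int]],
    hidden_size: int,
) -> bool:
    # No truncation requested: always valid.
    if dimensions is None:
        return True
    # The model must support Matryoshka truncation at all.
    if not is_matryoshka:
        return False
    # The dimension must be positive and at most the full hidden size.
    if dimensions < 1 or dimensions > hidden_size:
        return False
    # If a discrete list is configured, the value must be in it.
    if matryoshka_dimensions and dimensions not in matryoshka_dimensions:
        return False
    return True

print(dimensions_ok(256, True, [128, 256, 512], 1024))  # True
print(dimensions_ok(333, True, [128, 256, 512], 1024))  # False
```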
    def _validate_input_ids_in_vocab(
        self, input_ids: List[int], vocab_size: int
    ) -> None:
@@ -752,6 +784,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
            sampling_params,
            rid=obj.rid,
            priority=obj.priority,
            dimensions=obj.dimensions,
            http_worker_ipc=obj.http_worker_ipc,
        )
...
@@ -320,6 +320,9 @@ class ForwardBatch:
    tbo_parent_token_range: Optional[Tuple[int, int]] = None
    tbo_children: Optional[List[ForwardBatch]] = None

    # For matryoshka embeddings
    dimensions: Optional[list[int]] = None

    @classmethod
    def init_new(
        cls,
@@ -361,6 +364,7 @@ class ForwardBatch:
            input_embeds=batch.input_embeds,
            token_type_ids=batch.token_type_ids,
            tbo_split_seq_index=batch.tbo_split_seq_index,
            dimensions=batch.dimensions,
        )

        device = model_runner.device
...
@@ -12,10 +12,11 @@
# limitations under the License.
# ==============================================================================

import json
import multiprocessing as mp
import os
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
@@ -89,7 +90,9 @@ def get_token_ids_logprobs(logits, token_ids):
    return logprobs


def _get_sentence_transformer_embedding_model(
    model_path, torch_dtype, matryoshka_dim: Optional[int] = None
):
    from sentence_transformers import SentenceTransformer
    from sentence_transformers.util import is_sentence_transformer_model
@@ -97,6 +100,7 @@ def _get_sentence_transformer_embedding_model(model_path, torch_dtype):
        model = SentenceTransformer(
            model_path,
            model_kwargs={"torch_dtype": torch_dtype},
            truncate_dim=matryoshka_dim,
        )
    else:  # if no pre-trained sentence-transformers model
        from sentence_transformers import models
@@ -106,7 +110,9 @@ def _get_sentence_transformer_embedding_model(model_path, torch_dtype):
            word_embedding_model.get_word_embedding_dimension(),
            pooling_mode="lasttoken",
        )
        model = SentenceTransformer(
            modules=[word_embedding_model, pooling_model], truncate_dim=matryoshka_dim
        )

    return model.cuda()
@@ -135,6 +141,7 @@ class HFRunner:
        output_str_only: bool = False,
        trust_remote_code: bool = False,
        patch_model_do_sample_false: bool = False,
        matryoshka_dim: Optional[int] = None,
    ):
        self.model_type = model_type
        self.output_str_only = output_str_only
@@ -151,6 +158,7 @@ class HFRunner:
                self.out_queue,
                model_path,
                torch_dtype,
                matryoshka_dim,
            ),
        )
        self.model_proc.start()
@@ -225,7 +233,14 @@ class HFRunner:
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
            return embeddings.contiguous()

    def start_model_process(
        self,
        in_queue,
        out_queue,
        model_path,
        torch_dtype,
        matryoshka_dim: Optional[int] = None,
    ):
        # Apply model-specific patches
        monkey_patch_gemma2_sdpa()
@@ -259,7 +274,7 @@ class HFRunner:
                self.processor = AutoProcessor.from_pretrained(model_path)
            else:
                self.model = _get_sentence_transformer_embedding_model(
                    model_path, torch_dtype, matryoshka_dim=matryoshka_dim
                )
        elif self.model_type == "reward" or self.model_type == "cross_encoder":
            from transformers import AutoModelForSequenceClassification
@@ -519,6 +534,7 @@ class SRTRunner:
        lora_target_modules: Optional[List[str]] = None,
        enable_lora: Optional[bool] = None,
        max_loaded_loras: Optional[int] = None,
        json_model_override_args: Optional[dict[str, Any]] = None,
        lora_eviction_policy: str = "lru",
    ):
        self.model_type = model_type
@@ -566,6 +582,11 @@ class SRTRunner:
                lora_target_modules=lora_target_modules,
                enable_lora=enable_lora,
                max_loaded_loras=max_loaded_loras,
                json_model_override_args=(
                    json.dumps(json_model_override_args)
                    if json_model_override_args
                    else "{}"
                ),
                lora_eviction_policy=lora_eviction_policy,
                **spec_kwargs,
            )
@@ -594,6 +615,7 @@ class SRTRunner:
        logprob_start_len: int = 0,
        top_k: Optional[int] = None,
        token_ids_logprob: Optional[List[int]] = None,
        dimensions: Optional[int] = None,
    ):
        if self.is_generation:
            return self.forward_generation_raw(
@@ -607,7 +629,9 @@ class SRTRunner:
            )
        else:
            if self.model_type == "embedding":
                response = self.engine.encode(
                    prompt=prompts, image_data=image_data, dimensions=dimensions
                )
                if isinstance(response, list):
                    logits = [x["embedding"] for x in response]
                else:
...
@@ -15,6 +15,7 @@
import multiprocessing as mp
import random
import unittest
from typing import Optional

import torch
from transformers import AutoConfig, AutoTokenizer
@@ -69,6 +70,7 @@ class TestEmbeddingModels(CustomTestCase):
        tp_size,
        torch_dtype,
        prefill_tolerance,
        matryoshka_dim: Optional[int] = None,
    ) -> None:
        truncated_prompts = self._truncate_prompts(prompts, model_path)
@@ -76,6 +78,7 @@ class TestEmbeddingModels(CustomTestCase):
            model_path,
            torch_dtype=torch_dtype,
            model_type="embedding",
            matryoshka_dim=matryoshka_dim,
        ) as hf_runner:
            hf_outputs = hf_runner.forward(truncated_prompts)
@@ -86,8 +89,13 @@ class TestEmbeddingModels(CustomTestCase):
            torch_dtype=torch_dtype,
            model_type="embedding",
            attention_backend=attention_backend,
            json_model_override_args=(
                {"matryoshka_dimensions": [matryoshka_dim]} if matryoshka_dim else None
            ),
        ) as srt_runner:
            srt_outputs = srt_runner.forward(
                truncated_prompts, dimensions=matryoshka_dim
            )

        for i in range(len(prompts)):
            hf_logits = torch.Tensor(hf_outputs.embed_logits[i])
@@ -113,6 +121,25 @@ class TestEmbeddingModels(CustomTestCase):
                DEFAULT_PROMPTS, model, tp_size, torch_dtype, prefill_tolerance
            )
    def test_matryoshka_embedding(self):
        models_to_test = [
            model
            for model in MODELS
            if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == model[0]
        ]
        assert len(models_to_test) == 1

        for model, tp_size, prefill_tolerance in models_to_test:
            for torch_dtype in TORCH_DTYPES:
                self.assert_close_prefill_logits(
                    DEFAULT_PROMPTS,
                    model,
                    tp_size,
                    torch_dtype,
                    prefill_tolerance,
                    matryoshka_dim=128,
                )
if __name__ == "__main__":
    unittest.main()

...
import json
import os
import unittest

import numpy as np
import openai

from sglang.srt.utils import kill_process_tree
@@ -92,6 +95,105 @@ class TestOpenAIEmbedding(CustomTestCase):
        # check the status code
        self.assertEqual(cm.exception.status_code, 400)
    def test_embedding_with_dimensions_parameter(self):
        """Test that non-Matryoshka models reject the dimensions parameter."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)

        # Specifying dimensions should fail for non-Matryoshka models
        with self.assertRaises(openai.BadRequestError) as cm:
            client.embeddings.create(
                model=self.model, input="Hello world", dimensions=512
            )
        self.assertEqual(cm.exception.status_code, 400)
class TestMatryoshkaEmbeddingModel(CustomTestCase):
    """Tests for a model with Matryoshka embedding support, using the OpenAI API."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.matryoshka_dims = [128, 256, 512, 768, 1024]

        # Configure embedding-specific args with Matryoshka support via json_model_override_args
        matryoshka_config = {
            "is_matryoshka": True,
            "matryoshka_dimensions": cls.matryoshka_dims,
        }
        other_args = [
            "--is-embedding",
            "--enable-metrics",
            "--json-model-override-args",
            json.dumps(matryoshka_config),
        ]

        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=other_args,
        )
        cls.base_url += "/v1"

    @classmethod
    def tearDownClass(cls):
        if hasattr(cls, "process"):
            kill_process_tree(cls.process.pid)
    def test_matryoshka_embedding_valid_dimensions(self):
        """Test Matryoshka embedding with valid dimensions."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)

        # Test with various valid dimensions
        for dimensions in self.matryoshka_dims:
            with self.subTest(dimensions=dimensions):
                response = client.embeddings.create(
                    model=self.model, input="Hello world", dimensions=dimensions
                )
                self.assertEqual(len(response.data), 1)
                self.assertEqual(len(response.data[0].embedding), dimensions)
    def test_matryoshka_embedding_batch_same_dimensions(self):
        """Test Matryoshka embedding with batch input and same dimensions."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)

        response = client.embeddings.create(
            model=self.model,
            input=["Hello world", "Test text", "Another example"],
            dimensions=256,
        )

        self.assertEqual(len(response.data), 3)
        for embedding_data in response.data:
            self.assertEqual(len(embedding_data.embedding), 256)
    def test_matryoshka_embedding_no_dimensions(self):
        """Test embedding without specifying dimensions (should use full size)."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)

        response = client.embeddings.create(model=self.model, input="Hello world")

        self.assertEqual(len(response.data), 1)
        # Should return the full embedding size when no dimensions are specified
        self.assertEqual(len(response.data[0].embedding), 1536)
    def test_matryoshka_embedding_invalid_dimensions(self):
        """Test Matryoshka embedding with invalid dimensions."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)

        for dimensions in [100, 0, -1, 10000]:
            with self.assertRaises(openai.BadRequestError) as cm:
                client.embeddings.create(
                    model=self.model,
                    input="Hello world",
                    dimensions=dimensions,
                )
            self.assertEqual(cm.exception.status_code, 400)
if __name__ == "__main__":
    unittest.main()