Unverified Commit 4df5fc21 authored by woodx, committed by GitHub

Feat/refactor embedding server (#7322)

parent a06912ad
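
For reference, after this refactor the /v1/embeddings endpoint accepts standard OpenAI-style embedding requests instead of the previous bare "pass" stub. A minimal client-side sketch, assuming a locally running server; the host, port, and model name below are placeholders, not part of this commit:

import requests

# Placeholder host/port and model name; the body follows the OpenAI embeddings schema.
resp = requests.post(
    "http://localhost:30000/v1/embeddings",
    json={"model": "my-embedding-model", "input": "Hello, how are you?"},
)
print(resp.json()["data"][0]["embedding"][:8])  # first few dimensions of the returned vector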
@@ -40,9 +40,10 @@ from sglang.srt.disaggregation.utils import (
     register_disaggregation_server,
 )
 from sglang.srt.entrypoints.engine import Engine, _launch_subprocesses
+from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.metrics.func_timer import enable_func_timer
-from sglang.srt.openai_api.protocol import ModelCard, ModelList
+from sglang.srt.openai_api.protocol import EmbeddingRequest, ModelCard, ModelList
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     add_prometheus_middleware,
@@ -64,6 +65,7 @@ class AppState:
     server_args: Optional[ServerArgs] = None
     tokenizer_manager: Optional[TokenizerManager] = None
     scheduler_info: Optional[Dict] = None
+    embedding_server: Optional[OpenAIServingEmbedding] = None


 @asynccontextmanager
@@ -78,6 +80,9 @@ async def lifespan(app: FastAPI):
     tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
     app.state.tokenizer_manager = tokenizer_manager
     app.state.scheduler_info = scheduler_info
+    app.state.serving_embedding = OpenAIServingEmbedding(
+        tokenizer_manager=tokenizer_manager
+    )

     if server_args.enable_metrics:
         add_prometheus_middleware(app)
@@ -169,7 +174,16 @@ async def openai_v1_chat_completions(raw_request: Request):

 @app.post("/v1/embeddings")
 async def openai_v1_embeddings(raw_request: Request):
-    pass
+    try:
+        request_json = await raw_request.json()
+        request = EmbeddingRequest(**request_json)
+    except Exception as e:
+        return app.state.serving_embedding.create_error_response(
+            f"Invalid request body, error: {str(e)}"
+        )
+
+    ret = await app.state.serving_embedding.handle_request(request, raw_request)
+    return ret


 @app.post("/v1/score")
...
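
With the handler above, a body that fails EmbeddingRequest validation is rejected with a structured error response rather than silently ignored. A rough illustration, using the same placeholder host/port as before and assuming "input" is a required field of the schema:

import requests

# "input" is omitted, so EmbeddingRequest(**request_json) should raise inside the handler
# and the server should reply via create_error_response(...) instead of computing embeddings.
bad = requests.post("http://localhost:30000/v1/embeddings", json={"model": "my-embedding-model"})
print(bad.status_code, bad.text)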
@@ -37,7 +37,7 @@ class OpenAIServingBase(ABC):

         # Convert to internal format
         adapted_request, processed_request = self._convert_to_internal_request(
-            [request], [self._generate_request_id_base(request)]
+            request, self._generate_request_id_base(request)
         )

         # Note(Xinyuan): raw_request below is only used for detecting the connection of the client
@@ -73,8 +73,8 @@ class OpenAIServingBase(ABC):
     @abstractmethod
     def _convert_to_internal_request(
         self,
-        all_requests: List[OpenAIServingRequest],
-        request_ids: List[str],
+        request: OpenAIServingRequest,
+        request_id: str,
     ) -> tuple[
         GenerateReqInput, Union[OpenAIServingRequest, List[OpenAIServingRequest]]
     ]:
...
@@ -71,111 +71,61 @@ class OpenAIServingEmbedding(OpenAIServingBase):
     def _convert_to_internal_request(
         self,
-        all_requests: List[EmbeddingRequest],
-        request_ids: List[str],
+        request: EmbeddingRequest,
+        request_id: str,
     ) -> tuple[EmbeddingReqInput, Union[EmbeddingRequest, List[EmbeddingRequest]]]:
         """Convert OpenAI embedding request to internal format"""
-        prompts = [request.input for request in all_requests]
-
-        # Handle single vs multiple requests
-        if len(all_requests) == 1:
-            prompt = prompts[0]
-            if isinstance(prompt, str):
-                # Single string input
-                prompt_kwargs = {"text": prompt}
-            elif isinstance(prompt, list):
-                if len(prompt) > 0 and isinstance(prompt[0], str):
-                    # List of strings
-                    prompt_kwargs = {"text": prompt}
-                elif len(prompt) > 0 and isinstance(
-                    prompt[0], MultimodalEmbeddingInput
-                ):
-                    # Handle multimodal embedding inputs
-                    texts = []
-                    images = []
-                    for item in prompt:
-                        # Use padding for text if None - this could be improved
-                        texts.append(item.text if item.text is not None else "padding")
-                        images.append(item.image if item.image is not None else None)
-
-                    generate_prompts = []
-                    # Check if we have a chat template for multimodal embeddings
-                    # This would need to be passed in from the server configuration
-                    chat_template_name = getattr(
-                        self.tokenizer_manager, "chat_template_name", None
-                    )
-                    if chat_template_name is not None:
-                        convs = generate_embedding_convs(
-                            texts, images, chat_template_name
-                        )
-                        for conv in convs:
-                            generate_prompts.append(conv.get_prompt())
-                    else:
-                        generate_prompts = texts
-
-                    if len(generate_prompts) == 1:
-                        prompt_kwargs = {
-                            "text": generate_prompts[0],
-                            "image_data": images[0],
-                        }
-                    else:
-                        prompt_kwargs = {
-                            "text": generate_prompts,
-                            "image_data": images,
-                        }
-                else:
-                    # List of integers (token IDs) or empty list
-                    prompt_kwargs = {"input_ids": prompt}
-            else:
-                # Other types (should not happen but handle gracefully)
-                prompt_kwargs = {"input_ids": prompt}
-            # Use the passed request_ids for single request
-            final_request_id = request_ids[0] if len(all_requests) == 1 else request_ids
-        else:
-            # Handle batch requests
-            if len(prompts) > 0:
-                # Validate that all prompts have the same type
-                first_prompt = prompts[0]
-                first_type = type(first_prompt)
-                for i, prompt in enumerate(prompts[1:], 1):
-                    if type(prompt) != first_type:
-                        raise AssertionError(
-                            f"All prompts in batch must have the same type, but prompt at index {i} has different type"
-                        )
-                if isinstance(first_prompt, str):
-                    # Batch of strings
-                    prompt_kwargs = {"text": prompts}
-                elif isinstance(first_prompt, list):
-                    if len(first_prompt) > 0 and isinstance(first_prompt[0], str):
-                        # Batch of lists of strings
-                        prompt_kwargs = {"text": prompts}
-                    elif len(first_prompt) > 0 and isinstance(
-                        first_prompt[0], MultimodalEmbeddingInput
-                    ):
-                        # Handle multimodal batch requests
-                        raise NotImplementedError(
-                            "Multiple requests with multimodal inputs are not supported yet"
-                        )
-                    else:
-                        # Batch of token ID lists
-                        prompt_kwargs = {"input_ids": prompts}
-                else:
-                    # Other types
-                    prompt_kwargs = {"input_ids": prompts}
-            else:
-                prompt_kwargs = {"input_ids": prompts}
-            # Use the passed request_ids for batch requests
-            final_request_id = request_ids
-
-        adapted_request = EmbeddingReqInput(
-            rid=final_request_id,
-            **prompt_kwargs,
-        )
-
-        return adapted_request, (
-            all_requests[0] if len(all_requests) == 1 else all_requests
-        )
+        prompt = request.input
+        if isinstance(prompt, str):
+            # Single string input
+            prompt_kwargs = {"text": prompt}
+        elif isinstance(prompt, list):
+            if len(prompt) > 0 and isinstance(prompt[0], str):
+                # List of strings
+                prompt_kwargs = {"text": prompt}
+            elif len(prompt) > 0 and isinstance(prompt[0], MultimodalEmbeddingInput):
+                # Handle multimodal embedding inputs
+                texts = []
+                images = []
+                for item in prompt:
+                    # Use padding for text if None - this could be improved
+                    texts.append(item.text if item.text is not None else "padding")
+                    images.append(item.image if item.image is not None else None)
+
+                generate_prompts = []
+                # Check if we have a chat template for multimodal embeddings
+                # This would need to be passed in from the server configuration
+                chat_template_name = getattr(
+                    self.tokenizer_manager, "chat_template_name", None
+                )
+                if chat_template_name is not None:
+                    convs = generate_embedding_convs(texts, images, chat_template_name)
+                    for conv in convs:
+                        generate_prompts.append(conv.get_prompt())
+                else:
+                    generate_prompts = texts
+
+                if len(generate_prompts) == 1:
+                    prompt_kwargs = {
+                        "text": generate_prompts[0],
+                        "image_data": images[0],
+                    }
+                else:
+                    prompt_kwargs = {
+                        "text": generate_prompts,
+                        "image_data": images,
+                    }
+            else:
+                # List of integers (token IDs) or empty list
+                prompt_kwargs = {"input_ids": prompt}
+        else:
+            # Other types (should not happen but handle gracefully)
+            prompt_kwargs = {"input_ids": prompt}
+
+        adapted_request = EmbeddingReqInput(
+            **prompt_kwargs,
+        )
+
+        return adapted_request, request

     async def _handle_non_streaming_request(
         self,
@@ -194,14 +144,10 @@ class OpenAIServingEmbedding(OpenAIServingBase):

         if not isinstance(ret, list):
             ret = [ret]
-        response = self._build_embedding_response(
-            ret, self.tokenizer_manager.model_path
-        )
+        response = self._build_embedding_response(ret)
         return response

-    def _build_embedding_response(
-        self, ret: List[Dict[str, Any]], model_path: str
-    ) -> EmbeddingResponse:
+    def _build_embedding_response(self, ret: List[Dict[str, Any]]) -> EmbeddingResponse:
         """Build the embedding response"""
         embedding_objects = []
         prompt_tokens = 0
@@ -219,7 +165,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):

         return EmbeddingResponse(
             data=embedding_objects,
-            model=model_path,
+            model=self.tokenizer_manager.model_path,
             usage=UsageInfo(
                 prompt_tokens=prompt_tokens,
                 total_tokens=prompt_tokens,
...
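
The multimodal branch in _convert_to_internal_request above accepts a list of MultimodalEmbeddingInput items with optional text and image fields, which the tests below exercise. A sketch of such a request, assuming the import path and keyword constructor of MultimodalEmbeddingInput (both are assumptions, not shown in this diff), with placeholder model name and base64 payload mirroring the test fixtures:

from sglang.srt.openai_api.protocol import EmbeddingRequest, MultimodalEmbeddingInput

request = EmbeddingRequest(
    model="my-embedding-model",
    input=[
        MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
        MultimodalEmbeddingInput(text="World", image=None),
    ],
)
# _convert_to_internal_request(request, "req-id") then builds an EmbeddingReqInput whose
# text holds both prompts (raw or chat-template formatted) and whose image_data is
# ["base64_image_data", None].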
@@ -95,20 +95,20 @@ class ServingEmbeddingTestCase(unittest.TestCase):
         """Test converting single string request to internal format."""
         adapted_request, processed_request = (
             self.serving_embedding._convert_to_internal_request(
-                [self.basic_req], ["test-id"]
+                self.basic_req, "test-id"
             )
         )

         self.assertIsInstance(adapted_request, EmbeddingReqInput)
         self.assertEqual(adapted_request.text, "Hello, how are you?")
-        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(adapted_request.rid, None)
         self.assertEqual(processed_request, self.basic_req)

     def test_convert_list_string_request(self):
         """Test converting list of strings request to internal format."""
         adapted_request, processed_request = (
             self.serving_embedding._convert_to_internal_request(
-                [self.list_req], ["test-id"]
+                self.list_req, "test-id"
             )
         )
@@ -116,27 +116,27 @@ class ServingEmbeddingTestCase(unittest.TestCase):
         self.assertEqual(
             adapted_request.text, ["Hello, how are you?", "I am fine, thank you!"]
         )
-        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(adapted_request.rid, None)
         self.assertEqual(processed_request, self.list_req)

     def test_convert_token_ids_request(self):
         """Test converting token IDs request to internal format."""
         adapted_request, processed_request = (
             self.serving_embedding._convert_to_internal_request(
-                [self.token_ids_req], ["test-id"]
+                self.token_ids_req, "test-id"
             )
         )

         self.assertIsInstance(adapted_request, EmbeddingReqInput)
         self.assertEqual(adapted_request.input_ids, [1, 2, 3, 4, 5])
-        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(adapted_request.rid, None)
         self.assertEqual(processed_request, self.token_ids_req)

     def test_convert_multimodal_request(self):
         """Test converting multimodal request to internal format."""
         adapted_request, processed_request = (
             self.serving_embedding._convert_to_internal_request(
-                [self.multimodal_req], ["test-id"]
+                self.multimodal_req, "test-id"
             )
         )
@@ -147,7 +147,7 @@ class ServingEmbeddingTestCase(unittest.TestCase):
         self.assertIn("World", adapted_request.text)
         self.assertEqual(adapted_request.image_data[0], "base64_image_data")
         self.assertIsNone(adapted_request.image_data[1])
-        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(adapted_request.rid, None)

     def test_build_single_embedding_response(self):
         """Test building response for single embedding."""
@@ -158,9 +158,7 @@ class ServingEmbeddingTestCase(unittest.TestCase):
             }
         ]

-        response = self.serving_embedding._build_embedding_response(
-            ret_data, "test-model"
-        )
+        response = self.serving_embedding._build_embedding_response(ret_data)

         self.assertIsInstance(response, EmbeddingResponse)
         self.assertEqual(response.model, "test-model")
@@ -185,9 +183,7 @@ class ServingEmbeddingTestCase(unittest.TestCase):
             },
         ]

-        response = self.serving_embedding._build_embedding_response(
-            ret_data, "test-model"
-        )
+        response = self.serving_embedding._build_embedding_response(ret_data)

         self.assertIsInstance(response, EmbeddingResponse)
         self.assertEqual(len(response.data), 2)
...