Unverified Commit 09988080, authored by Xinyuan Tong, committed by GitHub

Refine OpenAI serving entrypoint to remove batch requests (#7372)


Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
Co-authored-by: Chang Su <csu272@usc.edu>
parent 794be55a
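
This commit moves every OpenAI serving entrypoint to a strict one-request-per-call model: _convert_to_internal_request no longer takes a list of requests with a parallel list of request ids, and the embedding input validation drops its list-of-token-lists branch. A minimal before/after sketch of the call site, using only names that appear in the test diffs below:

    # Before: batch-style API; callers passed parallel lists of requests and ids
    adapted, processed = self.chat._convert_to_internal_request([req], ["rid"])

    # After: a single request in, a (GenerateReqInput, original request) pair out
    adapted, processed = self.chat._convert_to_internal_request(req)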
@@ -20,7 +20,7 @@ import logging
 import os
 from enum import auto
 
-from sglang.srt.openai_api.protocol import ChatCompletionRequest
+from sglang.srt.entrypoints.openai.protocol import CompletionRequest
 
 logger = logging.getLogger(__name__)
 completion_template_name = None
@@ -116,7 +116,7 @@ def is_completion_template_defined() -> bool:
     return completion_template_name is not None
 
 
-def generate_completion_prompt_from_request(request: ChatCompletionRequest) -> str:
+def generate_completion_prompt_from_request(request: CompletionRequest) -> str:
     global completion_template_name
     if request.suffix == "":
         return request.prompt
...
@@ -2,7 +2,7 @@ import json
 import logging
 import uuid
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, Optional, Union
 
 from fastapi import Request
 from fastapi.responses import ORJSONResponse, StreamingResponse
@@ -37,7 +37,7 @@ class OpenAIServingBase(ABC):
         # Convert to internal format
         adapted_request, processed_request = self._convert_to_internal_request(
-            request, self._generate_request_id_base(request)
+            request
         )
 
         # Note(Xinyuan): raw_request below is only used for detecting the connection of the client
@@ -74,10 +74,7 @@ class OpenAIServingBase(ABC):
     def _convert_to_internal_request(
         self,
         request: OpenAIServingRequest,
-        request_id: str,
-    ) -> tuple[
-        GenerateReqInput, Union[OpenAIServingRequest, List[OpenAIServingRequest]]
-    ]:
+    ) -> tuple[GenerateReqInput, OpenAIServingRequest]:
         """Convert OpenAI request to internal format"""
         pass
...
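
For context, the abstract signature above now commits subclasses to the one-request contract. The following is an illustrative sketch only, not the actual sglang implementation: the class name ExampleServingCompletion is hypothetical, and the body assumes GenerateReqInput accepts text and stream keyword arguments, as suggested by the tests in this diff:

    class ExampleServingCompletion(OpenAIServingBase):
        def _convert_to_internal_request(
            self,
            request: CompletionRequest,
        ) -> tuple[GenerateReqInput, CompletionRequest]:
            # One request in, one (internal input, original request) pair out;
            # no request_id parameter, no List[...] in the return annotation.
            adapted = GenerateReqInput(text=request.prompt, stream=request.stream)
            return adapted, request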
@@ -54,35 +54,25 @@ class OpenAIServingEmbedding(OpenAIServingBase):
                     return f"All items in input list must be integers"
                 if item < 0:
                     return f"Token ID at index {i} must be non-negative"
-        elif isinstance(first_item, list):
-            # List of lists (multiple token sequences)
-            for i, item in enumerate(input):
-                if not isinstance(item, list):
-                    return f"Input at index {i} must be a list"
-                if not item:
-                    return f"Input at index {i} cannot be empty"
-                if not all(isinstance(token, int) for token in item):
-                    return f"Input at index {i} must contain only integers"
-                if any(token < 0 for token in item):
-                    return f"Input at index {i} contains negative token IDs"
-        # Note: MultimodalEmbeddingInput validation would be handled by Pydantic
         return None
 
     def _convert_to_internal_request(
         self,
         request: EmbeddingRequest,
-        request_id: str,
-    ) -> tuple[EmbeddingReqInput, Union[EmbeddingRequest, List[EmbeddingRequest]]]:
+    ) -> tuple[EmbeddingReqInput, EmbeddingRequest]:
         """Convert OpenAI embedding request to internal format"""
         prompt = request.input
         if isinstance(prompt, str):
             # Single string input
             prompt_kwargs = {"text": prompt}
         elif isinstance(prompt, list):
             if len(prompt) > 0 and isinstance(prompt[0], str):
-                # List of strings
-                prompt_kwargs = {"text": prompt}
+                # List of strings - if it's a single string in a list, treat as single string
+                if len(prompt) == 1:
+                    prompt_kwargs = {"text": prompt[0]}
+                else:
+                    prompt_kwargs = {"text": prompt}
             elif len(prompt) > 0 and isinstance(prompt[0], MultimodalEmbeddingInput):
                 # Handle multimodal embedding inputs
                 texts = []
@@ -94,7 +84,6 @@ class OpenAIServingEmbedding(OpenAIServingBase):
             generate_prompts = []
 
             # Check if we have a chat template for multimodal embeddings
-            # This would need to be passed in from the server configuration
             chat_template_name = getattr(
                 self.tokenizer_manager, "chat_template_name", None
             )
@@ -121,6 +110,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
         else:
             # Other types (should not happen but handle gracefully)
             prompt_kwargs = {"input_ids": prompt}
+
         adapted_request = EmbeddingReqInput(
             **prompt_kwargs,
         )
...
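
One behavioral detail in the embedding conversion above is easy to miss: a one-element list of strings is now unwrapped to a bare string before EmbeddingReqInput is built. A self-contained sketch of that normalization rule; normalize_embedding_input is a hypothetical helper written only to illustrate the branch logic:

    def normalize_embedding_input(prompt):
        """Illustrative mirror of the string/token-id branches above."""
        if isinstance(prompt, str):
            # Single string input
            return {"text": prompt}
        if isinstance(prompt, list) and prompt and isinstance(prompt[0], str):
            # A one-element list of strings is unwrapped to a bare string
            return {"text": prompt[0]} if len(prompt) == 1 else {"text": prompt}
        # Fallback mirrors the token-id branch
        return {"input_ids": prompt}

    assert normalize_embedding_input("hi") == {"text": "hi"}
    assert normalize_embedding_input(["hi"]) == {"text": "hi"}
    assert normalize_embedding_input(["a", "b"]) == {"text": ["a", "b"]}
    assert normalize_embedding_input([1, 2, 3]) == {"input_ids": [1, 2, 3]}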
@@ -104,52 +104,50 @@ class ServingChatTestCase(unittest.TestCase):
             None,
         )
-        adapted, processed = self.chat._convert_to_internal_request(
-            [self.basic_req], ["rid"]
-        )
+        adapted, processed = self.chat._convert_to_internal_request(self.basic_req)
         self.assertIsInstance(adapted, GenerateReqInput)
         self.assertFalse(adapted.stream)
         self.assertEqual(processed, self.basic_req)
 
-    # ------------- tool-call branch -------------
-    def test_tool_call_request_conversion(self):
-        req = ChatCompletionRequest(
-            model="x",
-            messages=[{"role": "user", "content": "Weather?"}],
-            tools=[
-                {
-                    "type": "function",
-                    "function": {
-                        "name": "get_weather",
-                        "parameters": {"type": "object", "properties": {}},
-                    },
-                }
-            ],
-            tool_choice="auto",
-        )
-        with patch.object(
-            self.chat,
-            "_process_messages",
-            return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
-        ):
-            adapted, _ = self.chat._convert_to_internal_request([req], ["rid"])
-        self.assertEqual(adapted.rid, "rid")
+    # # ------------- tool-call branch -------------
+    # def test_tool_call_request_conversion(self):
+    #     req = ChatCompletionRequest(
+    #         model="x",
+    #         messages=[{"role": "user", "content": "Weather?"}],
+    #         tools=[
+    #             {
+    #                 "type": "function",
+    #                 "function": {
+    #                     "name": "get_weather",
+    #                     "parameters": {"type": "object", "properties": {}},
+    #                 },
+    #             }
+    #         ],
+    #         tool_choice="auto",
+    #     )
+    #     with patch.object(
+    #         self.chat,
+    #         "_process_messages",
+    #         return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
+    #     ):
+    #         adapted, _ = self.chat._convert_to_internal_request(req, "rid")
+    #     self.assertEqual(adapted.rid, "rid")
 
-    def test_tool_choice_none(self):
-        req = ChatCompletionRequest(
-            model="x",
-            messages=[{"role": "user", "content": "Hi"}],
-            tools=[{"type": "function", "function": {"name": "noop"}}],
-            tool_choice="none",
-        )
-        with patch.object(
-            self.chat,
-            "_process_messages",
-            return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
-        ):
-            adapted, _ = self.chat._convert_to_internal_request([req], ["rid"])
-        self.assertEqual(adapted.rid, "rid")
+    # def test_tool_choice_none(self):
+    #     req = ChatCompletionRequest(
+    #         model="x",
+    #         messages=[{"role": "user", "content": "Hi"}],
+    #         tools=[{"type": "function", "function": {"name": "noop"}}],
+    #         tool_choice="none",
+    #     )
+    #     with patch.object(
+    #         self.chat,
+    #         "_process_messages",
+    #         return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
+    #     ):
+    #         adapted, _ = self.chat._convert_to_internal_request(req, "rid")
+    #     self.assertEqual(adapted.rid, "rid")
 
     # ------------- multimodal branch -------------
     def test_multimodal_request_with_images(self):
...
@@ -36,12 +36,12 @@ class ServingCompletionTestCase(unittest.TestCase):
     # ---------- prompt-handling ----------
     def test_single_string_prompt(self):
         req = CompletionRequest(model="x", prompt="Hello world", max_tokens=100)
-        internal, _ = self.sc._convert_to_internal_request([req], ["id"])
+        internal, _ = self.sc._convert_to_internal_request(req)
         self.assertEqual(internal.text, "Hello world")
 
     def test_single_token_ids_prompt(self):
         req = CompletionRequest(model="x", prompt=[1, 2, 3, 4], max_tokens=100)
-        internal, _ = self.sc._convert_to_internal_request([req], ["id"])
+        internal, _ = self.sc._convert_to_internal_request(req)
         self.assertEqual(internal.input_ids, [1, 2, 3, 4])
 
     def test_completion_template_handling(self):
@@ -55,7 +55,7 @@ class ServingCompletionTestCase(unittest.TestCase):
             "sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt_from_request",
             return_value="processed_prompt",
         ):
-            internal, _ = self.sc._convert_to_internal_request([req], ["id"])
+            internal, _ = self.sc._convert_to_internal_request(req)
         self.assertEqual(internal.text, "processed_prompt")
 
     # ---------- echo-handling ----------
...
@@ -94,50 +94,42 @@ class ServingEmbeddingTestCase(unittest.TestCase):
     def test_convert_single_string_request(self):
         """Test converting single string request to internal format."""
         adapted_request, processed_request = (
-            self.serving_embedding._convert_to_internal_request(
-                self.basic_req, "test-id"
-            )
+            self.serving_embedding._convert_to_internal_request(self.basic_req)
         )
 
         self.assertIsInstance(adapted_request, EmbeddingReqInput)
         self.assertEqual(adapted_request.text, "Hello, how are you?")
-        self.assertEqual(adapted_request.rid, None)
+        # self.assertEqual(adapted_request.rid, "test-id")
         self.assertEqual(processed_request, self.basic_req)
 
     def test_convert_list_string_request(self):
         """Test converting list of strings request to internal format."""
         adapted_request, processed_request = (
-            self.serving_embedding._convert_to_internal_request(
-                self.list_req, "test-id"
-            )
+            self.serving_embedding._convert_to_internal_request(self.list_req)
         )
 
         self.assertIsInstance(adapted_request, EmbeddingReqInput)
         self.assertEqual(
             adapted_request.text, ["Hello, how are you?", "I am fine, thank you!"]
         )
-        self.assertEqual(adapted_request.rid, None)
+        # self.assertEqual(adapted_request.rid, "test-id")
         self.assertEqual(processed_request, self.list_req)
 
     def test_convert_token_ids_request(self):
         """Test converting token IDs request to internal format."""
         adapted_request, processed_request = (
-            self.serving_embedding._convert_to_internal_request(
-                self.token_ids_req, "test-id"
-            )
+            self.serving_embedding._convert_to_internal_request(self.token_ids_req)
         )
 
         self.assertIsInstance(adapted_request, EmbeddingReqInput)
         self.assertEqual(adapted_request.input_ids, [1, 2, 3, 4, 5])
-        self.assertEqual(adapted_request.rid, None)
+        # self.assertEqual(adapted_request.rid, "test-id")
         self.assertEqual(processed_request, self.token_ids_req)
 
     def test_convert_multimodal_request(self):
         """Test converting multimodal request to internal format."""
         adapted_request, processed_request = (
-            self.serving_embedding._convert_to_internal_request(
-                self.multimodal_req, "test-id"
-            )
+            self.serving_embedding._convert_to_internal_request(self.multimodal_req)
         )
 
         self.assertIsInstance(adapted_request, EmbeddingReqInput)
@@ -147,7 +139,7 @@ class ServingEmbeddingTestCase(unittest.TestCase):
         self.assertIn("World", adapted_request.text)
         self.assertEqual(adapted_request.image_data[0], "base64_image_data")
         self.assertIsNone(adapted_request.image_data[1])
-        self.assertEqual(adapted_request.rid, None)
+        # self.assertEqual(adapted_request.rid, "test-id")
 
     def test_build_single_embedding_response(self):
         """Test building response for single embedding."""
@@ -194,72 +186,86 @@ class ServingEmbeddingTestCase(unittest.TestCase):
         self.assertEqual(response.usage.prompt_tokens, 7)  # 3 + 4
         self.assertEqual(response.usage.total_tokens, 7)
 
-    async def test_handle_request_success(self):
-        """Test successful embedding request handling."""
-        # Mock the generate_request to return expected data
-        async def mock_generate():
-            yield {
-                "embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
-                "meta_info": {"prompt_tokens": 5},
-            }
-
-        self.serving_embedding.tokenizer_manager.generate_request = Mock(
-            return_value=mock_generate()
-        )
-
-        response = await self.serving_embedding.handle_request(
-            self.basic_req, self.request
-        )
-
-        self.assertIsInstance(response, EmbeddingResponse)
-        self.assertEqual(len(response.data), 1)
-        self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
+    def test_handle_request_success(self):
+        """Test successful embedding request handling."""
+
+        async def run_test():
+            # Mock the generate_request to return expected data
+            async def mock_generate():
+                yield {
+                    "embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
+                    "meta_info": {"prompt_tokens": 5},
+                }
+
+            self.serving_embedding.tokenizer_manager.generate_request = Mock(
+                return_value=mock_generate()
+            )
+
+            response = await self.serving_embedding.handle_request(
+                self.basic_req, self.request
+            )
+
+            self.assertIsInstance(response, EmbeddingResponse)
+            self.assertEqual(len(response.data), 1)
+            self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
+
+        asyncio.run(run_test())
 
-    async def test_handle_request_validation_error(self):
-        """Test handling request with validation error."""
-        invalid_request = EmbeddingRequest(model="test-model", input="")
-        response = await self.serving_embedding.handle_request(
-            invalid_request, self.request
-        )
-        self.assertIsInstance(response, ORJSONResponse)
-        self.assertEqual(response.status_code, 400)
+    def test_handle_request_validation_error(self):
+        """Test handling request with validation error."""
+
+        async def run_test():
+            invalid_request = EmbeddingRequest(model="test-model", input="")
+
+            response = await self.serving_embedding.handle_request(
+                invalid_request, self.request
+            )
+
+            self.assertIsInstance(response, ORJSONResponse)
+            self.assertEqual(response.status_code, 400)
+
+        asyncio.run(run_test())
 
-    async def test_handle_request_generation_error(self):
-        """Test handling request with generation error."""
-        # Mock generate_request to raise an error
-        async def mock_generate_error():
-            raise ValueError("Generation failed")
-            yield  # This won't be reached but needed for async generator
-
-        self.serving_embedding.tokenizer_manager.generate_request = Mock(
-            return_value=mock_generate_error()
-        )
-
-        response = await self.serving_embedding.handle_request(
-            self.basic_req, self.request
-        )
-
-        self.assertIsInstance(response, ORJSONResponse)
-        self.assertEqual(response.status_code, 400)
+    def test_handle_request_generation_error(self):
+        """Test handling request with generation error."""
+
+        async def run_test():
+            # Mock generate_request to raise an error
+            async def mock_generate_error():
+                raise ValueError("Generation failed")
+                yield  # This won't be reached but needed for async generator
+
+            self.serving_embedding.tokenizer_manager.generate_request = Mock(
+                return_value=mock_generate_error()
+            )
+
+            response = await self.serving_embedding.handle_request(
+                self.basic_req, self.request
+            )
+
+            self.assertIsInstance(response, ORJSONResponse)
+            self.assertEqual(response.status_code, 400)
+
+        asyncio.run(run_test())
 
-    async def test_handle_request_internal_error(self):
-        """Test handling request with internal server error."""
-        # Mock _convert_to_internal_request to raise an exception
-        with patch.object(
-            self.serving_embedding,
-            "_convert_to_internal_request",
-            side_effect=Exception("Internal error"),
-        ):
-            response = await self.serving_embedding.handle_request(
-                self.basic_req, self.request
-            )
-            self.assertIsInstance(response, ORJSONResponse)
-            self.assertEqual(response.status_code, 500)
+    def test_handle_request_internal_error(self):
+        """Test handling request with internal server error."""
+
+        async def run_test():
+            # Mock _convert_to_internal_request to raise an exception
+            with patch.object(
+                self.serving_embedding,
+                "_convert_to_internal_request",
+                side_effect=Exception("Internal error"),
+            ):
+                response = await self.serving_embedding.handle_request(
+                    self.basic_req, self.request
+                )
+                self.assertIsInstance(response, ORJSONResponse)
+                self.assertEqual(response.status_code, 500)
+
+        asyncio.run(run_test())
 
 
 if __name__ == "__main__":
...
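
The embedding tests above also change how async code is driven: async def test_* methods on a plain unittest.TestCase are never awaited by the standard runner (calling them just produces an un-run coroutine and a RuntimeWarning), so each test becomes a synchronous method that runs an inner coroutine with asyncio.run. A self-contained sketch of that pattern, independent of sglang:

    import asyncio
    import unittest

    class AsyncPatternExample(unittest.TestCase):
        def test_coroutine_result(self):
            async def run_test():
                # Awaits against the code under test go here
                await asyncio.sleep(0)
                self.assertEqual(1 + 1, 2)

            # Drive the coroutine to completion on a fresh event loop
            asyncio.run(run_test())

    if __name__ == "__main__":
        unittest.main()

An alternative with the same effect is unittest.IsolatedAsyncioTestCase (Python 3.8+), which awaits async test methods natively; the wrapper pattern used here keeps the plain TestCase base class.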