Unverified Commit ffd1a26e authored by Jinn's avatar Jinn Committed by GitHub
Browse files

Add more refactored openai test & in CI (#7284)

parent 09ae5b20
...@@ -36,7 +36,7 @@ from fastapi.middleware.cors import CORSMiddleware ...@@ -36,7 +36,7 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response from fastapi.responses import Response
from sglang.srt.disaggregation.utils import ( from sglang.srt.disaggregation.utils import (
FakeBootstrapHost, FAKE_BOOTSTRAP_HOST,
register_disaggregation_server, register_disaggregation_server,
) )
from sglang.srt.entrypoints.engine import Engine, _launch_subprocesses from sglang.srt.entrypoints.engine import Engine, _launch_subprocesses
...@@ -265,7 +265,7 @@ def _wait_and_warmup( ...@@ -265,7 +265,7 @@ def _wait_and_warmup(
"max_new_tokens": 8, "max_new_tokens": 8,
"ignore_eos": True, "ignore_eos": True,
}, },
"bootstrap_host": [FakeBootstrapHost] * server_args.dp_size, "bootstrap_host": [FAKE_BOOTSTRAP_HOST] * server_args.dp_size,
# This is a hack to ensure fake transfer is enabled during prefill warmup # This is a hack to ensure fake transfer is enabled during prefill warmup
# ensure each dp rank has a unique bootstrap_room during prefill warmup # ensure each dp rank has a unique bootstrap_room during prefill warmup
"bootstrap_room": [ "bootstrap_room": [
......
...@@ -12,9 +12,10 @@ import pytest ...@@ -12,9 +12,10 @@ import pytest
import requests import requests
from sglang.srt.utils import kill_process_tree # reuse SGLang helper from sglang.srt.utils import kill_process_tree # reuse SGLang helper
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
SERVER_MODULE = "sglang.srt.entrypoints.openai.api_server" SERVER_MODULE = "sglang.srt.entrypoints.openai.api_server"
DEFAULT_MODEL = "dummy-model" DEFAULT_MODEL = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
STARTUP_TIMEOUT = float(os.getenv("SGLANG_OPENAI_STARTUP_TIMEOUT", 120)) STARTUP_TIMEOUT = float(os.getenv("SGLANG_OPENAI_STARTUP_TIMEOUT", 120))
...@@ -39,7 +40,7 @@ def _wait_until_healthy(proc: subprocess.Popen, base: str, timeout: float) -> No ...@@ -39,7 +40,7 @@ def _wait_until_healthy(proc: subprocess.Popen, base: str, timeout: float) -> No
def launch_openai_server(model: str = DEFAULT_MODEL, **kw): def launch_openai_server(model: str = DEFAULT_MODEL, **kw):
"""Spawn the draft OpenAI-compatible server and wait until its ready.""" """Spawn the draft OpenAI-compatible server and wait until it's ready."""
port = _pick_free_port() port = _pick_free_port()
cmd = [ cmd = [
sys.executable, sys.executable,
...@@ -79,7 +80,7 @@ def launch_openai_server(model: str = DEFAULT_MODEL, **kw): ...@@ -79,7 +80,7 @@ def launch_openai_server(model: str = DEFAULT_MODEL, **kw):
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def openai_server() -> Generator[str, None, None]: def openai_server() -> Generator[str, None, None]:
"""PyTest fixture that provides the servers base URL and cleans up.""" """PyTest fixture that provides the server's base URL and cleans up."""
proc, base, log_file = launch_openai_server() proc, base, log_file = launch_openai_server()
yield base yield base
kill_process_tree(proc.pid) kill_process_tree(proc.pid)
......
...@@ -15,9 +15,9 @@ ...@@ -15,9 +15,9 @@
import json import json
import time import time
import unittest
from typing import Dict, List, Optional from typing import Dict, List, Optional
import pytest
from pydantic import ValidationError from pydantic import ValidationError
from sglang.srt.entrypoints.openai.protocol import ( from sglang.srt.entrypoints.openai.protocol import (
...@@ -64,18 +64,18 @@ from sglang.srt.entrypoints.openai.protocol import ( ...@@ -64,18 +64,18 @@ from sglang.srt.entrypoints.openai.protocol import (
) )
class TestModelCard: class TestModelCard(unittest.TestCase):
"""Test ModelCard protocol model""" """Test ModelCard protocol model"""
def test_basic_model_card_creation(self): def test_basic_model_card_creation(self):
"""Test basic model card creation with required fields""" """Test basic model card creation with required fields"""
card = ModelCard(id="test-model") card = ModelCard(id="test-model")
assert card.id == "test-model" self.assertEqual(card.id, "test-model")
assert card.object == "model" self.assertEqual(card.object, "model")
assert card.owned_by == "sglang" self.assertEqual(card.owned_by, "sglang")
assert isinstance(card.created, int) self.assertIsInstance(card.created, int)
assert card.root is None self.assertIsNone(card.root)
assert card.max_model_len is None self.assertIsNone(card.max_model_len)
def test_model_card_with_optional_fields(self): def test_model_card_with_optional_fields(self):
"""Test model card with optional fields""" """Test model card with optional fields"""
...@@ -85,28 +85,28 @@ class TestModelCard: ...@@ -85,28 +85,28 @@ class TestModelCard:
max_model_len=2048, max_model_len=2048,
created=1234567890, created=1234567890,
) )
assert card.id == "test-model" self.assertEqual(card.id, "test-model")
assert card.root == "/path/to/model" self.assertEqual(card.root, "/path/to/model")
assert card.max_model_len == 2048 self.assertEqual(card.max_model_len, 2048)
assert card.created == 1234567890 self.assertEqual(card.created, 1234567890)
def test_model_card_serialization(self): def test_model_card_serialization(self):
"""Test model card JSON serialization""" """Test model card JSON serialization"""
card = ModelCard(id="test-model", max_model_len=4096) card = ModelCard(id="test-model", max_model_len=4096)
data = card.model_dump() data = card.model_dump()
assert data["id"] == "test-model" self.assertEqual(data["id"], "test-model")
assert data["object"] == "model" self.assertEqual(data["object"], "model")
assert data["max_model_len"] == 4096 self.assertEqual(data["max_model_len"], 4096)
class TestModelList: class TestModelList(unittest.TestCase):
"""Test ModelList protocol model""" """Test ModelList protocol model"""
def test_empty_model_list(self): def test_empty_model_list(self):
"""Test empty model list creation""" """Test empty model list creation"""
model_list = ModelList() model_list = ModelList()
assert model_list.object == "list" self.assertEqual(model_list.object, "list")
assert len(model_list.data) == 0 self.assertEqual(len(model_list.data), 0)
def test_model_list_with_cards(self): def test_model_list_with_cards(self):
"""Test model list with model cards""" """Test model list with model cards"""
...@@ -115,12 +115,12 @@ class TestModelList: ...@@ -115,12 +115,12 @@ class TestModelList:
ModelCard(id="model-2", max_model_len=2048), ModelCard(id="model-2", max_model_len=2048),
] ]
model_list = ModelList(data=cards) model_list = ModelList(data=cards)
assert len(model_list.data) == 2 self.assertEqual(len(model_list.data), 2)
assert model_list.data[0].id == "model-1" self.assertEqual(model_list.data[0].id, "model-1")
assert model_list.data[1].id == "model-2" self.assertEqual(model_list.data[1].id, "model-2")
class TestErrorResponse: class TestErrorResponse(unittest.TestCase):
"""Test ErrorResponse protocol model""" """Test ErrorResponse protocol model"""
def test_basic_error_response(self): def test_basic_error_response(self):
...@@ -128,11 +128,11 @@ class TestErrorResponse: ...@@ -128,11 +128,11 @@ class TestErrorResponse:
error = ErrorResponse( error = ErrorResponse(
message="Invalid request", type="BadRequestError", code=400 message="Invalid request", type="BadRequestError", code=400
) )
assert error.object == "error" self.assertEqual(error.object, "error")
assert error.message == "Invalid request" self.assertEqual(error.message, "Invalid request")
assert error.type == "BadRequestError" self.assertEqual(error.type, "BadRequestError")
assert error.code == 400 self.assertEqual(error.code, 400)
assert error.param is None self.assertIsNone(error.param)
def test_error_response_with_param(self): def test_error_response_with_param(self):
"""Test error response with parameter""" """Test error response with parameter"""
...@@ -142,19 +142,19 @@ class TestErrorResponse: ...@@ -142,19 +142,19 @@ class TestErrorResponse:
code=422, code=422,
param="temperature", param="temperature",
) )
assert error.param == "temperature" self.assertEqual(error.param, "temperature")
class TestUsageInfo: class TestUsageInfo(unittest.TestCase):
"""Test UsageInfo protocol model""" """Test UsageInfo protocol model"""
def test_basic_usage_info(self): def test_basic_usage_info(self):
"""Test basic usage info creation""" """Test basic usage info creation"""
usage = UsageInfo(prompt_tokens=10, completion_tokens=20, total_tokens=30) usage = UsageInfo(prompt_tokens=10, completion_tokens=20, total_tokens=30)
assert usage.prompt_tokens == 10 self.assertEqual(usage.prompt_tokens, 10)
assert usage.completion_tokens == 20 self.assertEqual(usage.completion_tokens, 20)
assert usage.total_tokens == 30 self.assertEqual(usage.total_tokens, 30)
assert usage.prompt_tokens_details is None self.assertIsNone(usage.prompt_tokens_details)
def test_usage_info_with_cache_details(self): def test_usage_info_with_cache_details(self):
"""Test usage info with cache details""" """Test usage info with cache details"""
...@@ -164,22 +164,22 @@ class TestUsageInfo: ...@@ -164,22 +164,22 @@ class TestUsageInfo:
total_tokens=30, total_tokens=30,
prompt_tokens_details={"cached_tokens": 5}, prompt_tokens_details={"cached_tokens": 5},
) )
assert usage.prompt_tokens_details == {"cached_tokens": 5} self.assertEqual(usage.prompt_tokens_details, {"cached_tokens": 5})
class TestCompletionRequest: class TestCompletionRequest(unittest.TestCase):
"""Test CompletionRequest protocol model""" """Test CompletionRequest protocol model"""
def test_basic_completion_request(self): def test_basic_completion_request(self):
"""Test basic completion request""" """Test basic completion request"""
request = CompletionRequest(model="test-model", prompt="Hello world") request = CompletionRequest(model="test-model", prompt="Hello world")
assert request.model == "test-model" self.assertEqual(request.model, "test-model")
assert request.prompt == "Hello world" self.assertEqual(request.prompt, "Hello world")
assert request.max_tokens == 16 # default self.assertEqual(request.max_tokens, 16) # default
assert request.temperature == 1.0 # default self.assertEqual(request.temperature, 1.0) # default
assert request.n == 1 # default self.assertEqual(request.n, 1) # default
assert not request.stream # default self.assertFalse(request.stream) # default
assert not request.echo # default self.assertFalse(request.echo) # default
def test_completion_request_with_options(self): def test_completion_request_with_options(self):
"""Test completion request with various options""" """Test completion request with various options"""
...@@ -195,15 +195,15 @@ class TestCompletionRequest: ...@@ -195,15 +195,15 @@ class TestCompletionRequest:
stop=[".", "!"], stop=[".", "!"],
logprobs=5, logprobs=5,
) )
assert request.prompt == ["Hello", "world"] self.assertEqual(request.prompt, ["Hello", "world"])
assert request.max_tokens == 100 self.assertEqual(request.max_tokens, 100)
assert request.temperature == 0.7 self.assertEqual(request.temperature, 0.7)
assert request.top_p == 0.9 self.assertEqual(request.top_p, 0.9)
assert request.n == 2 self.assertEqual(request.n, 2)
assert request.stream self.assertTrue(request.stream)
assert request.echo self.assertTrue(request.echo)
assert request.stop == [".", "!"] self.assertEqual(request.stop, [".", "!"])
assert request.logprobs == 5 self.assertEqual(request.logprobs, 5)
def test_completion_request_sglang_extensions(self): def test_completion_request_sglang_extensions(self):
"""Test completion request with SGLang-specific extensions""" """Test completion request with SGLang-specific extensions"""
...@@ -217,23 +217,23 @@ class TestCompletionRequest: ...@@ -217,23 +217,23 @@ class TestCompletionRequest:
json_schema='{"type": "object"}', json_schema='{"type": "object"}',
lora_path="/path/to/lora", lora_path="/path/to/lora",
) )
assert request.top_k == 50 self.assertEqual(request.top_k, 50)
assert request.min_p == 0.1 self.assertEqual(request.min_p, 0.1)
assert request.repetition_penalty == 1.1 self.assertEqual(request.repetition_penalty, 1.1)
assert request.regex == r"\d+" self.assertEqual(request.regex, r"\d+")
assert request.json_schema == '{"type": "object"}' self.assertEqual(request.json_schema, '{"type": "object"}')
assert request.lora_path == "/path/to/lora" self.assertEqual(request.lora_path, "/path/to/lora")
def test_completion_request_validation_errors(self): def test_completion_request_validation_errors(self):
"""Test completion request validation errors""" """Test completion request validation errors"""
with pytest.raises(ValidationError): with self.assertRaises(ValidationError):
CompletionRequest() # missing required fields CompletionRequest() # missing required fields
with pytest.raises(ValidationError): with self.assertRaises(ValidationError):
CompletionRequest(model="test-model") # missing prompt CompletionRequest(model="test-model") # missing prompt
class TestCompletionResponse: class TestCompletionResponse(unittest.TestCase):
"""Test CompletionResponse protocol model""" """Test CompletionResponse protocol model"""
def test_basic_completion_response(self): def test_basic_completion_response(self):
...@@ -245,28 +245,28 @@ class TestCompletionResponse: ...@@ -245,28 +245,28 @@ class TestCompletionResponse:
response = CompletionResponse( response = CompletionResponse(
id="test-id", model="test-model", choices=[choice], usage=usage id="test-id", model="test-model", choices=[choice], usage=usage
) )
assert response.id == "test-id" self.assertEqual(response.id, "test-id")
assert response.object == "text_completion" self.assertEqual(response.object, "text_completion")
assert response.model == "test-model" self.assertEqual(response.model, "test-model")
assert len(response.choices) == 1 self.assertEqual(len(response.choices), 1)
assert response.choices[0].text == "Hello world!" self.assertEqual(response.choices[0].text, "Hello world!")
assert response.usage.total_tokens == 5 self.assertEqual(response.usage.total_tokens, 5)
class TestChatCompletionRequest: class TestChatCompletionRequest(unittest.TestCase):
"""Test ChatCompletionRequest protocol model""" """Test ChatCompletionRequest protocol model"""
def test_basic_chat_completion_request(self): def test_basic_chat_completion_request(self):
"""Test basic chat completion request""" """Test basic chat completion request"""
messages = [{"role": "user", "content": "Hello"}] messages = [{"role": "user", "content": "Hello"}]
request = ChatCompletionRequest(model="test-model", messages=messages) request = ChatCompletionRequest(model="test-model", messages=messages)
assert request.model == "test-model" self.assertEqual(request.model, "test-model")
assert len(request.messages) == 1 self.assertEqual(len(request.messages), 1)
assert request.messages[0].role == "user" self.assertEqual(request.messages[0].role, "user")
assert request.messages[0].content == "Hello" self.assertEqual(request.messages[0].content, "Hello")
assert request.temperature == 0.7 # default self.assertEqual(request.temperature, 0.7) # default
assert not request.stream # default self.assertFalse(request.stream) # default
assert request.tool_choice == "none" # default when no tools self.assertEqual(request.tool_choice, "none") # default when no tools
def test_chat_completion_with_multimodal_content(self): def test_chat_completion_with_multimodal_content(self):
"""Test chat completion with multimodal content""" """Test chat completion with multimodal content"""
...@@ -283,9 +283,9 @@ class TestChatCompletionRequest: ...@@ -283,9 +283,9 @@ class TestChatCompletionRequest:
} }
] ]
request = ChatCompletionRequest(model="test-model", messages=messages) request = ChatCompletionRequest(model="test-model", messages=messages)
assert len(request.messages[0].content) == 2 self.assertEqual(len(request.messages[0].content), 2)
assert request.messages[0].content[0].type == "text" self.assertEqual(request.messages[0].content[0].type, "text")
assert request.messages[0].content[1].type == "image_url" self.assertEqual(request.messages[0].content[1].type, "image_url")
def test_chat_completion_with_tools(self): def test_chat_completion_with_tools(self):
"""Test chat completion with tools""" """Test chat completion with tools"""
...@@ -306,9 +306,9 @@ class TestChatCompletionRequest: ...@@ -306,9 +306,9 @@ class TestChatCompletionRequest:
request = ChatCompletionRequest( request = ChatCompletionRequest(
model="test-model", messages=messages, tools=tools model="test-model", messages=messages, tools=tools
) )
assert len(request.tools) == 1 self.assertEqual(len(request.tools), 1)
assert request.tools[0].function.name == "get_weather" self.assertEqual(request.tools[0].function.name, "get_weather")
assert request.tool_choice == "auto" # default when tools present self.assertEqual(request.tool_choice, "auto") # default when tools present
def test_chat_completion_tool_choice_validation(self): def test_chat_completion_tool_choice_validation(self):
"""Test tool choice validation logic""" """Test tool choice validation logic"""
...@@ -316,7 +316,7 @@ class TestChatCompletionRequest: ...@@ -316,7 +316,7 @@ class TestChatCompletionRequest:
# No tools, tool_choice should default to "none" # No tools, tool_choice should default to "none"
request1 = ChatCompletionRequest(model="test-model", messages=messages) request1 = ChatCompletionRequest(model="test-model", messages=messages)
assert request1.tool_choice == "none" self.assertEqual(request1.tool_choice, "none")
# With tools, tool_choice should default to "auto" # With tools, tool_choice should default to "auto"
tools = [ tools = [
...@@ -328,7 +328,7 @@ class TestChatCompletionRequest: ...@@ -328,7 +328,7 @@ class TestChatCompletionRequest:
request2 = ChatCompletionRequest( request2 = ChatCompletionRequest(
model="test-model", messages=messages, tools=tools model="test-model", messages=messages, tools=tools
) )
assert request2.tool_choice == "auto" self.assertEqual(request2.tool_choice, "auto")
def test_chat_completion_sglang_extensions(self): def test_chat_completion_sglang_extensions(self):
"""Test chat completion with SGLang extensions""" """Test chat completion with SGLang extensions"""
...@@ -342,14 +342,14 @@ class TestChatCompletionRequest: ...@@ -342,14 +342,14 @@ class TestChatCompletionRequest:
stream_reasoning=False, stream_reasoning=False,
chat_template_kwargs={"custom_param": "value"}, chat_template_kwargs={"custom_param": "value"},
) )
assert request.top_k == 40 self.assertEqual(request.top_k, 40)
assert request.min_p == 0.05 self.assertEqual(request.min_p, 0.05)
assert not request.separate_reasoning self.assertFalse(request.separate_reasoning)
assert not request.stream_reasoning self.assertFalse(request.stream_reasoning)
assert request.chat_template_kwargs == {"custom_param": "value"} self.assertEqual(request.chat_template_kwargs, {"custom_param": "value"})
class TestChatCompletionResponse: class TestChatCompletionResponse(unittest.TestCase):
"""Test ChatCompletionResponse protocol model""" """Test ChatCompletionResponse protocol model"""
def test_basic_chat_completion_response(self): def test_basic_chat_completion_response(self):
...@@ -362,11 +362,11 @@ class TestChatCompletionResponse: ...@@ -362,11 +362,11 @@ class TestChatCompletionResponse:
response = ChatCompletionResponse( response = ChatCompletionResponse(
id="test-id", model="test-model", choices=[choice], usage=usage id="test-id", model="test-model", choices=[choice], usage=usage
) )
assert response.id == "test-id" self.assertEqual(response.id, "test-id")
assert response.object == "chat.completion" self.assertEqual(response.object, "chat.completion")
assert response.model == "test-model" self.assertEqual(response.model, "test-model")
assert len(response.choices) == 1 self.assertEqual(len(response.choices), 1)
assert response.choices[0].message.content == "Hello there!" self.assertEqual(response.choices[0].message.content, "Hello there!")
def test_chat_completion_response_with_tool_calls(self): def test_chat_completion_response_with_tool_calls(self):
"""Test chat completion response with tool calls""" """Test chat completion response with tool calls"""
...@@ -384,28 +384,30 @@ class TestChatCompletionResponse: ...@@ -384,28 +384,30 @@ class TestChatCompletionResponse:
response = ChatCompletionResponse( response = ChatCompletionResponse(
id="test-id", model="test-model", choices=[choice], usage=usage id="test-id", model="test-model", choices=[choice], usage=usage
) )
assert response.choices[0].message.tool_calls[0].function.name == "get_weather" self.assertEqual(
assert response.choices[0].finish_reason == "tool_calls" response.choices[0].message.tool_calls[0].function.name, "get_weather"
)
self.assertEqual(response.choices[0].finish_reason, "tool_calls")
class TestEmbeddingRequest: class TestEmbeddingRequest(unittest.TestCase):
"""Test EmbeddingRequest protocol model""" """Test EmbeddingRequest protocol model"""
def test_basic_embedding_request(self): def test_basic_embedding_request(self):
"""Test basic embedding request""" """Test basic embedding request"""
request = EmbeddingRequest(model="test-model", input="Hello world") request = EmbeddingRequest(model="test-model", input="Hello world")
assert request.model == "test-model" self.assertEqual(request.model, "test-model")
assert request.input == "Hello world" self.assertEqual(request.input, "Hello world")
assert request.encoding_format == "float" # default self.assertEqual(request.encoding_format, "float") # default
assert request.dimensions is None # default self.assertIsNone(request.dimensions) # default
def test_embedding_request_with_list_input(self): def test_embedding_request_with_list_input(self):
"""Test embedding request with list input""" """Test embedding request with list input"""
request = EmbeddingRequest( request = EmbeddingRequest(
model="test-model", input=["Hello", "world"], dimensions=512 model="test-model", input=["Hello", "world"], dimensions=512
) )
assert request.input == ["Hello", "world"] self.assertEqual(request.input, ["Hello", "world"])
assert request.dimensions == 512 self.assertEqual(request.dimensions, 512)
def test_multimodal_embedding_request(self): def test_multimodal_embedding_request(self):
"""Test multimodal embedding request""" """Test multimodal embedding request"""
...@@ -414,14 +416,14 @@ class TestEmbeddingRequest: ...@@ -414,14 +416,14 @@ class TestEmbeddingRequest:
MultimodalEmbeddingInput(text="World", image=None), MultimodalEmbeddingInput(text="World", image=None),
] ]
request = EmbeddingRequest(model="test-model", input=multimodal_input) request = EmbeddingRequest(model="test-model", input=multimodal_input)
assert len(request.input) == 2 self.assertEqual(len(request.input), 2)
assert request.input[0].text == "Hello" self.assertEqual(request.input[0].text, "Hello")
assert request.input[0].image == "base64_image_data" self.assertEqual(request.input[0].image, "base64_image_data")
assert request.input[1].text == "World" self.assertEqual(request.input[1].text, "World")
assert request.input[1].image is None self.assertIsNone(request.input[1].image)
class TestEmbeddingResponse: class TestEmbeddingResponse(unittest.TestCase):
"""Test EmbeddingResponse protocol model""" """Test EmbeddingResponse protocol model"""
def test_basic_embedding_response(self): def test_basic_embedding_response(self):
...@@ -431,14 +433,14 @@ class TestEmbeddingResponse: ...@@ -431,14 +433,14 @@ class TestEmbeddingResponse:
response = EmbeddingResponse( response = EmbeddingResponse(
data=[embedding_obj], model="test-model", usage=usage data=[embedding_obj], model="test-model", usage=usage
) )
assert response.object == "list" self.assertEqual(response.object, "list")
assert len(response.data) == 1 self.assertEqual(len(response.data), 1)
assert response.data[0].embedding == [0.1, 0.2, 0.3] self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3])
assert response.data[0].index == 0 self.assertEqual(response.data[0].index, 0)
assert response.usage.prompt_tokens == 3 self.assertEqual(response.usage.prompt_tokens, 3)
class TestScoringRequest: class TestScoringRequest(unittest.TestCase):
"""Test ScoringRequest protocol model""" """Test ScoringRequest protocol model"""
def test_basic_scoring_request(self): def test_basic_scoring_request(self):
...@@ -446,11 +448,11 @@ class TestScoringRequest: ...@@ -446,11 +448,11 @@ class TestScoringRequest:
request = ScoringRequest( request = ScoringRequest(
model="test-model", query="Hello", items=["World", "Earth"] model="test-model", query="Hello", items=["World", "Earth"]
) )
assert request.model == "test-model" self.assertEqual(request.model, "test-model")
assert request.query == "Hello" self.assertEqual(request.query, "Hello")
assert request.items == ["World", "Earth"] self.assertEqual(request.items, ["World", "Earth"])
assert not request.apply_softmax # default self.assertFalse(request.apply_softmax) # default
assert not request.item_first # default self.assertFalse(request.item_first) # default
def test_scoring_request_with_token_ids(self): def test_scoring_request_with_token_ids(self):
"""Test scoring request with token IDs""" """Test scoring request with token IDs"""
...@@ -462,34 +464,34 @@ class TestScoringRequest: ...@@ -462,34 +464,34 @@ class TestScoringRequest:
apply_softmax=True, apply_softmax=True,
item_first=True, item_first=True,
) )
assert request.query == [1, 2, 3] self.assertEqual(request.query, [1, 2, 3])
assert request.items == [[4, 5], [6, 7]] self.assertEqual(request.items, [[4, 5], [6, 7]])
assert request.label_token_ids == [8, 9] self.assertEqual(request.label_token_ids, [8, 9])
assert request.apply_softmax self.assertTrue(request.apply_softmax)
assert request.item_first self.assertTrue(request.item_first)
class TestScoringResponse: class TestScoringResponse(unittest.TestCase):
"""Test ScoringResponse protocol model""" """Test ScoringResponse protocol model"""
def test_basic_scoring_response(self): def test_basic_scoring_response(self):
"""Test basic scoring response""" """Test basic scoring response"""
response = ScoringResponse(scores=[[0.1, 0.9], [0.3, 0.7]], model="test-model") response = ScoringResponse(scores=[[0.1, 0.9], [0.3, 0.7]], model="test-model")
assert response.object == "scoring" self.assertEqual(response.object, "scoring")
assert response.scores == [[0.1, 0.9], [0.3, 0.7]] self.assertEqual(response.scores, [[0.1, 0.9], [0.3, 0.7]])
assert response.model == "test-model" self.assertEqual(response.model, "test-model")
assert response.usage is None # default self.assertIsNone(response.usage) # default
class TestFileOperations: class TestFileOperations(unittest.TestCase):
"""Test file operation protocol models""" """Test file operation protocol models"""
def test_file_request(self): def test_file_request(self):
"""Test file request model""" """Test file request model"""
file_data = b"test file content" file_data = b"test file content"
request = FileRequest(file=file_data, purpose="batch") request = FileRequest(file=file_data, purpose="batch")
assert request.file == file_data self.assertEqual(request.file, file_data)
assert request.purpose == "batch" self.assertEqual(request.purpose, "batch")
def test_file_response(self): def test_file_response(self):
"""Test file response model""" """Test file response model"""
...@@ -500,20 +502,20 @@ class TestFileOperations: ...@@ -500,20 +502,20 @@ class TestFileOperations:
filename="test.jsonl", filename="test.jsonl",
purpose="batch", purpose="batch",
) )
assert response.id == "file-123" self.assertEqual(response.id, "file-123")
assert response.object == "file" self.assertEqual(response.object, "file")
assert response.bytes == 1024 self.assertEqual(response.bytes, 1024)
assert response.filename == "test.jsonl" self.assertEqual(response.filename, "test.jsonl")
def test_file_delete_response(self): def test_file_delete_response(self):
"""Test file delete response model""" """Test file delete response model"""
response = FileDeleteResponse(id="file-123", deleted=True) response = FileDeleteResponse(id="file-123", deleted=True)
assert response.id == "file-123" self.assertEqual(response.id, "file-123")
assert response.object == "file" self.assertEqual(response.object, "file")
assert response.deleted self.assertTrue(response.deleted)
class TestBatchOperations: class TestBatchOperations(unittest.TestCase):
"""Test batch operation protocol models""" """Test batch operation protocol models"""
def test_batch_request(self): def test_batch_request(self):
...@@ -524,10 +526,10 @@ class TestBatchOperations: ...@@ -524,10 +526,10 @@ class TestBatchOperations:
completion_window="24h", completion_window="24h",
metadata={"custom": "value"}, metadata={"custom": "value"},
) )
assert request.input_file_id == "file-123" self.assertEqual(request.input_file_id, "file-123")
assert request.endpoint == "/v1/chat/completions" self.assertEqual(request.endpoint, "/v1/chat/completions")
assert request.completion_window == "24h" self.assertEqual(request.completion_window, "24h")
assert request.metadata == {"custom": "value"} self.assertEqual(request.metadata, {"custom": "value"})
def test_batch_response(self): def test_batch_response(self):
"""Test batch response model""" """Test batch response model"""
...@@ -538,20 +540,20 @@ class TestBatchOperations: ...@@ -538,20 +540,20 @@ class TestBatchOperations:
completion_window="24h", completion_window="24h",
created_at=1234567890, created_at=1234567890,
) )
assert response.id == "batch-123" self.assertEqual(response.id, "batch-123")
assert response.object == "batch" self.assertEqual(response.object, "batch")
assert response.status == "validating" # default self.assertEqual(response.status, "validating") # default
assert response.endpoint == "/v1/chat/completions" self.assertEqual(response.endpoint, "/v1/chat/completions")
class TestResponseFormats: class TestResponseFormats(unittest.TestCase):
"""Test response format protocol models""" """Test response format protocol models"""
def test_basic_response_format(self): def test_basic_response_format(self):
"""Test basic response format""" """Test basic response format"""
format_obj = ResponseFormat(type="json_object") format_obj = ResponseFormat(type="json_object")
assert format_obj.type == "json_object" self.assertEqual(format_obj.type, "json_object")
assert format_obj.json_schema is None self.assertIsNone(format_obj.json_schema)
def test_json_schema_response_format(self): def test_json_schema_response_format(self):
"""Test JSON schema response format""" """Test JSON schema response format"""
...@@ -560,9 +562,9 @@ class TestResponseFormats: ...@@ -560,9 +562,9 @@ class TestResponseFormats:
name="person_schema", description="Person schema", schema=schema name="person_schema", description="Person schema", schema=schema
) )
format_obj = ResponseFormat(type="json_schema", json_schema=json_schema) format_obj = ResponseFormat(type="json_schema", json_schema=json_schema)
assert format_obj.type == "json_schema" self.assertEqual(format_obj.type, "json_schema")
assert format_obj.json_schema.name == "person_schema" self.assertEqual(format_obj.json_schema.name, "person_schema")
assert format_obj.json_schema.schema_ == schema self.assertEqual(format_obj.json_schema.schema_, schema)
def test_structural_tag_response_format(self): def test_structural_tag_response_format(self):
"""Test structural tag response format""" """Test structural tag response format"""
...@@ -576,12 +578,12 @@ class TestResponseFormats: ...@@ -576,12 +578,12 @@ class TestResponseFormats:
format_obj = StructuralTagResponseFormat( format_obj = StructuralTagResponseFormat(
type="structural_tag", structures=structures, triggers=["think"] type="structural_tag", structures=structures, triggers=["think"]
) )
assert format_obj.type == "structural_tag" self.assertEqual(format_obj.type, "structural_tag")
assert len(format_obj.structures) == 1 self.assertEqual(len(format_obj.structures), 1)
assert format_obj.triggers == ["think"] self.assertEqual(format_obj.triggers, ["think"])
class TestLogProbs: class TestLogProbs(unittest.TestCase):
"""Test LogProbs protocol models""" """Test LogProbs protocol models"""
def test_basic_logprobs(self): def test_basic_logprobs(self):
...@@ -592,9 +594,9 @@ class TestLogProbs: ...@@ -592,9 +594,9 @@ class TestLogProbs:
tokens=["Hello", " ", "world"], tokens=["Hello", " ", "world"],
top_logprobs=[{"Hello": -0.1}, {" ": -0.2}, {"world": -0.3}], top_logprobs=[{"Hello": -0.1}, {" ": -0.2}, {"world": -0.3}],
) )
assert len(logprobs.tokens) == 3 self.assertEqual(len(logprobs.tokens), 3)
assert logprobs.tokens == ["Hello", " ", "world"] self.assertEqual(logprobs.tokens, ["Hello", " ", "world"])
assert logprobs.token_logprobs == [-0.1, -0.2, -0.3] self.assertEqual(logprobs.token_logprobs, [-0.1, -0.2, -0.3])
def test_choice_logprobs(self): def test_choice_logprobs(self):
"""Test ChoiceLogprobs model""" """Test ChoiceLogprobs model"""
...@@ -607,17 +609,17 @@ class TestLogProbs: ...@@ -607,17 +609,17 @@ class TestLogProbs:
], ],
) )
choice_logprobs = ChoiceLogprobs(content=[token_logprob]) choice_logprobs = ChoiceLogprobs(content=[token_logprob])
assert len(choice_logprobs.content) == 1 self.assertEqual(len(choice_logprobs.content), 1)
assert choice_logprobs.content[0].token == "Hello" self.assertEqual(choice_logprobs.content[0].token, "Hello")
class TestStreamingModels: class TestStreamingModels(unittest.TestCase):
"""Test streaming response models""" """Test streaming response models"""
def test_stream_options(self): def test_stream_options(self):
"""Test StreamOptions model""" """Test StreamOptions model"""
options = StreamOptions(include_usage=True) options = StreamOptions(include_usage=True)
assert options.include_usage self.assertTrue(options.include_usage)
def test_chat_completion_stream_response(self): def test_chat_completion_stream_response(self):
"""Test ChatCompletionStreamResponse model""" """Test ChatCompletionStreamResponse model"""
...@@ -626,29 +628,29 @@ class TestStreamingModels: ...@@ -626,29 +628,29 @@ class TestStreamingModels:
response = ChatCompletionStreamResponse( response = ChatCompletionStreamResponse(
id="test-id", model="test-model", choices=[choice] id="test-id", model="test-model", choices=[choice]
) )
assert response.object == "chat.completion.chunk" self.assertEqual(response.object, "chat.completion.chunk")
assert response.choices[0].delta.content == "Hello" self.assertEqual(response.choices[0].delta.content, "Hello")
class TestValidationEdgeCases: class TestValidationEdgeCases(unittest.TestCase):
"""Test edge cases and validation scenarios""" """Test edge cases and validation scenarios"""
def test_empty_messages_validation(self): def test_empty_messages_validation(self):
"""Test validation with empty messages""" """Test validation with empty messages"""
with pytest.raises(ValidationError): with self.assertRaises(ValidationError):
ChatCompletionRequest(model="test-model", messages=[]) ChatCompletionRequest(model="test-model", messages=[])
def test_invalid_tool_choice_type(self): def test_invalid_tool_choice_type(self):
"""Test invalid tool choice type""" """Test invalid tool choice type"""
messages = [{"role": "user", "content": "Hello"}] messages = [{"role": "user", "content": "Hello"}]
with pytest.raises(ValidationError): with self.assertRaises(ValidationError):
ChatCompletionRequest( ChatCompletionRequest(
model="test-model", messages=messages, tool_choice=123 model="test-model", messages=messages, tool_choice=123
) )
def test_negative_token_limits(self): def test_negative_token_limits(self):
"""Test negative token limits""" """Test negative token limits"""
with pytest.raises(ValidationError): with self.assertRaises(ValidationError):
CompletionRequest(model="test-model", prompt="Hello", max_tokens=-1) CompletionRequest(model="test-model", prompt="Hello", max_tokens=-1)
def test_invalid_temperature_range(self): def test_invalid_temperature_range(self):
...@@ -656,7 +658,7 @@ class TestValidationEdgeCases: ...@@ -656,7 +658,7 @@ class TestValidationEdgeCases:
# Note: The current protocol doesn't enforce temperature range, # Note: The current protocol doesn't enforce temperature range,
# but this test documents expected behavior # but this test documents expected behavior
request = CompletionRequest(model="test-model", prompt="Hello", temperature=5.0) request = CompletionRequest(model="test-model", prompt="Hello", temperature=5.0)
assert request.temperature == 5.0 # Currently allowed self.assertEqual(request.temperature, 5.0) # Currently allowed
def test_model_serialization_roundtrip(self): def test_model_serialization_roundtrip(self):
"""Test that models can be serialized and deserialized""" """Test that models can be serialized and deserialized"""
...@@ -673,11 +675,11 @@ class TestValidationEdgeCases: ...@@ -673,11 +675,11 @@ class TestValidationEdgeCases:
# Deserialize back # Deserialize back
restored_request = ChatCompletionRequest(**data) restored_request = ChatCompletionRequest(**data)
assert restored_request.model == original_request.model self.assertEqual(restored_request.model, original_request.model)
assert restored_request.temperature == original_request.temperature self.assertEqual(restored_request.temperature, original_request.temperature)
assert restored_request.max_tokens == original_request.max_tokens self.assertEqual(restored_request.max_tokens, original_request.max_tokens)
assert len(restored_request.messages) == len(original_request.messages) self.assertEqual(len(restored_request.messages), len(original_request.messages))
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__]) unittest.main(verbosity=2)
# sglang/test/srt/openai/test_server.py # sglang/test/srt/openai/test_server.py
import pytest
import requests import requests
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST as MODEL_ID
def test_health(openai_server: str): def test_health(openai_server: str):
r = requests.get(f"{openai_server}/health") r = requests.get(f"{openai_server}/health")
assert r.status_code == 200, r.text assert r.status_code == 200
# FastAPI returns an empty body → r.text == ""
assert r.text == "" assert r.text == ""
@pytest.mark.xfail(reason="Endpoint skeleton not implemented yet")
def test_models_endpoint(openai_server: str): def test_models_endpoint(openai_server: str):
r = requests.get(f"{openai_server}/v1/models") r = requests.get(f"{openai_server}/v1/models")
# once implemented this should be 200 assert r.status_code == 200, r.text
assert r.status_code == 200 payload = r.json()
# Basic contract
assert "data" in payload and isinstance(payload["data"], list) and payload["data"]
# Validate fields of the first model card
first = payload["data"][0]
for key in ("id", "root", "max_model_len"):
assert key in first, f"missing {key} in {first}"
# max_model_len must be positive
assert isinstance(first["max_model_len"], int) and first["max_model_len"] > 0
# The server should report the same model id we launched it with
ids = {m["id"] for m in payload["data"]}
assert MODEL_ID in ids
def test_get_model_info(openai_server: str):
r = requests.get(f"{openai_server}/get_model_info")
assert r.status_code == 200, r.text
info = r.json()
expected_keys = {"model_path", "tokenizer_path", "is_generation"}
assert expected_keys.issubset(info.keys())
# model_path must end with the one we passed on the CLI
assert info["model_path"].endswith(MODEL_ID)
# is_generation is documented as a boolean
assert isinstance(info["is_generation"], bool)
def test_unknown_route_returns_404(openai_server: str):
r = requests.get(f"{openai_server}/definitely-not-a-real-route")
assert r.status_code == 404
""" """
Unit tests for the OpenAIServingChat class from serving_chat.py. Unit-tests for OpenAIServingChat — rewritten to use only the std-lib 'unittest'.
Run with either:
These tests ensure that the refactored implementation maintains compatibility python tests/test_serving_chat_unit.py -v
with the original adapter.py functionality. or
python -m unittest discover -s tests -p "test_*unit.py" -v
""" """
import unittest
import uuid import uuid
from typing import Optional
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import pytest
from fastapi import Request from fastapi import Request
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest, ErrorResponse from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.managers.io_struct import GenerateReqInput
# Mock TokenizerManager since it may not be directly importable in tests class _MockTokenizerManager:
class MockTokenizerManager: """Minimal mock that satisfies OpenAIServingChat."""
def __init__(self): def __init__(self):
self.model_config = Mock() self.model_config = Mock(is_multimodal=False)
self.model_config.is_multimodal = False self.server_args = Mock(
self.server_args = Mock() enable_cache_report=False,
self.server_args.enable_cache_report = False tool_call_parser="hermes",
self.server_args.tool_call_parser = "hermes" reasoning_parser=None,
self.server_args.reasoning_parser = None )
self.chat_template_name = "llama-3" self.chat_template_name: Optional[str] = "llama-3"
# Mock tokenizer # tokenizer stub
self.tokenizer = Mock() self.tokenizer = Mock()
self.tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5]) self.tokenizer.encode.return_value = [1, 2, 3, 4, 5]
self.tokenizer.decode = Mock(return_value="Test response") self.tokenizer.decode.return_value = "Test response"
self.tokenizer.chat_template = None self.tokenizer.chat_template = None
self.tokenizer.bos_token_id = 1 self.tokenizer.bos_token_id = 1
# Mock generate_request method # async generator stub for generate_request
async def mock_generate(): async def _mock_generate():
yield { yield {
"text": "Test response", "text": "Test response",
"meta_info": { "meta_info": {
...@@ -50,585 +53,176 @@ class MockTokenizerManager: ...@@ -50,585 +53,176 @@ class MockTokenizerManager:
"index": 0, "index": 0,
} }
self.generate_request = Mock(return_value=mock_generate()) self.generate_request = Mock(return_value=_mock_generate())
self.create_abort_task = Mock(return_value=None) self.create_abort_task = Mock()
@pytest.fixture
def mock_tokenizer_manager():
"""Create a mock tokenizer manager for testing."""
return MockTokenizerManager()
@pytest.fixture
def serving_chat(mock_tokenizer_manager):
"""Create a OpenAIServingChat instance for testing."""
return OpenAIServingChat(mock_tokenizer_manager)
class ServingChatTestCase(unittest.TestCase):
# ------------- common fixtures -------------
def setUp(self):
self.tm = _MockTokenizerManager()
self.chat = OpenAIServingChat(self.tm)
@pytest.fixture # frequently reused requests
def mock_request(): self.basic_req = ChatCompletionRequest(
"""Create a mock FastAPI request.""" model="x",
request = Mock(spec=Request) messages=[{"role": "user", "content": "Hi?"}],
request.headers = {}
return request
@pytest.fixture
def basic_chat_request():
"""Create a basic chat completion request."""
return ChatCompletionRequest(
model="test-model",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7, temperature=0.7,
max_tokens=100, max_tokens=100,
stream=False, stream=False,
) )
self.stream_req = ChatCompletionRequest(
model="x",
@pytest.fixture messages=[{"role": "user", "content": "Hi?"}],
def streaming_chat_request():
"""Create a streaming chat completion request."""
return ChatCompletionRequest(
model="test-model",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7, temperature=0.7,
max_tokens=100, max_tokens=100,
stream=True, stream=True,
) )
self.fastapi_request = Mock(spec=Request)
self.fastapi_request.headers = {}
class TestOpenAIServingChatConversion: # ------------- conversion tests -------------
"""Test request conversion methods.""" def test_convert_to_internal_request_single(self):
def test_convert_to_internal_request_single(
self, serving_chat, basic_chat_request, mock_tokenizer_manager
):
"""Test converting single request to internal format."""
with patch( with patch(
"sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv" "sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv"
) as mock_conv: ) as conv_mock, patch.object(self.chat, "_process_messages") as proc_mock:
mock_conv_instance = Mock() conv_ins = Mock()
mock_conv_instance.get_prompt.return_value = "Test prompt" conv_ins.get_prompt.return_value = "Test prompt"
mock_conv_instance.image_data = None conv_ins.image_data = conv_ins.audio_data = None
mock_conv_instance.audio_data = None conv_ins.modalities = []
mock_conv_instance.modalities = [] conv_ins.stop_str = ["</s>"]
mock_conv_instance.stop_str = ["</s>"] conv_mock.return_value = conv_ins
mock_conv.return_value = mock_conv_instance
proc_mock.return_value = (
# Mock the _process_messages method to return expected values
with patch.object(serving_chat, "_process_messages") as mock_process:
mock_process.return_value = (
"Test prompt", "Test prompt",
[1, 2, 3], [1, 2, 3],
None, None,
None, None,
[], [],
["</s>"], ["</s>"],
None, # tool_call_constraint None,
) )
adapted_request, processed_request = ( adapted, processed = self.chat._convert_to_internal_request(
serving_chat._convert_to_internal_request( [self.basic_req], ["rid"]
[basic_chat_request], ["test-id"]
)
) )
self.assertIsInstance(adapted, GenerateReqInput)
self.assertFalse(adapted.stream)
self.assertEqual(processed, self.basic_req)
assert isinstance(adapted_request, GenerateReqInput) # ------------- tool-call branch -------------
assert adapted_request.stream == basic_chat_request.stream def test_tool_call_request_conversion(self):
assert processed_request == basic_chat_request req = ChatCompletionRequest(
model="x",
messages=[{"role": "user", "content": "Weather?"}],
class TestToolCalls:
"""Test tool call functionality from adapter.py"""
def test_tool_call_request_conversion(self, serving_chat):
"""Test request with tool calls"""
request = ChatCompletionRequest(
model="test-model",
messages=[{"role": "user", "content": "What's the weather?"}],
tools=[ tools=[
{ {
"type": "function", "type": "function",
"function": { "function": {
"name": "get_weather", "name": "get_weather",
"description": "Get weather information", "parameters": {"type": "object", "properties": {}},
"parameters": {
"type": "object",
"properties": {"location": {"type": "string"}},
},
}, },
} }
], ],
tool_choice="auto", tool_choice="auto",
) )
with patch.object(serving_chat, "_process_messages") as mock_process: with patch.object(
mock_process.return_value = ( self.chat,
"Test prompt", "_process_messages",
[1, 2, 3], return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
None, ):
None, adapted, _ = self.chat._convert_to_internal_request([req], ["rid"])
[], self.assertEqual(adapted.rid, "rid")
["</s>"],
None, # tool_call_constraint def test_tool_choice_none(self):
) req = ChatCompletionRequest(
model="x",
adapted_request, _ = serving_chat._convert_to_internal_request( messages=[{"role": "user", "content": "Hi"}],
[request], ["test-id"] tools=[{"type": "function", "function": {"name": "noop"}}],
)
assert adapted_request.rid == "test-id"
# Tool call constraint should be processed
assert request.tools is not None
def test_tool_choice_none(self, serving_chat):
"""Test tool_choice=none disables tool calls"""
request = ChatCompletionRequest(
model="test-model",
messages=[{"role": "user", "content": "Hello"}],
tools=[{"type": "function", "function": {"name": "test_func"}}],
tool_choice="none", tool_choice="none",
) )
with patch.object(
self.chat,
"_process_messages",
return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
):
adapted, _ = self.chat._convert_to_internal_request([req], ["rid"])
self.assertEqual(adapted.rid, "rid")
with patch.object(serving_chat, "_process_messages") as mock_process: # ------------- multimodal branch -------------
mock_process.return_value = ( def test_multimodal_request_with_images(self):
"Test prompt", self.tm.model_config.is_multimodal = True
[1, 2, 3],
None,
None,
[],
["</s>"],
None, # tool_call_constraint
)
adapted_request, _ = serving_chat._convert_to_internal_request(
[request], ["test-id"]
)
# Tools should not be processed when tool_choice is "none"
assert adapted_request.rid == "test-id"
def test_tool_call_response_processing(self, serving_chat):
"""Test processing tool calls in response"""
mock_ret_item = {
"text": '{"name": "get_weather", "parameters": {"location": "Paris"}}',
"meta_info": {
"output_token_logprobs": [],
"output_top_logprobs": None,
},
}
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"parameters": {
"type": "object",
"properties": {"location": {"type": "string"}},
},
},
}
]
finish_reason = {"type": "stop", "matched": None}
# Mock FunctionCallParser
with patch(
"sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
) as mock_parser_class:
mock_parser = Mock()
mock_parser.has_tool_call.return_value = True
# Create proper mock tool call object
mock_tool_call = Mock()
mock_tool_call.name = "get_weather"
mock_tool_call.parameters = '{"location": "Paris"}'
mock_parser.parse_non_stream.return_value = ("", [mock_tool_call])
mock_parser_class.return_value = mock_parser
tool_calls, text, updated_finish_reason = serving_chat._process_tool_calls(
mock_ret_item["text"], tools, "hermes", finish_reason
)
assert tool_calls is not None
assert len(tool_calls) == 1
assert updated_finish_reason["type"] == "tool_calls"
class TestMultimodalContent:
"""Test multimodal content handling from adapter.py"""
def test_multimodal_request_with_images(self, serving_chat): req = ChatCompletionRequest(
"""Test request with image content""" model="x",
request = ChatCompletionRequest(
model="test-model",
messages=[ messages=[
{ {
"role": "user", "role": "user",
"content": [ "content": [
{"type": "text", "text": "What's in this image?"}, {"type": "text", "text": "What's in the image?"},
{ {
"type": "image_url", "type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,..."}, "image_url": {"url": "data:image/jpeg;base64,"},
}, },
], ],
} }
], ],
) )
# Set multimodal mode
serving_chat.tokenizer_manager.model_config.is_multimodal = True
with patch.object(serving_chat, "_apply_jinja_template") as mock_apply:
mock_apply.return_value = (
"prompt",
[1, 2, 3],
["image_data"],
None,
[],
[],
)
with patch.object( with patch.object(
serving_chat, "_apply_conversation_template" self.chat,
) as mock_conv: "_apply_jinja_template",
mock_conv.return_value = ("prompt", ["image_data"], None, [], []) return_value=("prompt", [1, 2], ["img"], None, [], []),
), patch.object(
( self.chat,
prompt, "_apply_conversation_template",
prompt_ids, return_value=("prompt", ["img"], None, [], []),
image_data, ):
audio_data, out = self.chat._process_messages(req, True)
modalities, _, _, image_data, *_ = out
stop, self.assertEqual(image_data, ["img"])
tool_call_constraint,
) = serving_chat._process_messages(request, True)
assert image_data == ["image_data"]
assert prompt == "prompt"
def test_multimodal_request_with_audio(self, serving_chat):
"""Test request with audio content"""
request = ChatCompletionRequest(
model="test-model",
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "Transcribe this audio"},
{
"type": "audio_url",
"audio_url": {"url": "data:audio/wav;base64,UklGR..."},
},
],
}
],
)
serving_chat.tokenizer_manager.model_config.is_multimodal = True
with patch.object(serving_chat, "_apply_jinja_template") as mock_apply: # ------------- template handling -------------
mock_apply.return_value = ( def test_jinja_template_processing(self):
"prompt", req = ChatCompletionRequest(
[1, 2, 3], model="x", messages=[{"role": "user", "content": "Hello"}]
None,
["audio_data"],
["audio"],
[],
) )
self.tm.chat_template_name = None
self.tm.tokenizer.chat_template = "<jinja>"
with patch.object( with patch.object(
serving_chat, "_apply_conversation_template" self.chat,
) as mock_conv: "_apply_jinja_template",
mock_conv.return_value = ("prompt", None, ["audio_data"], ["audio"], []) return_value=("processed", [1], None, None, [], ["</s>"]),
), patch("builtins.hasattr", return_value=True):
( prompt, prompt_ids, *_ = self.chat._process_messages(req, False)
prompt, self.assertEqual(prompt, "processed")
prompt_ids, self.assertEqual(prompt_ids, [1])
image_data,
audio_data, # ------------- sampling-params -------------
modalities, def test_sampling_param_build(self):
stop, req = ChatCompletionRequest(
tool_call_constraint, model="x",
) = serving_chat._process_messages(request, True) messages=[{"role": "user", "content": "Hi"}],
assert audio_data == ["audio_data"]
assert modalities == ["audio"]
class TestTemplateHandling:
"""Test chat template handling from adapter.py"""
def test_jinja_template_processing(self, serving_chat):
"""Test Jinja template processing"""
request = ChatCompletionRequest(
model="test-model", messages=[{"role": "user", "content": "Hello"}]
)
# Mock the template attribute directly
serving_chat.tokenizer_manager.chat_template_name = None
serving_chat.tokenizer_manager.tokenizer.chat_template = "<jinja_template>"
with patch.object(serving_chat, "_apply_jinja_template") as mock_apply:
mock_apply.return_value = (
"processed_prompt",
[1, 2, 3],
None,
None,
[],
["</s>"],
)
# Mock hasattr to simulate the None check
with patch("builtins.hasattr") as mock_hasattr:
mock_hasattr.return_value = True
(
prompt,
prompt_ids,
image_data,
audio_data,
modalities,
stop,
tool_call_constraint,
) = serving_chat._process_messages(request, False)
assert prompt == "processed_prompt"
assert prompt_ids == [1, 2, 3]
def test_conversation_template_processing(self, serving_chat):
"""Test conversation template processing"""
request = ChatCompletionRequest(
model="test-model", messages=[{"role": "user", "content": "Hello"}]
)
serving_chat.tokenizer_manager.chat_template_name = "llama-3"
with patch.object(serving_chat, "_apply_conversation_template") as mock_apply:
mock_apply.return_value = ("conv_prompt", None, None, [], ["</s>"])
(
prompt,
prompt_ids,
image_data,
audio_data,
modalities,
stop,
tool_call_constraint,
) = serving_chat._process_messages(request, False)
assert prompt == "conv_prompt"
assert stop == ["</s>"]
def test_continue_final_message(self, serving_chat):
"""Test continue_final_message functionality"""
request = ChatCompletionRequest(
model="test-model",
messages=[
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there"},
],
continue_final_message=True,
)
with patch.object(serving_chat, "_apply_conversation_template") as mock_apply:
mock_apply.return_value = ("Hi there", None, None, [], ["</s>"])
(
prompt,
prompt_ids,
image_data,
audio_data,
modalities,
stop,
tool_call_constraint,
) = serving_chat._process_messages(request, False)
# Should handle continue_final_message properly
assert prompt == "Hi there"
class TestReasoningContent:
"""Test reasoning content separation from adapter.py"""
def test_reasoning_content_request(self, serving_chat):
"""Test request with reasoning content separation"""
request = ChatCompletionRequest(
model="test-model",
messages=[{"role": "user", "content": "Solve this math problem"}],
separate_reasoning=True,
stream_reasoning=False,
)
with patch.object(serving_chat, "_process_messages") as mock_process:
mock_process.return_value = (
"Test prompt",
[1, 2, 3],
None,
None,
[],
["</s>"],
None, # tool_call_constraint
)
adapted_request, _ = serving_chat._convert_to_internal_request(
[request], ["test-id"]
)
assert adapted_request.rid == "test-id"
assert request.separate_reasoning == True
def test_reasoning_content_response(self, serving_chat):
"""Test reasoning content in response"""
mock_ret_item = {
"text": "<thinking>This is reasoning</thinking>Answer: 42",
"meta_info": {
"output_token_logprobs": [],
"output_top_logprobs": None,
},
}
# Mock ReasoningParser
with patch(
"sglang.srt.entrypoints.openai.serving_chat.ReasoningParser"
) as mock_parser_class:
mock_parser = Mock()
mock_parser.parse_non_stream.return_value = (
"This is reasoning",
"Answer: 42",
)
mock_parser_class.return_value = mock_parser
choice_logprobs = None
reasoning_text = None
text = mock_ret_item["text"]
# Simulate reasoning processing
enable_thinking = True
if enable_thinking:
parser = mock_parser_class(model_type="test", stream_reasoning=False)
reasoning_text, text = parser.parse_non_stream(text)
assert reasoning_text == "This is reasoning"
assert text == "Answer: 42"
class TestSamplingParams:
"""Test sampling parameter handling from adapter.py"""
def test_all_sampling_parameters(self, serving_chat):
"""Test all sampling parameters are properly handled"""
request = ChatCompletionRequest(
model="test-model",
messages=[{"role": "user", "content": "Hello"}],
temperature=0.8, temperature=0.8,
max_tokens=150, max_tokens=150,
max_completion_tokens=200,
min_tokens=5, min_tokens=5,
top_p=0.9, top_p=0.9,
top_k=50, stop=["</s>"],
min_p=0.1,
presence_penalty=0.1,
frequency_penalty=0.2,
repetition_penalty=1.1,
stop=["<|endoftext|>"],
stop_token_ids=[13, 14],
regex=r"\d+",
ebnf="<expr> ::= <number>",
n=2,
no_stop_trim=True,
ignore_eos=True,
skip_special_tokens=False,
logit_bias={"1": 0.5, "2": -0.3},
)
with patch.object(serving_chat, "_process_messages") as mock_process:
mock_process.return_value = (
"Test prompt",
[1, 2, 3],
None,
None,
[],
["</s>"],
None, # tool_call_constraint
)
sampling_params = serving_chat._build_sampling_params(
request, ["</s>"], None
)
# Verify all parameters
assert sampling_params["temperature"] == 0.8
assert sampling_params["max_new_tokens"] == 150
assert sampling_params["min_new_tokens"] == 5
assert sampling_params["top_p"] == 0.9
assert sampling_params["top_k"] == 50
assert sampling_params["min_p"] == 0.1
assert sampling_params["presence_penalty"] == 0.1
assert sampling_params["frequency_penalty"] == 0.2
assert sampling_params["repetition_penalty"] == 1.1
assert sampling_params["stop"] == ["</s>"]
assert sampling_params["logit_bias"] == {"1": 0.5, "2": -0.3}
def test_response_format_json_schema(self, serving_chat):
"""Test response format with JSON schema"""
request = ChatCompletionRequest(
model="test-model",
messages=[{"role": "user", "content": "Generate JSON"}],
response_format={
"type": "json_schema",
"json_schema": {
"name": "response",
"schema": {
"type": "object",
"properties": {"answer": {"type": "string"}},
},
},
},
)
with patch.object(serving_chat, "_process_messages") as mock_process:
mock_process.return_value = (
"Test prompt",
[1, 2, 3],
None,
None,
[],
["</s>"],
None, # tool_call_constraint
)
sampling_params = serving_chat._build_sampling_params(
request, ["</s>"], None
)
assert "json_schema" in sampling_params
assert '"type": "object"' in sampling_params["json_schema"]
def test_response_format_json_object(self, serving_chat):
"""Test response format with JSON object"""
request = ChatCompletionRequest(
model="test-model",
messages=[{"role": "user", "content": "Generate JSON"}],
response_format={"type": "json_object"},
)
with patch.object(serving_chat, "_process_messages") as mock_process:
mock_process.return_value = (
"Test prompt",
[1, 2, 3],
None,
None,
[],
["</s>"],
None, # tool_call_constraint
) )
with patch.object(
self.chat,
"_process_messages",
return_value=("Prompt", [1], None, None, [], ["</s>"], None),
):
params = self.chat._build_sampling_params(req, ["</s>"], None)
self.assertEqual(params["temperature"], 0.8)
self.assertEqual(params["max_new_tokens"], 150)
self.assertEqual(params["min_new_tokens"], 5)
self.assertEqual(params["stop"], ["</s>"])
sampling_params = serving_chat._build_sampling_params(
request, ["</s>"], None
)
assert sampling_params["json_schema"] == '{"type": "object"}' if __name__ == "__main__":
unittest.main(verbosity=2)
""" """
Tests for the refactored completions serving handler Unit-tests for the refactored completions-serving handler (no pytest).
Run with:
python -m unittest tests.test_serving_completions_unit -v
""" """
import unittest
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, Mock, patch
import pytest from sglang.srt.entrypoints.openai.protocol import CompletionRequest
from sglang.srt.entrypoints.openai.protocol import (
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionStreamResponse,
ErrorResponse,
)
from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.managers.tokenizer_manager import TokenizerManager
@pytest.fixture class ServingCompletionTestCase(unittest.TestCase):
def mock_tokenizer_manager(): """Bundle all prompt/echo tests in one TestCase."""
"""Create a mock tokenizer manager"""
manager = Mock(spec=TokenizerManager)
# Mock tokenizer
manager.tokenizer = Mock()
manager.tokenizer.encode = Mock(return_value=[1, 2, 3, 4])
manager.tokenizer.decode = Mock(return_value="decoded text")
manager.tokenizer.bos_token_id = 1
# Mock model config
manager.model_config = Mock()
manager.model_config.is_multimodal = False
# Mock server args
manager.server_args = Mock()
manager.server_args.enable_cache_report = False
# Mock generation # ---------- shared test fixtures ----------
manager.generate_request = AsyncMock() def setUp(self):
manager.create_abort_task = Mock(return_value=None) # build the mock TokenizerManager once for every test
tm = Mock(spec=TokenizerManager)
return manager tm.tokenizer = Mock()
tm.tokenizer.encode.return_value = [1, 2, 3, 4]
tm.tokenizer.decode.return_value = "decoded text"
tm.tokenizer.bos_token_id = 1
tm.model_config = Mock(is_multimodal=False)
tm.server_args = Mock(enable_cache_report=False)
@pytest.fixture tm.generate_request = AsyncMock()
def serving_completion(mock_tokenizer_manager): tm.create_abort_task = Mock()
"""Create a OpenAIServingCompletion instance"""
return OpenAIServingCompletion(mock_tokenizer_manager)
self.sc = OpenAIServingCompletion(tm)
class TestPromptHandling: # ---------- prompt-handling ----------
"""Test different prompt types and formats from adapter.py""" def test_single_string_prompt(self):
req = CompletionRequest(model="x", prompt="Hello world", max_tokens=100)
def test_single_string_prompt(self, serving_completion): internal, _ = self.sc._convert_to_internal_request([req], ["id"])
"""Test handling single string prompt""" self.assertEqual(internal.text, "Hello world")
request = CompletionRequest(
model="test-model", prompt="Hello world", max_tokens=100
)
adapted_request, _ = serving_completion._convert_to_internal_request(
[request], ["test-id"]
)
assert adapted_request.text == "Hello world" def test_single_token_ids_prompt(self):
req = CompletionRequest(model="x", prompt=[1, 2, 3, 4], max_tokens=100)
internal, _ = self.sc._convert_to_internal_request([req], ["id"])
self.assertEqual(internal.input_ids, [1, 2, 3, 4])
def test_single_token_ids_prompt(self, serving_completion): def test_completion_template_handling(self):
"""Test handling single token IDs prompt""" req = CompletionRequest(
request = CompletionRequest( model="x", prompt="def f():", suffix="return 1", max_tokens=100
model="test-model", prompt=[1, 2, 3, 4], max_tokens=100
) )
adapted_request, _ = serving_completion._convert_to_internal_request(
[request], ["test-id"]
)
assert adapted_request.input_ids == [1, 2, 3, 4]
def test_completion_template_handling(self, serving_completion):
"""Test completion template processing"""
request = CompletionRequest(
model="test-model",
prompt="def hello():",
suffix="return 'world'",
max_tokens=100,
)
with patch( with patch(
"sglang.srt.entrypoints.openai.serving_completions.is_completion_template_defined", "sglang.srt.entrypoints.openai.serving_completions.is_completion_template_defined",
return_value=True, return_value=True,
): ), patch(
with patch(
"sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt_from_request", "sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt_from_request",
return_value="processed_prompt", return_value="processed_prompt",
): ):
adapted_request, _ = serving_completion._convert_to_internal_request( internal, _ = self.sc._convert_to_internal_request([req], ["id"])
[request], ["test-id"] self.assertEqual(internal.text, "processed_prompt")
)
assert adapted_request.text == "processed_prompt"
class TestEchoHandling:
"""Test echo functionality from adapter.py"""
def test_echo_with_string_prompt_streaming(self, serving_completion):
"""Test echo handling with string prompt in streaming"""
request = CompletionRequest(
model="test-model", prompt="Hello", max_tokens=100, echo=True
)
# Test _get_echo_text method # ---------- echo-handling ----------
echo_text = serving_completion._get_echo_text(request, 0) def test_echo_with_string_prompt_streaming(self):
assert echo_text == "Hello" req = CompletionRequest(model="x", prompt="Hello", max_tokens=1, echo=True)
self.assertEqual(self.sc._get_echo_text(req, 0), "Hello")
def test_echo_with_list_of_strings_streaming(self, serving_completion):
"""Test echo handling with list of strings in streaming"""
request = CompletionRequest(
model="test-model",
prompt=["Hello", "World"],
max_tokens=100,
echo=True,
n=1,
)
echo_text = serving_completion._get_echo_text(request, 0)
assert echo_text == "Hello"
echo_text = serving_completion._get_echo_text(request, 1) def test_echo_with_list_of_strings_streaming(self):
assert echo_text == "World" req = CompletionRequest(
model="x", prompt=["A", "B"], max_tokens=1, echo=True, n=1
def test_echo_with_token_ids_streaming(self, serving_completion):
"""Test echo handling with token IDs in streaming"""
request = CompletionRequest(
model="test-model", prompt=[1, 2, 3], max_tokens=100, echo=True
) )
self.assertEqual(self.sc._get_echo_text(req, 0), "A")
self.assertEqual(self.sc._get_echo_text(req, 1), "B")
serving_completion.tokenizer_manager.tokenizer.decode.return_value = ( def test_echo_with_token_ids_streaming(self):
"decoded_prompt" req = CompletionRequest(model="x", prompt=[1, 2, 3], max_tokens=1, echo=True)
) self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded_prompt"
echo_text = serving_completion._get_echo_text(request, 0) self.assertEqual(self.sc._get_echo_text(req, 0), "decoded_prompt")
assert echo_text == "decoded_prompt"
def test_echo_with_multiple_token_ids_streaming(self, serving_completion): def test_echo_with_multiple_token_ids_streaming(self):
"""Test echo handling with multiple token ID prompts in streaming""" req = CompletionRequest(
request = CompletionRequest( model="x", prompt=[[1, 2], [3, 4]], max_tokens=1, echo=True, n=1
model="test-model", prompt=[[1, 2], [3, 4]], max_tokens=100, echo=True, n=1
) )
self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded"
self.assertEqual(self.sc._get_echo_text(req, 0), "decoded")
serving_completion.tokenizer_manager.tokenizer.decode.return_value = "decoded" def test_prepare_echo_prompts_non_streaming(self):
echo_text = serving_completion._get_echo_text(request, 0) # single string
assert echo_text == "decoded" req = CompletionRequest(model="x", prompt="Hi", echo=True)
self.assertEqual(self.sc._prepare_echo_prompts(req), ["Hi"])
def test_prepare_echo_prompts_non_streaming(self, serving_completion): # list of strings
"""Test prepare echo prompts for non-streaming response""" req = CompletionRequest(model="x", prompt=["Hi", "Yo"], echo=True)
# Test with single string self.assertEqual(self.sc._prepare_echo_prompts(req), ["Hi", "Yo"])
request = CompletionRequest(model="test-model", prompt="Hello", echo=True)
echo_prompts = serving_completion._prepare_echo_prompts(request)
assert echo_prompts == ["Hello"]
# Test with list of strings
request = CompletionRequest(
model="test-model", prompt=["Hello", "World"], echo=True
)
echo_prompts = serving_completion._prepare_echo_prompts(request) # token IDs
assert echo_prompts == ["Hello", "World"] req = CompletionRequest(model="x", prompt=[1, 2, 3], echo=True)
self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded"
self.assertEqual(self.sc._prepare_echo_prompts(req), ["decoded"])
# Test with token IDs
request = CompletionRequest(model="test-model", prompt=[1, 2, 3], echo=True)
serving_completion.tokenizer_manager.tokenizer.decode.return_value = "decoded" if __name__ == "__main__":
echo_prompts = serving_completion._prepare_echo_prompts(request) unittest.main(verbosity=2)
assert echo_prompts == ["decoded"]
...@@ -8,11 +8,11 @@ with the original adapter.py functionality and follows OpenAI API specifications ...@@ -8,11 +8,11 @@ with the original adapter.py functionality and follows OpenAI API specifications
import asyncio import asyncio
import json import json
import time import time
import unittest
import uuid import uuid
from typing import Any, Dict, List from typing import Any, Dict, List
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, Mock, patch
import pytest
from fastapi import Request from fastapi import Request
from fastapi.responses import ORJSONResponse from fastapi.responses import ORJSONResponse
from pydantic_core import ValidationError from pydantic_core import ValidationError
...@@ -30,7 +30,7 @@ from sglang.srt.managers.io_struct import EmbeddingReqInput ...@@ -30,7 +30,7 @@ from sglang.srt.managers.io_struct import EmbeddingReqInput
# Mock TokenizerManager for embedding tests # Mock TokenizerManager for embedding tests
class MockTokenizerManager: class _MockTokenizerManager:
def __init__(self): def __init__(self):
self.model_config = Mock() self.model_config = Mock()
self.model_config.is_multimodal = False self.model_config.is_multimodal = False
...@@ -58,50 +58,26 @@ class MockTokenizerManager: ...@@ -58,50 +58,26 @@ class MockTokenizerManager:
self.generate_request = Mock(return_value=mock_generate_embedding()) self.generate_request = Mock(return_value=mock_generate_embedding())
@pytest.fixture class ServingEmbeddingTestCase(unittest.TestCase):
def mock_tokenizer_manager(): def setUp(self):
"""Create a mock tokenizer manager for testing.""" """Set up test fixtures."""
return MockTokenizerManager() self.tokenizer_manager = _MockTokenizerManager()
self.serving_embedding = OpenAIServingEmbedding(self.tokenizer_manager)
self.request = Mock(spec=Request)
self.request.headers = {}
@pytest.fixture self.basic_req = EmbeddingRequest(
def serving_embedding(mock_tokenizer_manager):
"""Create an OpenAIServingEmbedding instance for testing."""
return OpenAIServingEmbedding(mock_tokenizer_manager)
@pytest.fixture
def mock_request():
"""Create a mock FastAPI request."""
request = Mock(spec=Request)
request.headers = {}
return request
@pytest.fixture
def basic_embedding_request():
"""Create a basic embedding request."""
return EmbeddingRequest(
model="test-model", model="test-model",
input="Hello, how are you?", input="Hello, how are you?",
encoding_format="float", encoding_format="float",
) )
self.list_req = EmbeddingRequest(
@pytest.fixture
def list_embedding_request():
"""Create an embedding request with list input."""
return EmbeddingRequest(
model="test-model", model="test-model",
input=["Hello, how are you?", "I am fine, thank you!"], input=["Hello, how are you?", "I am fine, thank you!"],
encoding_format="float", encoding_format="float",
) )
self.multimodal_req = EmbeddingRequest(
@pytest.fixture
def multimodal_embedding_request():
"""Create a multimodal embedding request."""
return EmbeddingRequest(
model="test-model", model="test-model",
input=[ input=[
MultimodalEmbeddingInput(text="Hello", image="base64_image_data"), MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
...@@ -109,90 +85,71 @@ def multimodal_embedding_request(): ...@@ -109,90 +85,71 @@ def multimodal_embedding_request():
], ],
encoding_format="float", encoding_format="float",
) )
self.token_ids_req = EmbeddingRequest(
@pytest.fixture
def token_ids_embedding_request():
"""Create an embedding request with token IDs."""
return EmbeddingRequest(
model="test-model", model="test-model",
input=[1, 2, 3, 4, 5], input=[1, 2, 3, 4, 5],
encoding_format="float", encoding_format="float",
) )
def test_convert_single_string_request(self):
class TestOpenAIServingEmbeddingConversion:
"""Test request conversion methods."""
def test_convert_single_string_request(
self, serving_embedding, basic_embedding_request
):
"""Test converting single string request to internal format.""" """Test converting single string request to internal format."""
adapted_request, processed_request = ( adapted_request, processed_request = (
serving_embedding._convert_to_internal_request( self.serving_embedding._convert_to_internal_request(
[basic_embedding_request], ["test-id"] [self.basic_req], ["test-id"]
) )
) )
assert isinstance(adapted_request, EmbeddingReqInput) self.assertIsInstance(adapted_request, EmbeddingReqInput)
assert adapted_request.text == "Hello, how are you?" self.assertEqual(adapted_request.text, "Hello, how are you?")
assert adapted_request.rid == "test-id" self.assertEqual(adapted_request.rid, "test-id")
assert processed_request == basic_embedding_request self.assertEqual(processed_request, self.basic_req)
def test_convert_list_string_request( def test_convert_list_string_request(self):
self, serving_embedding, list_embedding_request
):
"""Test converting list of strings request to internal format.""" """Test converting list of strings request to internal format."""
adapted_request, processed_request = ( adapted_request, processed_request = (
serving_embedding._convert_to_internal_request( self.serving_embedding._convert_to_internal_request(
[list_embedding_request], ["test-id"] [self.list_req], ["test-id"]
) )
) )
assert isinstance(adapted_request, EmbeddingReqInput) self.assertIsInstance(adapted_request, EmbeddingReqInput)
assert adapted_request.text == ["Hello, how are you?", "I am fine, thank you!"] self.assertEqual(
assert adapted_request.rid == "test-id" adapted_request.text, ["Hello, how are you?", "I am fine, thank you!"]
assert processed_request == list_embedding_request )
self.assertEqual(adapted_request.rid, "test-id")
self.assertEqual(processed_request, self.list_req)
def test_convert_token_ids_request( def test_convert_token_ids_request(self):
self, serving_embedding, token_ids_embedding_request
):
"""Test converting token IDs request to internal format.""" """Test converting token IDs request to internal format."""
adapted_request, processed_request = ( adapted_request, processed_request = (
serving_embedding._convert_to_internal_request( self.serving_embedding._convert_to_internal_request(
[token_ids_embedding_request], ["test-id"] [self.token_ids_req], ["test-id"]
) )
) )
assert isinstance(adapted_request, EmbeddingReqInput) self.assertIsInstance(adapted_request, EmbeddingReqInput)
assert adapted_request.input_ids == [1, 2, 3, 4, 5] self.assertEqual(adapted_request.input_ids, [1, 2, 3, 4, 5])
assert adapted_request.rid == "test-id" self.assertEqual(adapted_request.rid, "test-id")
assert processed_request == token_ids_embedding_request self.assertEqual(processed_request, self.token_ids_req)
def test_convert_multimodal_request( def test_convert_multimodal_request(self):
self, serving_embedding, multimodal_embedding_request
):
"""Test converting multimodal request to internal format.""" """Test converting multimodal request to internal format."""
adapted_request, processed_request = ( adapted_request, processed_request = (
serving_embedding._convert_to_internal_request( self.serving_embedding._convert_to_internal_request(
[multimodal_embedding_request], ["test-id"] [self.multimodal_req], ["test-id"]
) )
) )
assert isinstance(adapted_request, EmbeddingReqInput) self.assertIsInstance(adapted_request, EmbeddingReqInput)
# Should extract text and images separately # Should extract text and images separately
assert len(adapted_request.text) == 2 self.assertEqual(len(adapted_request.text), 2)
assert "Hello" in adapted_request.text self.assertIn("Hello", adapted_request.text)
assert "World" in adapted_request.text self.assertIn("World", adapted_request.text)
assert adapted_request.image_data[0] == "base64_image_data" self.assertEqual(adapted_request.image_data[0], "base64_image_data")
assert adapted_request.image_data[1] is None self.assertIsNone(adapted_request.image_data[1])
assert adapted_request.rid == "test-id" self.assertEqual(adapted_request.rid, "test-id")
def test_build_single_embedding_response(self):
class TestEmbeddingResponseBuilding:
"""Test response building methods."""
def test_build_single_embedding_response(self, serving_embedding):
"""Test building response for single embedding.""" """Test building response for single embedding."""
ret_data = [ ret_data = [
{ {
...@@ -201,19 +158,21 @@ class TestEmbeddingResponseBuilding: ...@@ -201,19 +158,21 @@ class TestEmbeddingResponseBuilding:
} }
] ]
response = serving_embedding._build_embedding_response(ret_data, "test-model") response = self.serving_embedding._build_embedding_response(
ret_data, "test-model"
assert isinstance(response, EmbeddingResponse) )
assert response.model == "test-model"
assert len(response.data) == 1
assert response.data[0].embedding == [0.1, 0.2, 0.3, 0.4, 0.5]
assert response.data[0].index == 0
assert response.data[0].object == "embedding"
assert response.usage.prompt_tokens == 5
assert response.usage.total_tokens == 5
assert response.usage.completion_tokens == 0
def test_build_multiple_embedding_response(self, serving_embedding): self.assertIsInstance(response, EmbeddingResponse)
self.assertEqual(response.model, "test-model")
self.assertEqual(len(response.data), 1)
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
self.assertEqual(response.data[0].index, 0)
self.assertEqual(response.data[0].object, "embedding")
self.assertEqual(response.usage.prompt_tokens, 5)
self.assertEqual(response.usage.total_tokens, 5)
self.assertEqual(response.usage.completion_tokens, 0)
def test_build_multiple_embedding_response(self):
"""Test building response for multiple embeddings.""" """Test building response for multiple embeddings."""
ret_data = [ ret_data = [
{ {
...@@ -226,25 +185,20 @@ class TestEmbeddingResponseBuilding: ...@@ -226,25 +185,20 @@ class TestEmbeddingResponseBuilding:
}, },
] ]
response = serving_embedding._build_embedding_response(ret_data, "test-model") response = self.serving_embedding._build_embedding_response(
ret_data, "test-model"
assert isinstance(response, EmbeddingResponse) )
assert len(response.data) == 2
assert response.data[0].embedding == [0.1, 0.2, 0.3]
assert response.data[0].index == 0
assert response.data[1].embedding == [0.4, 0.5, 0.6]
assert response.data[1].index == 1
assert response.usage.prompt_tokens == 7 # 3 + 4
assert response.usage.total_tokens == 7
@pytest.mark.asyncio self.assertIsInstance(response, EmbeddingResponse)
class TestOpenAIServingEmbeddingAsyncMethods: self.assertEqual(len(response.data), 2)
"""Test async methods of OpenAIServingEmbedding.""" self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3])
self.assertEqual(response.data[0].index, 0)
self.assertEqual(response.data[1].embedding, [0.4, 0.5, 0.6])
self.assertEqual(response.data[1].index, 1)
self.assertEqual(response.usage.prompt_tokens, 7) # 3 + 4
self.assertEqual(response.usage.total_tokens, 7)
async def test_handle_request_success( async def test_handle_request_success(self):
self, serving_embedding, basic_embedding_request, mock_request
):
"""Test successful embedding request handling.""" """Test successful embedding request handling."""
# Mock the generate_request to return expected data # Mock the generate_request to return expected data
...@@ -254,32 +208,30 @@ class TestOpenAIServingEmbeddingAsyncMethods: ...@@ -254,32 +208,30 @@ class TestOpenAIServingEmbeddingAsyncMethods:
"meta_info": {"prompt_tokens": 5}, "meta_info": {"prompt_tokens": 5},
} }
serving_embedding.tokenizer_manager.generate_request = Mock( self.serving_embedding.tokenizer_manager.generate_request = Mock(
return_value=mock_generate() return_value=mock_generate()
) )
response = await serving_embedding.handle_request( response = await self.serving_embedding.handle_request(
basic_embedding_request, mock_request self.basic_req, self.request
) )
assert isinstance(response, EmbeddingResponse) self.assertIsInstance(response, EmbeddingResponse)
assert len(response.data) == 1 self.assertEqual(len(response.data), 1)
assert response.data[0].embedding == [0.1, 0.2, 0.3, 0.4, 0.5] self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
async def test_handle_request_validation_error( async def test_handle_request_validation_error(self):
self, serving_embedding, mock_request
):
"""Test handling request with validation error.""" """Test handling request with validation error."""
invalid_request = EmbeddingRequest(model="test-model", input="") invalid_request = EmbeddingRequest(model="test-model", input="")
response = await serving_embedding.handle_request(invalid_request, mock_request) response = await self.serving_embedding.handle_request(
invalid_request, self.request
)
assert isinstance(response, ORJSONResponse) self.assertIsInstance(response, ORJSONResponse)
assert response.status_code == 400 self.assertEqual(response.status_code, 400)
async def test_handle_request_generation_error( async def test_handle_request_generation_error(self):
self, serving_embedding, basic_embedding_request, mock_request
):
"""Test handling request with generation error.""" """Test handling request with generation error."""
# Mock generate_request to raise an error # Mock generate_request to raise an error
...@@ -287,30 +239,32 @@ class TestOpenAIServingEmbeddingAsyncMethods: ...@@ -287,30 +239,32 @@ class TestOpenAIServingEmbeddingAsyncMethods:
raise ValueError("Generation failed") raise ValueError("Generation failed")
yield # This won't be reached but needed for async generator yield # This won't be reached but needed for async generator
serving_embedding.tokenizer_manager.generate_request = Mock( self.serving_embedding.tokenizer_manager.generate_request = Mock(
return_value=mock_generate_error() return_value=mock_generate_error()
) )
response = await serving_embedding.handle_request( response = await self.serving_embedding.handle_request(
basic_embedding_request, mock_request self.basic_req, self.request
) )
assert isinstance(response, ORJSONResponse) self.assertIsInstance(response, ORJSONResponse)
assert response.status_code == 400 self.assertEqual(response.status_code, 400)
async def test_handle_request_internal_error( async def test_handle_request_internal_error(self):
self, serving_embedding, basic_embedding_request, mock_request
):
"""Test handling request with internal server error.""" """Test handling request with internal server error."""
# Mock _convert_to_internal_request to raise an exception # Mock _convert_to_internal_request to raise an exception
with patch.object( with patch.object(
serving_embedding, self.serving_embedding,
"_convert_to_internal_request", "_convert_to_internal_request",
side_effect=Exception("Internal error"), side_effect=Exception("Internal error"),
): ):
response = await serving_embedding.handle_request( response = await self.serving_embedding.handle_request(
basic_embedding_request, mock_request self.basic_req, self.request
) )
assert isinstance(response, ORJSONResponse) self.assertIsInstance(response, ORJSONResponse)
assert response.status_code == 500 self.assertEqual(response.status_code, 500)
if __name__ == "__main__":
unittest.main(verbosity=2)
...@@ -62,6 +62,11 @@ suites = { ...@@ -62,6 +62,11 @@ suites = {
TestFile("test_openai_adapter.py", 1), TestFile("test_openai_adapter.py", 1),
TestFile("test_openai_function_calling.py", 60), TestFile("test_openai_function_calling.py", 60),
TestFile("test_openai_server.py", 149), TestFile("test_openai_server.py", 149),
TestFile("openai/test_server.py", 120),
TestFile("openai/test_protocol.py", 60),
TestFile("openai/test_serving_chat.py", 120),
TestFile("openai/test_serving_completions.py", 120),
TestFile("openai/test_serving_embedding.py", 120),
TestFile("test_openai_server_hidden_states.py", 240), TestFile("test_openai_server_hidden_states.py", 240),
TestFile("test_penalty.py", 41), TestFile("test_penalty.py", 41),
TestFile("test_page_size.py", 60), TestFile("test_page_size.py", 60),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment