Unverified commit ffd1a26e, authored by Jinn, committed by GitHub

Add more refactored openai test & in CI (#7284)

parent 09ae5b20
@@ -36,7 +36,7 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import Response
 from sglang.srt.disaggregation.utils import (
-    FakeBootstrapHost,
+    FAKE_BOOTSTRAP_HOST,
     register_disaggregation_server,
 )
 from sglang.srt.entrypoints.engine import Engine, _launch_subprocesses
@@ -265,7 +265,7 @@ def _wait_and_warmup(
             "max_new_tokens": 8,
             "ignore_eos": True,
         },
-        "bootstrap_host": [FakeBootstrapHost] * server_args.dp_size,
+        "bootstrap_host": [FAKE_BOOTSTRAP_HOST] * server_args.dp_size,
         # This is a hack to ensure fake transfer is enabled during prefill warmup
         # ensure each dp rank has a unique bootstrap_room during prefill warmup
         "bootstrap_room": [
...
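The rename from FakeBootstrapHost to FAKE_BOOTSTRAP_HOST brings the module-level constant in line with PEP 8's UPPER_CASE naming convention for constants; call sites such as the warmup payload above are updated to match.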
@@ -12,9 +12,10 @@ import pytest
 import requests

 from sglang.srt.utils import kill_process_tree  # reuse SGLang helper
+from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST

 SERVER_MODULE = "sglang.srt.entrypoints.openai.api_server"
-DEFAULT_MODEL = "dummy-model"
+DEFAULT_MODEL = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
 STARTUP_TIMEOUT = float(os.getenv("SGLANG_OPENAI_STARTUP_TIMEOUT", 120))
@@ -39,7 +40,7 @@ def _wait_until_healthy(proc: subprocess.Popen, base: str, timeout: float) -> None:
 def launch_openai_server(model: str = DEFAULT_MODEL, **kw):
-    """Spawn the draft OpenAI-compatible server and wait until its ready."""
+    """Spawn the draft OpenAI-compatible server and wait until it's ready."""
     port = _pick_free_port()
     cmd = [
         sys.executable,
@@ -79,7 +80,7 @@ def launch_openai_server(model: str = DEFAULT_MODEL, **kw):
 @pytest.fixture(scope="session")
 def openai_server() -> Generator[str, None, None]:
-    """PyTest fixture that provides the servers base URL and cleans up."""
+    """PyTest fixture that provides the server's base URL and cleans up."""
     proc, base, log_file = launch_openai_server()
     yield base
     kill_process_tree(proc.pid)
...
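The harness above also calls a _pick_free_port() helper that sits outside these hunks. A minimal sketch of how such a helper is typically written (the name comes from the call site; the body is an assumption, not the committed implementation):

    import socket

    def _pick_free_port() -> int:
        # Bind to port 0 so the OS assigns an unused port, then release it;
        # the server is launched on this port immediately afterwards.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("127.0.0.1", 0))
            return s.getsockname()[1]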
 # sglang/test/srt/openai/test_server.py
-import pytest
 import requests

+from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST as MODEL_ID
+

 def test_health(openai_server: str):
     r = requests.get(f"{openai_server}/health")
-    assert r.status_code == 200, r.text
+    assert r.status_code == 200
+    # FastAPI returns an empty body → r.text == ""
     assert r.text == ""


-@pytest.mark.xfail(reason="Endpoint skeleton not implemented yet")
 def test_models_endpoint(openai_server: str):
     r = requests.get(f"{openai_server}/v1/models")
-    # once implemented this should be 200
-    assert r.status_code == 200
+    assert r.status_code == 200, r.text
+    payload = r.json()
+    # Basic contract
+    assert "data" in payload and isinstance(payload["data"], list) and payload["data"]
+    # Validate fields of the first model card
+    first = payload["data"][0]
+    for key in ("id", "root", "max_model_len"):
+        assert key in first, f"missing {key} in {first}"
+    # max_model_len must be positive
+    assert isinstance(first["max_model_len"], int) and first["max_model_len"] > 0
+    # The server should report the same model id we launched it with
+    ids = {m["id"] for m in payload["data"]}
+    assert MODEL_ID in ids
+
+
+def test_get_model_info(openai_server: str):
+    r = requests.get(f"{openai_server}/get_model_info")
+    assert r.status_code == 200, r.text
+    info = r.json()
+    expected_keys = {"model_path", "tokenizer_path", "is_generation"}
+    assert expected_keys.issubset(info.keys())
+    # model_path must end with the one we passed on the CLI
+    assert info["model_path"].endswith(MODEL_ID)
+    # is_generation is documented as a boolean
+    assert isinstance(info["is_generation"], bool)
+
+
+def test_unknown_route_returns_404(openai_server: str):
+    r = requests.get(f"{openai_server}/definitely-not-a-real-route")
+    assert r.status_code == 404
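For reference, a /v1/models payload that satisfies the assertions above would have roughly this shape (a hypothetical body; only the asserted keys are guaranteed by the test, and the values here are illustrative assumptions, not taken from the commit):

    # Hypothetical /v1/models response body
    payload = {
        "object": "list",
        "data": [
            {
                "id": "<the model id passed on the CLI>",
                "object": "model",
                "root": "<model path>",
                "max_model_len": 32768,  # any positive int passes the test
            }
        ],
    }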
""" """
Tests for the refactored completions serving handler Unit-tests for the refactored completions-serving handler (no pytest).
Run with:
python -m unittest tests.test_serving_completions_unit -v
""" """
import unittest
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, Mock, patch
import pytest from sglang.srt.entrypoints.openai.protocol import CompletionRequest
from sglang.srt.entrypoints.openai.protocol import (
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionStreamResponse,
ErrorResponse,
)
from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.managers.tokenizer_manager import TokenizerManager
@pytest.fixture class ServingCompletionTestCase(unittest.TestCase):
def mock_tokenizer_manager(): """Bundle all prompt/echo tests in one TestCase."""
"""Create a mock tokenizer manager"""
manager = Mock(spec=TokenizerManager)
# Mock tokenizer
manager.tokenizer = Mock()
manager.tokenizer.encode = Mock(return_value=[1, 2, 3, 4])
manager.tokenizer.decode = Mock(return_value="decoded text")
manager.tokenizer.bos_token_id = 1
# Mock model config
manager.model_config = Mock()
manager.model_config.is_multimodal = False
# Mock server args
manager.server_args = Mock()
manager.server_args.enable_cache_report = False
# Mock generation # ---------- shared test fixtures ----------
manager.generate_request = AsyncMock() def setUp(self):
manager.create_abort_task = Mock(return_value=None) # build the mock TokenizerManager once for every test
tm = Mock(spec=TokenizerManager)
return manager tm.tokenizer = Mock()
tm.tokenizer.encode.return_value = [1, 2, 3, 4]
tm.tokenizer.decode.return_value = "decoded text"
tm.tokenizer.bos_token_id = 1
tm.model_config = Mock(is_multimodal=False)
tm.server_args = Mock(enable_cache_report=False)
@pytest.fixture tm.generate_request = AsyncMock()
def serving_completion(mock_tokenizer_manager): tm.create_abort_task = Mock()
"""Create a OpenAIServingCompletion instance"""
return OpenAIServingCompletion(mock_tokenizer_manager)
self.sc = OpenAIServingCompletion(tm)
class TestPromptHandling: # ---------- prompt-handling ----------
"""Test different prompt types and formats from adapter.py""" def test_single_string_prompt(self):
req = CompletionRequest(model="x", prompt="Hello world", max_tokens=100)
def test_single_string_prompt(self, serving_completion): internal, _ = self.sc._convert_to_internal_request([req], ["id"])
"""Test handling single string prompt""" self.assertEqual(internal.text, "Hello world")
request = CompletionRequest(
model="test-model", prompt="Hello world", max_tokens=100
)
adapted_request, _ = serving_completion._convert_to_internal_request(
[request], ["test-id"]
)
assert adapted_request.text == "Hello world" def test_single_token_ids_prompt(self):
req = CompletionRequest(model="x", prompt=[1, 2, 3, 4], max_tokens=100)
internal, _ = self.sc._convert_to_internal_request([req], ["id"])
self.assertEqual(internal.input_ids, [1, 2, 3, 4])
def test_single_token_ids_prompt(self, serving_completion): def test_completion_template_handling(self):
"""Test handling single token IDs prompt""" req = CompletionRequest(
request = CompletionRequest( model="x", prompt="def f():", suffix="return 1", max_tokens=100
model="test-model", prompt=[1, 2, 3, 4], max_tokens=100
) )
adapted_request, _ = serving_completion._convert_to_internal_request(
[request], ["test-id"]
)
assert adapted_request.input_ids == [1, 2, 3, 4]
def test_completion_template_handling(self, serving_completion):
"""Test completion template processing"""
request = CompletionRequest(
model="test-model",
prompt="def hello():",
suffix="return 'world'",
max_tokens=100,
)
with patch( with patch(
"sglang.srt.entrypoints.openai.serving_completions.is_completion_template_defined", "sglang.srt.entrypoints.openai.serving_completions.is_completion_template_defined",
return_value=True, return_value=True,
): ), patch(
with patch(
"sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt_from_request", "sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt_from_request",
return_value="processed_prompt", return_value="processed_prompt",
): ):
adapted_request, _ = serving_completion._convert_to_internal_request( internal, _ = self.sc._convert_to_internal_request([req], ["id"])
[request], ["test-id"] self.assertEqual(internal.text, "processed_prompt")
)
assert adapted_request.text == "processed_prompt"
class TestEchoHandling:
"""Test echo functionality from adapter.py"""
def test_echo_with_string_prompt_streaming(self, serving_completion):
"""Test echo handling with string prompt in streaming"""
request = CompletionRequest(
model="test-model", prompt="Hello", max_tokens=100, echo=True
)
# Test _get_echo_text method # ---------- echo-handling ----------
echo_text = serving_completion._get_echo_text(request, 0) def test_echo_with_string_prompt_streaming(self):
assert echo_text == "Hello" req = CompletionRequest(model="x", prompt="Hello", max_tokens=1, echo=True)
self.assertEqual(self.sc._get_echo_text(req, 0), "Hello")
def test_echo_with_list_of_strings_streaming(self, serving_completion):
"""Test echo handling with list of strings in streaming"""
request = CompletionRequest(
model="test-model",
prompt=["Hello", "World"],
max_tokens=100,
echo=True,
n=1,
)
echo_text = serving_completion._get_echo_text(request, 0)
assert echo_text == "Hello"
echo_text = serving_completion._get_echo_text(request, 1) def test_echo_with_list_of_strings_streaming(self):
assert echo_text == "World" req = CompletionRequest(
model="x", prompt=["A", "B"], max_tokens=1, echo=True, n=1
def test_echo_with_token_ids_streaming(self, serving_completion):
"""Test echo handling with token IDs in streaming"""
request = CompletionRequest(
model="test-model", prompt=[1, 2, 3], max_tokens=100, echo=True
) )
self.assertEqual(self.sc._get_echo_text(req, 0), "A")
self.assertEqual(self.sc._get_echo_text(req, 1), "B")
serving_completion.tokenizer_manager.tokenizer.decode.return_value = ( def test_echo_with_token_ids_streaming(self):
"decoded_prompt" req = CompletionRequest(model="x", prompt=[1, 2, 3], max_tokens=1, echo=True)
) self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded_prompt"
echo_text = serving_completion._get_echo_text(request, 0) self.assertEqual(self.sc._get_echo_text(req, 0), "decoded_prompt")
assert echo_text == "decoded_prompt"
def test_echo_with_multiple_token_ids_streaming(self, serving_completion): def test_echo_with_multiple_token_ids_streaming(self):
"""Test echo handling with multiple token ID prompts in streaming""" req = CompletionRequest(
request = CompletionRequest( model="x", prompt=[[1, 2], [3, 4]], max_tokens=1, echo=True, n=1
model="test-model", prompt=[[1, 2], [3, 4]], max_tokens=100, echo=True, n=1
) )
self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded"
self.assertEqual(self.sc._get_echo_text(req, 0), "decoded")
serving_completion.tokenizer_manager.tokenizer.decode.return_value = "decoded" def test_prepare_echo_prompts_non_streaming(self):
echo_text = serving_completion._get_echo_text(request, 0) # single string
assert echo_text == "decoded" req = CompletionRequest(model="x", prompt="Hi", echo=True)
self.assertEqual(self.sc._prepare_echo_prompts(req), ["Hi"])
def test_prepare_echo_prompts_non_streaming(self, serving_completion): # list of strings
"""Test prepare echo prompts for non-streaming response""" req = CompletionRequest(model="x", prompt=["Hi", "Yo"], echo=True)
# Test with single string self.assertEqual(self.sc._prepare_echo_prompts(req), ["Hi", "Yo"])
request = CompletionRequest(model="test-model", prompt="Hello", echo=True)
echo_prompts = serving_completion._prepare_echo_prompts(request)
assert echo_prompts == ["Hello"]
# Test with list of strings
request = CompletionRequest(
model="test-model", prompt=["Hello", "World"], echo=True
)
echo_prompts = serving_completion._prepare_echo_prompts(request) # token IDs
assert echo_prompts == ["Hello", "World"] req = CompletionRequest(model="x", prompt=[1, 2, 3], echo=True)
self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded"
self.assertEqual(self.sc._prepare_echo_prompts(req), ["decoded"])
# Test with token IDs
request = CompletionRequest(model="test-model", prompt=[1, 2, 3], echo=True)
serving_completion.tokenizer_manager.tokenizer.decode.return_value = "decoded" if __name__ == "__main__":
echo_prompts = serving_completion._prepare_echo_prompts(request) unittest.main(verbosity=2)
assert echo_prompts == ["decoded"]
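With the __main__ guard added at the bottom, the rewritten file can be executed directly or through the unittest CLI; the exact module path depends on where the file lives (the paths below are assumed from the CI entries added later in this commit, not spelled out in this hunk):

    # run directly (uses the __main__ guard)
    python openai/test_serving_completions.py
    # or via the unittest CLI, adjusting the module path to your layout
    python -m unittest openai.test_serving_completions -v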
@@ -8,11 +8,11 @@ with the original adapter.py functionality and follows OpenAI API specifications
 import asyncio
 import json
 import time
+import unittest
 import uuid
 from typing import Any, Dict, List
 from unittest.mock import AsyncMock, Mock, patch

-import pytest
 from fastapi import Request
 from fastapi.responses import ORJSONResponse
 from pydantic_core import ValidationError
@@ -30,7 +30,7 @@ from sglang.srt.managers.io_struct import EmbeddingReqInput

 # Mock TokenizerManager for embedding tests
-class MockTokenizerManager:
+class _MockTokenizerManager:
     def __init__(self):
         self.model_config = Mock()
         self.model_config.is_multimodal = False
@@ -58,50 +58,26 @@ class MockTokenizerManager:
         self.generate_request = Mock(return_value=mock_generate_embedding())


-@pytest.fixture
-def mock_tokenizer_manager():
-    """Create a mock tokenizer manager for testing."""
-    return MockTokenizerManager()
-
-
-@pytest.fixture
-def serving_embedding(mock_tokenizer_manager):
-    """Create an OpenAIServingEmbedding instance for testing."""
-    return OpenAIServingEmbedding(mock_tokenizer_manager)
-
-
-@pytest.fixture
-def mock_request():
-    """Create a mock FastAPI request."""
-    request = Mock(spec=Request)
-    request.headers = {}
-    return request
-
-
-@pytest.fixture
-def basic_embedding_request():
-    """Create a basic embedding request."""
-    return EmbeddingRequest(
-        model="test-model",
-        input="Hello, how are you?",
-        encoding_format="float",
-    )
-
-
-@pytest.fixture
-def list_embedding_request():
-    """Create an embedding request with list input."""
-    return EmbeddingRequest(
-        model="test-model",
-        input=["Hello, how are you?", "I am fine, thank you!"],
-        encoding_format="float",
-    )
-
-
-@pytest.fixture
-def multimodal_embedding_request():
-    """Create a multimodal embedding request."""
-    return EmbeddingRequest(
-        model="test-model",
-        input=[
-            MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
+class ServingEmbeddingTestCase(unittest.TestCase):
+    def setUp(self):
+        """Set up test fixtures."""
+        self.tokenizer_manager = _MockTokenizerManager()
+        self.serving_embedding = OpenAIServingEmbedding(self.tokenizer_manager)
+        self.request = Mock(spec=Request)
+        self.request.headers = {}
+
+        self.basic_req = EmbeddingRequest(
+            model="test-model",
+            input="Hello, how are you?",
+            encoding_format="float",
+        )
+        self.list_req = EmbeddingRequest(
+            model="test-model",
+            input=["Hello, how are you?", "I am fine, thank you!"],
+            encoding_format="float",
+        )
+        self.multimodal_req = EmbeddingRequest(
+            model="test-model",
+            input=[
+                MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
@@ -109,90 +85,71 @@ def multimodal_embedding_request():
-        ],
-        encoding_format="float",
-    )
-
-
-@pytest.fixture
-def token_ids_embedding_request():
-    """Create an embedding request with token IDs."""
-    return EmbeddingRequest(
-        model="test-model",
-        input=[1, 2, 3, 4, 5],
-        encoding_format="float",
-    )
-
-
-class TestOpenAIServingEmbeddingConversion:
-    """Test request conversion methods."""
-
-    def test_convert_single_string_request(
-        self, serving_embedding, basic_embedding_request
-    ):
+            ],
+            encoding_format="float",
+        )
+        self.token_ids_req = EmbeddingRequest(
+            model="test-model",
+            input=[1, 2, 3, 4, 5],
+            encoding_format="float",
+        )
+
+    def test_convert_single_string_request(self):
         """Test converting single string request to internal format."""
         adapted_request, processed_request = (
-            serving_embedding._convert_to_internal_request(
-                [basic_embedding_request], ["test-id"]
+            self.serving_embedding._convert_to_internal_request(
+                [self.basic_req], ["test-id"]
             )
         )

-        assert isinstance(adapted_request, EmbeddingReqInput)
-        assert adapted_request.text == "Hello, how are you?"
-        assert adapted_request.rid == "test-id"
-        assert processed_request == basic_embedding_request
+        self.assertIsInstance(adapted_request, EmbeddingReqInput)
+        self.assertEqual(adapted_request.text, "Hello, how are you?")
+        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(processed_request, self.basic_req)

-    def test_convert_list_string_request(
-        self, serving_embedding, list_embedding_request
-    ):
+    def test_convert_list_string_request(self):
         """Test converting list of strings request to internal format."""
         adapted_request, processed_request = (
-            serving_embedding._convert_to_internal_request(
-                [list_embedding_request], ["test-id"]
+            self.serving_embedding._convert_to_internal_request(
+                [self.list_req], ["test-id"]
             )
         )

-        assert isinstance(adapted_request, EmbeddingReqInput)
-        assert adapted_request.text == ["Hello, how are you?", "I am fine, thank you!"]
-        assert adapted_request.rid == "test-id"
-        assert processed_request == list_embedding_request
+        self.assertIsInstance(adapted_request, EmbeddingReqInput)
+        self.assertEqual(
+            adapted_request.text, ["Hello, how are you?", "I am fine, thank you!"]
+        )
+        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(processed_request, self.list_req)

-    def test_convert_token_ids_request(
-        self, serving_embedding, token_ids_embedding_request
-    ):
+    def test_convert_token_ids_request(self):
         """Test converting token IDs request to internal format."""
         adapted_request, processed_request = (
-            serving_embedding._convert_to_internal_request(
-                [token_ids_embedding_request], ["test-id"]
+            self.serving_embedding._convert_to_internal_request(
+                [self.token_ids_req], ["test-id"]
             )
         )

-        assert isinstance(adapted_request, EmbeddingReqInput)
-        assert adapted_request.input_ids == [1, 2, 3, 4, 5]
-        assert adapted_request.rid == "test-id"
-        assert processed_request == token_ids_embedding_request
+        self.assertIsInstance(adapted_request, EmbeddingReqInput)
+        self.assertEqual(adapted_request.input_ids, [1, 2, 3, 4, 5])
+        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(processed_request, self.token_ids_req)

-    def test_convert_multimodal_request(
-        self, serving_embedding, multimodal_embedding_request
-    ):
+    def test_convert_multimodal_request(self):
         """Test converting multimodal request to internal format."""
         adapted_request, processed_request = (
-            serving_embedding._convert_to_internal_request(
-                [multimodal_embedding_request], ["test-id"]
+            self.serving_embedding._convert_to_internal_request(
+                [self.multimodal_req], ["test-id"]
             )
         )

-        assert isinstance(adapted_request, EmbeddingReqInput)
+        self.assertIsInstance(adapted_request, EmbeddingReqInput)
         # Should extract text and images separately
-        assert len(adapted_request.text) == 2
-        assert "Hello" in adapted_request.text
-        assert "World" in adapted_request.text
-        assert adapted_request.image_data[0] == "base64_image_data"
-        assert adapted_request.image_data[1] is None
-        assert adapted_request.rid == "test-id"
-
-
-class TestEmbeddingResponseBuilding:
-    """Test response building methods."""
-
-    def test_build_single_embedding_response(self, serving_embedding):
+        self.assertEqual(len(adapted_request.text), 2)
+        self.assertIn("Hello", adapted_request.text)
+        self.assertIn("World", adapted_request.text)
+        self.assertEqual(adapted_request.image_data[0], "base64_image_data")
+        self.assertIsNone(adapted_request.image_data[1])
+        self.assertEqual(adapted_request.rid, "test-id")
+
+    def test_build_single_embedding_response(self):
         """Test building response for single embedding."""
         ret_data = [
             {
@@ -201,19 +158,21 @@ class TestEmbeddingResponseBuilding:
             }
         ]

-        response = serving_embedding._build_embedding_response(ret_data, "test-model")
+        response = self.serving_embedding._build_embedding_response(
+            ret_data, "test-model"
+        )

-        assert isinstance(response, EmbeddingResponse)
-        assert response.model == "test-model"
-        assert len(response.data) == 1
-        assert response.data[0].embedding == [0.1, 0.2, 0.3, 0.4, 0.5]
-        assert response.data[0].index == 0
-        assert response.data[0].object == "embedding"
-        assert response.usage.prompt_tokens == 5
-        assert response.usage.total_tokens == 5
-        assert response.usage.completion_tokens == 0
+        self.assertIsInstance(response, EmbeddingResponse)
+        self.assertEqual(response.model, "test-model")
+        self.assertEqual(len(response.data), 1)
+        self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
+        self.assertEqual(response.data[0].index, 0)
+        self.assertEqual(response.data[0].object, "embedding")
+        self.assertEqual(response.usage.prompt_tokens, 5)
+        self.assertEqual(response.usage.total_tokens, 5)
+        self.assertEqual(response.usage.completion_tokens, 0)

-    def test_build_multiple_embedding_response(self, serving_embedding):
+    def test_build_multiple_embedding_response(self):
         """Test building response for multiple embeddings."""
         ret_data = [
             {
@@ -226,25 +185,20 @@ class TestEmbeddingResponseBuilding:
             },
         ]

-        response = serving_embedding._build_embedding_response(ret_data, "test-model")
+        response = self.serving_embedding._build_embedding_response(
+            ret_data, "test-model"
+        )

-        assert isinstance(response, EmbeddingResponse)
-        assert len(response.data) == 2
-        assert response.data[0].embedding == [0.1, 0.2, 0.3]
-        assert response.data[0].index == 0
-        assert response.data[1].embedding == [0.4, 0.5, 0.6]
-        assert response.data[1].index == 1
-        assert response.usage.prompt_tokens == 7  # 3 + 4
-        assert response.usage.total_tokens == 7
-
-
-@pytest.mark.asyncio
-class TestOpenAIServingEmbeddingAsyncMethods:
-    """Test async methods of OpenAIServingEmbedding."""
-
-    async def test_handle_request_success(
-        self, serving_embedding, basic_embedding_request, mock_request
-    ):
+        self.assertIsInstance(response, EmbeddingResponse)
+        self.assertEqual(len(response.data), 2)
+        self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3])
+        self.assertEqual(response.data[0].index, 0)
+        self.assertEqual(response.data[1].embedding, [0.4, 0.5, 0.6])
+        self.assertEqual(response.data[1].index, 1)
+        self.assertEqual(response.usage.prompt_tokens, 7)  # 3 + 4
+        self.assertEqual(response.usage.total_tokens, 7)

+    async def test_handle_request_success(self):
         """Test successful embedding request handling."""

         # Mock the generate_request to return expected data
@@ -254,32 +208,30 @@ class TestOpenAIServingEmbeddingAsyncMethods:
             "meta_info": {"prompt_tokens": 5},
         }

-        serving_embedding.tokenizer_manager.generate_request = Mock(
+        self.serving_embedding.tokenizer_manager.generate_request = Mock(
             return_value=mock_generate()
         )

-        response = await serving_embedding.handle_request(
-            basic_embedding_request, mock_request
+        response = await self.serving_embedding.handle_request(
+            self.basic_req, self.request
         )

-        assert isinstance(response, EmbeddingResponse)
-        assert len(response.data) == 1
-        assert response.data[0].embedding == [0.1, 0.2, 0.3, 0.4, 0.5]
+        self.assertIsInstance(response, EmbeddingResponse)
+        self.assertEqual(len(response.data), 1)
+        self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])

-    async def test_handle_request_validation_error(
-        self, serving_embedding, mock_request
-    ):
+    async def test_handle_request_validation_error(self):
         """Test handling request with validation error."""
         invalid_request = EmbeddingRequest(model="test-model", input="")

-        response = await serving_embedding.handle_request(invalid_request, mock_request)
+        response = await self.serving_embedding.handle_request(
+            invalid_request, self.request
+        )

-        assert isinstance(response, ORJSONResponse)
-        assert response.status_code == 400
+        self.assertIsInstance(response, ORJSONResponse)
+        self.assertEqual(response.status_code, 400)

-    async def test_handle_request_generation_error(
-        self, serving_embedding, basic_embedding_request, mock_request
-    ):
+    async def test_handle_request_generation_error(self):
         """Test handling request with generation error."""

         # Mock generate_request to raise an error
@@ -287,30 +239,32 @@ class TestOpenAIServingEmbeddingAsyncMethods:
             raise ValueError("Generation failed")
             yield  # This won't be reached but needed for async generator

-        serving_embedding.tokenizer_manager.generate_request = Mock(
+        self.serving_embedding.tokenizer_manager.generate_request = Mock(
             return_value=mock_generate_error()
         )

-        response = await serving_embedding.handle_request(
-            basic_embedding_request, mock_request
+        response = await self.serving_embedding.handle_request(
+            self.basic_req, self.request
         )

-        assert isinstance(response, ORJSONResponse)
-        assert response.status_code == 400
+        self.assertIsInstance(response, ORJSONResponse)
+        self.assertEqual(response.status_code, 400)

-    async def test_handle_request_internal_error(
-        self, serving_embedding, basic_embedding_request, mock_request
-    ):
+    async def test_handle_request_internal_error(self):
         """Test handling request with internal server error."""

         # Mock _convert_to_internal_request to raise an exception
         with patch.object(
-            serving_embedding,
+            self.serving_embedding,
             "_convert_to_internal_request",
             side_effect=Exception("Internal error"),
         ):
-            response = await serving_embedding.handle_request(
-                basic_embedding_request, mock_request
+            response = await self.serving_embedding.handle_request(
+                self.basic_req, self.request
             )

-        assert isinstance(response, ORJSONResponse)
-        assert response.status_code == 500
+        self.assertIsInstance(response, ORJSONResponse)
+        self.assertEqual(response.status_code, 500)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
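One caveat worth flagging: the async test_handle_request_* methods above sit on a plain unittest.TestCase, which does not await coroutine test methods; stock unittest would report them as passing without executing their bodies (with a "coroutine was never awaited" warning). The standard-library fix (Python 3.8+) is unittest.IsolatedAsyncioTestCase, sketched here with hypothetical names:

    import unittest

    class AsyncEmbeddingTestCase(unittest.IsolatedAsyncioTestCase):
        # IsolatedAsyncioTestCase runs each "async def test_*" on its own
        # event loop, so the awaits in the test body actually execute.
        async def test_handle_request_awaits(self):
            self.assertEqual(await self._fake_handle_request(), "ok")

        async def _fake_handle_request(self):
            # Stand-in for self.serving_embedding.handle_request(...)
            return "ok"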
@@ -62,6 +62,11 @@ suites = {
     TestFile("test_openai_adapter.py", 1),
     TestFile("test_openai_function_calling.py", 60),
     TestFile("test_openai_server.py", 149),
+    TestFile("openai/test_server.py", 120),
+    TestFile("openai/test_protocol.py", 60),
+    TestFile("openai/test_serving_chat.py", 120),
+    TestFile("openai/test_serving_completions.py", 120),
+    TestFile("openai/test_serving_embedding.py", 120),
     TestFile("test_openai_server_hidden_states.py", 240),
     TestFile("test_penalty.py", 41),
     TestFile("test_page_size.py", 60),
...
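In run_suite.py, the second argument to TestFile is the file's estimated running time in seconds, which the suite runner uses to budget and shard CI jobs. The new per-commit entries can be exercised locally with something like:

    cd test/srt
    python3 run_suite.py --suite per-commit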