Unverified Commit e321c971 authored by Keyang Ru's avatar Keyang Ru Committed by GitHub
Browse files

[router] Add comprehensive E2E tests for Response API (#11988)

parent d6fee73d
......@@ -144,12 +144,6 @@ jobs:
python3 -m pip --no-cache-dir install --upgrade --break-system-packages genai-bench==0.0.2
pytest -m e2e -s -vv -o log_cli=true --log-cli-level=INFO
- name: Run Python E2E gRPC tests
run: |
bash scripts/killall_sglang.sh "nuk_gpus"
cd sgl-router
SHOW_ROUTER_LOGS=1 ROUTER_LOCAL_MODEL_PATH="/home/ubuntu/models" pytest py_test/e2e_grpc -s -vv -o log_cli=true --log-cli-level=INFO
- name: Upload benchmark results
if: success()
uses: actions/upload-artifact@v4
......@@ -157,8 +151,58 @@ jobs:
name: genai-bench-results-all-policies
path: sgl-router/benchmark_**/
pytest-rust-2:
if: github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'run-ci')
runs-on: 4-gpu-a10
timeout-minutes: 16
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install rust dependencies
run: |
bash scripts/ci/ci_install_rust.sh
- name: Configure sccache
uses: mozilla-actions/sccache-action@v0.0.9
with:
version: "v0.10.0"
- name: Rust cache
uses: Swatinem/rust-cache@v2
with:
workspaces: sgl-router
cache-all-crates: true
cache-on-failure: true
- name: Install SGLang dependencies
run: |
sudo --preserve-env=PATH bash scripts/ci/ci_install_dependency.sh
- name: Build python binding
run: |
source "$HOME/.cargo/env"
export RUSTC_WRAPPER=sccache
cd sgl-router
pip install setuptools-rust wheel build
python3 -m build
pip install --force-reinstall dist/*.whl
- name: Run Python E2E response API tests
run: |
bash scripts/killall_sglang.sh "nuk_gpus"
cd sgl-router
SHOW_ROUTER_LOGS=1 pytest py_test/e2e_response_api -s -vv -o log_cli=true --log-cli-level=INFO
- name: Run Python E2E gRPC tests
run: |
bash scripts/killall_sglang.sh "nuk_gpus"
cd sgl-router
SHOW_ROUTER_LOGS=1 ROUTER_LOCAL_MODEL_PATH="/home/ubuntu/models" pytest py_test/e2e_grpc -s -vv -o log_cli=true --log-cli-level=INFO
finish:
needs: [unit-test-rust, pytest-rust]
needs: [unit-test-rust, pytest-rust, pytest-rust-2]
runs-on: ubuntu-latest
steps:
- name: Finish
......
......@@ -267,8 +267,6 @@ def popen_launch_workers_and_router(
policy,
"--model-path",
model,
"--log-level",
"warn",
]
# Add worker URLs
......
"""
Base test class for Response API e2e tests.
This module provides base test classes that can be reused across different backends
(OpenAI, XAI, gRPC) with common test logic.
"""
import json
import sys
import time
import unittest
from pathlib import Path
from typing import Optional
import requests
# Add current directory for local imports
_TEST_DIR = Path(__file__).parent
sys.path.insert(0, str(_TEST_DIR))
from util import CustomTestCase
class ResponseAPIBaseTest(CustomTestCase):
    """Base class for Response API tests with common utilities.

    Provides thin HTTP helpers over the router's Response API
    (``/v1/responses``) and Conversation API (``/v1/conversations``),
    plus SSE parsing and background-task polling.

    Subclasses must set ``base_url`` and ``model`` (and ``api_key`` when the
    router requires authentication) before the tests run.
    """

    # To be set by subclasses
    base_url: Optional[str] = None  # router base URL, e.g. "http://127.0.0.1:30000"
    api_key: Optional[str] = None  # bearer token; header omitted when falsy
    model: Optional[str] = None  # model name sent in every /v1/responses body

    def make_request(
        self,
        endpoint: str,
        method: str = "POST",
        json_data: Optional[dict] = None,
        params: Optional[dict] = None,
    ) -> requests.Response:
        """
        Make HTTP request to router.

        Args:
            endpoint: Endpoint path (e.g., "/v1/responses")
            method: HTTP method (GET, POST, DELETE)
            json_data: JSON body for POST requests
            params: Query parameters

        Returns:
            requests.Response object

        Raises:
            ValueError: If ``method`` is not one of GET, POST, DELETE.
        """
        url = f"{self.base_url}{endpoint}"
        headers = {"Content-Type": "application/json"}
        # Only attach the Authorization header when an API key is configured.
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        if method == "POST":
            resp = requests.post(url, json=json_data, headers=headers, params=params)
        elif method == "GET":
            resp = requests.get(url, headers=headers, params=params)
        elif method == "DELETE":
            resp = requests.delete(url, headers=headers, params=params)
        else:
            raise ValueError(f"Unsupported method: {method}")
        return resp

    def create_response(
        self,
        input_text: str,
        instructions: Optional[str] = None,
        stream: bool = False,
        max_output_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        previous_response_id: Optional[str] = None,
        conversation: Optional[str] = None,
        tools: Optional[list] = None,
        background: bool = False,
        **kwargs,
    ) -> requests.Response:
        """
        Create a response via POST /v1/responses.

        Args:
            input_text: User input
            instructions: Optional system instructions
            stream: Whether to stream response
            max_output_tokens: Optional max tokens to generate
            temperature: Sampling temperature
            previous_response_id: Optional previous response ID for state management
            conversation: Optional conversation ID for state management
            tools: Optional list of MCP tools
            background: Whether to run in background mode
            **kwargs: Additional request parameters

        Returns:
            requests.Response object (opened with ``stream=True`` when
            ``stream`` is requested)
        """
        data = {
            "model": self.model,
            "input": input_text,
            "stream": stream,
            **kwargs,
        }
        # Optional fields are only sent when provided, so the request body
        # mirrors exactly what the caller asked for.
        if instructions:
            data["instructions"] = instructions
        if max_output_tokens is not None:
            data["max_output_tokens"] = max_output_tokens
        if temperature is not None:
            data["temperature"] = temperature
        if previous_response_id:
            data["previous_response_id"] = previous_response_id
        if conversation:
            data["conversation"] = conversation
        if tools:
            data["tools"] = tools
        if background:
            data["background"] = background
        if stream:
            # For streaming, we need to handle SSE
            return self._create_streaming_response(data)
        else:
            return self.make_request("/v1/responses", "POST", data)

    def _create_streaming_response(self, data: dict) -> requests.Response:
        """Handle streaming response creation.

        Issues the POST directly with ``stream=True`` so the caller can
        iterate the SSE body (see :meth:`parse_sse_events`).
        """
        url = f"{self.base_url}/v1/responses"
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        # Return response object with stream=True
        return requests.post(url, json=data, headers=headers, stream=True)

    def get_response(self, response_id: str) -> requests.Response:
        """Get response by ID via GET /v1/responses/{response_id}."""
        return self.make_request(f"/v1/responses/{response_id}", "GET")

    def delete_response(self, response_id: str) -> requests.Response:
        """Delete response by ID via DELETE /v1/responses/{response_id}."""
        return self.make_request(f"/v1/responses/{response_id}", "DELETE")

    def cancel_response(self, response_id: str) -> requests.Response:
        """Cancel response by ID via POST /v1/responses/{response_id}/cancel."""
        return self.make_request(f"/v1/responses/{response_id}/cancel", "POST", {})

    def get_response_input(self, response_id: str) -> requests.Response:
        """Get response input items via GET /v1/responses/{response_id}/input."""
        return self.make_request(f"/v1/responses/{response_id}/input", "GET")

    def create_conversation(self, metadata: Optional[dict] = None) -> requests.Response:
        """Create conversation via POST /v1/conversations."""
        data = {}
        if metadata:
            data["metadata"] = metadata
        return self.make_request("/v1/conversations", "POST", data)

    def get_conversation(self, conversation_id: str) -> requests.Response:
        """Get conversation by ID via GET /v1/conversations/{conversation_id}."""
        return self.make_request(f"/v1/conversations/{conversation_id}", "GET")

    def update_conversation(
        self, conversation_id: str, metadata: dict
    ) -> requests.Response:
        """Update conversation via POST /v1/conversations/{conversation_id}."""
        return self.make_request(
            f"/v1/conversations/{conversation_id}", "POST", {"metadata": metadata}
        )

    def delete_conversation(self, conversation_id: str) -> requests.Response:
        """Delete conversation via DELETE /v1/conversations/{conversation_id}."""
        return self.make_request(f"/v1/conversations/{conversation_id}", "DELETE")

    def list_conversation_items(
        self,
        conversation_id: str,
        limit: Optional[int] = None,
        after: Optional[str] = None,
        before: Optional[str] = None,
        order: str = "asc",
    ) -> requests.Response:
        """List conversation items via GET /v1/conversations/{conversation_id}/items.

        Args:
            conversation_id: Conversation to list items for.
            limit: Optional page size.
            after / before: Optional pagination cursors.
            order: Sort order, "asc" by default.
        """
        params = {"order": order}
        if limit:
            params["limit"] = limit
        if after:
            params["after"] = after
        if before:
            params["before"] = before
        return self.make_request(
            f"/v1/conversations/{conversation_id}/items", "GET", params=params
        )

    def create_conversation_items(
        self, conversation_id: str, items: list
    ) -> requests.Response:
        """Create conversation items via POST /v1/conversations/{conversation_id}/items."""
        return self.make_request(
            f"/v1/conversations/{conversation_id}/items", "POST", {"items": items}
        )

    def get_conversation_item(
        self, conversation_id: str, item_id: str
    ) -> requests.Response:
        """Get conversation item via GET /v1/conversations/{conversation_id}/items/{item_id}."""
        return self.make_request(
            f"/v1/conversations/{conversation_id}/items/{item_id}", "GET"
        )

    def delete_conversation_item(
        self, conversation_id: str, item_id: str
    ) -> requests.Response:
        """Delete conversation item via DELETE /v1/conversations/{conversation_id}/items/{item_id}."""
        return self.make_request(
            f"/v1/conversations/{conversation_id}/items/{item_id}", "DELETE"
        )

    def parse_sse_events(self, response: requests.Response) -> list:
        """
        Parse Server-Sent Events from streaming response.

        An SSE event is a run of "event:"/"data:" lines terminated by a
        blank line. Events that never received a "data:" line are dropped.

        Args:
            response: requests.Response with stream=True

        Returns:
            List of event dictionaries with 'event' and 'data' keys
        """
        events = []
        current_event = None
        for line in response.iter_lines():
            if not line:
                # Empty line signals end of event
                if current_event and current_event.get("data"):
                    events.append(current_event)
                current_event = None
                continue
            line = line.decode("utf-8")
            if line.startswith("event:"):
                # len("event:") == 6; strip the prefix and whitespace.
                current_event = {"event": line[6:].strip()}
            elif line.startswith("data:"):
                if current_event is None:
                    current_event = {}
                # len("data:") == 5; payload is usually JSON but may be a
                # bare string (e.g. "[DONE]"), kept verbatim on parse failure.
                data_str = line[5:].strip()
                try:
                    current_event["data"] = json.loads(data_str)
                except json.JSONDecodeError:
                    current_event["data"] = data_str
        # Don't forget the last event if stream ends without empty line
        if current_event and current_event.get("data"):
            events.append(current_event)
        return events

    def wait_for_background_task(
        self, response_id: str, timeout: int = 30, poll_interval: float = 0.5
    ) -> dict:
        """
        Wait for background task to complete.

        Args:
            response_id: Response ID to poll
            timeout: Max seconds to wait
            poll_interval: Seconds between polls

        Returns:
            Final response data

        Raises:
            TimeoutError: If task doesn't complete in time
            AssertionError: If task fails
        """
        start_time = time.time()
        while time.time() - start_time < timeout:
            resp = self.get_response(response_id)
            self.assertEqual(resp.status_code, 200)
            data = resp.json()
            status = data.get("status")
            if status == "completed":
                return data
            elif status == "failed":
                raise AssertionError(
                    f"Background task failed: {data.get('error', 'Unknown error')}"
                )
            elif status == "cancelled":
                raise AssertionError("Background task was cancelled")
            # Any other status (e.g. still in progress) keeps polling.
            time.sleep(poll_interval)
        raise TimeoutError(
            f"Background task {response_id} did not complete within {timeout}s"
        )
class StateManagementBaseTest(ResponseAPIBaseTest):
    """Base class for state management tests (previous_response_id and conversation)."""

    def test_basic_response_creation(self):
        """A plain, stateless request completes and reports usage."""
        result = self.create_response("What is 2+2?", max_output_tokens=50)
        self.assertEqual(result.status_code, 200)
        body = result.json()
        for field in ("id", "output"):
            self.assertIn(field, body)
        self.assertEqual(body["status"], "completed")
        self.assertIn("usage", body)

    def test_streaming_response(self):
        """A streaming request yields SSE lifecycle events."""
        result = self.create_response("Count to 5", stream=True, max_output_tokens=50)
        self.assertEqual(result.status_code, 200)
        events = self.parse_sse_events(result)
        self.assertGreater(len(events), 0)
        names = [evt.get("event") for evt in events]
        # The stream must announce creation of the response...
        self.assertGreater(names.count("response.created"), 0)
        # ...and report either completion or in-progress status.
        self.assertTrue(
            any(
                name in ["response.completed", "response.in_progress"]
                for name in names
            )
        )
class ResponseCRUDBaseTest(ResponseAPIBaseTest):
    """Base class for Response API CRUD tests."""

    def test_create_and_get_response(self):
        """Create a response, then fetch it back by ID."""
        created = self.create_response("Hello, world!")
        self.assertEqual(created.status_code, 200)
        rid = created.json()["id"]
        # Retrieve the same response and confirm it round-trips.
        fetched = self.get_response(rid)
        self.assertEqual(fetched.status_code, 200)
        body = fetched.json()
        self.assertEqual(body["id"], rid)
        self.assertEqual(body["status"], "completed")
        # The input-items endpoint is not merged yet, so 501 is expected
        # for now; switch this to 200 + payload checks once it lands.
        inputs = self.get_response_input(body["id"])
        self.assertEqual(inputs.status_code, 501)

    @unittest.skip("TODO: Add delete response feature")
    def test_delete_response(self):
        """Create a response, delete it, and verify it is gone."""
        created = self.create_response("Test deletion", max_output_tokens=50)
        self.assertEqual(created.status_code, 200)
        rid = created.json()["id"]
        removed = self.delete_response(rid)
        self.assertEqual(removed.status_code, 200)
        # A deleted response should no longer be retrievable.
        self.assertEqual(self.get_response(rid).status_code, 404)

    @unittest.skip("TODO: Add background response feature")
    def test_background_response(self):
        """Launch a background response and poll it to completion."""
        created = self.create_response(
            "Write a short story", background=True, max_output_tokens=100
        )
        self.assertEqual(created.status_code, 200)
        body = created.json()
        self.assertEqual(body["status"], "in_progress")
        # Poll until the background task finishes.
        final = self.wait_for_background_task(body["id"], timeout=60)
        self.assertEqual(final["status"], "completed")
class ConversationCRUDBaseTest(ResponseAPIBaseTest):
    """Base class for Conversation API CRUD tests."""

    def test_create_and_get_conversation(self):
        """Create a conversation and read it back."""
        created = self.create_conversation(metadata={"user": "test_user"})
        self.assertEqual(created.status_code, 200)
        body = created.json()
        conv_id = body["id"]
        self.assertEqual(body["metadata"]["user"], "test_user")
        # The same metadata must come back on a fresh GET.
        fetched = self.get_conversation(conv_id)
        self.assertEqual(fetched.status_code, 200)
        fetched_body = fetched.json()
        self.assertEqual(fetched_body["id"], conv_id)
        self.assertEqual(fetched_body["metadata"]["user"], "test_user")

    def test_update_conversation(self):
        """Metadata updates are visible on subsequent reads."""
        created = self.create_conversation(metadata={"key1": "value1"})
        self.assertEqual(created.status_code, 200)
        conv_id = created.json()["id"]
        updated = self.update_conversation(
            conv_id, metadata={"key1": "value1", "key2": "value2"}
        )
        self.assertEqual(updated.status_code, 200)
        # Re-fetch and confirm the new key is present.
        refreshed = self.get_conversation(conv_id).json()
        self.assertEqual(refreshed["metadata"]["key2"], "value2")

    def test_delete_conversation(self):
        """Deleting a conversation makes it unretrievable (404)."""
        created = self.create_conversation()
        self.assertEqual(created.status_code, 200)
        conv_id = created.json()["id"]
        removed = self.delete_conversation(conv_id)
        self.assertEqual(removed.status_code, 200)
        self.assertEqual(self.get_conversation(conv_id).status_code, 404)

    def test_list_conversation_items(self):
        """Responses attached to a conversation show up as items."""
        conv_id = self.create_conversation().json()["id"]
        # Two turns in the same conversation.
        for prompt in ("First message", "Second message"):
            self.create_response(
                prompt, conversation=conv_id, max_output_tokens=50
            )
        listing = self.list_conversation_items(conv_id)
        self.assertEqual(listing.status_code, 200)
        body = listing.json()
        self.assertIn("data", body)
        # Two turns => at least 2 inputs + 2 outputs.
        self.assertGreaterEqual(len(body["data"]), 4)
"""
pytest configuration for e2e_response_api tests.
This configures pytest to not collect base test classes that are meant to be inherited.
"""
import pytest
def pytest_collection_modifyitems(config, items):
    """
    Drop collected tests whose immediate parent is a base test class.

    The classes listed below exist only to be inherited by backend-specific
    subclasses (OpenAI, XAI, gRPC); running them directly would fail, so
    their tests are filtered out of the collection in place.
    """
    base_class_names = {
        "StateManagementBaseTest",
        "ResponseCRUDBaseTest",
        "ConversationCRUDBaseTest",
        "MCPTests",
        "StateManagementTests",
    }
    # Rebuild the collection in place, keeping only tests that do not
    # belong to one of the base classes above.
    items[:] = [
        item
        for item in items
        if (item.parent.name if hasattr(item, "parent") else None)
        not in base_class_names
    ]
"""
MCP (Model Context Protocol) tests for Response API.
Tests MCP tool calling in both streaming and non-streaming modes.
These tests should work across all backends that support MCP (OpenAI, XAI).
"""
from base import ResponseAPIBaseTest
class MCPTests(ResponseAPIBaseTest):
    """Tests for MCP tool calling in both streaming and non-streaming modes."""

    def test_mcp_basic_tool_call(self):
        """Non-streaming MCP tool call against the public DeepWiki server."""
        tool_spec = {
            "type": "mcp",
            "server_label": "deepwiki",
            "server_url": "https://mcp.deepwiki.com/mcp",
            "require_approval": "never",
        }
        resp = self.create_response(
            "What transport protocols does the 2025-03-26 version of the MCP spec (modelcontextprotocol/modelcontextprotocol) support?",
            tools=[tool_spec],
            stream=False,
        )
        # The request itself must succeed.
        self.assertEqual(resp.status_code, 200)
        payload = resp.json()
        print(f"MCP response: {payload}")
        # Top-level response shape.
        self.assertIn("id", payload)
        self.assertIn("status", payload)
        self.assertEqual(payload["status"], "completed")
        self.assertIn("output", payload)
        self.assertIn("model", payload)
        # The output array must carry at least one item.
        entries = payload["output"]
        self.assertIsInstance(entries, list)
        self.assertGreater(len(entries), 0)
        kinds = [entry.get("type") for entry in entries]
        # Tool listing happens before any call is made.
        self.assertIn(
            "mcp_list_tools", kinds, "Response should contain mcp_list_tools"
        )
        tool_calls = [entry for entry in entries if entry.get("type") == "mcp_call"]
        self.assertGreater(
            len(tool_calls), 0, "Response should contain at least one mcp_call"
        )
        # Each mcp_call item carries the full call record.
        for call in tool_calls:
            for key in ("id", "status", "server_label", "name", "arguments", "output"):
                self.assertIn(key, call)
            self.assertEqual(call["status"], "completed")
            self.assertEqual(call["server_label"], "deepwiki")
        # A final assistant message must follow the tool calls.
        message_items = [entry for entry in entries if entry.get("type") == "message"]
        self.assertGreater(
            len(message_items), 0, "Response should contain at least one message"
        )
        for message in message_items:
            self.assertIn("content", message)
            self.assertIsInstance(message["content"], list)
            # Any text part must be a non-empty string.
            for part in message["content"]:
                if part.get("type") == "output_text":
                    self.assertIn("text", part)
                    self.assertIsInstance(part["text"], str)
                    self.assertGreater(len(part["text"]), 0)

    def test_mcp_basic_tool_call_streaming(self):
        """Streaming MCP tool call: verify the full SSE event lifecycle."""
        tool_spec = {
            "type": "mcp",
            "server_label": "deepwiki",
            "server_url": "https://mcp.deepwiki.com/mcp",
            "require_approval": "never",
        }
        resp = self.create_response(
            "What transport protocols does the 2025-03-26 version of the MCP spec (modelcontextprotocol/modelcontextprotocol) support?",
            tools=[tool_spec],
            stream=True,
        )
        self.assertEqual(resp.status_code, 200)
        events = self.parse_sse_events(resp)
        self.assertGreater(len(events), 0)
        seen = [evt.get("event") for evt in events]
        # Every lifecycle / MCP / text event the stream is expected to emit,
        # paired with its failure message.
        expected_events = [
            ("response.created", "Should have response.created event"),
            ("response.completed", "Should have response.completed event"),
            ("response.output_item.added", "Should have output_item.added events"),
            (
                "response.mcp_list_tools.in_progress",
                "Should have mcp_list_tools.in_progress event",
            ),
            (
                "response.mcp_list_tools.completed",
                "Should have mcp_list_tools.completed event",
            ),
            (
                "response.mcp_call.in_progress",
                "Should have mcp_call.in_progress event",
            ),
            (
                "response.mcp_call_arguments.delta",
                "Should have mcp_call_arguments.delta event",
            ),
            (
                "response.mcp_call_arguments.done",
                "Should have mcp_call_arguments.done event",
            ),
            (
                "response.mcp_call.completed",
                "Should have mcp_call.completed event",
            ),
            (
                "response.content_part.added",
                "Should have content_part.added event",
            ),
            (
                "response.output_text.delta",
                "Should have output_text.delta events",
            ),
            (
                "response.output_text.done",
                "Should have output_text.done event",
            ),
            (
                "response.content_part.done",
                "Should have content_part.done event",
            ),
        ]
        for name, message in expected_events:
            self.assertIn(name, seen, message)
        # Exactly one terminal event, carrying the full response snapshot.
        done_events = [evt for evt in events if evt.get("event") == "response.completed"]
        self.assertEqual(len(done_events), 1)
        final = done_events[0].get("data", {}).get("response", {})
        self.assertIn("id", final)
        self.assertEqual(final.get("status"), "completed")
        self.assertIn("output", final)
        # The final output must include listing, call and message items.
        final_entries = final.get("output", [])
        final_kinds = [entry.get("type") for entry in final_entries]
        for kind in ("mcp_list_tools", "mcp_call", "message"):
            self.assertIn(kind, final_kinds)
        tool_calls = [
            entry for entry in final_entries if entry.get("type") == "mcp_call"
        ]
        self.assertGreater(len(tool_calls), 0)
        for call in tool_calls:
            self.assertEqual(call.get("status"), "completed")
            self.assertEqual(call.get("server_label"), "deepwiki")
            for key in ("name", "arguments", "output"):
                self.assertIn(key, call)
        # Text deltas must have been streamed...
        deltas = [
            evt.get("data", {}).get("delta", "")
            for evt in events
            if evt.get("event") == "response.output_text.delta"
        ]
        self.assertGreater(len(deltas), 0, "Should have text deltas")
        # ...and the done event must carry the assembled text.
        text_done = [
            evt for evt in events if evt.get("event") == "response.output_text.done"
        ]
        self.assertGreater(len(text_done), 0)
        full_text = text_done[0].get("data", {}).get("text", "")
        self.assertGreater(len(full_text), 0, "Final text should not be empty")
"""
Fixtures for launching OpenAI/XAI router for response API e2e testing.
This module provides fixtures for launching SGLang router with OpenAI or XAI backends:
1. Launch router with --backend openai pointing to OpenAI or XAI API
2. Configure history backend (memory or oracle)
This supports testing the Response API against real cloud providers.
"""
import os
import socket
import subprocess
import time
from typing import Optional
import requests
def wait_for_workers_ready(
    router_url: str,
    expected_workers: int,
    timeout: int = 300,
    api_key: Optional[str] = None,
) -> None:
    """
    Wait for router to have all workers connected.

    Polls the /workers endpoint until the 'total' field matches expected_workers.

    Example response from /workers endpoint:
    {"workers":[],"total":0,"stats":{"prefill_count":0,"decode_count":0,"regular_count":0}}

    Args:
        router_url: Base URL of router (e.g., "http://127.0.0.1:30000")
        expected_workers: Number of workers expected to be connected
        timeout: Max seconds to wait
        api_key: Optional API key for authentication

    Raises:
        TimeoutError: If the expected worker count is not reached within
            ``timeout`` seconds; the message includes the last observed status.
    """
    start_time = time.time()
    last_error = None  # last failure reason, surfaced in the TimeoutError
    attempt = 0
    headers = {}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    # One Session reuses the TCP connection across polls.
    with requests.Session() as session:
        while time.time() - start_time < timeout:
            attempt += 1
            elapsed = int(time.time() - start_time)
            # Print progress every 10 seconds
            # (attempt check throttles duplicate prints within the same second
            # — loop runs roughly once per second due to the sleep below)
            if elapsed > 0 and elapsed % 10 == 0 and attempt % 10 == 0:
                print(f" Still waiting for workers... ({elapsed}/{timeout}s elapsed)")
            try:
                response = session.get(
                    f"{router_url}/workers", headers=headers, timeout=5
                )
                if response.status_code == 200:
                    data = response.json()
                    total_workers = data.get("total", 0)
                    if total_workers == expected_workers:
                        print(
                            f" All {expected_workers} workers connected after {elapsed}s"
                        )
                        return
                    else:
                        last_error = f"Workers: {total_workers}/{expected_workers}"
                else:
                    last_error = f"HTTP {response.status_code}"
            except requests.ConnectionError:
                # Router process likely still starting up.
                last_error = "Connection refused (router not ready yet)"
            except requests.Timeout:
                last_error = "Timeout"
            except requests.RequestException as e:
                last_error = str(e)
            except (ValueError, KeyError) as e:
                # response.json() parse failures land here.
                last_error = f"Invalid response: {e}"
            time.sleep(1)
    raise TimeoutError(
        f"Router at {router_url} did not get {expected_workers} workers within {timeout}s.\n"
        f"Last status: {last_error}\n"
        f"Hint: Run with SHOW_ROUTER_LOGS=1 to see startup logs"
    )
def find_free_port() -> int:
    """Ask the OS for an unused TCP port on 127.0.0.1 and return it."""
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Binding to port 0 lets the kernel pick any free ephemeral port.
        probe.bind(("127.0.0.1", 0))
        return probe.getsockname()[1]
    finally:
        probe.close()
def wait_for_router_ready(
    router_url: str,
    timeout: int = 60,
    api_key: Optional[str] = None,
) -> None:
    """
    Wait for router to be ready.

    Polls the /health endpoint until it returns 200.

    Args:
        router_url: Base URL of router (e.g., "http://127.0.0.1:30000")
        timeout: Max seconds to wait
        api_key: Optional API key for authentication

    Raises:
        TimeoutError: If /health does not return 200 within ``timeout``
            seconds; the message includes the last observed failure.
    """
    start_time = time.time()
    last_error = None  # last failure reason, surfaced in the TimeoutError
    attempt = 0
    headers = {}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    # One Session reuses the TCP connection across polls.
    with requests.Session() as session:
        while time.time() - start_time < timeout:
            attempt += 1
            elapsed = int(time.time() - start_time)
            # Print progress every 10 seconds
            # (attempt check throttles duplicate prints within the same second
            # — loop runs roughly once per second due to the sleep below)
            if elapsed > 0 and elapsed % 10 == 0 and attempt % 10 == 0:
                print(f" Still waiting for router... ({elapsed}/{timeout}s elapsed)")
            try:
                response = session.get(
                    f"{router_url}/health", headers=headers, timeout=5
                )
                if response.status_code == 200:
                    print(f" Router ready after {elapsed}s")
                    return
                else:
                    last_error = f"HTTP {response.status_code}"
            except requests.ConnectionError:
                # Router process likely still starting up.
                last_error = "Connection refused (router not ready yet)"
            except requests.Timeout:
                last_error = "Timeout"
            except requests.RequestException as e:
                last_error = str(e)
            time.sleep(1)
    raise TimeoutError(
        f"Router at {router_url} did not become ready within {timeout}s.\n"
        f"Last status: {last_error}\n"
        f"Hint: Run with SHOW_ROUTER_LOGS=1 to see startup logs"
    )
def popen_launch_openai_xai_router(
    backend: str,  # "openai" or "xai"
    base_url: str,
    timeout: int = 60,
    history_backend: str = "memory",
    api_key: Optional[str] = None,
    router_args: Optional[list] = None,
    stdout=None,
    stderr=None,
    prometheus_port: Optional[int] = None,
) -> dict:
    """
    Launch SGLang router with OpenAI or XAI backend.

    This approach:
    1. Starts router with --backend openai
    2. Points to OpenAI or XAI API via --worker-urls
    3. Configures history backend (memory or oracle)
    4. Waits for router health check to pass

    Args:
        backend: "openai" or "xai"
        base_url: Base URL for router (e.g., "http://127.0.0.1:30000")
        timeout: Timeout for router startup (default: 60s)
        history_backend: "memory" or "oracle" (default: memory)
        api_key: Optional API key for router authentication
        router_args: Additional arguments for router
        stdout: Optional file handle for router stdout
        stderr: Optional file handle for router stderr
        prometheus_port: Optional metrics port; defaults to router port + 1000
            to avoid conflicts when several routers run on one host

    Returns:
        dict with:
        - router: router process object
        - base_url: router URL (HTTP endpoint)

    Raises:
        ValueError: If ``backend`` is unsupported or its API-key env var is unset.
        TimeoutError: If the router does not pass its health check in time.

    Example:
        >>> cluster = popen_launch_openai_xai_router(
        ...     "openai", "http://127.0.0.1:30000"
        ... )
        >>> # Use cluster['base_url'] for HTTP requests
        >>> # Cleanup:
        >>> kill_process_tree(cluster['router'].pid)
    """
    show_output = os.environ.get("SHOW_ROUTER_LOGS", "0") == "1"
    # Parse router port from base_url; fall back to a free port when absent.
    if ":" in base_url.split("//")[-1]:
        router_port = int(base_url.split(":")[-1])
    else:
        router_port = find_free_port()
    print(f"\n{'='*70}")
    print(f"Launching {backend.upper()} router")
    print(f"{'='*70}")
    print(f" Backend: {backend}")
    print(f" Router port: {router_port}")
    print(f" History backend: {history_backend}")
    # Determine worker URL based on backend
    if backend == "openai":
        worker_url = "https://api.openai.com"
        # Get API key from environment
        backend_api_key = os.environ.get("OPENAI_API_KEY")
        if not backend_api_key:
            raise ValueError(
                "OPENAI_API_KEY environment variable must be set for OpenAI backend"
            )
    elif backend == "xai":
        worker_url = "https://api.x.ai"
        # Get API key from environment
        backend_api_key = os.environ.get("XAI_API_KEY")
        if not backend_api_key:
            raise ValueError(
                "XAI_API_KEY environment variable must be set for XAI backend"
            )
    else:
        raise ValueError(f"Unsupported backend: {backend}")
    print(f" Worker URL: {worker_url}")
    # Build router command
    router_cmd = [
        "python3",
        "-m",
        "sglang_router.launch_router",
        "--host",
        "127.0.0.1",
        "--port",
        str(router_port),
        "--backend",
        "openai",
        "--worker-urls",
        worker_url,
        "--history-backend",
        history_backend,
        "--log-level",
        "warn",
    ]
    # Note: Not adding --api-key to router command for local testing
    # The router will not require authentication
    # Add Prometheus port to avoid conflicts (use unique port or disable)
    if prometheus_port is None:
        # Auto-assign a unique prometheus port based on router port
        prometheus_port = router_port + 1000
    router_cmd.extend(["--prometheus-port", str(prometheus_port)])
    # Add router-specific args
    if router_args:
        router_cmd.extend(router_args)
    if show_output:
        print(f" Command: {' '.join(router_cmd)}")
    # Set up environment with backend API key
    env = os.environ.copy()
    if backend == "openai":
        env["OPENAI_API_KEY"] = backend_api_key
    else:
        env["XAI_API_KEY"] = backend_api_key
    # Launch router
    if show_output:
        # Inherit the parent's stdout/stderr unless handles were supplied.
        router_proc = subprocess.Popen(
            router_cmd,
            env=env,
            stdout=stdout,
            stderr=stderr,
        )
    else:
        # NOTE(review): PIPE output is never drained; a very chatty router
        # could block on a full pipe buffer — confirm log volume is small.
        router_proc = subprocess.Popen(
            router_cmd,
            stdout=stdout if stdout is not None else subprocess.PIPE,
            stderr=stderr if stderr is not None else subprocess.PIPE,
            env=env,
        )
    print(f" PID: {router_proc.pid}")
    # Wait for router to be ready
    router_url = f"http://127.0.0.1:{router_port}"
    print(f"\nWaiting for router to start at {router_url}...")
    try:
        wait_for_router_ready(router_url, timeout=timeout, api_key=None)
        print(f"✓ Router ready at {router_url}")
    except TimeoutError:
        print("✗ Router failed to start")
        # Best-effort cleanup: kill the router process before re-raising.
        # (Was a bare `except:`, which would also swallow KeyboardInterrupt.)
        try:
            router_proc.kill()
        except Exception:
            pass
        raise
    print(f"\n{'='*70}")
    print(f"✓ {backend.upper()} router ready!")
    print(f" Router: {router_url}")
    print(f"{'='*70}\n")
    return {
        "router": router_proc,
        "base_url": router_url,
    }
def popen_launch_workers_and_router(
    model: str,
    base_url: str,
    timeout: int = 300,
    num_workers: int = 2,
    policy: str = "round_robin",
    api_key: Optional[str] = None,
    worker_args: Optional[list] = None,
    router_args: Optional[list] = None,
    tp_size: int = 1,
    env: Optional[dict] = None,
    stdout=None,
    stderr=None,
) -> dict:
    """
    Launch SGLang workers and gRPC router separately.
    This approach:
    1. Starts N SGLang workers with --grpc-mode flag
    2. Waits for workers to initialize (process startup)
    3. Starts a gRPC router pointing to those workers
    4. Waits for router health check to pass (router validates worker connectivity)
    This matches production deployment patterns better than the integrated approach.
    Args:
        model: Model path (e.g., /home/ubuntu/models/llama-3.1-8b-instruct)
        base_url: Base URL for router (e.g., "http://127.0.0.1:8080")
        timeout: Timeout in seconds for the router/worker readiness check
            (default: 300s)
        num_workers: Number of workers to launch
        policy: Routing policy (round_robin, random, power_of_two, cache_aware)
        api_key: Optional API key for router
        worker_args: Additional arguments for workers (e.g., ["--context-len", "8192"])
        router_args: Additional arguments for router (e.g., ["--max-total-token", "1536"])
        tp_size: Tensor parallelism size for workers (default: 1)
        env: Optional environment variables for workers (e.g., {"SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION": "256"})
        stdout: Optional file handle for worker stdout (default: subprocess.PIPE)
        stderr: Optional file handle for worker stderr (default: subprocess.PIPE)
    Returns:
        dict with:
            - workers: list of worker process objects
            - worker_urls: list of gRPC worker URLs
            - router: router process object
            - base_url: router URL (HTTP endpoint)
    Example:
        >>> cluster = popen_launch_workers_and_router(model, base_url, num_workers=2)
        >>> # Use cluster['base_url'] for HTTP requests
        >>> # Cleanup:
        >>> for worker in cluster['workers']:
        >>>     kill_process_tree(worker.pid)
        >>> kill_process_tree(cluster['router'].pid)
    """
    show_output = os.environ.get("SHOW_ROUTER_LOGS", "0") == "1"
    # Parse router port from base_url, or pick a free one when none is given.
    if ":" in base_url.split("//")[-1]:
        router_port = int(base_url.split(":")[-1])
    else:
        router_port = find_free_port()
    print(f"\n{'='*70}")
    print("Launching gRPC cluster (separate workers + router)")
    print(f"{'='*70}")
    print(f" Model: {model}")
    print(f" Router port: {router_port}")
    print(f" Workers: {num_workers}")
    print(f" TP size: {tp_size}")
    print(f" Policy: {policy}")
    # Step 1: Launch workers with gRPC enabled
    workers = []
    worker_urls = []
    for i in range(num_workers):
        worker_port = find_free_port()
        worker_url = f"grpc://127.0.0.1:{worker_port}"
        worker_urls.append(worker_url)
        print(f"\n[Worker {i+1}/{num_workers}]")
        print(f" Port: {worker_port}")
        print(f" URL: {worker_url}")
        # Build worker command
        worker_cmd = [
            "python3",
            "-m",
            "sglang.launch_server",
            "--model-path",
            model,
            "--host",
            "127.0.0.1",
            "--port",
            str(worker_port),
            "--grpc-mode",  # Enable gRPC for this worker
            "--mem-fraction-static",
            "0.8",
            "--attention-backend",
            "fa3",
        ]
        # Add TP size
        if tp_size > 1:
            worker_cmd.extend(["--tp-size", str(tp_size)])
        # Add worker-specific args
        if worker_args:
            worker_cmd.extend(worker_args)
        # Launch worker with optional environment variables; pipe output
        # unless the caller wants logs shown or supplied explicit handles.
        if show_output:
            worker_proc = subprocess.Popen(
                worker_cmd,
                env=env,
                stdout=stdout,
                stderr=stderr,
            )
        else:
            worker_proc = subprocess.Popen(
                worker_cmd,
                stdout=stdout if stdout is not None else subprocess.PIPE,
                stderr=stderr if stderr is not None else subprocess.PIPE,
                env=env,
            )
        workers.append(worker_proc)
        print(f" PID: {worker_proc.pid}")
    # Give workers a moment to start binding to ports
    # The router will check worker health when it starts
    print(f"\nWaiting for {num_workers} workers to initialize (20s)...")
    time.sleep(20)
    # Quick check: make sure worker processes are still alive
    for i, worker in enumerate(workers):
        if worker.poll() is not None:
            print(f" ✗ Worker {i+1} died during startup (exit code: {worker.poll()})")
            # Cleanup: kill all workers before bailing out.
            for w in workers:
                try:
                    w.kill()
                except Exception:
                    # Narrowed from a bare `except:` so KeyboardInterrupt/
                    # SystemExit are not swallowed during cleanup.
                    pass
            raise RuntimeError(f"Worker {i+1} failed to start")
    print(f"✓ All {num_workers} workers started (router will verify connectivity)")
    # Step 2: Launch router pointing to workers
    print(f"\n[Router]")
    print(f" Port: {router_port}")
    print(f" Worker URLs: {', '.join(worker_urls)}")
    # Build router command. The Prometheus port is derived from the router
    # port (matching popen_launch_openai_xai_router) instead of a hard-coded
    # 9321, so concurrently running clusters do not collide on it.
    router_cmd = [
        "python3",
        "-m",
        "sglang_router.launch_router",
        "--host",
        "127.0.0.1",
        "--port",
        str(router_port),
        "--prometheus-port",
        str(router_port + 1000),
        "--policy",
        policy,
        "--model-path",
        model,
    ]
    # Add worker URLs
    router_cmd.append("--worker-urls")
    router_cmd.extend(worker_urls)
    # Add API key
    if api_key:
        router_cmd.extend(["--api-key", api_key])
    # Add router-specific args
    if router_args:
        router_cmd.extend(router_args)
    if show_output:
        print(f" Command: {' '.join(router_cmd)}")
    # Launch router
    if show_output:
        router_proc = subprocess.Popen(router_cmd)
    else:
        router_proc = subprocess.Popen(
            router_cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
    print(f" PID: {router_proc.pid}")
    # Wait for router to be ready. Honors the caller-supplied `timeout`;
    # previously this was hard-coded to 180s and the parameter was ignored.
    router_url = f"http://127.0.0.1:{router_port}"
    print(f"\nWaiting for router to start at {router_url}...")
    try:
        wait_for_workers_ready(
            router_url, expected_workers=num_workers, timeout=timeout, api_key=api_key
        )
        print(f"✓ Router ready at {router_url}")
    except TimeoutError:
        print("✗ Router failed to start")
        # Cleanup: kill router and all workers
        try:
            router_proc.kill()
        except Exception:
            pass
        for worker in workers:
            try:
                worker.kill()
            except Exception:
                pass
        raise
    print(f"\n{'='*70}")
    print("✓ gRPC cluster ready!")
    print(f" Router: {router_url}")
    print(f" Workers: {len(workers)}")
    print(f"{'='*70}\n")
    return {
        "workers": workers,
        "worker_urls": worker_urls,
        "router": router_proc,
        "base_url": router_url,
    }
"""
State management tests for Response API.
Tests both previous_response_id and conversation-based state management.
These tests should work across all backends (OpenAI, XAI, gRPC).
"""
import unittest
from base import ResponseAPIBaseTest
class StateManagementTests(ResponseAPIBaseTest):
    """Tests for state management using previous_response_id and conversation.

    Exercises multi-turn memory through two mechanisms: chaining responses
    via ``previous_response_id`` and grouping them under a conversation ID.
    """

    def test_previous_response_id_chaining(self):
        """Test chaining responses using previous_response_id."""
        # First response establishes the facts to be recalled later.
        resp1 = self.create_response(
            "My name is Alice and my friend is Bob. Remember it."
        )
        self.assertEqual(resp1.status_code, 200)
        response1_id = resp1.json()["id"]
        # Second response referencing first
        resp2 = self.create_response(
            "What is my name", previous_response_id=response1_id
        )
        self.assertEqual(resp2.status_code, 200)
        response2_data = resp2.json()
        # The model should remember the name from previous response
        output_text = self._extract_output_text(response2_data)
        self.assertIn("Alice", output_text)
        # Third response referencing second
        resp3 = self.create_response(
            "What is my friend name?",
            previous_response_id=response2_data["id"],
        )
        # Assert the status BEFORE parsing the body, so a failed request
        # surfaces as a clear assertion rather than a JSON/KeyError
        # (previously the status was checked after extraction).
        self.assertEqual(resp3.status_code, 200)
        response3_data = resp3.json()
        output_text = self._extract_output_text(response3_data)
        self.assertIn("Bob", output_text)

    @unittest.skip("TODO: Add the invalid previous_response_id check")
    def test_previous_response_id_invalid(self):
        """Test using invalid previous_response_id."""
        resp = self.create_response(
            "Test", previous_response_id="resp_invalid123", max_output_tokens=50
        )
        # Should return 404 or 400 for invalid response ID
        if resp.status_code != 200:
            print(f"\n❌ Response creation failed!")
            print(f"Status: {resp.status_code}")
            print(f"Response: {resp.text}")
        self.assertIn(resp.status_code, [400, 404])

    def test_conversation_with_multiple_turns(self):
        """Test state management using conversation ID."""
        # Create conversation
        conv_resp = self.create_conversation(metadata={"topic": "math"})
        self.assertEqual(conv_resp.status_code, 200)
        conversation_id = conv_resp.json()["id"]
        # First response in conversation
        resp1 = self.create_response("I have 5 apples.", conversation=conversation_id)
        self.assertEqual(resp1.status_code, 200)
        # Second response in same conversation
        resp2 = self.create_response(
            "How many apples do I have?",
            conversation=conversation_id,
        )
        self.assertEqual(resp2.status_code, 200)
        output_text = self._extract_output_text(resp2.json())
        # Should remember "5 apples"
        self.assertTrue("5" in output_text or "five" in output_text.lower())
        # Third response in same conversation
        resp3 = self.create_response(
            "If I get 3 more, how many total?",
            conversation=conversation_id,
        )
        self.assertEqual(resp3.status_code, 200)
        output_text = self._extract_output_text(resp3.json())
        # Should calculate 5 + 3 = 8
        self.assertTrue("8" in output_text or "eight" in output_text.lower())
        # The conversation should have accumulated all turns.
        list_resp = self.list_conversation_items(conversation_id)
        self.assertEqual(list_resp.status_code, 200)
        items = list_resp.json()["data"]
        # Should have at least 6 items (3 inputs + 3 outputs)
        self.assertGreaterEqual(len(items), 6)

    def test_mutually_exclusive_parameters(self):
        """Test that previous_response_id and conversation are mutually exclusive."""
        # Create conversation and response
        conv_resp = self.create_conversation()
        conversation_id = conv_resp.json()["id"]
        resp1 = self.create_response("Test")
        response1_id = resp1.json()["id"]
        # Try to use both parameters
        resp = self.create_response(
            "This should fail",
            previous_response_id=response1_id,
            conversation=conversation_id,
        )
        # Should return 400 Bad Request
        self.assertEqual(resp.status_code, 400)
        error_data = resp.json()
        self.assertIn("error", error_data)
        self.assertIn("mutually exclusive", error_data["error"]["message"].lower())

    # Helper methods
    def _extract_output_text(self, response_data: dict) -> str:
        """Extract and join all output_text parts from a response payload."""
        # Missing "output"/"content" keys degrade to empty iterables, so the
        # result is "" for an empty or malformed payload (same as before).
        return " ".join(
            part.get("text", "")
            for item in response_data.get("output", [])
            for part in item.get("content", [])
            if part.get("type") == "output_text"
        )
"""
OpenAI backend tests for Response API.
Run with:
export OPENAI_API_KEY=your_key
python3 -m pytest py_test/e2e_response_api/test_openai_backend.py -v
python3 -m unittest e2e_response_api.test_openai_backend.TestOpenAIStateManagement
"""
import os
import sys
import unittest
from pathlib import Path
# Add current directory for imports
_TEST_DIR = Path(__file__).parent
sys.path.insert(0, str(_TEST_DIR))
# Import local modules
from base import ConversationCRUDBaseTest, ResponseCRUDBaseTest
from mcp import MCPTests
from router_fixtures import (
popen_launch_openai_xai_router,
popen_launch_workers_and_router,
)
from state_management import StateManagementTests
from util import kill_process_tree
class TestOpenaiBackend(
    ResponseCRUDBaseTest, ConversationCRUDBaseTest, StateManagementTests, MCPTests
):
    """End to end tests for OpenAI backend."""

    # Real OpenAI key; the router forwards upstream requests with it.
    api_key = os.environ.get("OPENAI_API_KEY")

    @classmethod
    def setUpClass(cls):
        """Spin up an OpenAI-backed router with in-memory history storage."""
        cls.model = "gpt-5-nano"
        cls.base_url_port = "http://127.0.0.1:30010"
        launched = popen_launch_openai_xai_router(
            backend="openai",
            base_url=cls.base_url_port,
            history_backend="memory",
        )
        cls.cluster = launched
        cls.base_url = launched["base_url"]

    @classmethod
    def tearDownClass(cls):
        """Tear down the router process tree started in setUpClass."""
        kill_process_tree(cls.cluster["router"].pid)
class TestXaiBackend(StateManagementTests):
    """End to end tests for XAI backend."""

    # Real XAI key; the router forwards upstream requests with it.
    api_key = os.environ.get("XAI_API_KEY")

    @classmethod
    def setUpClass(cls):
        """Spin up an XAI-backed router with in-memory history storage."""
        cls.model = "grok-4-fast"
        cls.base_url_port = "http://127.0.0.1:30023"
        launched = popen_launch_openai_xai_router(
            backend="xai",
            base_url=cls.base_url_port,
            history_backend="memory",
        )
        cls.cluster = launched
        cls.base_url = launched["base_url"]

    @classmethod
    def tearDownClass(cls):
        """Tear down the router process tree started in setUpClass."""
        kill_process_tree(cls.cluster["router"].pid)
class TestGrpcBackend(StateManagementTests, MCPTests):
    """End to end tests for gRPC backend.

    Runs a local SGLang worker in gRPC mode behind the router. Several
    inherited tests are skipped pending gRPC-path fixes; each skip reason
    records the failure observed at the time of writing.
    """

    @classmethod
    def setUpClass(cls):
        # One gRPC worker with tensor parallelism across 2 GPUs; the router
        # keeps response/conversation history in its in-memory backend.
        cls.model = "/home/ubuntu/models/meta-llama/Llama-3.1-8B-Instruct"
        cls.base_url_port = "http://127.0.0.1:30030"
        cls.cluster = popen_launch_workers_and_router(
            cls.model,
            cls.base_url_port,
            timeout=90,
            num_workers=1,
            tp_size=2,
            policy="round_robin",
            router_args=["--history-backend", "memory"],
        )
        cls.base_url = cls.cluster["base_url"]

    @classmethod
    def tearDownClass(cls):
        # Kill the router first, then every worker process tree.
        kill_process_tree(cls.cluster["router"].pid)
        for worker in cls.cluster.get("workers", []):
            kill_process_tree(worker.pid)

    # The overrides below disable inherited tests that do not yet pass on
    # the gRPC backend; they delegate to the parent implementation so the
    # skips can be removed without further changes once fixed.
    @unittest.skip(
        "TODO: transport error, details: [], metadata: MetadataMap { headers: {} }"
    )
    def test_previous_response_id_chaining(self):
        super().test_previous_response_id_chaining()

    @unittest.skip("TODO: return 501 Not Implemented")
    def test_conversation_with_multiple_turns(self):
        super().test_conversation_with_multiple_turns()

    @unittest.skip("TODO: decode error message")
    def test_mutually_exclusive_parameters(self):
        super().test_mutually_exclusive_parameters()

    @unittest.skip(
        "TODO: Pipeline execution failed: Pipeline stage WorkerSelection failed"
    )
    def test_mcp_basic_tool_call(self):
        super().test_mcp_basic_tool_call()

    @unittest.skip("TODO: no event fields")
    def test_mcp_basic_tool_call_streaming(self):
        return super().test_mcp_basic_tool_call_streaming()
if __name__ == "__main__":
    # Allow running this test module directly (outside pytest).
    unittest.main()
"""
Utility functions for Response API e2e tests.
"""
import os
import signal
import threading
import unittest
import psutil
def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):
    """
    Kill the process and all its child processes.
    Args:
        parent_pid: PID of the parent process (None means the current process,
            in which case the parent itself is never killed)
        include_parent: Whether to kill the parent process itself
        skip_pid: Optional PID to skip during cleanup
    """
    # Remove sigchld handler to avoid spammy logs
    if threading.current_thread() is threading.main_thread():
        signal.signal(signal.SIGCHLD, signal.SIG_DFL)

    if parent_pid is None:
        parent_pid = os.getpid()
        include_parent = False

    try:
        root = psutil.Process(parent_pid)
    except psutil.NoSuchProcess:
        # Nothing to do if the target is already gone.
        return

    # Collect children first, then (optionally) the parent, so the parent
    # is always killed last.
    targets = [c for c in root.children(recursive=True) if c.pid != skip_pid]
    if include_parent:
        targets.append(root)

    for proc in targets:
        try:
            proc.kill()
        except psutil.NoSuchProcess:
            # Process exited on its own between enumeration and kill.
            pass
class CustomTestCase(unittest.TestCase):
    """
    Custom test case base class with retry support.

    Set the SGLANG_TEST_MAX_RETRY environment variable to N to retry each
    failing test up to N additional times before reporting the failure.
    """

    def _callTestMethod(self, method):
        """Run ``method``, retrying on failure per SGLANG_TEST_MAX_RETRY."""
        max_retry = int(os.environ.get("SGLANG_TEST_MAX_RETRY", "0"))
        if max_retry == 0:
            # No retry, just run once
            return super()._callTestMethod(method)
        # Retry logic: up to max_retry + 1 total attempts.
        for attempt in range(max_retry + 1):
            try:
                return super()._callTestMethod(method)
            except Exception:
                if attempt >= max_retry:
                    # Last attempt, re-raise the exception
                    raise
                print(
                    f"Test failed on attempt {attempt + 1}/{max_retry + 1}, retrying..."
                )

    def setUp(self):
        """Print test method name at the start of each test."""
        print(f"[Test Method] {self._testMethodName}", flush=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment