Unverified Commit 63cfe1b0 authored by Keyang Ru, committed by GitHub

[router] Add gRPC E2E test suite (#11790)

parent 70f6309c
@@ -86,7 +86,7 @@ jobs:
   pytest-rust:
     if: github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'run-ci')
     runs-on: 4-gpu-a10
-    timeout-minutes: 25
+    timeout-minutes: 32
     steps:
       - name: Checkout code
         uses: actions/checkout@v4
@@ -144,6 +144,12 @@ jobs:
           python3 -m pip --no-cache-dir install --upgrade --break-system-packages genai-bench==0.0.2
           pytest -m e2e -s -vv -o log_cli=true --log-cli-level=INFO
+      - name: Run Python E2E gRPC tests
+        run: |
+          bash scripts/killall_sglang.sh "nuk_gpus"
+          cd sgl-router
+          SHOW_ROUTER_LOGS=1 ROUTER_LOCAL_MODEL_PATH="/home/ubuntu/models" pytest py_test/e2e_grpc -s -vv -o log_cli=true --log-cli-level=INFO
       - name: Upload benchmark results
         if: success()
         uses: actions/upload-artifact@v4
...
"""
gRPC Router E2E Test - OpenAI Server API Compatibility
This test file is REUSED from test/srt/openai_server/basic/test_openai_server.py
with minimal changes:
- Swap popen_launch_server() → popen_launch_workers_and_router()
- Update teardown to cleanup router + workers
- All test logic and assertions remain identical
Run with:
python3 -m pytest e2e_grpc/basic/test_openai_server.py -v
python3 -m unittest e2e_grpc.basic.test_openai_server.TestOpenAIServer.test_chat_completion
"""
import json
# CHANGE: Import router launcher instead of server launcher
import sys
import unittest
from pathlib import Path
import openai
import requests
_TEST_DIR = Path(__file__).parent
sys.path.insert(0, str(_TEST_DIR.parent))
from fixtures import popen_launch_workers_and_router
from util import (
DEFAULT_MODEL_PATH,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
get_tokenizer,
kill_process_tree,
)
class TestOpenAIServer(CustomTestCase):
"""
Test OpenAI API through gRPC router.
REUSED from test/srt/openai_server/basic/test_openai_server.py
ONLY CHANGE: Server launch mechanism
- Launches SGLang workers with --grpc-mode
- Launches gRPC router pointing to those workers
"""
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_PATH
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.cluster = popen_launch_workers_and_router(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
num_workers=1,
tp_size=2,
policy="round_robin",
api_key=cls.api_key,
)
cls.base_url += "/v1"
cls.tokenizer = get_tokenizer(cls.model)
@classmethod
def tearDownClass(cls):
# Cleanup router and workers
kill_process_tree(cls.cluster["router"].pid)
for worker in cls.cluster.get("workers", []):
kill_process_tree(worker.pid)
# ALL TEST METHODS BELOW ARE UNCHANGED FROM ORIGINAL
# They validate that the router maintains OpenAI API compatibility
def run_chat_completion(self, logprobs, parallel_sample_num):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{
"role": "user",
"content": "What is the capital of France? Answer in a few words.",
},
],
temperature=0,
logprobs=logprobs is not None and logprobs > 0,
top_logprobs=logprobs,
n=parallel_sample_num,
)
if logprobs:
assert isinstance(
response.choices[0].logprobs.content[0].top_logprobs[0].token, str
)
ret_num_top_logprobs = len(
response.choices[0].logprobs.content[0].top_logprobs
)
assert (
ret_num_top_logprobs == logprobs
), f"{ret_num_top_logprobs} vs {logprobs}"
assert len(response.choices) == parallel_sample_num
assert response.choices[0].message.role == "assistant"
assert isinstance(response.choices[0].message.content, str)
assert response.id
assert response.created
assert response.usage.prompt_tokens > 0
assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0
def run_chat_completion_stream(self, logprobs, parallel_sample_num=1):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
generator = client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{"role": "user", "content": "What is the capital of France?"},
],
temperature=0,
logprobs=logprobs is not None and logprobs > 0,
top_logprobs=logprobs,
stream=True,
stream_options={"include_usage": True},
n=parallel_sample_num,
)
is_firsts = {}
is_finished = {}
finish_reason_counts = {}
for response in generator:
usage = response.usage
if usage is not None:
assert usage.prompt_tokens > 0, f"usage.prompt_tokens was zero"
assert usage.completion_tokens > 0, f"usage.completion_tokens was zero"
assert usage.total_tokens > 0, f"usage.total_tokens was zero"
continue
index = response.choices[0].index
finish_reason = response.choices[0].finish_reason
if finish_reason is not None:
is_finished[index] = True
finish_reason_counts[index] = finish_reason_counts.get(index, 0) + 1
data = response.choices[0].delta
if is_firsts.get(index, True):
assert (
data.role == "assistant"
), f"data.role was not 'assistant' for first chunk"
is_firsts[index] = False
continue
if logprobs and not is_finished.get(index, False):
assert response.choices[0].logprobs, f"logprobs was not returned"
assert isinstance(
response.choices[0].logprobs.content[0].top_logprobs[0].token, str
), f"top_logprobs token was not a string"
assert isinstance(
response.choices[0].logprobs.content[0].top_logprobs, list
), f"top_logprobs was not a list"
ret_num_top_logprobs = len(
response.choices[0].logprobs.content[0].top_logprobs
)
assert (
ret_num_top_logprobs == logprobs
), f"{ret_num_top_logprobs} vs {logprobs}"
assert (
isinstance(data.content, str)
or isinstance(data.reasoning_content, str)
or (isinstance(data.tool_calls, list) and len(data.tool_calls) > 0)
or response.choices[0].finish_reason
)
assert response.id
assert response.created
for index in range(parallel_sample_num):
assert not is_firsts.get(
index, True
), f"index {index} is not found in the response"
for index in range(parallel_sample_num):
assert (
index in finish_reason_counts
), f"No finish_reason found for index {index}"
assert (
finish_reason_counts[index] == 1
), f"Expected 1 finish_reason chunk for index {index}, got {finish_reason_counts[index]}"
def test_chat_completion(self):
for logprobs in [None, 5]:
for parallel_sample_num in [1, 2]:
self.run_chat_completion(logprobs, parallel_sample_num)
def test_chat_completion_stream(self):
for logprobs in [None, 5]:
for parallel_sample_num in [1, 2]:
self.run_chat_completion_stream(logprobs, parallel_sample_num)
def test_regex(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
regex = (
r"""\{\n"""
+ r""" "name": "[\w]+",\n"""
+ r""" "population": [\d]+\n"""
+ r"""\}"""
)
response = client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{"role": "user", "content": "Introduce the capital of France."},
],
temperature=0,
max_tokens=128,
extra_body={"regex": regex},
)
text = response.choices[0].message.content
try:
js_obj = json.loads(text)
except (TypeError, json.decoder.JSONDecodeError):
print("JSONDecodeError", text)
raise
assert isinstance(js_obj["name"], str)
assert isinstance(js_obj["population"], int)
def test_penalty(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{"role": "user", "content": "Introduce the capital of France."},
],
temperature=0,
max_tokens=32,
frequency_penalty=1.0,
)
text = response.choices[0].message.content
assert isinstance(text, str)
def test_response_prefill(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{
"role": "user",
"content": """
Extract the name, size, price, and color from this product description as a JSON object:
<description>
The SmartHome Mini is a compact smart home assistant available in black or white for only $49.99. At just 5 inches wide, it lets you control lights, thermostats, and other connected devices via voice or app—no matter where you place it in your home. This affordable little hub brings convenient hands-free control to your smart devices.
</description>
""",
},
{
"role": "assistant",
"content": "{\n",
},
],
temperature=0,
extra_body={"continue_final_message": True},
)
assert (
response.choices[0]
.message.content.strip()
.startswith('"name": "SmartHome Mini",')
)
def test_model_list(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
# TODO: Update the logic here once the router /v1/models response format matches the OpenAI API standard
models = list(client.models.list().models)
assert len(models) == 1
# assert isinstance(getattr(models[0], "max_model_len", None), int)
@unittest.skip("Skipping retrieve model test as it is not supported by the router")
def test_retrieve_model(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
retrieved_model = client.models.retrieve(self.model)
self.assertEqual(retrieved_model.id, self.model)
self.assertEqual(retrieved_model.root, self.model)
with self.assertRaises(openai.NotFoundError):
client.models.retrieve("non-existent-model")
if __name__ == "__main__":
unittest.main()
"""
Pytest configuration for gRPC router e2e tests.
This module provides shared fixtures that can be used across all gRPC router tests.
"""
import sys
from pathlib import Path
import pytest
# Ensure router py_src is importable
_ROUTER_ROOT = Path(__file__).resolve().parents[2]
_ROUTER_SRC = _ROUTER_ROOT / "py_src"
if str(_ROUTER_SRC) not in sys.path:
sys.path.insert(0, str(_ROUTER_SRC))
# Ensure e2e_grpc test utilities are importable
_E2E_GRPC_DIR = Path(__file__).parent
if str(_E2E_GRPC_DIR) not in sys.path:
sys.path.insert(0, str(_E2E_GRPC_DIR))
# Pytest markers for test organization
def pytest_configure(config):
config.addinivalue_line("markers", "e2e: end-to-end tests with real workers")
config.addinivalue_line("markers", "grpc: gRPC-specific tests")
config.addinivalue_line("markers", "slow: slow-running tests")
config.addinivalue_line("markers", "pd: prefill-decode disaggregation tests")
"""
Usage:
python3 -m unittest openai_server.features.test_enable_thinking.TestEnableThinking.test_chat_completion_with_reasoning
python3 -m unittest openai_server.features.test_enable_thinking.TestEnableThinking.test_chat_completion_without_reasoning
python3 -m unittest openai_server.features.test_enable_thinking.TestEnableThinking.test_stream_chat_completion_with_reasoning
python3 -m unittest openai_server.features.test_enable_thinking.TestEnableThinking.test_stream_chat_completion_without_reasoning
"""
import asyncio
import json
import os
import sys
import time
import unittest
# CHANGE: Import router launcher instead of server launcher
from pathlib import Path
import openai
import requests
_TEST_DIR = Path(__file__).parent
sys.path.insert(0, str(_TEST_DIR.parent))
from fixtures import popen_launch_workers_and_router
from util import (
DEFAULT_ENABLE_THINKING_MODEL_PATH,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
get_tokenizer,
kill_process_tree,
)
class TestEnableThinking(CustomTestCase):
@classmethod
def setUpClass(cls):
# CHANGE: Launch gRPC router with integrated workers (single command)
cls.model = DEFAULT_ENABLE_THINKING_MODEL_PATH
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-1234"
cls.cluster = popen_launch_workers_and_router(
cls.model,
cls.base_url,
timeout=120,
api_key=cls.api_key,
router_args=[
"--reasoning-parser",
"qwen3",
],
num_workers=1,
tp_size=4,
)
cls.additional_chat_kwargs = {}
@classmethod
def tearDownClass(cls):
# Cleanup router and workers
kill_process_tree(cls.cluster["router"].pid)
for worker in cls.cluster.get("workers", []):
kill_process_tree(worker.pid)
def test_chat_completion_with_reasoning(self):
# Test non-streaming with "enable_thinking": True, reasoning_content should not be empty
client = requests.post(
f"{self.base_url}/v1/chat/completions",
headers={"Authorization": f"Bearer {self.api_key}"},
json={
"model": self.model,
"messages": [{"role": "user", "content": "Hello"}],
"temperature": 0,
"separate_reasoning": True,
"chat_template_kwargs": {"enable_thinking": True},
**self.additional_chat_kwargs,
},
)
self.assertEqual(client.status_code, 200, f"Failed with: {client.text}")
data = client.json()
self.assertIn("choices", data)
self.assertTrue(len(data["choices"]) > 0)
self.assertIn("message", data["choices"][0])
self.assertIn("reasoning_content", data["choices"][0]["message"])
self.assertIsNotNone(data["choices"][0]["message"]["reasoning_content"])
def test_chat_completion_without_reasoning(self):
# Test non-streaming with "enable_thinking": False, reasoning_content should be empty
client = requests.post(
f"{self.base_url}/v1/chat/completions",
headers={"Authorization": f"Bearer {self.api_key}"},
json={
"model": self.model,
"messages": [{"role": "user", "content": "Hello"}],
"temperature": 0,
"separate_reasoning": True,
"chat_template_kwargs": {"enable_thinking": False},
**self.additional_chat_kwargs,
},
)
self.assertEqual(client.status_code, 200, f"Failed with: {client.text}")
data = client.json()
self.assertIn("choices", data)
self.assertTrue(len(data["choices"]) > 0)
self.assertIn("message", data["choices"][0])
if "reasoning_content" in data["choices"][0]["message"]:
self.assertIsNone(data["choices"][0]["message"]["reasoning_content"])
def test_stream_chat_completion_with_reasoning(self):
# Test streaming with "enable_thinking": True, reasoning_content should not be empty
response = requests.post(
f"{self.base_url}/v1/chat/completions",
headers={"Authorization": f"Bearer {self.api_key}"},
json={
"model": self.model,
"messages": [{"role": "user", "content": "Hello"}],
"temperature": 0,
"separate_reasoning": True,
"stream": True,
"chat_template_kwargs": {"enable_thinking": True},
**self.additional_chat_kwargs,
},
stream=True,
)
self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
has_reasoning = False
has_content = False
print("\n=== Stream With Reasoning ===")
for line in response.iter_lines():
if line:
line = line.decode("utf-8")
if line.startswith("data:") and not line.startswith("data: [DONE]"):
data = json.loads(line[6:])
if "choices" in data and len(data["choices"]) > 0:
delta = data["choices"][0].get("delta", {})
if "reasoning_content" in delta and delta["reasoning_content"]:
has_reasoning = True
if "content" in delta and delta["content"]:
has_content = True
self.assertTrue(
has_reasoning,
"The reasoning content is not included in the stream response",
)
self.assertTrue(
has_content, "The stream response does not contain normal content"
)
def test_stream_chat_completion_without_reasoning(self):
# Test streaming with "enable_thinking": False, reasoning_content should be empty
response = requests.post(
f"{self.base_url}/v1/chat/completions",
headers={"Authorization": f"Bearer {self.api_key}"},
json={
"model": self.model,
"messages": [{"role": "user", "content": "Hello"}],
"temperature": 0,
"separate_reasoning": True,
"stream": True,
"chat_template_kwargs": {"enable_thinking": False},
**self.additional_chat_kwargs,
},
stream=True,
)
self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
has_reasoning = False
has_content = False
print("\n=== Stream Without Reasoning ===")
for line in response.iter_lines():
if line:
line = line.decode("utf-8")
if line.startswith("data:") and not line.startswith("data: [DONE]"):
data = json.loads(line[6:])
if "choices" in data and len(data["choices"]) > 0:
delta = data["choices"][0].get("delta", {})
if "reasoning_content" in delta and delta["reasoning_content"]:
has_reasoning = True
if "content" in delta and delta["content"]:
has_content = True
self.assertFalse(
has_reasoning,
"The reasoning content should not be included in the stream response",
)
self.assertTrue(
has_content, "The stream response does not contain normal content"
)
if __name__ == "__main__":
unittest.main()
"""
Usage:
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentAPI.test_streaming_separate_reasoning_false
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentAPI.test_streaming_separate_reasoning_true
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentAPI.test_streaming_separate_reasoning_true_stream_reasoning_false
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentAPI.test_nonstreaming_separate_reasoning_false
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentAPI.test_nonstreaming_separate_reasoning_true
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentStartup.test_nonstreaming
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentStartup.test_streaming
"""
import json
# CHANGE: Import router launcher instead of server launcher
import sys
import unittest
from pathlib import Path
import openai
import requests
_TEST_DIR = Path(__file__).parent
sys.path.insert(0, str(_TEST_DIR.parent))
from fixtures import popen_launch_workers_and_router
from util import (
DEFAULT_REASONING_MODEL_PATH,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
kill_process_tree,
)
class TestReasoningContentAPI(CustomTestCase):
@classmethod
def setUpClass(cls):
# CHANGE: Launch gRPC router with integrated workers (single command)
cls.model = DEFAULT_REASONING_MODEL_PATH
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-1234"
cls.cluster = popen_launch_workers_and_router(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
router_args=[
"--reasoning-parser",
"deepseek_r1",
],
num_workers=1,
tp_size=2,
)
cls.base_url += "/v1"
@classmethod
def tearDownClass(cls):
# Cleanup router and workers
kill_process_tree(cls.cluster["router"].pid)
for worker in cls.cluster.get("workers", []):
kill_process_tree(worker.pid)
def test_streaming_separate_reasoning_false(self):
# Test streaming with separate_reasoning=False, reasoning_content should be empty
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
payload = {
"model": self.model,
"messages": [
{
"role": "user",
"content": "What is 1+3?",
}
],
"max_tokens": 100,
"stream": True,
"extra_body": {"separate_reasoning": False},
}
response = client.chat.completions.create(**payload)
reasoning_content = ""
content = ""
for chunk in response:
if chunk.choices[0].delta.content:
content += chunk.choices[0].delta.content
elif chunk.choices[0].delta.reasoning_content:
reasoning_content += chunk.choices[0].delta.reasoning_content
assert len(reasoning_content) == 0
assert len(content) > 0
def test_streaming_separate_reasoning_true(self):
# Test streaming with separate_reasoning=True, reasoning_content should not be empty
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
payload = {
"model": self.model,
"messages": [
{
"role": "user",
"content": "What is 1+3?",
}
],
"max_tokens": 100,
"stream": True,
"extra_body": {"separate_reasoning": True},
}
response = client.chat.completions.create(**payload)
reasoning_content = ""
content = ""
for chunk in response:
if chunk.choices[0].delta.content:
content += chunk.choices[0].delta.content
elif chunk.choices[0].delta.reasoning_content:
reasoning_content += chunk.choices[0].delta.reasoning_content
assert len(reasoning_content) > 0
assert len(content) > 0
def test_streaming_separate_reasoning_true_stream_reasoning_false(self):
# Test streaming with separate_reasoning=True and stream_reasoning=False; reasoning_content should still be non-empty but should not be streamed incrementally
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
payload = {
"model": self.model,
"messages": [
{
"role": "user",
"content": "What is 1+3?",
}
],
"max_tokens": 100,
"stream": True,
"extra_body": {"separate_reasoning": True, "stream_reasoning": False},
}
response = client.chat.completions.create(**payload)
reasoning_content = ""
content = ""
first_chunk = False
for chunk in response:
if chunk.choices[0].delta.reasoning_content:
reasoning_content = chunk.choices[0].delta.reasoning_content
first_chunk = True
if chunk.choices[0].delta.content:
content += chunk.choices[0].delta.content
if not first_chunk:
reasoning_content = chunk.choices[0].delta.reasoning_content
first_chunk = True
if not first_chunk:
assert (
not chunk.choices[0].delta.reasoning_content
or len(chunk.choices[0].delta.reasoning_content) == 0
)
assert len(reasoning_content) > 0
assert len(content) > 0
def test_nonstreaming_separate_reasoning_false(self):
# Test non-streaming with separate_reasoning=False, reasoning_content should be empty
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
payload = {
"model": self.model,
"messages": [
{
"role": "user",
"content": "What is 1+3?",
}
],
"max_tokens": 100,
"extra_body": {"separate_reasoning": False},
}
response = client.chat.completions.create(**payload)
assert (
not response.choices[0].message.reasoning_content
or len(response.choices[0].message.reasoning_content) == 0
)
assert len(response.choices[0].message.content) > 0
def test_nonstreaming_separate_reasoning_true(self):
# Test non-streaming with separate_reasoning=True, reasoning_content should not be empty
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
payload = {
"model": self.model,
"messages": [
{
"role": "user",
"content": "What is 1+3?",
}
],
"max_tokens": 100,
"extra_body": {"separate_reasoning": True},
}
response = client.chat.completions.create(**payload)
assert len(response.choices[0].message.reasoning_content) > 0
assert len(response.choices[0].message.content) > 0
if __name__ == "__main__":
unittest.main()
"""
Fixtures for launching gRPC router + workers for e2e testing.
This module provides fixtures for launching SGLang workers and gRPC router separately:
1. Launch N SGLang workers with gRPC enabled
2. Launch router pointing to those workers
This approach gives more control and matches production deployment patterns.
"""
import socket
import subprocess
import time
from typing import Optional
import requests
def find_free_port() -> int:
"""Find an available port on localhost."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
def wait_for_workers_ready(
router_url: str,
expected_workers: int,
timeout: int = 300,
api_key: Optional[str] = None,
) -> None:
"""
Wait for router to have all workers connected.
Polls the /workers endpoint until the 'total' field matches expected_workers.
Example response from /workers endpoint:
{"workers":[],"total":0,"stats":{"prefill_count":0,"decode_count":0,"regular_count":0}}
Args:
router_url: Base URL of router (e.g., "http://127.0.0.1:30000")
expected_workers: Number of workers expected to be connected
timeout: Max seconds to wait
api_key: Optional API key for authentication
"""
start_time = time.time()
last_error = None
attempt = 0
headers = {}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
with requests.Session() as session:
while time.time() - start_time < timeout:
attempt += 1
elapsed = int(time.time() - start_time)
# Print progress every 10 seconds
if elapsed > 0 and elapsed % 10 == 0 and attempt % 10 == 0:
print(f" Still waiting for workers... ({elapsed}/{timeout}s elapsed)")
try:
response = session.get(
f"{router_url}/workers", headers=headers, timeout=5
)
if response.status_code == 200:
data = response.json()
total_workers = data.get("total", 0)
if total_workers == expected_workers:
print(
f" All {expected_workers} workers connected after {elapsed}s"
)
return
else:
last_error = f"Workers: {total_workers}/{expected_workers}"
else:
last_error = f"HTTP {response.status_code}"
except requests.ConnectionError:
last_error = "Connection refused (router not ready yet)"
except requests.Timeout:
last_error = "Timeout"
except requests.RequestException as e:
last_error = str(e)
except (ValueError, KeyError) as e:
last_error = f"Invalid response: {e}"
time.sleep(1)
raise TimeoutError(
f"Router at {router_url} did not get {expected_workers} workers within {timeout}s.\n"
f"Last status: {last_error}\n"
f"Hint: Run with SHOW_ROUTER_LOGS=1 to see startup logs"
)
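# For reference, a single probe of the /workers endpoint polled above can be issued
# directly. This is a minimal sketch: the helper is not used elsewhere in this suite,
# the default URL is the example value from the docstring, and auth is omitted.
def probe_workers_once(router_url: str = "http://127.0.0.1:30000") -> int:
    """Return the router-reported worker count from a single /workers request."""
    resp = requests.get(f"{router_url}/workers", timeout=5)
    resp.raise_for_status()
    # Expected shape: {"workers": [...], "total": N, "stats": {...}}
    return resp.json().get("total", 0)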
def popen_launch_workers_and_router(
model: str,
base_url: str,
timeout: int = 300,
num_workers: int = 2,
policy: str = "round_robin",
api_key: Optional[str] = None,
worker_args: Optional[list] = None,
router_args: Optional[list] = None,
tp_size: int = 1,
env: Optional[dict] = None,
stdout=None,
stderr=None,
) -> dict:
"""
Launch SGLang workers and gRPC router separately.
This approach:
1. Starts N SGLang workers with --grpc-mode flag
2. Waits for workers to initialize (process startup)
3. Starts a gRPC router pointing to those workers
4. Waits for router health check to pass (router validates worker connectivity)
This matches production deployment patterns better than the integrated approach.
Args:
model: Model path (e.g., /home/ubuntu/models/llama-3.1-8b-instruct)
base_url: Base URL for router (e.g., "http://127.0.0.1:8080")
timeout: Timeout for server startup (default: 300s)
num_workers: Number of workers to launch
policy: Routing policy (round_robin, random, power_of_two, cache_aware)
api_key: Optional API key for router
worker_args: Additional arguments for workers (e.g., ["--context-len", "8192"])
router_args: Additional arguments for router (e.g., ["--max-total-token", "1536"])
tp_size: Tensor parallelism size for workers (default: 1)
env: Optional environment variables for workers (e.g., {"SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION": "256"})
stdout: Optional file handle for worker stdout (default: subprocess.PIPE)
stderr: Optional file handle for worker stderr (default: subprocess.PIPE)
Returns:
dict with:
- workers: list of worker process objects
- worker_urls: list of gRPC worker URLs
- router: router process object
- base_url: router URL (HTTP endpoint)
Example:
>>> cluster = popen_launch_workers_and_router(model, base_url, num_workers=2)
>>> # Use cluster['base_url'] for HTTP requests
>>> # Cleanup:
>>> for worker in cluster['workers']:
>>> kill_process_tree(worker.pid)
>>> kill_process_tree(cluster['router'].pid)
"""
import os
show_output = os.environ.get("SHOW_ROUTER_LOGS", "0") == "1"
# Note: timeout parameter is used for router health check below
# Parse router port from base_url
if ":" in base_url.split("//")[-1]:
router_port = int(base_url.split(":")[-1])
else:
router_port = find_free_port()
print(f"\n{'='*70}")
print(f"Launching gRPC cluster (separate workers + router)")
print(f"{'='*70}")
print(f" Model: {model}")
print(f" Router port: {router_port}")
print(f" Workers: {num_workers}")
print(f" TP size: {tp_size}")
print(f" Policy: {policy}")
# Step 1: Launch workers with gRPC enabled
workers = []
worker_urls = []
for i in range(num_workers):
worker_port = find_free_port()
worker_url = f"grpc://127.0.0.1:{worker_port}"
worker_urls.append(worker_url)
print(f"\n[Worker {i+1}/{num_workers}]")
print(f" Port: {worker_port}")
print(f" URL: {worker_url}")
# Build worker command
worker_cmd = [
"python3",
"-m",
"sglang.launch_server",
"--model-path",
model,
"--host",
"127.0.0.1",
"--port",
str(worker_port),
"--grpc-mode", # Enable gRPC for this worker
"--mem-fraction-static",
"0.8",
"--attention-backend",
"fa3",
]
# Add TP size
if tp_size > 1:
worker_cmd.extend(["--tp-size", str(tp_size)])
# Add worker-specific args
if worker_args:
worker_cmd.extend(worker_args)
# Launch worker with optional environment variables
if show_output:
worker_proc = subprocess.Popen(
worker_cmd,
env=env,
stdout=stdout,
stderr=stderr,
)
else:
worker_proc = subprocess.Popen(
worker_cmd,
stdout=stdout if stdout is not None else subprocess.PIPE,
stderr=stderr if stderr is not None else subprocess.PIPE,
env=env,
)
workers.append(worker_proc)
print(f" PID: {worker_proc.pid}")
# Give workers a moment to start binding to ports
# The router will check worker health when it starts
print(f"\nWaiting for {num_workers} workers to initialize (20s)...")
time.sleep(20)
# Quick check: make sure worker processes are still alive
for i, worker in enumerate(workers):
if worker.poll() is not None:
print(f" ✗ Worker {i+1} died during startup (exit code: {worker.poll()})")
# Cleanup: kill all workers
for w in workers:
try:
w.kill()
except:
pass
raise RuntimeError(f"Worker {i+1} failed to start")
print(f"✓ All {num_workers} workers started (router will verify connectivity)")
# Step 2: Launch router pointing to workers
print(f"\n[Router]")
print(f" Port: {router_port}")
print(f" Worker URLs: {', '.join(worker_urls)}")
# Build router command
router_cmd = [
"python3",
"-m",
"sglang_router.launch_router",
"--host",
"127.0.0.1",
"--port",
str(router_port),
"--prometheus-port",
"9321",
"--policy",
policy,
"--model-path",
model,
]
# Add worker URLs
router_cmd.append("--worker-urls")
router_cmd.extend(worker_urls)
# Add API key
if api_key:
router_cmd.extend(["--api-key", api_key])
# Add router-specific args
if router_args:
router_cmd.extend(router_args)
if show_output:
print(f" Command: {' '.join(router_cmd)}")
# Launch router
if show_output:
router_proc = subprocess.Popen(router_cmd)
else:
router_proc = subprocess.Popen(
router_cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
print(f" PID: {router_proc.pid}")
# Wait for router to be ready
router_url = f"http://127.0.0.1:{router_port}"
print(f"\nWaiting for router to start at {router_url}...")
try:
wait_for_workers_ready(
router_url, expected_workers=num_workers, timeout=timeout, api_key=api_key
)
print(f"✓ Router ready at {router_url}")
except TimeoutError:
print(f"✗ Router failed to start")
# Cleanup: kill router and all workers
try:
router_proc.kill()
except:
pass
for worker in workers:
try:
worker.kill()
except:
pass
raise
print(f"\n{'='*70}")
print(f"✓ gRPC cluster ready!")
print(f" Router: {router_url}")
print(f" Workers: {len(workers)}")
print(f"{'='*70}\n")
return {
"workers": workers,
"worker_urls": worker_urls,
"router": router_proc,
"base_url": router_url,
}
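# As a hedged, self-contained sketch (the model path, URL, and API key are example
# values taken from the docstrings and tests in this suite; kill_process_tree comes
# from the util module used by the tests), the launcher can also be driven directly:
if __name__ == "__main__":
    import openai

    from util import kill_process_tree

    cluster = popen_launch_workers_and_router(
        "/home/ubuntu/models/llama-3.1-8b-instruct",
        "http://127.0.0.1:30000",
        num_workers=1,
        policy="round_robin",
        api_key="sk-123456",
    )
    try:
        client = openai.Client(
            api_key="sk-123456", base_url=cluster["base_url"] + "/v1"
        )
        reply = client.chat.completions.create(
            model="/home/ubuntu/models/llama-3.1-8b-instruct",
            messages=[{"role": "user", "content": "What is the capital of France?"}],
            temperature=0,
        )
        print(reply.choices[0].message.content)
    finally:
        # Mirror the tearDownClass cleanup used by the tests above.
        kill_process_tree(cluster["router"].pid)
        for worker in cluster["workers"]:
            kill_process_tree(worker.pid)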
"""
gRPC Router E2E Test - OpenAI Function Calling
This test file is REUSED from test/srt/openai_server/function_call/test_openai_function_calling.py
with minimal changes:
- Swap popen_launch_server() → popen_launch_workers_and_router()
- Update teardown to cleanup router + workers
- All test logic and assertions remain identical
Run with:
pytest py_test/e2e_grpc/function_call/test_openai_function_calling.py -v
"""
import json
# CHANGE: Import router launcher instead of server launcher
import sys
import time
import unittest
from pathlib import Path
import openai
_TEST_DIR = Path(__file__).parent
sys.path.insert(0, str(_TEST_DIR.parent))
from fixtures import popen_launch_workers_and_router
from util import (
DEFAULT_MODEL_PATH,
DEFAULT_SMALL_MODEL_PATH,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
get_tokenizer,
kill_process_tree,
)
class TestOpenAIServerFunctionCalling(CustomTestCase):
# NOTE: this system_message is for Llama3.2 system prompt. Without this,
# sometimes Llama3.2 gives a different tool call format such as:
# '<|python_tag|>{"type": "function", "function": "add", "parameters": {"a": "3", "b": "5"}}'
SYSTEM_MESSAGE = (
"You are a helpful assistant with tool calling capabilities. "
"Only reply with a tool call if the function exists in the library provided by the user. "
"If it doesn't exist, just reply directly in natural language. "
"When you receive a tool call response, use the output to format an answer to the original user question. "
"You have access to the following functions. "
"To call a function, please respond with JSON for a function call. "
'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. '
"Do not use variables.\n\n"
)
@classmethod
def setUpClass(cls):
# CHANGE: Launch gRPC router with integrated workers (single command)
# Using small model for function calling tests
cls.model = DEFAULT_SMALL_MODEL_PATH
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
# Start the local OpenAI Server. If necessary, you can add other parameters such as --enable-tools.
cls.cluster = popen_launch_workers_and_router(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
router_args=[
# If your server needs extra parameters to test function calling, please add them here.
"--tool-call-parser",
"llama",
],
num_workers=1,
tp_size=2,
)
cls.base_url += "/v1"
cls.tokenizer = get_tokenizer(cls.model)
@classmethod
def tearDownClass(cls):
# Cleanup router and workers
kill_process_tree(cls.cluster["router"].pid)
for worker in cls.cluster.get("workers", []):
kill_process_tree(worker.pid)
def test_function_calling_format(self):
"""
Test: Whether the function call format returned by the AI is correct.
When returning a tool call, message.content should be None, and tool_calls should be a list.
"""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
tools = [
{
"type": "function",
"function": {
"name": "add",
"description": "Compute the sum of two numbers",
"parameters": {
"type": "object",
"properties": {
"a": {
"type": "integer",
"description": "A number",
},
"b": {
"type": "integer",
"description": "A number",
},
},
"required": ["a", "b"],
},
},
}
]
messages = [
{"role": "system", "content": self.SYSTEM_MESSAGE},
{"role": "user", "content": "Compute (3+5)"},
]
response = client.chat.completions.create(
model=self.model,
max_tokens=2048,
messages=messages,
temperature=0.8,
top_p=0.8,
stream=False,
tools=tools,
)
tool_calls = response.choices[0].message.tool_calls
assert (
isinstance(tool_calls, list) and len(tool_calls) > 0
), "tool_calls should be a non-empty list"
function_name = tool_calls[0].function.name
assert function_name == "add", "Function name should be 'add'"
# This unit test is too difficult for the default model. Mark it as an optional unit test so it won't run unless explicitly specified.
def _test_function_calling_multiturn(self):
"""
Test: Whether the function call format returned by the AI is correct.
When returning a tool call, message.content should be None, and tool_calls should be a list.
"""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
tools = [
{
"type": "function",
"function": {
"name": "add",
"description": "Compute the sum of two numbers",
"parameters": {
"type": "object",
"properties": {
"a": {
"type": "integer",
"description": "A number",
},
"b": {
"type": "integer",
"description": "A number",
},
},
"required": ["a", "b"],
},
},
}
]
messages = [{"role": "user", "content": "Compute (3+5)"}]
response = client.chat.completions.create(
model=self.model,
max_tokens=2048,
messages=messages,
temperature=0.8,
top_p=0.8,
stream=False,
tools=tools,
)
tool_call = response.choices[0].message.tool_calls[0]
function_name = tool_call.function.name
assert function_name == "add", "Function name should be 'add'"
function_arguments = json.loads(tool_call.function.arguments)
assert function_arguments in [
{"a": 3, "b": 5},
{"a": "3", "b": "5"},
], f"Unexpected function arguments: {function_arguments}"
messages.append(response.choices[0].message)
messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"content": "8",
"name": function_name,
}
)
final_response = client.chat.completions.create(
model=self.model,
max_tokens=2048,
messages=messages,
temperature=0.8,
top_p=0.8,
stream=False,
tools=tools,
)
assert (
"8" in final_response.choices[0].message.content
), "tool_call response should have the sum 8 in the content"
def test_function_calling_streaming_simple(self):
"""
Test: Whether the function name can be correctly recognized in streaming mode.
- Expect a function call to be found, and the function name to be correct.
- Verify that streaming mode returns at least multiple chunks.
"""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city to find the weather for",
},
"unit": {
"type": "string",
"description": "Weather unit (celsius or fahrenheit)",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["city", "unit"],
},
},
}
]
messages = [
{"role": "system", "content": self.SYSTEM_MESSAGE},
{
"role": "user",
"content": "What is the temperature in Paris in celsius??",
},
]
response_stream = client.chat.completions.create(
model=self.model,
max_tokens=2048,
messages=messages,
temperature=0.8,
top_p=0.8,
stream=True,
tools=tools,
)
chunks = list(response_stream)
self.assertTrue(len(chunks) > 0, "Streaming should return at least one chunk")
found_function_name = False
for chunk in chunks:
choice = chunk.choices[0]
# Check whether the current chunk contains tool_calls
if choice.delta.tool_calls:
tool_call = choice.delta.tool_calls[0]
if tool_call.function.name:
self.assertEqual(
tool_call.function.name,
"get_current_weather",
"Function name should be 'get_current_weather'",
)
found_function_name = True
break
self.assertTrue(
found_function_name,
"Target function name 'get_current_weather' was not found in the streaming chunks",
)
finish_reason = chunks[-1].choices[0].finish_reason
self.assertEqual(
finish_reason,
"tool_calls",
"Final response of function calling should have finish_reason 'tool_calls'",
)
def test_function_calling_streaming_args_parsing(self):
"""
Test: Whether the function call arguments returned in streaming mode can be correctly concatenated into valid JSON.
- The user request requires multiple parameters.
- AI may return the arguments in chunks that need to be concatenated.
"""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
tools = [
{
"type": "function",
"function": {
"name": "add",
"description": "Compute the sum of two integers",
"parameters": {
"type": "object",
"properties": {
"a": {
"type": "integer",
"description": "First integer",
},
"b": {
"type": "integer",
"description": "Second integer",
},
},
"required": ["a", "b"],
},
"strict": True, # Llama-3.2-1B is flaky in tool call. It won't always respond with parameters unless we set strict.
},
}
]
messages = [
{"role": "system", "content": self.SYSTEM_MESSAGE},
{"role": "user", "content": "Please sum 5 and 7, just call the function."},
]
response_stream = client.chat.completions.create(
model=self.model,
max_tokens=2048,
messages=messages,
temperature=0.9,
top_p=0.9,
stream=True,
tools=tools,
)
argument_fragments = []
chunks = list(response_stream)
function_name = None
for chunk in chunks:
choice = chunk.choices[0]
if choice.delta.tool_calls:
tool_call = choice.delta.tool_calls[0]
# Record the function name on first occurrence
function_name = tool_call.function.name or function_name
# In case of multiple chunks, JSON fragments may need to be concatenated
if tool_call.function.arguments is not None:
argument_fragments.append(tool_call.function.arguments)
self.assertEqual(function_name, "add", "Function name should be 'add'")
joined_args = "".join(argument_fragments)
self.assertTrue(
len(joined_args) > 0,
"No parameter fragments were returned in the function call",
)
finish_reason = chunks[-1].choices[0].finish_reason
self.assertEqual(
finish_reason,
"tool_calls",
"Final response of function calling should have finish_reason 'tool_calls'",
)
# Check whether the concatenated JSON is valid
try:
args_obj = json.loads(joined_args)
except json.JSONDecodeError:
self.fail(
"The concatenated tool call arguments are not valid JSON, parsing failed"
)
self.assertIn("a", args_obj, "Missing parameter 'a'")
self.assertIn("b", args_obj, "Missing parameter 'b'")
self.assertEqual(str(args_obj["a"]), "5", "Parameter a should be 5")
self.assertEqual(str(args_obj["b"]), "7", "Parameter b should be 7")
@unittest.skip(
"Skipping function call strict test as it is not supported by the router"
)
def test_function_call_strict(self):
"""
Test: Whether the strict mode of function calling works as expected.
- When strict mode is enabled, the AI should not return a function call if the function name is not recognized.
"""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
tools = [
{
"type": "function",
"function": {
"name": "sub",
"description": "Compute the difference of two integers",
"parameters": {
"type": "object",
"properties": {
"int_a": {
"type": "integer",
"description": "First integer",
},
"int_b": {
"type": "integer",
"description": "Second integer",
},
},
"required": ["int_a", "int_b"],
},
"strict": True,
},
}
]
messages = [
{"role": "user", "content": "Please compute 5 - 7, using your tool."}
]
response = client.chat.completions.create(
model=self.model,
max_tokens=2048,
messages=messages,
temperature=0.8,
top_p=0.8,
stream=False,
tools=tools,
)
tool_calls = response.choices[0].message.tool_calls
function_name = tool_calls[0].function.name
arguments = tool_calls[0].function.arguments
args_obj = json.loads(arguments)
self.assertEqual(function_name, "sub", "Function name should be 'sub'")
self.assertEqual(str(args_obj["int_a"]), "5", "Parameter int_a should be 5")
self.assertEqual(str(args_obj["int_b"]), "7", "Parameter int_b should be 7")
def test_function_call_required(self):
"""
Test: Whether tool_choice: "required" works as expected
- When tool_choice == "required", the model should return one or more tool_calls.
"""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
tools = [
{
"type": "function",
"function": {
"name": "sub",
"description": "Compute the difference of two integers",
"parameters": {
"type": "object",
"properties": {
"int_a": {
"type": "integer",
"description": "First integer",
},
"int_b": {
"type": "integer",
"description": "Second integer",
},
},
"required": ["int_a", "int_b"],
},
"strict": True,
},
},
{
"type": "function",
"function": {
"name": "get_weather",
"description": "use this to get latest weather information for a city given its name",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "name of the city to get weather for",
}
},
"required": ["city"],
},
},
},
]
messages = [{"role": "user", "content": "What is the capital of France?"}]
response = client.chat.completions.create(
model=self.model,
max_tokens=2048,
messages=messages,
temperature=0.8,
top_p=0.8,
stream=False,
tools=tools,
tool_choice="required",
)
tool_calls = response.choices[0].message.tool_calls
self.assertIsNotNone(tool_calls, "No tool_calls in the response")
function_name = tool_calls[0].function.name
arguments = tool_calls[0].function.arguments
args_obj = json.loads(arguments)
self.assertEqual(
function_name,
"get_weather",
f"Function name should be 'get_weather', got: {function_name}",
)
self.assertIn(
"city", args_obj, f"Function arguments should have 'city', got: {args_obj}"
)
# Make the test more robust by checking type and accepting valid responses
city_value = args_obj["city"]
self.assertIsInstance(
city_value,
str,
f"Parameter city should be a string, got: {type(city_value)}",
)
self.assertTrue(
"Paris" in city_value or "France" in city_value,
f"Parameter city should contain either 'Paris' or 'France', got: {city_value}",
)
def test_function_call_specific(self):
"""
Test: Whether tool_choice: ToolChoice works as expected
- When tool_choice is a specific ToolChoice, the model should return one or more tool_calls.
"""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
tools = [
{
"type": "function",
"function": {
"name": "sub",
"description": "Compute the difference of two integers",
"parameters": {
"type": "object",
"properties": {
"int_a": {
"type": "integer",
"description": "First integer",
},
"int_b": {
"type": "integer",
"description": "Second integer",
},
},
"required": ["int_a", "int_b"],
},
"strict": True,
},
},
{
"type": "function",
"function": {
"name": "get_weather",
"description": "use this to get latest weather information for a city given its name",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "name of the city to get weather for",
}
},
"required": ["city"],
},
},
},
]
messages = [{"role": "user", "content": "What is the capital of France?"}]
response = client.chat.completions.create(
model=self.model,
max_tokens=2048,
messages=messages,
temperature=0.8,
top_p=0.8,
stream=False,
tools=tools,
tool_choice={"type": "function", "function": {"name": "get_weather"}},
)
tool_calls = response.choices[0].message.tool_calls
self.assertIsNotNone(tool_calls, "No tool_calls in the response")
function_name = tool_calls[0].function.name
arguments = tool_calls[0].function.arguments
args_obj = json.loads(arguments)
self.assertEqual(
function_name, "get_weather", "Function name should be 'get_weather'"
)
self.assertIn("city", args_obj, "Function arguments should have 'city'")
def test_streaming_multiple_choices_finish_reason(self):
"""
Test: Verify that each choice gets its own finish_reason chunk in streaming mode with n > 1.
This tests the fix for the bug where only the last index got a finish_reason chunk.
"""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["location"],
},
},
}
]
messages = [
{"role": "user", "content": "What is the weather like in Los Angeles?"}
]
# Request with n=2 to get multiple choices
response_stream = client.chat.completions.create(
model=self.model,
messages=messages,
max_tokens=2048,
temperature=0.8,
stream=True,
tools=tools,
tool_choice="required", # Force tool calls
n=2, # Multiple choices
)
chunks = list(response_stream)
# Track finish_reason chunks for each index
finish_reason_chunks = {}
for chunk in chunks:
if chunk.choices:
for choice in chunk.choices:
if choice.finish_reason is not None:
index = choice.index
if index not in finish_reason_chunks:
finish_reason_chunks[index] = []
finish_reason_chunks[index].append(choice.finish_reason)
# Verify we got finish_reason chunks for both indices
self.assertEqual(
len(finish_reason_chunks),
2,
f"Expected finish_reason chunks for 2 indices, got {len(finish_reason_chunks)}",
)
# Verify both index 0 and 1 have finish_reason
self.assertIn(
0, finish_reason_chunks, "Missing finish_reason chunk for index 0"
)
self.assertIn(
1, finish_reason_chunks, "Missing finish_reason chunk for index 1"
)
# Verify the finish_reason is "tool_calls" since we forced tool calls
for index, reasons in finish_reason_chunks.items():
self.assertEqual(
reasons[-1], # Last finish_reason for this index
"tool_calls",
f"Expected finish_reason 'tool_calls' for index {index}, got {reasons[-1]}",
)
def test_function_calling_streaming_no_tool_call(self):
"""
Test: Whether the finish_reason is stop in streaming mode when no tool call is given.
- Expect no function call to be found.
- Verify that finish_reason is stop
"""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city to find the weather for",
},
"unit": {
"type": "string",
"description": "Weather unit (celsius or fahrenheit)",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["city", "unit"],
},
},
}
]
messages = [{"role": "user", "content": "Who are you?"}]
response_stream = client.chat.completions.create(
model=self.model,
max_tokens=2048,
messages=messages,
temperature=0.8,
top_p=0.8,
stream=True,
tools=tools,
tool_choice="none",
)
chunks = list(response_stream)
self.assertTrue(len(chunks) > 0, "Streaming should return at least one chunk")
found_tool_call = False
for chunk in chunks:
choice = chunk.choices[0]
# Check whether the current chunk contains tool_calls
found_tool_call = choice.delta.tool_calls is not None
self.assertFalse(
found_tool_call,
"Shouldn't have any tool_call in the streaming chunks",
)
finish_reason = chunks[-1].choices[0].finish_reason
self.assertEqual(
finish_reason,
"stop",
"Final response of no function calling should have finish_reason 'stop'",
)
def test_streaming_multiple_choices_without_tools(self):
"""
Test: Verify that each choice gets its own finish_reason chunk without tool calls.
This tests the fix for regular content streaming with multiple choices.
"""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
messages = [{"role": "user", "content": "Say hello in one word."}]
# Request with n=2 to get multiple choices, no tools
response_stream = client.chat.completions.create(
model=self.model,
messages=messages,
temperature=0.8,
stream=True,
max_tokens=10, # Keep it short
n=2, # Multiple choices
)
chunks = list(response_stream)
# Track finish_reason chunks for each index
finish_reason_chunks = {}
for chunk in chunks:
if chunk.choices:
for choice in chunk.choices:
if choice.finish_reason is not None:
index = choice.index
if index not in finish_reason_chunks:
finish_reason_chunks[index] = []
finish_reason_chunks[index].append(choice.finish_reason)
# Verify we got finish_reason chunks for both indices
self.assertEqual(
len(finish_reason_chunks),
2,
f"Expected finish_reason chunks for 2 indices, got {len(finish_reason_chunks)}",
)
# Verify both index 0 and 1 have finish_reason
self.assertIn(
0, finish_reason_chunks, "Missing finish_reason chunk for index 0"
)
self.assertIn(
1, finish_reason_chunks, "Missing finish_reason chunk for index 1"
)
# Verify the finish_reason is "stop" (regular completion)
for index, reasons in finish_reason_chunks.items():
self.assertIn(
reasons[-1],
["stop", "length"], # Could be either depending on how model responds
f"Expected finish_reason 'stop' or 'length' for index {index}, got {reasons[-1]}",
)
class TestOpenAIPythonicFunctionCalling(CustomTestCase):
PYTHONIC_TOOLS = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the current weather for a given location.",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The name of the city or location.",
}
},
"required": ["location"],
},
},
},
{
"type": "function",
"function": {
"name": "get_tourist_attractions",
"description": "Get a list of top tourist attractions for a given city.",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The name of the city to find attractions for.",
}
},
"required": ["city"],
},
},
},
]
PYTHONIC_MESSAGES = [
{
"role": "system",
"content": (
"You are a travel assistant. "
"When asked to call functions, ALWAYS respond ONLY with a python list of function calls, "
"using this format: [func_name1(param1=value1, param2=value2), func_name2(param=value)]. "
"Do NOT use JSON, do NOT use variables, do NOT use any other format. "
"Here is an example:\n"
'[get_weather(location="Paris"), get_tourist_attractions(city="Paris")]'
),
},
{
"role": "user",
"content": (
"I'm planning a trip to Tokyo next week. What's the weather like and what are some top tourist attractions? "
"Propose parallel tool calls at once, using the python list of function calls format as shown above."
),
},
]
@classmethod
def setUpClass(cls):
# CHANGE: Launch gRPC router with integrated workers (single command)
cls.model = DEFAULT_MODEL_PATH
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.cluster = popen_launch_workers_and_router(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
router_args=[
"--tool-call-parser",
"pythonic",
],
num_workers=1,
tp_size=2,
)
cls.base_url += "/v1"
cls.tokenizer = get_tokenizer(cls.model)
@classmethod
def tearDownClass(cls):
# Cleanup router and workers
kill_process_tree(cls.cluster["router"].pid)
for worker in cls.cluster.get("workers", []):
kill_process_tree(worker.pid)
def test_pythonic_tool_call_prompt(self):
"""
Test: Explicit prompt for pythonic tool call format without chat template.
"""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.chat.completions.create(
model=self.model,
messages=self.PYTHONIC_MESSAGES,
tools=self.PYTHONIC_TOOLS,
temperature=0.1,
stream=False,
)
tool_calls = response.choices[0].message.tool_calls
self.assertIsInstance(tool_calls, list, "No tool_calls found")
self.assertGreaterEqual(len(tool_calls), 1)
names = [tc.function.name for tc in tool_calls]
self.assertTrue(
"get_weather" in names or "get_tourist_attractions" in names,
f"Function name '{names}' should container either 'get_weather' or 'get_tourist_attractions'",
)
def test_pythonic_tool_call_streaming(self):
"""
Test: Streaming pythonic tool call format; assert tool_call index is present.
"""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response_stream = client.chat.completions.create(
model=self.model,
messages=self.PYTHONIC_MESSAGES,
tools=self.PYTHONIC_TOOLS,
temperature=0.1,
stream=True,
)
found_tool_calls = False
found_index = False
found_names = set()
for chunk in response_stream:
choice = chunk.choices[0]
if getattr(choice.delta, "tool_calls", None):
found_tool_calls = True
tool_call = choice.delta.tool_calls[0]
if hasattr(tool_call, "index") or (
isinstance(tool_call, dict) and "index" in tool_call
):
found_index = True
found_names.add(str(tool_call.function.name))
self.assertTrue(found_tool_calls, "No tool_calls found in streaming response")
self.assertTrue(found_index, "No index field found in any streamed tool_call")
self.assertTrue(
"get_weather" in found_names or "get_tourist_attractions" in found_names,
f"Function name '{found_names}' should container either 'get_weather' or 'get_tourist_attractions'",
)
if __name__ == "__main__":
unittest.main()
"""
Test script for tool_choice functionality in SGLang
Tests: required, auto, and specific function choices in both streaming and non-streaming modes
# To run the tests, use the following command:
#
# python3 -m unittest openai_server.function_call.test_tool_choice
"""
import json
# CHANGE: Import router launcher instead of server launcher
import sys
import unittest
from pathlib import Path
import openai
_TEST_DIR = Path(__file__).parent
sys.path.insert(0, str(_TEST_DIR.parent))
from fixtures import popen_launch_workers_and_router
from util import (
DEFAULT_MISTRAL_FUNCTION_CALLING_MODEL_PATH,
DEFAULT_QWEN_FUNCTION_CALLING_MODEL_PATH,
DEFAULT_SMALL_MODEL_PATH,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
get_tokenizer,
kill_process_tree,
)
class TestToolChoiceLlama32(CustomTestCase):
@classmethod
def setUpClass(cls):
# CHANGE: Launch gRPC router with integrated workers (single command)
# Mark flaky tests for this model
cls.flaky_tests = {
"test_multi_tool_scenario_auto",
"test_multi_tool_scenario_required",
}
# Use a model that supports function calling
cls.model = DEFAULT_SMALL_MODEL_PATH
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
# Start the local OpenAI Server with tool calling support
cls.cluster = popen_launch_workers_and_router(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
router_args=[
"--tool-call-parser",
"llama", # Default parser for the test model
],
num_workers=1,
tp_size=2,
)
cls.base_url += "/v1"
cls.tokenizer = get_tokenizer(cls.model)
@classmethod
def tearDownClass(cls):
# Cleanup router and workers
kill_process_tree(cls.cluster["router"].pid)
for worker in cls.cluster.get("workers", []):
kill_process_tree(worker.pid)
def setUp(self):
self.client = openai.Client(base_url=self.base_url, api_key=self.api_key)
# TODO: Update the logic here once the router /v1/models response format matches the OpenAI API standard
self.model_name = self.client.models.list().models[0]
def _is_flaky_test(self):
"""Check if the current test is marked as flaky for this class"""
return (
hasattr(self.__class__, "flaky_tests")
and self._testMethodName in self.__class__.flaky_tests
)
def get_test_tools(self):
"""Get the test tools for function calling"""
return [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "use this to get latest weather information for a city given its name",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "name of the city to get weather for",
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["city"],
},
},
},
{
"type": "function",
"function": {
"name": "get_pokemon_info",
"description": "get detailed information about a pokemon given its name",
"parameters": {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "name of the pokemon to get info for",
}
},
"required": ["name"],
},
},
},
{
"type": "function",
"function": {
"name": "make_next_step_decision",
"description": "You will be given a trace of thinking process in the following format.\n\nQuestion: the input question you must answer\nTOOL: think about what to do, and choose a tool to use ONLY IF there are defined tools. \n You should never call the same tool with the same input twice in a row.\n If the previous conversation history already contains the information that can be retrieved from the tool, you should not call the tool again.\nOBSERVATION: the result of the tool call, NEVER include this in your response, this information will be provided\n... (this TOOL/OBSERVATION can repeat N times)\nANSWER: If you know the answer to the original question, require for more information,\n or you don't know the answer and there are no defined tools or all available tools are not helpful, respond with the answer without mentioning anything else.\n If the previous conversation history already contains the answer, respond with the answer right away.\n\n If no tools are configured, naturally mention this limitation while still being helpful. Briefly note that adding tools in the agent configuration would expand capabilities.\n\nYour task is to respond with the next step to take, based on the traces, \nor answer the question if you have enough information.",
"parameters": {
"type": "object",
"properties": {
"decision": {
"type": "string",
"description": 'The next step to take, it must be either "TOOL" or "ANSWER". If the previous conversation history already contains the information that can be retrieved from the tool, you should not call the tool again. If there are no defined tools, you should not return "TOOL" in your response.',
},
"content": {
"type": "string",
"description": 'The content of the next step. If the decision is "TOOL", this should be a short and concise reasoning of why you chose the tool, MUST include the tool name. If the decision is "ANSWER", this should be the answer to the question. If no tools are available, integrate this limitation conversationally without sounding scripted.',
},
},
"required": ["decision", "content"],
},
},
},
]
def get_test_messages(self):
"""Get test messages that should trigger tool usage"""
return [
{
"role": "user",
"content": "Answer the following questions as best you can:\n\nYou will be given a trace of thinking process in the following format.\n\nQuestion: the input question you must answer\nTOOL: think about what to do, and choose a tool to use ONLY IF there are defined tools\nOBSERVATION: the result of the tool call or the observation of the current task, NEVER include this in your response, this information will be provided\n... (this TOOL/OBSERVATION can repeat N times)\nANSWER: If you know the answer to the original question, require for more information, \nif the previous conversation history already contains the answer, \nor you don't know the answer and there are no defined tools or all available tools are not helpful, respond with the answer without mentioning anything else.\nYou may use light Markdown formatting to improve clarity (e.g. lists, **bold**, *italics*), but keep it minimal and unobtrusive.\n\nYour task is to respond with the next step to take, based on the traces, \nor answer the question if you have enough information.\n\nQuestion: what is the weather in top 5 populated cities in the US in celsius?\n\nTraces:\n\n\nThese are some additional instructions that you should follow:",
}
]
def get_travel_tools(self):
"""Get tools for travel assistant scenario that should trigger multiple tool calls"""
return [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the current weather for a given location.",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The name of the city or location.",
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["location"],
},
},
},
{
"type": "function",
"function": {
"name": "get_tourist_attractions",
"description": "Get a list of top tourist attractions for a given city.",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The name of the city to find attractions for.",
}
},
"required": ["city"],
},
},
},
]
def get_travel_messages(self):
"""Get travel assistant messages that should trigger multiple tool calls"""
return [
{
"content": "You are a travel assistant providing real-time weather updates and top tourist attractions.",
"role": "system",
},
{
"content": "I'm planning a trip to Tokyo next week. What's the weather like? What are the most amazing sights?",
"role": "user",
},
]
def test_tool_choice_auto_non_streaming(self):
"""Test tool_choice='auto' in non-streaming mode"""
tools = self.get_test_tools()
messages = self.get_test_messages()
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
max_tokens=2048,
tools=tools,
tool_choice="auto",
stream=False,
)
self.assertIsNotNone(response.choices[0].message)
# With auto, tool calls are optional
def test_tool_choice_auto_streaming(self):
"""Test tool_choice='auto' in streaming mode"""
tools = self.get_test_tools()
messages = self.get_test_messages()
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
max_tokens=2048,
tools=tools,
tool_choice="auto",
stream=True,
)
# Collect streaming response
content_chunks = []
tool_call_chunks = []
for chunk in response:
if chunk.choices[0].delta.content:
content_chunks.append(chunk.choices[0].delta.content)
elif chunk.choices[0].delta.tool_calls:
tool_call_chunks.extend(chunk.choices[0].delta.tool_calls)
# Should complete without errors
self.assertIsInstance(content_chunks, list)
self.assertIsInstance(tool_call_chunks, list)
def test_tool_choice_required_non_streaming(self):
"""Test tool_choice='required' in non-streaming mode"""
tools = self.get_test_tools()
messages = self.get_test_messages()
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
max_tokens=2048,
temperature=0.2,
tools=tools,
tool_choice="required",
stream=False,
)
# With required, we should get tool calls
tool_calls = response.choices[0].message.tool_calls
self.assertIsNotNone(tool_calls)
self.assertGreater(len(tool_calls), 0)
@unittest.skip(
"Skipping required streaming test as it is not supported by the router"
)
def test_tool_choice_required_streaming(self):
"""Test tool_choice='required' in streaming mode"""
tools = self.get_test_tools()
messages = self.get_test_messages()
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
max_tokens=2048,
tools=tools,
tool_choice="required",
stream=True,
)
# Collect streaming response
tool_call_chunks = []
for chunk in response:
if chunk.choices[0].delta.tool_calls:
tool_call_chunks.extend(chunk.choices[0].delta.tool_calls)
# With required, we should get tool call chunks
self.assertGreater(len(tool_call_chunks), 0)
def test_tool_choice_specific_function_non_streaming(self):
"""Test tool_choice with specific function in non-streaming mode"""
tools = self.get_test_tools()
messages = self.get_test_messages()
tool_choice = {"type": "function", "function": {"name": "get_weather"}}
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
max_tokens=2048,
tools=tools,
tool_choice=tool_choice,
stream=False,
)
# Should call the specific function
tool_calls = response.choices[0].message.tool_calls
self.assertIsNotNone(tool_calls)
        # Our messages ask about the top 5 most populated cities in the US, so the model could emit up to 5 tool calls
self.assertGreaterEqual(len(tool_calls), 1)
for tool_call in tool_calls:
self.assertEqual(tool_call.function.name, "get_weather")
@unittest.skip(
"Skipping required streaming test as it is not supported by the router"
)
def test_tool_choice_specific_function_streaming(self):
"""Test tool_choice with specific function in streaming mode"""
tools = self.get_test_tools()
messages = self.get_test_messages()
tool_choice = {"type": "function", "function": {"name": "get_weather"}}
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
max_tokens=2048,
tools=tools,
tool_choice=tool_choice,
stream=True,
)
# Collect streaming response
tool_call_chunks = []
for chunk in response:
if chunk.choices[0].delta.tool_calls:
tool_call_chunks.extend(chunk.choices[0].delta.tool_calls)
# Should get tool call chunks for the specific function
self.assertGreater(len(tool_call_chunks), 0)
# Find function name in chunks
found_name = None
for chunk in tool_call_chunks:
if chunk.function and chunk.function.name:
found_name = chunk.function.name
break
self.assertEqual(found_name, "get_weather")
@unittest.skip(
"Skipping required streaming arguments chunks json test as it is not supported by the router"
)
def test_required_streaming_arguments_chunks_json(self):
"""In streaming required mode, complete tool call arguments should be valid JSON when all chunks are combined"""
tools = self.get_test_tools()
messages = self.get_test_messages()
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
max_tokens=1024,
temperature=0.1,
tools=tools,
tool_choice="required",
stream=True,
)
# Collect all tool call chunks and reconstruct complete tool calls
tool_calls_by_index = {}
for chunk in response:
if chunk.choices[0].delta.tool_calls:
for tool_call_delta in chunk.choices[0].delta.tool_calls:
tool_index = tool_call_delta.index
# Initialize tool call if not seen before
if tool_index not in tool_calls_by_index:
tool_calls_by_index[tool_index] = {
"id": tool_call_delta.id,
"type": "function",
"function": {"name": "", "arguments": ""},
}
# Update function name if present (first chunk)
if tool_call_delta.function and tool_call_delta.function.name:
tool_calls_by_index[tool_index]["function"][
"name"
] = tool_call_delta.function.name
# Accumulate arguments (all chunks)
if tool_call_delta.function and tool_call_delta.function.arguments:
tool_calls_by_index[tool_index]["function"][
"arguments"
] += tool_call_delta.function.arguments
self.assertGreater(len(tool_calls_by_index), 0)
# Validate that complete tool calls have valid JSON arguments
for tool_call in tool_calls_by_index.values():
self.assertIsNotNone(tool_call["function"]["name"])
self.assertIsNotNone(tool_call["function"]["arguments"])
# The complete arguments should be valid JSON
try:
args = json.loads(tool_call["function"]["arguments"])
self.assertIsInstance(args, dict)
except json.JSONDecodeError:
self.fail(
f"Invalid JSON in complete tool call arguments: {tool_call['function']['arguments']}"
)
def test_complex_parameters_required_non_streaming(self):
"""Validate complex nested parameter schemas in non-streaming required mode"""
complex_tools = [
{
"type": "function",
"function": {
"name": "analyze_data",
"description": "Analyze complex data structures",
"parameters": {
"type": "object",
"properties": {
"data": {
"type": "object",
"properties": {
"metrics": {
"type": "array",
"items": {"type": "string"},
},
"config": {
"type": "object",
"properties": {
"threshold": {"type": "number"},
"enabled": {"type": "boolean"},
},
},
},
"required": ["metrics"],
},
"options": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": {"type": "string"},
"value": {"type": "string"},
},
},
},
},
"required": ["data"],
},
},
}
]
messages = [
{
"role": "user",
"content": "Analyze some data with metrics and configuration",
}
]
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
max_tokens=1024,
temperature=0.1,
tools=complex_tools,
tool_choice="required",
stream=False,
)
tool_calls = response.choices[0].message.tool_calls
self.assertIsNotNone(tool_calls)
self.assertGreater(len(tool_calls), 0)
for tool_call in tool_calls:
self.assertEqual(tool_call.function.name, "analyze_data")
try:
args = json.loads(tool_call.function.arguments)
self.assertIsInstance(args, dict)
self.assertIn("data", args)
self.assertIsInstance(args["data"], dict)
except json.JSONDecodeError:
self.fail(
f"Invalid JSON in complex tool call arguments: {tool_call.function.arguments}"
)
def test_multi_tool_scenario_auto(self):
"""Test multi-tool scenario with tool_choice='auto'"""
tools = self.get_travel_tools()
messages = self.get_travel_messages()
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
max_tokens=2048,
temperature=0.2,
tools=tools,
tool_choice="auto",
stream=False,
)
# Should complete without errors
self.assertIsNotNone(response.choices[0].message)
tool_calls = response.choices[0].message.tool_calls
expected_functions = {"get_weather", "get_tourist_attractions"}
if self._is_flaky_test():
# For flaky tests, just verify all called functions are available tools
if tool_calls:
available_names = [tool["function"]["name"] for tool in tools]
for call in tool_calls:
self.assertIn(call.function.name, available_names)
else:
# For non-flaky tests, enforce strict requirements
self.assertIsNotNone(tool_calls, "Expected tool calls but got none")
self.assertEqual(
len(tool_calls), 2, f"Expected 2 tool calls, got {len(tool_calls)}"
)
called_functions = {call.function.name for call in tool_calls}
self.assertEqual(
called_functions,
expected_functions,
f"Expected functions {expected_functions}, got {called_functions}",
)
def test_multi_tool_scenario_required(self):
"""Test multi-tool scenario with tool_choice='required'"""
tools = self.get_travel_tools()
messages = self.get_travel_messages()
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
max_tokens=2048,
temperature=0.2,
tools=tools,
tool_choice="required",
stream=False,
)
# With required, we should get at least one tool call
tool_calls = response.choices[0].message.tool_calls
self.assertIsNotNone(tool_calls)
self.assertGreater(len(tool_calls), 0)
# Verify all called functions are available tools
available_names = [tool["function"]["name"] for tool in tools]
expected_functions = {"get_weather", "get_tourist_attractions"}
for tool_call in tool_calls:
self.assertIsNotNone(tool_call.function.name)
self.assertIsNotNone(tool_call.function.arguments)
if self._is_flaky_test():
# For flaky tests, just ensure basic functionality works
self.assertGreater(
len(tool_calls),
0,
f"Expected at least 1 tool call, got {len(tool_calls)}",
)
for call in tool_calls:
self.assertIn(call.function.name, available_names)
else:
# For non-flaky tests, enforce strict requirements
self.assertEqual(
len(tool_calls), 2, f"Expected 2 tool calls, got {len(tool_calls)}"
)
called_functions = {call.function.name for call in tool_calls}
self.assertEqual(
called_functions,
expected_functions,
f"Expected functions {expected_functions}, got {called_functions}",
)
def test_error_handling_invalid_tool_choice(self):
"""Test error handling for invalid tool_choice"""
tools = self.get_test_tools()
messages = self.get_test_messages()
# Test with invalid function name
tool_choice = {"type": "function", "function": {"name": "nonexistent_function"}}
# Expect a 400 BadRequestError to be raised for invalid tool_choice
with self.assertRaises(openai.BadRequestError) as context:
self.client.chat.completions.create(
model=self.model_name,
messages=messages,
max_tokens=2048,
tools=tools,
tool_choice=tool_choice,
stream=False,
)
# Verify the error message contains the expected text
self.assertIn(
"function 'nonexistent_function' not found in",
str(context.exception),
)
def test_invalid_tool_missing_name(self):
"""Test what happens when user doesn't provide a tool name in request"""
# Test with malformed JSON in tool parameters - missing required "name" field
invalid_tools = [
{
"type": "function",
"function": {
# Missing required "name" field
"description": "Test function with invalid schema",
"parameters": {
"type": "object",
"properties": {
"test_field": {
"type": "string",
"description": "Test field",
}
},
"required": ["test_field"],
},
},
}
]
messages = [
{
"role": "user",
"content": "Test the function",
}
]
# Should raise BadRequestError due to missing required 'name' field
with self.assertRaises(openai.BadRequestError) as context:
self.client.chat.completions.create(
model=self.model_name,
messages=messages,
max_tokens=100,
temperature=0.1,
tools=invalid_tools,
tool_choice="required",
stream=False,
)
# Verify the error message indicates missing name field
error_msg = str(context.exception).lower()
self.assertIn("name", error_msg)
def test_conflicting_defs_required_tool_choice(self):
"""Test that conflicting $defs with required tool_choice returns 400 error"""
conflicting_tools = [
{
"type": "function",
"function": {
"name": "tool1",
"description": "Tool 1 with conflicting $defs",
"parameters": {
"type": "object",
"properties": {
"data": {"$ref": "#/$defs/DataType"},
},
"required": ["data"],
"$defs": {
"DataType": {
"type": "object",
"properties": {"value": {"type": "string"}},
"required": ["value"],
},
},
},
},
},
{
"type": "function",
"function": {
"name": "tool2",
"description": "Tool 2 with conflicting $defs",
"parameters": {
"type": "object",
"properties": {
"data": {"$ref": "#/$defs/DataType"},
},
"required": ["data"],
"$defs": {
"DataType": { # Different definition for DataType
"type": "object",
"properties": {"value": {"type": "number"}},
"required": ["value"],
},
},
},
},
},
]
messages = [
{
"role": "user",
"content": "Test the conflicting tools",
}
]
# Should raise BadRequestError due to conflicting $defs
with self.assertRaises(openai.BadRequestError) as context:
self.client.chat.completions.create(
model=self.model_name,
messages=messages,
max_tokens=100,
temperature=0.1,
tools=conflicting_tools,
tool_choice="required",
stream=False,
)
# Verify the error message indicates conflicting tool definitions
error_msg = str(context.exception).lower()
self.assertIn("invalid tool configuration", error_msg)
self.assertIn("not supported", error_msg)
class TestToolChoiceQwen25(TestToolChoiceLlama32):
"""Test tool_choice functionality with Qwen2.5 model"""
@classmethod
def setUpClass(cls):
# CHANGE: Launch gRPC router with integrated workers (single command)
cls.flaky_tests = {}
cls.model = DEFAULT_QWEN_FUNCTION_CALLING_MODEL_PATH
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.cluster = popen_launch_workers_and_router(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
router_args=[
"--tool-call-parser",
"qwen",
],
num_workers=1,
tp_size=2,
)
cls.base_url += "/v1"
cls.tokenizer = get_tokenizer(cls.model)
class TestToolChoiceMistral(TestToolChoiceLlama32):
"""Test tool_choice functionality with Mistral model"""
@classmethod
def setUpClass(cls):
# CHANGE: Launch gRPC router with integrated workers (single command)
# Mark flaky tests for this model
cls.flaky_tests = {
"test_multi_tool_scenario_auto",
"test_multi_tool_scenario_required",
}
cls.model = DEFAULT_MISTRAL_FUNCTION_CALLING_MODEL_PATH
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.cluster = popen_launch_workers_and_router(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
router_args=[
"--tool-call-parser",
"mistral",
],
num_workers=1,
tp_size=2,
)
cls.base_url += "/v1"
cls.tokenizer = get_tokenizer(cls.model)
@unittest.skip("Fails due to whitespace issue with Mistral - skipping")
def test_complex_parameters_required_non_streaming(self):
"""Validate complex nested parameter schemas in non-streaming required mode"""
super().test_complex_parameters_required_non_streaming()
if __name__ == "__main__":
unittest.main()
[pytest]
# Show print statements and logs
log_cli = true
log_cli_level = INFO
log_cli_format = %(asctime)s [%(levelname)8s] %(message)s
log_cli_date_format = %Y-%m-%d %H:%M:%S
# Show stdout/stderr
addopts = -v -s --tb=short
# Capture settings
# -s means don't capture stdout (show print statements)
# --tb=short means short traceback format
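# Example (assuming this file sits at the root of the e2e_grpc test directory):
#   ROUTER_LOCAL_MODEL_PATH=/path/to/models pytest .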
"""
Standalone utilities for e2e_grpc tests.
This module provides all necessary utilities without depending on sglang Python package.
Extracted and adapted from:
- sglang.srt.utils.kill_process_tree
- sglang.srt.utils.hf_transformers_utils.get_tokenizer
- sglang.test.test_utils (constants and CustomTestCase)
"""
import os
import signal
import threading
import unittest
from pathlib import Path
from typing import Optional, Union
import psutil
try:
from transformers import (
AutoTokenizer,
PreTrainedTokenizer,
PreTrainedTokenizerBase,
PreTrainedTokenizerFast,
)
except ImportError:
raise ImportError(
"transformers is required for tokenizer utilities. "
"Install with: pip install transformers"
)
# ============================================================================
# Constants
# ============================================================================
# Server and timeout constants
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 600
DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 20000
DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000}"
# File name constants for test output
STDOUT_FILENAME = "/tmp/sglang_test_stdout.txt"
STDERR_FILENAME = "/tmp/sglang_test_stderr.txt"
# Model base path - can be overridden via environment variable
# By default, use HuggingFace model identifiers (no local path prefix)
# Set ROUTER_LOCAL_MODEL_PATH to use local models (e.g., "/home/ubuntu/models")
ROUTER_LOCAL_MODEL_PATH = os.environ.get("ROUTER_LOCAL_MODEL_PATH", "")
# Helper function to build model paths
def _get_model_path(model_identifier: str) -> str:
"""
Build model path from base path and model identifier.
If ROUTER_LOCAL_MODEL_PATH is set, prepend it to the identifier.
Otherwise, return the identifier as-is (for HuggingFace download).
"""
if ROUTER_LOCAL_MODEL_PATH:
return os.path.join(ROUTER_LOCAL_MODEL_PATH, model_identifier)
return model_identifier
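# Example: with ROUTER_LOCAL_MODEL_PATH="/home/ubuntu/models",
# _get_model_path("meta-llama/Llama-3.1-8B-Instruct") returns
# "/home/ubuntu/models/meta-llama/Llama-3.1-8B-Instruct"; with the variable
# unset, the HuggingFace identifier is returned unchanged.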
# Model paths used in e2e_grpc tests
# These can be either HuggingFace identifiers or local paths (depending on ROUTER_LOCAL_MODEL_PATH)
# Main test model - Llama 3.1 8B Instruct
DEFAULT_MODEL_PATH = _get_model_path("meta-llama/Llama-3.1-8B-Instruct")
# Small models for function calling tests
DEFAULT_SMALL_MODEL_PATH = _get_model_path("meta-llama/Llama-3.2-1B-Instruct")
# Reasoning models
DEFAULT_REASONING_MODEL_PATH = _get_model_path(
"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
)
# Thinking-enabled models
DEFAULT_ENABLE_THINKING_MODEL_PATH = _get_model_path("Qwen/Qwen3-30B-A3B")
# Function calling models
DEFAULT_QWEN_FUNCTION_CALLING_MODEL_PATH = _get_model_path("Qwen/Qwen2.5-7B-Instruct")
DEFAULT_MISTRAL_FUNCTION_CALLING_MODEL_PATH = _get_model_path(
"mistralai/Mistral-7B-Instruct-v0.3"
)
# ============================================================================
# Process Management
# ============================================================================
def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):
"""
Kill the process and all its child processes.
Args:
parent_pid: PID of the parent process
include_parent: Whether to kill the parent process itself
skip_pid: Optional PID to skip during cleanup
"""
    # Remove the SIGCHLD handler to avoid spammy logs
if threading.current_thread() is threading.main_thread():
signal.signal(signal.SIGCHLD, signal.SIG_DFL)
if parent_pid is None:
parent_pid = os.getpid()
include_parent = False
try:
itself = psutil.Process(parent_pid)
except psutil.NoSuchProcess:
return
children = itself.children(recursive=True)
for child in children:
if child.pid == skip_pid:
continue
try:
child.kill()
except psutil.NoSuchProcess:
pass
if include_parent:
try:
itself.kill()
except psutil.NoSuchProcess:
pass
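# Example: the E2E test teardowns call kill_process_tree(cluster["router"].pid)
# first, then kill_process_tree(worker.pid) for each launched worker process.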
# ============================================================================
# Tokenizer Utilities
# ============================================================================
def check_gguf_file(model_path: str) -> bool:
"""Check if the model path points to a GGUF file."""
if not isinstance(model_path, str):
return False
return model_path.endswith(".gguf")
def is_remote_url(path: str) -> bool:
"""Check if the path is a remote URL."""
if not isinstance(path, str):
return False
return path.startswith("http://") or path.startswith("https://")
def get_tokenizer(
tokenizer_name: str,
*args,
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
tokenizer_revision: Optional[str] = None,
**kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
"""
Gets a tokenizer for the given model name via Huggingface.
Args:
tokenizer_name: Name or path of the tokenizer
tokenizer_mode: Mode for tokenizer loading ("auto", "slow")
trust_remote_code: Whether to trust remote code
tokenizer_revision: Specific revision to use
**kwargs: Additional arguments passed to AutoTokenizer.from_pretrained
Returns:
Loaded tokenizer instance
"""
if tokenizer_mode == "slow":
if kwargs.get("use_fast", False):
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
kwargs["use_fast"] = False
# Handle special model name mapping
if tokenizer_name == "mistralai/Devstral-Small-2505":
tokenizer_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
is_gguf = check_gguf_file(tokenizer_name)
if is_gguf:
kwargs["gguf_file"] = tokenizer_name
tokenizer_name = Path(tokenizer_name).parent
# Note: Removed remote URL handling and local directory download
# as they depend on sglang-specific utilities
try:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
tokenizer_revision=tokenizer_revision,
**kwargs,
)
except TypeError as e:
# Handle specific errors
err_msg = (
"Failed to load the tokenizer. If you are running a model with "
"a custom tokenizer, please set the --trust-remote-code flag."
)
raise RuntimeError(err_msg) from e
if not isinstance(tokenizer, PreTrainedTokenizerFast):
print(
f"Warning: Using a slow tokenizer. This might cause a performance "
f"degradation. Consider using a fast tokenizer instead."
)
return tokenizer
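# Example:
#   tokenizer = get_tokenizer(DEFAULT_MODEL_PATH)
#   num_tokens = len(tokenizer.encode("Count from 1 to 20."))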
def get_tokenizer_from_processor(processor):
"""Extract tokenizer from a processor object."""
if isinstance(processor, PreTrainedTokenizerBase):
return processor
return processor.tokenizer
# ============================================================================
# Test Utilities
# ============================================================================
class CustomTestCase(unittest.TestCase):
"""
Custom test case base class with retry support.
This provides automatic test retry functionality based on environment variables.
"""
def _callTestMethod(self, method):
"""Override to add retry logic."""
max_retry = int(os.environ.get("SGLANG_TEST_MAX_RETRY", "0"))
if max_retry == 0:
# No retry, just run once
return super(CustomTestCase, self)._callTestMethod(method)
# Retry logic
for attempt in range(max_retry + 1):
try:
return super(CustomTestCase, self)._callTestMethod(method)
except Exception as e:
if attempt < max_retry:
print(
f"Test failed on attempt {attempt + 1}/{max_retry + 1}, retrying..."
)
continue
else:
# Last attempt, re-raise the exception
raise
def setUp(self):
"""Print test method name at the start of each test."""
print(f"[Test Method] {self._testMethodName}", flush=True)
"""
python3 -m unittest openai_server.validation.test_large_max_new_tokens.TestLargeMaxNewTokens.test_chat_completion
"""
import os
# CHANGE: Import router launcher instead of server launcher
import sys
import time
import unittest
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import openai
_TEST_DIR = Path(__file__).parent
sys.path.insert(0, str(_TEST_DIR.parent))
from fixtures import popen_launch_workers_and_router
from util import (
DEFAULT_MODEL_PATH,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
STDERR_FILENAME,
STDOUT_FILENAME,
CustomTestCase,
get_tokenizer,
kill_process_tree,
)
class TestLargeMaxNewTokens(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_PATH
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.stdout = open(STDOUT_FILENAME, "w")
cls.stderr = open(STDERR_FILENAME, "w")
cls.cluster = popen_launch_workers_and_router(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
worker_args=(
"--max-total-token",
"1536",
"--context-len",
"8192",
"--decode-log-interval",
"2",
),
num_workers=1,
tp_size=2,
env={"SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION": "256", **os.environ},
stdout=cls.stdout,
stderr=cls.stderr,
)
cls.base_url += "/v1"
cls.tokenizer = get_tokenizer(cls.model)
@classmethod
def tearDownClass(cls):
# Cleanup router and workers
kill_process_tree(cls.cluster["router"].pid)
for worker in cls.cluster.get("workers", []):
kill_process_tree(worker.pid)
cls.stdout.close()
cls.stderr.close()
os.remove(STDOUT_FILENAME)
os.remove(STDERR_FILENAME)
def run_chat_completion(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{
"role": "user",
"content": "Please repeat the world 'hello' for 10000 times.",
},
],
temperature=0,
)
return response
def test_chat_completion(self):
num_requests = 4
all_requests_running = False
futures = []
with ThreadPoolExecutor(num_requests) as executor:
# Send multiple requests
for i in range(num_requests):
futures.append(executor.submit(self.run_chat_completion))
# Ensure that they are running concurrently
pt = 0
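            # `pt` is the index of the next unread stderr line; keep tailing the
            # worker log until a "#running-req" entry reports all requests in
            # flight, then set pt = -1 to exit the outer loop.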
while pt >= 0:
time.sleep(5)
# Flush stderr to ensure logs are written
self.stderr.flush()
lines = open(STDERR_FILENAME).readlines()
for line in lines[pt:]:
print(line, end="", flush=True)
if f"#running-req: {num_requests}" in line:
all_requests_running = True
pt = -1
break
pt += 1
assert all_requests_running
if __name__ == "__main__":
unittest.main()
"""
gRPC Router E2E Test - OpenAI Server Ignore EOS
This test file is REUSED from test/srt/openai_server/validation/test_openai_server_ignore_eos.py
with minimal changes:
- Swap popen_launch_server() → popen_launch_workers_and_router()
- Update teardown to cleanup router + workers
- All test logic and assertions remain identical
Run with:
pytest py_test/e2e_grpc/validation/test_openai_server_ignore_eos.py -v
"""
# CHANGE: Import router launcher instead of server launcher
import sys
from pathlib import Path
import openai
_TEST_DIR = Path(__file__).parent
sys.path.insert(0, str(_TEST_DIR.parent))
from fixtures import popen_launch_workers_and_router
from util import (
DEFAULT_MODEL_PATH,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
get_tokenizer,
kill_process_tree,
)
class TestOpenAIServerIgnoreEOS(CustomTestCase):
@classmethod
def setUpClass(cls):
# CHANGE: Launch gRPC router with integrated workers (single command)
cls.model = DEFAULT_MODEL_PATH
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.cluster = popen_launch_workers_and_router(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
num_workers=1,
tp_size=2,
)
cls.base_url += "/v1"
cls.tokenizer = get_tokenizer(cls.model)
@classmethod
def tearDownClass(cls):
# Cleanup router and workers
kill_process_tree(cls.cluster["router"].pid)
for worker in cls.cluster.get("workers", []):
kill_process_tree(worker.pid)
def test_ignore_eos(self):
"""
Test that ignore_eos=True allows generation to continue beyond EOS token
and reach the max_tokens limit.
"""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
max_tokens = 200
response_default = client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Count from 1 to 20."},
],
temperature=0,
max_tokens=max_tokens,
extra_body={"ignore_eos": False},
)
response_ignore_eos = client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Count from 1 to 20."},
],
temperature=0,
max_tokens=max_tokens,
extra_body={"ignore_eos": True},
)
default_tokens = len(
self.tokenizer.encode(response_default.choices[0].message.content)
)
ignore_eos_tokens = len(
self.tokenizer.encode(response_ignore_eos.choices[0].message.content)
)
# Check if ignore_eos resulted in more tokens or exactly max_tokens
# The ignore_eos response should either:
# 1. Have more tokens than the default response (if default stopped at EOS before max_tokens)
# 2. Have exactly max_tokens (if it reached the max_tokens limit)
self.assertTrue(
ignore_eos_tokens > default_tokens or ignore_eos_tokens >= max_tokens,
f"ignore_eos did not generate more tokens: {ignore_eos_tokens} vs {default_tokens}",
)
self.assertEqual(
response_ignore_eos.choices[0].finish_reason,
"length",
f"Expected finish_reason='length' for ignore_eos=True, got {response_ignore_eos.choices[0].finish_reason}",
)
...@@ -625,8 +625,6 @@ pub struct ChatCompletionMessage { ...@@ -625,8 +625,6 @@ pub struct ChatCompletionMessage {
pub content: Option<String>, pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>, pub tool_calls: Option<Vec<ToolCall>>,
/// Reasoning content for O1-style models (SGLang extension)
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_content: Option<String>, pub reasoning_content: Option<String>,
// Note: function_call is deprecated and not included // Note: function_call is deprecated and not included
// Note: refusal, annotations, audio are not added yet // Note: refusal, annotations, audio are not added yet
...@@ -669,8 +667,6 @@ pub struct ChatMessageDelta { ...@@ -669,8 +667,6 @@ pub struct ChatMessageDelta {
pub content: Option<String>, pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCallDelta>>, pub tool_calls: Option<Vec<ToolCallDelta>>,
/// Reasoning content delta for O1-style models (SGLang extension)
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_content: Option<String>, pub reasoning_content: Option<String>,
} }
......