Unverified Commit e983d666 authored by Binyao Jiang's avatar Binyao Jiang Committed by GitHub
Browse files

Fix: Improve test_openai_function_calling unit test and fix...


Fix: Improve test_openai_function_calling unit test and fix reasoning_parser.py think_start_token logic (#8316)
Co-authored-by: default avatarChang Su <chang.s.su@oracle.com>
parent b58c3c28
...@@ -493,9 +493,6 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -493,9 +493,6 @@ class OpenAIServingChat(OpenAIServingBase):
) )
yield f"data: {chunk.model_dump_json()}\n\n" yield f"data: {chunk.model_dump_json()}\n\n"
if not delta:
continue
# Handle tool calls # Handle tool calls
if request.tool_choice != "none" and request.tools: if request.tool_choice != "none" and request.tools:
async for ( async for (
......
...@@ -32,7 +32,7 @@ class BaseReasoningFormatDetector: ...@@ -32,7 +32,7 @@ class BaseReasoningFormatDetector:
One-time parsing: Detects and parses reasoning sections in the provided text. One-time parsing: Detects and parses reasoning sections in the provided text.
Returns both reasoning content and normal text separately. Returns both reasoning content and normal text separately.
""" """
in_reasoning = self._in_reasoning or text.startswith(self.think_start_token) in_reasoning = self._in_reasoning or self.think_start_token in text
if not in_reasoning: if not in_reasoning:
return StreamingParseResult(normal_text=text) return StreamingParseResult(normal_text=text)
......
...@@ -76,6 +76,7 @@ class TestOpenAIServerFunctionCalling(CustomTestCase): ...@@ -76,6 +76,7 @@ class TestOpenAIServerFunctionCalling(CustomTestCase):
messages = [{"role": "user", "content": "Compute (3+5)"}] messages = [{"role": "user", "content": "Compute (3+5)"}]
response = client.chat.completions.create( response = client.chat.completions.create(
model=self.model, model=self.model,
max_tokens=2048,
messages=messages, messages=messages,
temperature=0.8, temperature=0.8,
top_p=0.8, top_p=0.8,
...@@ -92,6 +93,84 @@ class TestOpenAIServerFunctionCalling(CustomTestCase): ...@@ -92,6 +93,84 @@ class TestOpenAIServerFunctionCalling(CustomTestCase):
function_name = tool_calls[0].function.name function_name = tool_calls[0].function.name
assert function_name == "add", "Function name should be 'add'" assert function_name == "add", "Function name should be 'add'"
# This unit test is too difficult for default model. Mark it as optional unit tests so it won't trigger unless specified.
def _test_function_calling_multiturn(self):
"""
Test: Whether the function call format returned by the AI is correct.
When returning a tool call, message.content should be None, and tool_calls should be a list.
"""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
tools = [
{
"type": "function",
"function": {
"name": "add",
"description": "Compute the sum of two numbers",
"parameters": {
"type": "object",
"properties": {
"a": {
"type": "int",
"description": "A number",
},
"b": {
"type": "int",
"description": "A number",
},
},
"required": ["a", "b"],
},
},
}
]
messages = [{"role": "user", "content": "Compute (3+5)"}]
response = client.chat.completions.create(
model=self.model,
max_tokens=2048,
messages=messages,
temperature=0.8,
top_p=0.8,
stream=False,
tools=tools,
)
tool_call = response.choices[0].message.tool_calls[0]
function_name = tool_call.function.name
assert function_name == "add", "Function name should be 'add'"
function_arguments = tool_call.function.arguments
function_arguments = json.loads(tool_call.function.arguments)
assert function_arguments in [
{"a": 3, "b": 5},
{"a": "3", "b": "5"},
], f"Unexpected function arguments: {function_arguments}"
messages.append(response.choices[0].message)
messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"content": "8",
"name": function_name,
}
)
final_response = client.chat.completions.create(
model=self.model,
max_tokens=2048,
messages=messages,
temperature=0.8,
top_p=0.8,
stream=False,
tools=tools,
)
assert (
"8" in final_response.choices[0].message.content
), "tool_call response should have the sum 8 in the content"
def test_function_calling_streaming_simple(self): def test_function_calling_streaming_simple(self):
""" """
Test: Whether the function name can be correctly recognized in streaming mode. Test: Whether the function name can be correctly recognized in streaming mode.
...@@ -125,10 +204,13 @@ class TestOpenAIServerFunctionCalling(CustomTestCase): ...@@ -125,10 +204,13 @@ class TestOpenAIServerFunctionCalling(CustomTestCase):
} }
] ]
messages = [{"role": "user", "content": "What is the temperature in Paris?"}] messages = [
{"role": "user", "content": "What is the temperature in Paris in celsius?"}
]
response_stream = client.chat.completions.create( response_stream = client.chat.completions.create(
model=self.model, model=self.model,
max_tokens=2048,
messages=messages, messages=messages,
temperature=0.8, temperature=0.8,
top_p=0.8, top_p=0.8,
...@@ -166,6 +248,74 @@ class TestOpenAIServerFunctionCalling(CustomTestCase): ...@@ -166,6 +248,74 @@ class TestOpenAIServerFunctionCalling(CustomTestCase):
"Final response of function calling should have finish_reason 'tool_calls'", "Final response of function calling should have finish_reason 'tool_calls'",
) )
# TODO: There is a bug in sglang preventing this UT from passing. We are working on it. Once done, we will add this UT back.
def _test_function_calling_streaming_no_tool_call(self):
"""
Test: Whether the finish_reason is stop in streaming mode when no tool call is given.
- Expect no function call to be found.
- Verify that finish_reason is stop
"""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city to find the weather for",
},
"unit": {
"type": "string",
"description": "Weather unit (celsius or fahrenheit)",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["city", "unit"],
},
},
}
]
messages = [{"role": "user", "content": "Who are you?"}]
response_stream = client.chat.completions.create(
model=self.model,
max_tokens=2048,
messages=messages,
temperature=0.8,
top_p=0.8,
stream=True,
tools=tools,
tool_choice="none",
)
chunks = list(response_stream)
self.assertTrue(len(chunks) > 0, "Streaming should return at least one chunk")
found_tool_call = False
for chunk in chunks:
choice = chunk.choices[0]
# Check whether the current chunk contains tool_calls
found_tool_call = choice.delta.tool_calls is not None
self.assertFalse(
found_tool_call,
"Shouldn't have any tool_call in the streaming chunks",
)
finish_reason = chunks[-1].choices[0].finish_reason
self.assertEqual(
finish_reason,
"stop",
"Final response of no function calling should have finish_reason 'stop'",
)
def test_function_calling_streaming_args_parsing(self): def test_function_calling_streaming_args_parsing(self):
""" """
Test: Whether the function call arguments returned in streaming mode can be correctly concatenated into valid JSON. Test: Whether the function call arguments returned in streaming mode can be correctly concatenated into valid JSON.
...@@ -205,6 +355,7 @@ class TestOpenAIServerFunctionCalling(CustomTestCase): ...@@ -205,6 +355,7 @@ class TestOpenAIServerFunctionCalling(CustomTestCase):
response_stream = client.chat.completions.create( response_stream = client.chat.completions.create(
model=self.model, model=self.model,
max_tokens=2048,
messages=messages, messages=messages,
temperature=0.9, temperature=0.9,
top_p=0.9, top_p=0.9,
...@@ -213,8 +364,9 @@ class TestOpenAIServerFunctionCalling(CustomTestCase): ...@@ -213,8 +364,9 @@ class TestOpenAIServerFunctionCalling(CustomTestCase):
) )
argument_fragments = [] argument_fragments = []
chunks = list(response_stream)
function_name = None function_name = None
for chunk in response_stream: for chunk in chunks:
choice = chunk.choices[0] choice = chunk.choices[0]
if choice.delta.tool_calls: if choice.delta.tool_calls:
tool_call = choice.delta.tool_calls[0] tool_call = choice.delta.tool_calls[0]
...@@ -231,6 +383,13 @@ class TestOpenAIServerFunctionCalling(CustomTestCase): ...@@ -231,6 +383,13 @@ class TestOpenAIServerFunctionCalling(CustomTestCase):
"No parameter fragments were returned in the function call", "No parameter fragments were returned in the function call",
) )
finish_reason = chunks[-1].choices[0].finish_reason
self.assertEqual(
finish_reason,
"tool_calls",
"Final response of function calling should have finish_reason 'tool_calls'",
)
# Check whether the concatenated JSON is valid # Check whether the concatenated JSON is valid
try: try:
args_obj = json.loads(joined_args) args_obj = json.loads(joined_args)
...@@ -281,6 +440,7 @@ class TestOpenAIServerFunctionCalling(CustomTestCase): ...@@ -281,6 +440,7 @@ class TestOpenAIServerFunctionCalling(CustomTestCase):
] ]
response = client.chat.completions.create( response = client.chat.completions.create(
model=self.model, model=self.model,
max_tokens=2048,
messages=messages, messages=messages,
temperature=0.8, temperature=0.8,
top_p=0.8, top_p=0.8,
...@@ -349,6 +509,7 @@ class TestOpenAIServerFunctionCalling(CustomTestCase): ...@@ -349,6 +509,7 @@ class TestOpenAIServerFunctionCalling(CustomTestCase):
messages = [{"role": "user", "content": "What is the capital of France?"}] messages = [{"role": "user", "content": "What is the capital of France?"}]
response = client.chat.completions.create( response = client.chat.completions.create(
model=self.model, model=self.model,
max_tokens=2048,
messages=messages, messages=messages,
temperature=0.8, temperature=0.8,
top_p=0.8, top_p=0.8,
...@@ -436,6 +597,7 @@ class TestOpenAIServerFunctionCalling(CustomTestCase): ...@@ -436,6 +597,7 @@ class TestOpenAIServerFunctionCalling(CustomTestCase):
messages = [{"role": "user", "content": "What is the capital of France?"}] messages = [{"role": "user", "content": "What is the capital of France?"}]
response = client.chat.completions.create( response = client.chat.completions.create(
model=self.model, model=self.model,
max_tokens=2048,
messages=messages, messages=messages,
temperature=0.8, temperature=0.8,
top_p=0.8, top_p=0.8,
...@@ -544,6 +706,7 @@ class TestOpenAIPythonicFunctionCalling(CustomTestCase): ...@@ -544,6 +706,7 @@ class TestOpenAIPythonicFunctionCalling(CustomTestCase):
client = openai.Client(api_key=self.api_key, base_url=self.base_url) client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.chat.completions.create( response = client.chat.completions.create(
model=self.model, model=self.model,
max_tokens=2048,
messages=self.PYTHONIC_MESSAGES, messages=self.PYTHONIC_MESSAGES,
tools=self.PYTHONIC_TOOLS, tools=self.PYTHONIC_TOOLS,
temperature=0.1, temperature=0.1,
...@@ -565,6 +728,7 @@ class TestOpenAIPythonicFunctionCalling(CustomTestCase): ...@@ -565,6 +728,7 @@ class TestOpenAIPythonicFunctionCalling(CustomTestCase):
client = openai.Client(api_key=self.api_key, base_url=self.base_url) client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response_stream = client.chat.completions.create( response_stream = client.chat.completions.create(
model=self.model, model=self.model,
max_tokens=2048,
messages=self.PYTHONIC_MESSAGES, messages=self.PYTHONIC_MESSAGES,
tools=self.PYTHONIC_TOOLS, tools=self.PYTHONIC_TOOLS,
temperature=0.1, temperature=0.1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment