"examples/pytorch/vscode:/vscode.git/clone" did not exist on "8b64ae59b8e02ccf35474e99ac63f5c6822b15d5"
Unverified commit b47eda33 authored by Chang Su, committed by GitHub

bugfix: Fix multiple finish_reason chunks and tool_calls finish reason check (#8417)

parent e983d666
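In short: instead of attaching finish_reason to every intermediate delta chunk, the change records the finish reason per choice index while streaming and, after the stream ends, emits exactly one terminal chunk per index, upgrading a natural "stop" to "tool_calls" when that choice streamed tool calls. A minimal sketch of the bookkeeping (hypothetical helper names, not the actual sglang functions), assuming the per-request state dicts introduced in the diff:

finish_reasons = {}   # index -> finish_reason dict from meta_info
has_tool_calls = {}   # index -> True once a tool-call delta was streamed

def record(index, finish_reason, streamed_tool_call):
    # Accumulate state while streaming deltas for a given choice index.
    if streamed_tool_call:
        has_tool_calls[index] = True
    if finish_reason:
        finish_reasons[index] = finish_reason

def final_chunks():
    # After the stream ends, yield one terminal chunk per completed index.
    for idx, reason in finish_reasons.items():
        reason_type = reason["type"]
        # A natural stop becomes "tool_calls" if this choice emitted tool calls.
        if has_tool_calls.get(idx, False) and reason_type == "stop":
            reason_type = "tool_calls"
        yield {"index": idx, "delta": {}, "finish_reason": reason_type}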
@@ -412,6 +412,8 @@ class OpenAIServingChat(OpenAIServingBase):
is_firsts = {}
stream_buffers = {}
n_prev_tokens = {}
has_tool_calls = {}
finish_reasons = {}
# Usage tracking
prompt_tokens = {}
@@ -443,6 +445,10 @@ class OpenAIServingChat(OpenAIServingBase):
finish_reason = content["meta_info"]["finish_reason"]
finish_reason_type = finish_reason["type"] if finish_reason else None
# Track finish_reason for each index
if finish_reason_type:
finish_reasons[index] = finish_reason
# First chunk with role
if is_firsts.get(index, True):
is_firsts[index] = False
@@ -450,13 +456,8 @@ class OpenAIServingChat(OpenAIServingBase):
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=delta,
finish_reason=finish_reason_type,
matched_stop=(
finish_reason["matched"]
if finish_reason and "matched" in finish_reason
else None
),
logprobs=choice_logprobs,
finish_reason=None,
logprobs=None,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
@@ -483,7 +484,7 @@ class OpenAIServingChat(OpenAIServingBase):
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(reasoning_content=reasoning_text),
finish_reason=finish_reason_type,
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
@@ -495,40 +496,34 @@ class OpenAIServingChat(OpenAIServingBase):
# Handle tool calls
if request.tool_choice != "none" and request.tools:
async for (
chunk,
tool_call_finish_reason_type,
) in self._process_tool_call_stream(
async for chunk in self._process_tool_call_stream(
index,
delta,
parser_dict,
content,
request,
finish_reason_type,
has_tool_calls,
):
if chunk:
yield chunk
finish_reason_type = tool_call_finish_reason_type
# Send any remaining tool call arguments when generation finishes
if finish_reason_type is not None and index in parser_dict:
parser = parser_dict[index]
remaining_chunk = self._check_for_unstreamed_tool_args(
parser, content, request, index
)
if remaining_chunk:
yield remaining_chunk
else:
# Regular content
if delta or not (
request.stream_options and request.stream_options.include_usage
):
if delta:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(content=delta if delta else None),
finish_reason=(
None
if request.stream_options
and request.stream_options.include_usage
else finish_reason_type
),
matched_stop=(
finish_reason["matched"]
if finish_reason and "matched" in finish_reason
else None
),
finish_reason=None,
matched_stop=None,
logprobs=choice_logprobs,
)
chunk = ChatCompletionStreamResponse(
@@ -539,26 +534,36 @@ class OpenAIServingChat(OpenAIServingBase):
)
yield f"data: {chunk.model_dump_json()}\n\n"
# Final chunk with finish_reason
finish_reason_chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=int(time.time()),
choices=[
ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(),
finish_reason=finish_reason_type,
matched_stop=(
finish_reason["matched"]
if finish_reason and "matched" in finish_reason
else None
),
)
],
model=request.model,
usage=None,
)
yield f"data: {finish_reason_chunk.model_dump_json()}\n\n"
# Send finish_reason chunks for each index that completed
for idx, finish_reason_data in finish_reasons.items():
finish_reason_type = finish_reason_data["type"]
# Change finish_reason to "tool_calls" if we had tool calls and stopped naturally
final_finish_reason = finish_reason_type
if has_tool_calls.get(idx, False) and finish_reason_type == "stop":
final_finish_reason = "tool_calls"
finish_reason_chunk = ChatCompletionStreamResponse(
id=content["meta_info"][
"id"
], # NOTE: openai uses the same chatcmpl-id for all indices
created=int(time.time()),
choices=[
ChatCompletionResponseStreamChoice(
index=idx,
delta=DeltaMessage(),
finish_reason=final_finish_reason,
matched_stop=(
finish_reason_data["matched"]
if "matched" in finish_reason_data
else None
),
)
],
model=request.model,
usage=None,
)
yield f"data: {finish_reason_chunk.model_dump_json()}\n\n"
# Send hidden states if requested
if request.return_hidden_states and hidden_states:
@@ -578,7 +583,7 @@ class OpenAIServingChat(OpenAIServingBase):
delta=DeltaMessage(
hidden_states=last_token_hidden_states
),
finish_reason=finish_reason_type,
finish_reason=None, # Hidden states don't need finish_reason
)
],
model=request.model,
@@ -857,7 +862,7 @@ class OpenAIServingChat(OpenAIServingBase):
parser_dict: Dict[int, FunctionCallParser],
content: Dict[str, Any],
request: ChatCompletionRequest,
finish_reason_type: Optional[str],
has_tool_calls: Dict[int, bool],
):
"""Process tool calls in streaming response"""
if index not in parser_dict:
@@ -874,7 +879,7 @@ class OpenAIServingChat(OpenAIServingBase):
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(content=normal_text),
finish_reason=finish_reason_type,
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
@@ -882,10 +887,13 @@ class OpenAIServingChat(OpenAIServingBase):
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n", finish_reason_type
yield f"data: {chunk.model_dump_json()}\n\n"
# Yield tool calls
for call_item in calls:
# Mark that this choice has tool calls
has_tool_calls[index] = True
# Tool call ID should be generated only once per tool call
if call_item.name:
# First chunk: include ID and function name
@@ -896,23 +904,6 @@ class OpenAIServingChat(OpenAIServingBase):
tool_call_id = None
function_name = None
if finish_reason_type == "stop":
# Handle remaining arguments
latest_delta_len = 0
if isinstance(call_item.parameters, str):
latest_delta_len = len(call_item.parameters)
expected_call = json.dumps(
parser.detector.prev_tool_call_arr[index].get("arguments", {}),
ensure_ascii=False,
)
actual_call = parser.detector.streamed_args_for_tool[index]
if latest_delta_len > 0:
actual_call = actual_call[:-latest_delta_len]
remaining_call = expected_call.replace(actual_call, "", 1)
call_item.parameters = remaining_call
finish_reason_type = "tool_calls"
tool_call = ToolCall(
id=tool_call_id,
index=call_item.tool_index,
@@ -925,19 +916,84 @@ class OpenAIServingChat(OpenAIServingBase):
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(tool_calls=[tool_call]),
finish_reason=(
None
if request.stream_options and request.stream_options.include_usage
else finish_reason_type
),
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=int(time.time()),
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
def _check_for_unstreamed_tool_args(
self,
parser: FunctionCallParser,
content: Dict[str, Any],
request: ChatCompletionRequest,
index: int,
) -> Optional[str]:
"""
Check for any remaining tool call arguments that need to be streamed
when generation finishes. This ensures tool calls are properly completed
even if the model generates the final arguments in the last chunk.
"""
# Only check if we have tool calls and the parser has tracked data
if (
not hasattr(parser.detector, "prev_tool_call_arr")
or not parser.detector.prev_tool_call_arr
):
return None
if (
not hasattr(parser.detector, "streamed_args_for_tool")
or not parser.detector.streamed_args_for_tool
):
return None
# Get the last tool call that was being processed
tool_index = len(parser.detector.prev_tool_call_arr) - 1
if tool_index < 0 or tool_index >= len(parser.detector.streamed_args_for_tool):
return None
# Get expected vs actual arguments
expected_args = parser.detector.prev_tool_call_arr[tool_index].get(
"arguments", {}
)
expected_call = json.dumps(expected_args, ensure_ascii=False)
actual_call = parser.detector.streamed_args_for_tool[tool_index]
# Check if there are remaining arguments to send
remaining_call = (
expected_call.replace(actual_call, "", 1)
if actual_call in expected_call
else ""
)
if remaining_call:
# Create tool call chunk with remaining arguments
tool_call = ToolCall(
id=None, # No ID for argument deltas
index=tool_index,
function=FunctionResponse(
name=None, # No name for argument deltas
arguments=remaining_call,
),
)
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(tool_calls=[tool_call]),
finish_reason=None, # Don't send finish_reason with this chunk
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=int(time.time()),
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n", finish_reason_type
if finish_reason_type == "stop":
yield None, "tool_calls"
return f"data: {chunk.model_dump_json()}\n\n"
return None
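The helper above compares the fully parsed tool-call arguments against what has already been streamed and returns only the missing suffix as one final tool-call delta. A standalone sketch of that comparison (hypothetical function name; the real code reads its inputs from parser.detector):

import json

def remaining_tool_args(expected_args: dict, streamed_so_far: str) -> str:
    # Return the argument text that still needs to be streamed, or "".
    expected_call = json.dumps(expected_args, ensure_ascii=False)
    # Drop the already-streamed prefix once; whatever is left is the missing suffix.
    return (
        expected_call.replace(streamed_so_far, "", 1)
        if streamed_so_far in expected_call
        else ""
    )

# Example: only the "unit" field is still missing.
print(remaining_tool_args(
    {"location": "San Francisco", "unit": "celsius"},
    '{"location": "San Francisco"',
))  # -> ', "unit": "celsius"}'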
@@ -233,6 +233,7 @@ class TestOpenAIServer(CustomTestCase):
is_firsts = {}
is_finished = {}
finish_reason_counts = {}
for response in generator:
usage = response.usage
if usage is not None:
@@ -245,6 +246,7 @@ class TestOpenAIServer(CustomTestCase):
finish_reason = response.choices[0].finish_reason
if finish_reason is not None:
is_finished[index] = True
finish_reason_counts[index] = finish_reason_counts.get(index, 0) + 1
data = response.choices[0].delta
@@ -284,6 +286,15 @@ class TestOpenAIServer(CustomTestCase):
index, True
), f"index {index} is not found in the response"
# Verify that each choice gets exactly one finish_reason chunk
for index in range(parallel_sample_num):
assert (
index in finish_reason_counts
), f"No finish_reason found for index {index}"
assert (
finish_reason_counts[index] == 1
), f"Expected 1 finish_reason chunk for index {index}, got {finish_reason_counts[index]}"
def test_completion(self):
for echo in [False, True]:
for logprobs in [None, 5]:
@@ -420,91 +431,6 @@ The SmartHome Mini is a compact smart home assistant available in black or white
client.models.retrieve("non-existent-model")
# -------------------------------------------------------------------------
# EBNF Test Class: TestOpenAIServerEBNF
# Launches the server with xgrammar, has only EBNF tests
# -------------------------------------------------------------------------
class TestOpenAIServerEBNF(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
# passing xgrammar specifically
other_args = ["--grammar-backend", "xgrammar"]
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
other_args=other_args,
)
cls.base_url += "/v1"
cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_ebnf(self):
"""
Ensure we can pass `ebnf` to the local openai server
and that it enforces the grammar.
"""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
ebnf_grammar = r"""
root ::= "Hello" | "Hi" | "Hey"
"""
pattern = re.compile(r"^(Hello|Hi|Hey)[.!?]*\s*$")
response = client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": "You are a helpful EBNF test bot."},
{"role": "user", "content": "Say a greeting (Hello, Hi, or Hey)."},
],
temperature=0,
max_tokens=32,
extra_body={"ebnf": ebnf_grammar},
)
text = response.choices[0].message.content.strip()
self.assertTrue(len(text) > 0, "Got empty text from EBNF generation")
self.assertRegex(text, pattern, f"Text '{text}' doesn't match EBNF choices")
def test_ebnf_strict_json(self):
"""
A stricter EBNF that produces exactly {"name":"Alice"} format
with no trailing punctuation or extra fields.
"""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
ebnf_grammar = r"""
root ::= "{" pair "}"
pair ::= "\"name\"" ":" string
string ::= "\"" [A-Za-z]+ "\""
"""
pattern = re.compile(r'^\{"name":"[A-Za-z]+"\}$')
response = client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": "EBNF mini-JSON generator."},
{
"role": "user",
"content": "Generate single key JSON with only letters.",
},
],
temperature=0,
max_tokens=64,
extra_body={"ebnf": ebnf_grammar},
)
text = response.choices[0].message.content.strip()
self.assertTrue(len(text) > 0, "Got empty text from EBNF strict JSON test")
self.assertRegex(
text, pattern, f"Text '{text}' not matching the EBNF strict JSON shape"
)
class TestOpenAIV1Rerank(CustomTestCase):
@classmethod
def setUpClass(cls):
......
@@ -197,6 +197,134 @@ class ServingChatTestCase(unittest.TestCase):
self.assertEqual(params["min_new_tokens"], 5)
self.assertEqual(params["stop"], ["</s>"])
async def test_unstreamed_tool_args_completion(self):
"""Test that remaining tool call arguments are sent when generation finishes."""
# Mock FunctionCallParser with detector that has partial tool call data
mock_parser = Mock()
mock_detector = Mock()
# Simulate a tool call that was partially streamed
mock_detector.prev_tool_call_arr = [
{
"name": "get_weather",
"arguments": {"location": "San Francisco", "unit": "celsius"},
}
]
mock_detector.streamed_args_for_tool = [
'{"location": "San Francisco"' # Partial arguments streamed so far
]
mock_parser.detector = mock_detector
content = {
"meta_info": {
"id": "chatcmpl-test123",
}
}
request = ChatCompletionRequest(
model="test",
messages=[{"role": "user", "content": "What's the weather?"}],
tools=[{"type": "function", "function": {"name": "get_weather"}}],
)
# Test the completion method
result = self.chat._check_for_unstreamed_tool_args(
parser=mock_parser,
content=content,
request=request,
finish_reason_type="stop",
index=0,
)
# Should return a chunk with remaining arguments
self.assertIsNotNone(result, "Should return chunk with remaining arguments")
self.assertIn('"arguments":', result, "Should contain arguments field")
self.assertIn(
', "unit": "celsius"}', result, "Should contain remaining arguments"
)
self.assertIn(
'"finish_reason":null',
result,
"Should not include finish_reason in completion chunk",
)
async def test_unstreamed_tool_args_no_completion_needed(self):
"""Test that no completion chunk is sent when all arguments were already streamed."""
# Mock FunctionCallParser with detector that has complete tool call data
mock_parser = Mock()
mock_detector = Mock()
# Simulate a tool call that was completely streamed
mock_detector.prev_tool_call_arr = [
{"name": "get_weather", "arguments": {"location": "San Francisco"}}
]
mock_detector.streamed_args_for_tool = [
'{"location": "San Francisco"}' # All arguments already streamed
]
mock_parser.detector = mock_detector
content = {
"meta_info": {
"id": "chatcmpl-test123",
}
}
request = ChatCompletionRequest(
model="test",
messages=[{"role": "user", "content": "What's the weather?"}],
tools=[{"type": "function", "function": {"name": "get_weather"}}],
)
# Test the completion method
result = self.chat._check_for_unstreamed_tool_args(
parser=mock_parser,
content=content,
request=request,
finish_reason_type="stop",
index=0,
)
# Should return None since no completion is needed
self.assertIsNone(result, "Should return None when no completion is needed")
async def test_unstreamed_tool_args_no_parser_data(self):
"""Test that no completion chunk is sent when parser has no tool call data."""
# Mock FunctionCallParser with empty detector
mock_parser = Mock()
mock_detector = Mock()
mock_detector.prev_tool_call_arr = []
mock_detector.streamed_args_for_tool = []
mock_parser.detector = mock_detector
content = {
"meta_info": {
"id": "chatcmpl-test123",
}
}
request = ChatCompletionRequest(
model="test",
messages=[{"role": "user", "content": "What's the weather?"}],
tools=[{"type": "function", "function": {"name": "get_weather"}}],
)
# Test the completion method
result = self.chat._check_for_unstreamed_tool_args(
parser=mock_parser,
content=content,
request=request,
finish_reason_type="stop",
index=0,
)
# Should return None since there's no parser data
self.assertIsNone(
result, "Should return None when parser has no tool call data"
)
if __name__ == "__main__":
unittest.main(verbosity=2)
@@ -16,6 +16,20 @@ from sglang.test.test_utils import (
class TestOpenAIServerFunctionCalling(CustomTestCase):
# NOTE: this system_message is for Llama3.2 system prompt. Without this,
# sometimes Llama3.2 gives a different tool call format such as:
# '<|python_tag|>{"type": "function", "function": "add", "parameters": {"a": "3", "b": "5"}}'
SYSTEM_MESSAGE = (
"You are a helpful assistant with tool calling capabilities. "
"Only reply with a tool call if the function exists in the library provided by the user. "
"If it doesn't exist, just reply directly in natural language. "
"When you receive a tool call response, use the output to format an answer to the original user question. "
"You have access to the following functions. "
"To call a function, please respond with JSON for a function call. "
'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. '
"Do not use variables.\n\n"
)
@classmethod
def setUpClass(cls):
# Replace with the model name needed for testing; if not required, reuse DEFAULT_SMALL_MODEL_NAME_FOR_TEST
@@ -73,7 +87,10 @@ class TestOpenAIServerFunctionCalling(CustomTestCase):
}
]
messages = [{"role": "user", "content": "Compute (3+5)"}]
messages = [
{"role": "system", "content": self.SYSTEM_MESSAGE},
{"role": "user", "content": "Compute (3+5)"},
]
response = client.chat.completions.create(
model=self.model,
max_tokens=2048,
@@ -205,7 +222,8 @@ class TestOpenAIServerFunctionCalling(CustomTestCase):
]
messages = [
{"role": "user", "content": "What is the temperature in Paris in celsius?"}
{"role": "system", "content": self.SYSTEM_MESSAGE},
{"role": "user", "content": "What is the temperature in Paris?"},
]
response_stream = client.chat.completions.create(
@@ -248,74 +266,6 @@ class TestOpenAIServerFunctionCalling(CustomTestCase):
"Final response of function calling should have finish_reason 'tool_calls'",
)
# TODO: There is a bug in sglang preventing this UT from passing. We are working on it. Once done, we will add this UT back.
def _test_function_calling_streaming_no_tool_call(self):
"""
Test: Whether the finish_reason is stop in streaming mode when no tool call is given.
- Expect no function call to be found.
- Verify that finish_reason is stop
"""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city to find the weather for",
},
"unit": {
"type": "string",
"description": "Weather unit (celsius or fahrenheit)",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["city", "unit"],
},
},
}
]
messages = [{"role": "user", "content": "Who are you?"}]
response_stream = client.chat.completions.create(
model=self.model,
max_tokens=2048,
messages=messages,
temperature=0.8,
top_p=0.8,
stream=True,
tools=tools,
tool_choice="none",
)
chunks = list(response_stream)
self.assertTrue(len(chunks) > 0, "Streaming should return at least one chunk")
found_tool_call = False
for chunk in chunks:
choice = chunk.choices[0]
# Check whether the current chunk contains tool_calls
found_tool_call = choice.delta.tool_calls is not None
self.assertFalse(
found_tool_call,
"Shouldn't have any tool_call in the streaming chunks",
)
finish_reason = chunks[-1].choices[0].finish_reason
self.assertEqual(
finish_reason,
"stop",
"Final response of no function calling should have finish_reason 'stop'",
)
def test_function_calling_streaming_args_parsing(self):
"""
Test: Whether the function call arguments returned in streaming mode can be correctly concatenated into valid JSON.
@@ -350,7 +300,8 @@ class TestOpenAIServerFunctionCalling(CustomTestCase):
]
messages = [
{"role": "user", "content": "Please sum 5 and 7, just call the function."}
{"role": "system", "content": self.SYSTEM_MESSAGE},
{"role": "user", "content": "Please sum 5 and 7, just call the function."},
]
response_stream = client.chat.completions.create(
@@ -617,6 +568,212 @@ class TestOpenAIServerFunctionCalling(CustomTestCase):
)
self.assertIn("city", args_obj, "Function arguments should have 'city'")
def test_streaming_multiple_choices_finish_reason(self):
"""
Test: Verify that each choice gets its own finish_reason chunk in streaming mode with n > 1.
This tests the fix for the bug where only the last index got a finish_reason chunk.
"""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["location"],
},
},
}
]
messages = [
{"role": "user", "content": "What is the weather like in Los Angeles?"}
]
# Request with n=2 to get multiple choices
response_stream = client.chat.completions.create(
model=self.model,
messages=messages,
max_tokens=2048,
temperature=0.8,
stream=True,
tools=tools,
tool_choice="required", # Force tool calls
n=2, # Multiple choices
)
chunks = list(response_stream)
# Track finish_reason chunks for each index
finish_reason_chunks = {}
for chunk in chunks:
if chunk.choices:
for choice in chunk.choices:
if choice.finish_reason is not None:
index = choice.index
if index not in finish_reason_chunks:
finish_reason_chunks[index] = []
finish_reason_chunks[index].append(choice.finish_reason)
# Verify we got finish_reason chunks for both indices
self.assertEqual(
len(finish_reason_chunks),
2,
f"Expected finish_reason chunks for 2 indices, got {len(finish_reason_chunks)}",
)
# Verify both index 0 and 1 have finish_reason
self.assertIn(
0, finish_reason_chunks, "Missing finish_reason chunk for index 0"
)
self.assertIn(
1, finish_reason_chunks, "Missing finish_reason chunk for index 1"
)
# Verify the finish_reason is "tool_calls" since we forced tool calls
for index, reasons in finish_reason_chunks.items():
self.assertEqual(
reasons[-1], # Last finish_reason for this index
"tool_calls",
f"Expected finish_reason 'tool_calls' for index {index}, got {reasons[-1]}",
)
def test_function_calling_streaming_no_tool_call(self):
"""
Test: Whether the finish_reason is stop in streaming mode when no tool call is given.
- Expect no function call to be found.
- Verify that finish_reason is stop
"""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city to find the weather for",
},
"unit": {
"type": "string",
"description": "Weather unit (celsius or fahrenheit)",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["city", "unit"],
},
},
}
]
messages = [{"role": "user", "content": "Who are you?"}]
response_stream = client.chat.completions.create(
model=self.model,
max_tokens=2048,
messages=messages,
temperature=0.8,
top_p=0.8,
stream=True,
tools=tools,
tool_choice="none",
)
chunks = list(response_stream)
self.assertTrue(len(chunks) > 0, "Streaming should return at least one chunk")
found_tool_call = False
for chunk in chunks:
choice = chunk.choices[0]
# Check whether the current chunk contains tool_calls
found_tool_call = choice.delta.tool_calls is not None
self.assertFalse(
found_tool_call,
"Shouldn't have any tool_call in the streaming chunks",
)
finish_reason = chunks[-1].choices[0].finish_reason
self.assertEqual(
finish_reason,
"stop",
"Final response of no function calling should have finish_reason 'stop'",
)
def test_streaming_multiple_choices_without_tools(self):
"""
Test: Verify that each choice gets its own finish_reason chunk without tool calls.
This tests the fix for regular content streaming with multiple choices.
"""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
messages = [{"role": "user", "content": "Say hello in one word."}]
# Request with n=2 to get multiple choices, no tools
response_stream = client.chat.completions.create(
model=self.model,
messages=messages,
temperature=0.8,
stream=True,
max_tokens=10,  # Keep it short
n=2, # Multiple choices
)
chunks = list(response_stream)
# Track finish_reason chunks for each index
finish_reason_chunks = {}
for chunk in chunks:
if chunk.choices:
for choice in chunk.choices:
if choice.finish_reason is not None:
index = choice.index
if index not in finish_reason_chunks:
finish_reason_chunks[index] = []
finish_reason_chunks[index].append(choice.finish_reason)
# Verify we got finish_reason chunks for both indices
self.assertEqual(
len(finish_reason_chunks),
2,
f"Expected finish_reason chunks for 2 indices, got {len(finish_reason_chunks)}",
)
# Verify both index 0 and 1 have finish_reason
self.assertIn(
0, finish_reason_chunks, "Missing finish_reason chunk for index 0"
)
self.assertIn(
1, finish_reason_chunks, "Missing finish_reason chunk for index 1"
)
# Verify the finish_reason is "stop" (regular completion)
for index, reasons in finish_reason_chunks.items():
self.assertIn(
reasons[-1],
["stop", "length"], # Could be either depending on how model responds
f"Expected finish_reason 'stop' or 'length' for index {index}, got {reasons[-1]}",
)
class TestOpenAIPythonicFunctionCalling(CustomTestCase):
PYTHONIC_TOOLS = [
@@ -706,7 +863,6 @@ class TestOpenAIPythonicFunctionCalling(CustomTestCase):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.chat.completions.create(
model=self.model,
max_tokens=2048,
messages=self.PYTHONIC_MESSAGES,
tools=self.PYTHONIC_TOOLS,
temperature=0.1,
@@ -728,7 +884,6 @@ class TestOpenAIPythonicFunctionCalling(CustomTestCase):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response_stream = client.chat.completions.create(
model=self.model,
max_tokens=2048,
messages=self.PYTHONIC_MESSAGES,
tools=self.PYTHONIC_TOOLS,
temperature=0.1,
......