sglang · Commits · b47eda33

Commit b47eda33 (Unverified), authored Jul 27, 2025 by Chang Su, committed via GitHub on Jul 27, 2025

bugfix: Fix multiple finish_reason chunks and tool_calls finish reason check (#8417)

Parent: e983d666

Showing 4 changed files with 499 additions and 234 deletions (+499, -234)
python/sglang/srt/entrypoints/openai/serving_chat.py                  +132  -76
test/srt/openai_server/basic/test_openai_server.py                     +11  -85
test/srt/openai_server/basic/test_serving_chat.py                     +128   -0
test/srt/openai_server/function_call/test_openai_function_calling.py  +228  -73
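Before the per-file diffs, here is a minimal, self-contained sketch of the bookkeeping this commit introduces: remember the last finish_reason seen for each choice index, then emit exactly one finish chunk per index, upgrading "stop" to "tool_calls" when that choice produced tool calls. The helper names below are hypothetical and are not the sglang implementation.

# Minimal sketch of per-index finish_reason bookkeeping (hypothetical names,
# not the sglang implementation).
from typing import Any, Dict, Iterable, Iterator, Tuple


def collect_finish_reasons(
    events: Iterable[Tuple[int, Dict[str, Any]]]
) -> Dict[int, Dict[str, Any]]:
    """events yields (choice_index, meta_info) pairs; keep the last finish_reason per index."""
    finish_reasons: Dict[int, Dict[str, Any]] = {}
    for index, meta in events:
        reason = meta.get("finish_reason")
        if reason:
            finish_reasons[index] = reason
    return finish_reasons


def final_chunks(
    finish_reasons: Dict[int, Dict[str, Any]], has_tool_calls: Dict[int, bool]
) -> Iterator[Dict[str, Any]]:
    """Emit one finish chunk per index; upgrade 'stop' to 'tool_calls' when tools were emitted."""
    for idx, reason in finish_reasons.items():
        reason_type = reason["type"]
        if has_tool_calls.get(idx, False) and reason_type == "stop":
            reason_type = "tool_calls"
        yield {"index": idx, "finish_reason": reason_type, "matched_stop": reason.get("matched")}


if __name__ == "__main__":
    events = [
        (0, {"finish_reason": {"type": "stop", "matched": "</s>"}}),
        (1, {"finish_reason": {"type": "stop"}}),
    ]
    print(list(final_chunks(collect_finish_reasons(events), {1: True})))
    # [{'index': 0, 'finish_reason': 'stop', 'matched_stop': '</s>'},
    #  {'index': 1, 'finish_reason': 'tool_calls', 'matched_stop': None}]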
python/sglang/srt/entrypoints/openai/serving_chat.py

@@ -412,6 +412,8 @@ class OpenAIServingChat(OpenAIServingBase):
         is_firsts = {}
         stream_buffers = {}
         n_prev_tokens = {}
+        has_tool_calls = {}
+        finish_reasons = {}

         # Usage tracking
         prompt_tokens = {}
@@ -443,6 +445,10 @@ class OpenAIServingChat(OpenAIServingBase):
                 finish_reason = content["meta_info"]["finish_reason"]
                 finish_reason_type = finish_reason["type"] if finish_reason else None

+                # Track finish_reason for each index
+                if finish_reason_type:
+                    finish_reasons[index] = finish_reason
+
                 # First chunk with role
                 if is_firsts.get(index, True):
                     is_firsts[index] = False
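For context, the finish_reason read from content["meta_info"] above is a small dict rather than a plain string; an illustrative (simplified) shape is:

# Illustrative only: simplified shape of the meta_info consumed above.
content = {
    "meta_info": {
        "id": "chatcmpl-abc123",
        "finish_reason": {"type": "stop", "matched": "</s>"},  # "matched" may be absent
    }
}

finish_reason = content["meta_info"]["finish_reason"]
finish_reason_type = finish_reason["type"] if finish_reason else None
assert finish_reason_type == "stop"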
@@ -450,13 +456,8 @@ class OpenAIServingChat(OpenAIServingBase):
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=index,
                         delta=delta,
-                        finish_reason=finish_reason_type,
-                        matched_stop=(
-                            finish_reason["matched"]
-                            if finish_reason and "matched" in finish_reason
-                            else None
-                        ),
-                        logprobs=choice_logprobs,
+                        finish_reason=None,
+                        logprobs=None,
                     )
                     chunk = ChatCompletionStreamResponse(
                         id=content["meta_info"]["id"],
@@ -483,7 +484,7 @@ class OpenAIServingChat(OpenAIServingBase):
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=index,
                         delta=DeltaMessage(reasoning_content=reasoning_text),
-                        finish_reason=finish_reason_type,
+                        finish_reason=None,
                     )
                     chunk = ChatCompletionStreamResponse(
                         id=content["meta_info"]["id"],
@@ -495,40 +496,34 @@ class OpenAIServingChat(OpenAIServingBase):
                 # Handle tool calls
                 if request.tool_choice != "none" and request.tools:
-                    async for chunk in self._process_tool_call_stream(
+                    async for (
+                        chunk,
+                        tool_call_finish_reason_type,
+                    ) in self._process_tool_call_stream(
                         index,
                         delta,
                         parser_dict,
                         content,
                         request,
-                        finish_reason_type,
+                        has_tool_calls,
                     ):
                         if chunk:
                             yield chunk
+                        finish_reason_type = tool_call_finish_reason_type
+
+                    # Send any remaining tool call arguments when generation finishes
+                    if finish_reason_type is not None and index in parser_dict:
+                        parser = parser_dict[index]
+                        remaining_chunk = self._check_for_unstreamed_tool_args(
+                            parser, content, request, index
+                        )
+                        if remaining_chunk:
+                            yield remaining_chunk
+
                 else:
                     # Regular content
-                    if delta or not (
-                        request.stream_options and request.stream_options.include_usage
-                    ):
+                    if delta:
                         choice_data = ChatCompletionResponseStreamChoice(
                             index=index,
                             delta=DeltaMessage(content=delta if delta else None),
-                            finish_reason=(
-                                None
-                                if request.stream_options and request.stream_options.include_usage
-                                else finish_reason_type
-                            ),
-                            matched_stop=(
-                                finish_reason["matched"]
-                                if finish_reason and "matched" in finish_reason
-                                else None
-                            ),
+                            finish_reason=None,
+                            matched_stop=None,
                             logprobs=choice_logprobs,
                         )
                         chunk = ChatCompletionStreamResponse(
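A sketch of the tuple-yielding pattern the caller relies on here, with illustrative names only: the helper yields (sse_chunk, finish_reason_type) pairs, so the outer loop can both forward chunks to the client and pick up the corrected finish reason.

# Illustrative sketch of consuming a (chunk, finish_reason_type) generator.
from typing import Iterator, Optional, Tuple


def tool_call_stream() -> Iterator[Tuple[Optional[str], Optional[str]]]:
    # Each item is (sse_chunk_or_None, finish_reason_type_or_None).
    yield ('data: {"delta": {"tool_calls": ["..."]}}\n\n', None)
    yield (None, "tool_calls")  # no payload, but the finish reason is corrected


def consume() -> Optional[str]:
    finish_reason_type: Optional[str] = None
    for chunk, reason in tool_call_stream():
        if chunk:
            print(chunk, end="")
        if reason is not None:
            finish_reason_type = reason
    return finish_reason_type


if __name__ == "__main__":
    print(consume())  # -> tool_calls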
@@ -539,26 +534,36 @@ class OpenAIServingChat(OpenAIServingBase):
                         )
                         yield f"data: {chunk.model_dump_json()}\n\n"

-            # Final chunk with finish_reason
-            finish_reason_chunk = ChatCompletionStreamResponse(
-                id=content["meta_info"]["id"],
-                created=int(time.time()),
-                choices=[
-                    ChatCompletionResponseStreamChoice(
-                        index=index,
-                        delta=DeltaMessage(),
-                        finish_reason=finish_reason_type,
-                        matched_stop=(
-                            finish_reason["matched"]
-                            if finish_reason and "matched" in finish_reason
-                            else None
-                        ),
-                    )
-                ],
-                model=request.model,
-                usage=None,
-            )
-            yield f"data: {finish_reason_chunk.model_dump_json()}\n\n"
+            # Send finish_reason chunks for each index that completed
+            for idx, finish_reason_data in finish_reasons.items():
+                finish_reason_type = finish_reason_data["type"]
+
+                # Change finish_reason to "tool_calls" if we had tool calls and stopped naturally
+                final_finish_reason = finish_reason_type
+                if has_tool_calls.get(idx, False) and finish_reason_type == "stop":
+                    final_finish_reason = "tool_calls"
+
+                finish_reason_chunk = ChatCompletionStreamResponse(
+                    id=content["meta_info"][
+                        "id"
+                    ],  # NOTE: openai uses the same chatcmpl-id for all indices
+                    created=int(time.time()),
+                    choices=[
+                        ChatCompletionResponseStreamChoice(
+                            index=idx,
+                            delta=DeltaMessage(),
+                            finish_reason=final_finish_reason,
+                            matched_stop=(
+                                finish_reason_data["matched"]
+                                if "matched" in finish_reason_data
+                                else None
+                            ),
+                        )
+                    ],
+                    model=request.model,
+                    usage=None,
+                )
+                yield f"data: {finish_reason_chunk.model_dump_json()}\n\n"

             # Send hidden states if requested
             if request.return_hidden_states and hidden_states:
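With this change, a streamed request with n=2 ends with one finish chunk per choice rather than a single chunk for the last index only. An illustrative, abbreviated tail of the SSE stream (payloads simplified, field order not guaranteed):

data: {"id":"chatcmpl-abc123","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"model":"..."}

data: {"id":"chatcmpl-abc123","choices":[{"index":1,"delta":{},"finish_reason":"tool_calls"}],"model":"..."}

data: [DONE]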
@@ -578,7 +583,7 @@ class OpenAIServingChat(OpenAIServingBase):
                         delta=DeltaMessage(
                             hidden_states=last_token_hidden_states
                         ),
-                        finish_reason=finish_reason_type,
+                        finish_reason=None,  # Hidden states don't need finish_reason
                     )
                 ],
                 model=request.model,
@@ -857,7 +862,7 @@ class OpenAIServingChat(OpenAIServingBase):
         parser_dict: Dict[int, FunctionCallParser],
         content: Dict[str, Any],
         request: ChatCompletionRequest,
-        finish_reason_type: Optional[str],
+        has_tool_calls: Dict[int, bool],
     ):
         """Process tool calls in streaming response"""
         if index not in parser_dict:
@@ -874,7 +879,7 @@ class OpenAIServingChat(OpenAIServingBase):
             choice_data = ChatCompletionResponseStreamChoice(
                 index=index,
                 delta=DeltaMessage(content=normal_text),
-                finish_reason=finish_reason_type,
+                finish_reason=None,
             )
             chunk = ChatCompletionStreamResponse(
                 id=content["meta_info"]["id"],
@@ -882,10 +887,13 @@ class OpenAIServingChat(OpenAIServingBase):
                 choices=[choice_data],
                 model=request.model,
             )
-            yield f"data: {chunk.model_dump_json()}\n\n"
+            yield f"data: {chunk.model_dump_json()}\n\n", finish_reason_type

         # Yield tool calls
         for call_item in calls:
+            # Mark that this choice has tool calls
+            has_tool_calls[index] = True
+
             # Tool call ID should be generated only once per tool call
             if call_item.name:
                 # First chunk: include ID and function name
@@ -896,23 +904,6 @@ class OpenAIServingChat(OpenAIServingBase):
                 tool_call_id = None
                 function_name = None

-            if finish_reason_type == "stop":
-                # Handle remaining arguments
-                latest_delta_len = 0
-                if isinstance(call_item.parameters, str):
-                    latest_delta_len = len(call_item.parameters)
-
-                expected_call = json.dumps(
-                    parser.detector.prev_tool_call_arr[index].get("arguments", {}),
-                    ensure_ascii=False,
-                )
-                actual_call = parser.detector.streamed_args_for_tool[index]
-                if latest_delta_len > 0:
-                    actual_call = actual_call[:-latest_delta_len]
-                remaining_call = expected_call.replace(actual_call, "", 1)
-                call_item.parameters = remaining_call
-
-                finish_reason_type = "tool_calls"
-
             tool_call = ToolCall(
                 id=tool_call_id,
                 index=call_item.tool_index,
@@ -925,19 +916,84 @@ class OpenAIServingChat(OpenAIServingBase):
             choice_data = ChatCompletionResponseStreamChoice(
                 index=index,
                 delta=DeltaMessage(tool_calls=[tool_call]),
-                finish_reason=(
-                    None
-                    if request.stream_options and request.stream_options.include_usage
-                    else finish_reason_type
-                ),
+                finish_reason=None,
             )
             chunk = ChatCompletionStreamResponse(
                 id=content["meta_info"]["id"],
                 created=int(time.time()),
                 choices=[choice_data],
                 model=request.model,
             )
-            yield f"data: {chunk.model_dump_json()}\n\n"
+            yield f"data: {chunk.model_dump_json()}\n\n", finish_reason_type
+
+        if finish_reason_type == "stop":
+            yield None, "tool_calls"
+
+    def _check_for_unstreamed_tool_args(
+        self,
+        parser: FunctionCallParser,
+        content: Dict[str, Any],
+        request: ChatCompletionRequest,
+        index: int,
+    ) -> Optional[str]:
+        """
+        Check for any remaining tool call arguments that need to be streamed
+        when generation finishes. This ensures tool calls are properly completed
+        even if the model generates the final arguments in the last chunk.
+        """
+        # Only check if we have tool calls and the parser has tracked data
+        if (
+            not hasattr(parser.detector, "prev_tool_call_arr")
+            or not parser.detector.prev_tool_call_arr
+        ):
+            return None
+        if (
+            not hasattr(parser.detector, "streamed_args_for_tool")
+            or not parser.detector.streamed_args_for_tool
+        ):
+            return None
+
+        # Get the last tool call that was being processed
+        tool_index = len(parser.detector.prev_tool_call_arr) - 1
+        if tool_index < 0 or tool_index >= len(parser.detector.streamed_args_for_tool):
+            return None
+
+        # Get expected vs actual arguments
+        expected_args = parser.detector.prev_tool_call_arr[tool_index].get("arguments", {})
+        expected_call = json.dumps(expected_args, ensure_ascii=False)
+        actual_call = parser.detector.streamed_args_for_tool[tool_index]
+
+        # Check if there are remaining arguments to send
+        remaining_call = (
+            expected_call.replace(actual_call, "", 1)
+            if actual_call in expected_call
+            else ""
+        )
+
+        if remaining_call:
+            # Create tool call chunk with remaining arguments
+            tool_call = ToolCall(
+                id=None,  # No ID for argument deltas
+                index=tool_index,
+                function=FunctionResponse(
+                    name=None,  # No name for argument deltas
+                    arguments=remaining_call,
+                ),
+            )
+            choice_data = ChatCompletionResponseStreamChoice(
+                index=index,
+                delta=DeltaMessage(tool_calls=[tool_call]),
+                finish_reason=None,  # Don't send finish_reason with this chunk
+            )
+            chunk = ChatCompletionStreamResponse(
+                id=content["meta_info"]["id"],
+                created=int(time.time()),
+                choices=[choice_data],
+                model=request.model,
+            )
+            return f"data: {chunk.model_dump_json()}\n\n"
+
+        return None
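The core of the new _check_for_unstreamed_tool_args helper is a string-remainder computation between the fully parsed arguments and what has already been streamed. A standalone sketch using the same example values as the unit test added below:

# Standalone sketch of the "remaining arguments" computation (illustrative values).
import json

expected_args = {"location": "San Francisco", "unit": "celsius"}  # what the parser finally saw
streamed_so_far = '{"location": "San Francisco"'  # what was already sent to the client

expected_call = json.dumps(expected_args, ensure_ascii=False)
remaining_call = (
    expected_call.replace(streamed_so_far, "", 1)
    if streamed_so_far in expected_call
    else ""
)
print(remaining_call)  # -> , "unit": "celsius"}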
test/srt/openai_server/basic/test_openai_server.py
@@ -233,6 +233,7 @@ class TestOpenAIServer(CustomTestCase):
         is_firsts = {}
         is_finished = {}
+        finish_reason_counts = {}
         for response in generator:
             usage = response.usage
             if usage is not None:
@@ -245,6 +246,7 @@ class TestOpenAIServer(CustomTestCase):
             finish_reason = response.choices[0].finish_reason
             if finish_reason is not None:
                 is_finished[index] = True
+                finish_reason_counts[index] = finish_reason_counts.get(index, 0) + 1
             data = response.choices[0].delta
@@ -284,6 +286,15 @@ class TestOpenAIServer(CustomTestCase):
                 index, True
             ), f"index {index} is not found in the response"

+        # Verify that each choice gets exactly one finish_reason chunk
+        for index in range(parallel_sample_num):
+            assert (
+                index in finish_reason_counts
+            ), f"No finish_reason found for index {index}"
+            assert (
+                finish_reason_counts[index] == 1
+            ), f"Expected 1 finish_reason chunk for index {index}, got {finish_reason_counts[index]}"
+
     def test_completion(self):
         for echo in [False, True]:
             for logprobs in [None, 5]:
@@ -420,91 +431,6 @@ The SmartHome Mini is a compact smart home assistant available in black or white
             client.models.retrieve("non-existent-model")


-# -------------------------------------------------------------------------
-# EBNF Test Class: TestOpenAIServerEBNF
-# Launches the server with xgrammar, has only EBNF tests
-# -------------------------------------------------------------------------
-class TestOpenAIServerEBNF(CustomTestCase):
-    @classmethod
-    def setUpClass(cls):
-        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
-        cls.base_url = DEFAULT_URL_FOR_TEST
-        cls.api_key = "sk-123456"
-
-        # passing xgrammar specifically
-        other_args = ["--grammar-backend", "xgrammar"]
-        cls.process = popen_launch_server(
-            cls.model,
-            cls.base_url,
-            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            api_key=cls.api_key,
-            other_args=other_args,
-        )
-        cls.base_url += "/v1"
-        cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
-
-    @classmethod
-    def tearDownClass(cls):
-        kill_process_tree(cls.process.pid)
-
-    def test_ebnf(self):
-        """
-        Ensure we can pass `ebnf` to the local openai server
-        and that it enforces the grammar.
-        """
-        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
-        ebnf_grammar = r"""
-        root ::= "Hello" | "Hi" | "Hey"
-        """
-        pattern = re.compile(r"^(Hello|Hi|Hey)[.!?]*\s*$")
-
-        response = client.chat.completions.create(
-            model=self.model,
-            messages=[
-                {"role": "system", "content": "You are a helpful EBNF test bot."},
-                {"role": "user", "content": "Say a greeting (Hello, Hi, or Hey)."},
-            ],
-            temperature=0,
-            max_tokens=32,
-            extra_body={"ebnf": ebnf_grammar},
-        )
-        text = response.choices[0].message.content.strip()
-        self.assertTrue(len(text) > 0, "Got empty text from EBNF generation")
-        self.assertRegex(text, pattern, f"Text '{text}' doesn't match EBNF choices")
-
-    def test_ebnf_strict_json(self):
-        """
-        A stricter EBNF that produces exactly {"name":"Alice"} format
-        with no trailing punctuation or extra fields.
-        """
-        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
-        ebnf_grammar = r"""
-        root ::= "{" pair "}"
-        pair ::= "\"name\"" ":" string
-        string ::= "\"" [A-Za-z]+ "\""
-        """
-        pattern = re.compile(r'^\{"name":"[A-Za-z]+"\}$')
-
-        response = client.chat.completions.create(
-            model=self.model,
-            messages=[
-                {"role": "system", "content": "EBNF mini-JSON generator."},
-                {
-                    "role": "user",
-                    "content": "Generate single key JSON with only letters.",
-                },
-            ],
-            temperature=0,
-            max_tokens=64,
-            extra_body={"ebnf": ebnf_grammar},
-        )
-        text = response.choices[0].message.content.strip()
-        self.assertTrue(len(text) > 0, "Got empty text from EBNF strict JSON test")
-        self.assertRegex(text, pattern, f"Text '{text}' not matching the EBNF strict JSON shape")
-
-
 class TestOpenAIV1Rerank(CustomTestCase):
     @classmethod
     def setUpClass(cls):
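The same invariant can be checked from any OpenAI-compatible client. A hedged sketch (model name, base URL, and API key are placeholders for a locally launched server):

# Client-side sketch: each choice index should see exactly one finish_reason chunk.
import openai

client = openai.Client(api_key="sk-123456", base_url="http://127.0.0.1:30000/v1")
stream = client.chat.completions.create(
    model="your-model",
    messages=[{"role": "user", "content": "Say hello in one word."}],
    n=2,
    stream=True,
    max_tokens=10,
)

finish_reason_counts = {}
for chunk in stream:
    for choice in chunk.choices:
        if choice.finish_reason is not None:
            finish_reason_counts[choice.index] = finish_reason_counts.get(choice.index, 0) + 1

assert all(count == 1 for count in finish_reason_counts.values()), finish_reason_counts
print(finish_reason_counts)  # expected: {0: 1, 1: 1}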
test/srt/openai_server/basic/test_serving_chat.py
@@ -197,6 +197,134 @@ class ServingChatTestCase(unittest.TestCase):
         self.assertEqual(params["min_new_tokens"], 5)
         self.assertEqual(params["stop"], ["</s>"])

+    async def test_unstreamed_tool_args_completion(self):
+        """Test that remaining tool call arguments are sent when generation finishes."""
+        # Mock FunctionCallParser with detector that has partial tool call data
+        mock_parser = Mock()
+        mock_detector = Mock()
+
+        # Simulate a tool call that was partially streamed
+        mock_detector.prev_tool_call_arr = [
+            {
+                "name": "get_weather",
+                "arguments": {"location": "San Francisco", "unit": "celsius"},
+            }
+        ]
+        mock_detector.streamed_args_for_tool = [
+            '{"location": "San Francisco"'  # Partial arguments streamed so far
+        ]
+        mock_parser.detector = mock_detector
+
+        content = {
+            "meta_info": {
+                "id": "chatcmpl-test123",
+            }
+        }
+        request = ChatCompletionRequest(
+            model="test",
+            messages=[{"role": "user", "content": "What's the weather?"}],
+            tools=[{"type": "function", "function": {"name": "get_weather"}}],
+        )
+
+        # Test the completion method
+        result = self.chat._check_for_unstreamed_tool_args(
+            parser=mock_parser,
+            content=content,
+            request=request,
+            finish_reason_type="stop",
+            index=0,
+        )
+
+        # Should return a chunk with remaining arguments
+        self.assertIsNotNone(result, "Should return chunk with remaining arguments")
+        self.assertIn('"arguments":', result, "Should contain arguments field")
+        self.assertIn(', "unit": "celsius"}', result, "Should contain remaining arguments")
+        self.assertIn(
+            '"finish_reason":null',
+            result,
+            "Should not include finish_reason in completion chunk",
+        )
+
+    async def test_unstreamed_tool_args_no_completion_needed(self):
+        """Test that no completion chunk is sent when all arguments were already streamed."""
+        # Mock FunctionCallParser with detector that has complete tool call data
+        mock_parser = Mock()
+        mock_detector = Mock()
+
+        # Simulate a tool call that was completely streamed
+        mock_detector.prev_tool_call_arr = [
+            {"name": "get_weather", "arguments": {"location": "San Francisco"}}
+        ]
+        mock_detector.streamed_args_for_tool = [
+            '{"location": "San Francisco"}'  # All arguments already streamed
+        ]
+        mock_parser.detector = mock_detector
+
+        content = {
+            "meta_info": {
+                "id": "chatcmpl-test123",
+            }
+        }
+        request = ChatCompletionRequest(
+            model="test",
+            messages=[{"role": "user", "content": "What's the weather?"}],
+            tools=[{"type": "function", "function": {"name": "get_weather"}}],
+        )
+
+        # Test the completion method
+        result = self.chat._check_for_unstreamed_tool_args(
+            parser=mock_parser,
+            content=content,
+            request=request,
+            finish_reason_type="stop",
+            index=0,
+        )
+
+        # Should return None since no completion is needed
+        self.assertIsNone(result, "Should return None when no completion is needed")
+
+    async def test_unstreamed_tool_args_no_parser_data(self):
+        """Test that no completion chunk is sent when parser has no tool call data."""
+        # Mock FunctionCallParser with empty detector
+        mock_parser = Mock()
+        mock_detector = Mock()
+        mock_detector.prev_tool_call_arr = []
+        mock_detector.streamed_args_for_tool = []
+        mock_parser.detector = mock_detector
+
+        content = {
+            "meta_info": {
+                "id": "chatcmpl-test123",
+            }
+        }
+        request = ChatCompletionRequest(
+            model="test",
+            messages=[{"role": "user", "content": "What's the weather?"}],
+            tools=[{"type": "function", "function": {"name": "get_weather"}}],
+        )
+
+        # Test the completion method
+        result = self.chat._check_for_unstreamed_tool_args(
+            parser=mock_parser,
+            content=content,
+            request=request,
+            finish_reason_type="stop",
+            index=0,
+        )
+
+        # Should return None since there's no parser data
+        self.assertIsNone(result, "Should return None when parser has no tool call data")
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)
test/srt/openai_server/function_call/test_openai_function_calling.py
@@ -16,6 +16,20 @@ from sglang.test.test_utils import (
 class TestOpenAIServerFunctionCalling(CustomTestCase):
+    # NOTE: this system_message is for Llama3.2 system prompt. Without this,
+    # sometimes Llama3.2 gives a different tool call format such as:
+    # '<|python_tag|>{"type": "function", "function": "add", "parameters": {"a": "3", "b": "5"}}'
+    SYSTEM_MESSAGE = (
+        "You are a helpful assistant with tool calling capabilities. "
+        "Only reply with a tool call if the function exists in the library provided by the user. "
+        "If it doesn't exist, just reply directly in natural language. "
+        "When you receive a tool call response, use the output to format an answer to the original user question. "
+        "You have access to the following functions. "
+        "To call a function, please respond with JSON for a function call. "
+        'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. '
+        "Do not use variables.\n\n"
+    )
+
     @classmethod
     def setUpClass(cls):
         # Replace with the model name needed for testing; if not required, reuse DEFAULT_SMALL_MODEL_NAME_FOR_TEST
@@ -73,7 +87,10 @@ class TestOpenAIServerFunctionCalling(CustomTestCase):
             }
         ]
-        messages = [{"role": "user", "content": "Compute (3+5)"}]
+        messages = [
+            {"role": "system", "content": self.SYSTEM_MESSAGE},
+            {"role": "user", "content": "Compute (3+5)"},
+        ]
         response = client.chat.completions.create(
             model=self.model,
             max_tokens=2048,
@@ -205,7 +222,8 @@ class TestOpenAIServerFunctionCalling(CustomTestCase):
         ]
         messages = [
-            {"role": "user", "content": "What is the temperature in Paris in celsius?"}
+            {"role": "system", "content": self.SYSTEM_MESSAGE},
+            {"role": "user", "content": "What is the temperature in Paris?"},
         ]
         response_stream = client.chat.completions.create(
@@ -248,74 +266,6 @@ class TestOpenAIServerFunctionCalling(CustomTestCase):
             "Final response of function calling should have finish_reason 'tool_calls'",
         )

-    # TODO: There is a bug in sglang preventing this UT from passing. We are working on it. Once done, we will add this UT back.
-    def _test_function_calling_streaming_no_tool_call(self):
-        """
-        Test: Whether the finish_reason is stop in streaming mode when no tool call is given.
-        - Expect no function call to be found.
-        - Verify that finish_reason is stop
-        """
-        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
-        tools = [
-            {
-                "type": "function",
-                "function": {
-                    "name": "get_current_weather",
-                    "description": "Get the current weather in a given location",
-                    "parameters": {
-                        "type": "object",
-                        "properties": {
-                            "city": {
-                                "type": "string",
-                                "description": "The city to find the weather for",
-                            },
-                            "unit": {
-                                "type": "string",
-                                "description": "Weather unit (celsius or fahrenheit)",
-                                "enum": ["celsius", "fahrenheit"],
-                            },
-                        },
-                        "required": ["city", "unit"],
-                    },
-                },
-            }
-        ]
-        messages = [{"role": "user", "content": "Who are you?"}]
-
-        response_stream = client.chat.completions.create(
-            model=self.model,
-            max_tokens=2048,
-            messages=messages,
-            temperature=0.8,
-            top_p=0.8,
-            stream=True,
-            tools=tools,
-            tool_choice="none",
-        )
-
-        chunks = list(response_stream)
-        self.assertTrue(len(chunks) > 0, "Streaming should return at least one chunk")
-
-        found_tool_call = False
-        for chunk in chunks:
-            choice = chunk.choices[0]
-            # Check whether the current chunk contains tool_calls
-            found_tool_call = choice.delta.tool_calls is not None
-            self.assertFalse(
-                found_tool_call,
-                "Shouldn't have any tool_call in the streaming chunks",
-            )
-
-        finish_reason = chunks[-1].choices[0].finish_reason
-        self.assertEqual(
-            finish_reason,
-            "stop",
-            "Final response of no function calling should have finish_reason 'stop'",
-        )
-
     def test_function_calling_streaming_args_parsing(self):
         """
         Test: Whether the function call arguments returned in streaming mode can be correctly concatenated into valid JSON.
@@ -350,7 +300,8 @@ class TestOpenAIServerFunctionCalling(CustomTestCase):
         ]
         messages = [
-            {"role": "user", "content": "Please sum 5 and 7, just call the function."}
+            {"role": "system", "content": self.SYSTEM_MESSAGE},
+            {"role": "user", "content": "Please sum 5 and 7, just call the function."},
         ]
         response_stream = client.chat.completions.create(
@@ -617,6 +568,212 @@ class TestOpenAIServerFunctionCalling(CustomTestCase):
         )
         self.assertIn("city", args_obj, "Function arguments should have 'city'")

+    def test_streaming_multiple_choices_finish_reason(self):
+        """
+        Test: Verify that each choice gets its own finish_reason chunk in streaming mode with n > 1.
+        This tests the fix for the bug where only the last index got a finish_reason chunk.
+        """
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+
+        tools = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_current_weather",
+                    "description": "Get the current weather in a given location",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "location": {
+                                "type": "string",
+                                "description": "The city and state, e.g. San Francisco, CA",
+                            },
+                            "unit": {
+                                "type": "string",
+                                "enum": ["celsius", "fahrenheit"],
+                            },
+                        },
+                        "required": ["location"],
+                    },
+                },
+            }
+        ]
+        messages = [
+            {"role": "user", "content": "What is the weather like in Los Angeles?"}
+        ]
+
+        # Request with n=2 to get multiple choices
+        response_stream = client.chat.completions.create(
+            model=self.model,
+            messages=messages,
+            max_tokens=2048,
+            temperature=0.8,
+            stream=True,
+            tools=tools,
+            tool_choice="required",  # Force tool calls
+            n=2,  # Multiple choices
+        )
+
+        chunks = list(response_stream)
+
+        # Track finish_reason chunks for each index
+        finish_reason_chunks = {}
+        for chunk in chunks:
+            if chunk.choices:
+                for choice in chunk.choices:
+                    if choice.finish_reason is not None:
+                        index = choice.index
+                        if index not in finish_reason_chunks:
+                            finish_reason_chunks[index] = []
+                        finish_reason_chunks[index].append(choice.finish_reason)
+
+        # Verify we got finish_reason chunks for both indices
+        self.assertEqual(
+            len(finish_reason_chunks),
+            2,
+            f"Expected finish_reason chunks for 2 indices, got {len(finish_reason_chunks)}",
+        )
+
+        # Verify both index 0 and 1 have finish_reason
+        self.assertIn(0, finish_reason_chunks, "Missing finish_reason chunk for index 0")
+        self.assertIn(1, finish_reason_chunks, "Missing finish_reason chunk for index 1")
+
+        # Verify the finish_reason is "tool_calls" since we forced tool calls
+        for index, reasons in finish_reason_chunks.items():
+            self.assertEqual(
+                reasons[-1],  # Last finish_reason for this index
+                "tool_calls",
+                f"Expected finish_reason 'tool_calls' for index {index}, got {reasons[-1]}",
+            )
+
+    def test_function_calling_streaming_no_tool_call(self):
+        """
+        Test: Whether the finish_reason is stop in streaming mode when no tool call is given.
+        - Expect no function call to be found.
+        - Verify that finish_reason is stop
+        """
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+        tools = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_current_weather",
+                    "description": "Get the current weather in a given location",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "city": {
+                                "type": "string",
+                                "description": "The city to find the weather for",
+                            },
+                            "unit": {
+                                "type": "string",
+                                "description": "Weather unit (celsius or fahrenheit)",
+                                "enum": ["celsius", "fahrenheit"],
+                            },
+                        },
+                        "required": ["city", "unit"],
+                    },
+                },
+            }
+        ]
+        messages = [{"role": "user", "content": "Who are you?"}]
+
+        response_stream = client.chat.completions.create(
+            model=self.model,
+            max_tokens=2048,
+            messages=messages,
+            temperature=0.8,
+            top_p=0.8,
+            stream=True,
+            tools=tools,
+            tool_choice="none",
+        )
+
+        chunks = list(response_stream)
+        self.assertTrue(len(chunks) > 0, "Streaming should return at least one chunk")
+
+        found_tool_call = False
+        for chunk in chunks:
+            choice = chunk.choices[0]
+            # Check whether the current chunk contains tool_calls
+            found_tool_call = choice.delta.tool_calls is not None
+            self.assertFalse(
+                found_tool_call,
+                "Shouldn't have any tool_call in the streaming chunks",
+            )
+
+        finish_reason = chunks[-1].choices[0].finish_reason
+        self.assertEqual(
+            finish_reason,
+            "stop",
+            "Final response of no function calling should have finish_reason 'stop'",
+        )
+
+    def test_streaming_multiple_choices_without_tools(self):
+        """
+        Test: Verify that each choice gets its own finish_reason chunk without tool calls.
+        This tests the fix for regular content streaming with multiple choices.
+        """
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+
+        messages = [{"role": "user", "content": "Say hello in one word."}]
+
+        # Request with n=2 to get multiple choices, no tools
+        response_stream = client.chat.completions.create(
+            model=self.model,
+            messages=messages,
+            temperature=0.8,
+            stream=True,
+            max_tokens=10,  # Keep it short
+            n=2,  # Multiple choices
+        )
+
+        chunks = list(response_stream)
+
+        # Track finish_reason chunks for each index
+        finish_reason_chunks = {}
+        for chunk in chunks:
+            if chunk.choices:
+                for choice in chunk.choices:
+                    if choice.finish_reason is not None:
+                        index = choice.index
+                        if index not in finish_reason_chunks:
+                            finish_reason_chunks[index] = []
+                        finish_reason_chunks[index].append(choice.finish_reason)
+
+        # Verify we got finish_reason chunks for both indices
+        self.assertEqual(
+            len(finish_reason_chunks),
+            2,
+            f"Expected finish_reason chunks for 2 indices, got {len(finish_reason_chunks)}",
+        )
+
+        # Verify both index 0 and 1 have finish_reason
+        self.assertIn(0, finish_reason_chunks, "Missing finish_reason chunk for index 0")
+        self.assertIn(1, finish_reason_chunks, "Missing finish_reason chunk for index 1")
+
+        # Verify the finish_reason is "stop" (regular completion)
+        for index, reasons in finish_reason_chunks.items():
+            self.assertIn(
+                reasons[-1],
+                ["stop", "length"],  # Could be either depending on how model responds
+                f"Expected finish_reason 'stop' or 'length' for index {index}, got {reasons[-1]}",
+            )
+

 class TestOpenAIPythonicFunctionCalling(CustomTestCase):
     PYTHONIC_TOOLS = [
@@ -706,7 +863,6 @@ class TestOpenAIPythonicFunctionCalling(CustomTestCase):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         response = client.chat.completions.create(
             model=self.model,
-            max_tokens=2048,
             messages=self.PYTHONIC_MESSAGES,
             tools=self.PYTHONIC_TOOLS,
             temperature=0.1,
@@ -728,7 +884,6 @@ class TestOpenAIPythonicFunctionCalling(CustomTestCase):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         response_stream = client.chat.completions.create(
             model=self.model,
-            max_tokens=2048,
             messages=self.PYTHONIC_MESSAGES,
             tools=self.PYTHONIC_TOOLS,
             temperature=0.1,
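As a closing usage note, the behavior these tests exercise is what lets a client concatenate streamed tool-call argument fragments into valid JSON. A hedged client-side sketch (endpoint, API key, and model name are placeholders):

# Client-side sketch: accumulate streamed tool-call argument fragments per call index.
import json
import openai

client = openai.Client(api_key="sk-123456", base_url="http://127.0.0.1:30000/v1")
stream = client.chat.completions.create(
    model="your-model",
    messages=[{"role": "user", "content": "What is the weather like in Los Angeles?"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            },
        }
    ],
    tool_choice="required",
    stream=True,
)

arguments_by_call = {}
for chunk in stream:
    for choice in chunk.choices:
        for call in choice.delta.tool_calls or []:
            if call.function and call.function.arguments:
                arguments_by_call[call.index] = (
                    arguments_by_call.get(call.index, "") + call.function.arguments
                )

# With the unstreamed-args fix, every accumulated string parses as complete JSON.
for call_index, raw_args in arguments_by_call.items():
    print(call_index, json.loads(raw_args))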