Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
61555307
Unverified
Commit
61555307
authored
Jul 11, 2025
by
Atream
Committed by
GitHub
Jul 11, 2025
Browse files
Support Kimi K2 (#7940)
parent
49a5915f
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
480 additions
and
3 deletions
+480
-3
docs/backend/server_arguments.md
docs/backend/server_arguments.md
+1
-1
python/sglang/srt/configs/model_config.py
python/sglang/srt/configs/model_config.py
+21
-0
python/sglang/srt/function_call/function_call_parser.py
python/sglang/srt/function_call/function_call_parser.py
+2
-0
python/sglang/srt/function_call/kimik2_detector.py
python/sglang/srt/function_call/kimik2_detector.py
+220
-0
python/sglang/srt/hf_transformers_utils.py
python/sglang/srt/hf_transformers_utils.py
+18
-0
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+9
-2
test/srt/test_function_call_parser.py
test/srt/test_function_call_parser.py
+209
-0
No files found.
docs/backend/server_arguments.md
View file @
61555307
...
...
@@ -135,7 +135,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
|
`--file-storage-path`
| The path of the file storage in backend. | sglang_storage |
|
`--enable-cache-report`
| Return number of cached tokens in usage.prompt_tokens_details for each openai request. | False |
|
`--reasoning-parser`
| Specify the parser for reasoning models, supported parsers are: {list(ReasoningParser.DetectorMap.keys())}. | None |
|
`--tool-call-parser`
| Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3',
and
'pythonic'. | None |
|
`--tool-call-parser`
| Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic'
, and 'kimi_k2'
. | None |
## Data parallelism
...
...
python/sglang/srt/configs/model_config.py
View file @
61555307
...
...
@@ -25,6 +25,7 @@ from transformers import PretrainedConfig
from
sglang.srt.hf_transformers_utils
import
(
get_config
,
get_context_length
,
get_generation_config
,
get_hf_text_config
,
)
from
sglang.srt.layers.quantization
import
QUANTIZATION_METHODS
...
...
@@ -83,6 +84,13 @@ class ModelConfig:
**
kwargs
,
)
self
.
hf_generation_config
=
get_generation_config
(
self
.
model_path
,
trust_remote_code
=
trust_remote_code
,
revision
=
revision
,
**
kwargs
,
)
self
.
hf_text_config
=
get_hf_text_config
(
self
.
hf_config
)
self
.
attention_chunk_size
=
getattr
(
self
.
hf_text_config
,
"attention_chunk_size"
,
None
...
...
@@ -467,6 +475,19 @@ class ModelConfig:
if
eos_ids
:
# it can be either int or list of int
eos_ids
=
{
eos_ids
}
if
isinstance
(
eos_ids
,
int
)
else
set
(
eos_ids
)
if
eos_ids
is
None
:
eos_ids
=
set
()
if
self
.
hf_generation_config
:
generation_eos_ids
=
getattr
(
self
.
hf_generation_config
,
"eos_token_id"
,
None
)
if
generation_eos_ids
:
generation_eos_ids
=
(
{
generation_eos_ids
}
if
isinstance
(
generation_eos_ids
,
int
)
else
set
(
generation_eos_ids
)
)
eos_ids
=
eos_ids
|
generation_eos_ids
return
eos_ids
def
maybe_pull_model_tokenizer_from_remote
(
self
)
->
None
:
...
...
python/sglang/srt/function_call/function_call_parser.py
View file @
61555307
...
...
@@ -10,6 +10,7 @@ from sglang.srt.entrypoints.openai.protocol import (
from
sglang.srt.function_call.base_format_detector
import
BaseFormatDetector
from
sglang.srt.function_call.core_types
import
ToolCallItem
from
sglang.srt.function_call.deepseekv3_detector
import
DeepSeekV3Detector
from
sglang.srt.function_call.kimik2_detector
import
KimiK2Detector
from
sglang.srt.function_call.llama32_detector
import
Llama32Detector
from
sglang.srt.function_call.mistral_detector
import
MistralDetector
from
sglang.srt.function_call.pythonic_detector
import
PythonicDetector
...
...
@@ -33,6 +34,7 @@ class FunctionCallParser:
"mistral"
:
MistralDetector
,
"deepseekv3"
:
DeepSeekV3Detector
,
"pythonic"
:
PythonicDetector
,
"kimi_k2"
:
KimiK2Detector
,
}
def
__init__
(
self
,
tools
:
List
[
Tool
],
tool_call_parser
:
str
):
...
...
python/sglang/srt/function_call/kimik2_detector.py
0 → 100644
View file @
61555307
import
json
import
logging
import
re
from
typing
import
List
from
sglang.srt.entrypoints.openai.protocol
import
Tool
from
sglang.srt.function_call.base_format_detector
import
BaseFormatDetector
from
sglang.srt.function_call.core_types
import
(
StreamingParseResult
,
StructureInfo
,
ToolCallItem
,
_GetInfoFunc
,
)
from
sglang.srt.function_call.ebnf_composer
import
EBNFComposer
from
sglang.srt.function_call.utils
import
_is_complete_json
# Module-level logger; named after this module so log output can be filtered
# per-detector via the standard logging hierarchy.
logger = logging.getLogger(__name__)
class KimiK2Detector(BaseFormatDetector):
    """Detector for the Kimi K2 tool-call output format.

    Expected model output shape::

        <|tool_calls_section_begin|>
        <|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"city": "Paris"}<|tool_call_end|>
        <|tool_calls_section_end|>

    Supports both one-shot parsing (``detect_and_parse``) and incremental
    streaming parsing (``parse_streaming_increment``).
    """

    def __init__(self):
        super().__init__()

        # Unconsumed text carried across streaming calls.
        self._buffer = ""
        # Whether the current call's name has already been emitted to the client.
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: list[str] = (
            []
        )  # map what has been streamed for each tool so far to a list

        # Delimiter tokens emitted by the model around tool-call sections/calls.
        self.bot_token: str = "<|tool_calls_section_begin|>"
        self.eot_token: str = "<|tool_calls_section_end|>"
        self.tool_call_start_token: str = "<|tool_call_begin|>"
        self.tool_call_end_token: str = "<|tool_call_end|>"

        # Matches one *complete* tool call (used by one-shot parsing).
        self.tool_call_regex = re.compile(
            r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*?\})\s*<\|tool_call_end\|>"
        )

        # Matches a possibly-unfinished tool call (used by streaming parsing).
        self.stream_tool_call_portion_regex = re.compile(
            r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*)"
        )

        # Argument text already streamed for the current tool call.
        self._last_arguments = ""

    @staticmethod
    def _parse_tool_call_id(function_id: str) -> "tuple[str, int]":
        """Split a raw id such as ``functions.get_weather:0`` into (name, index).

        Uses ``rpartition``/``split(maxsplit=1)`` so tool names that themselves
        contain dots (``functions.ns.tool:1`` -> ``ns.tool``) are preserved,
        where naive ``split(".")[1]`` would silently truncate them.
        Raises ``IndexError``/``ValueError`` on malformed ids; callers catch
        broadly and fall back to plain text.
        """
        name_part, _, index_part = function_id.rpartition(":")
        function_name = name_part.split(".", 1)[1]
        return function_name, int(index_part)

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a KimiK2 format tool call."""
        return self.bot_token in text

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
        """
        if self.bot_token not in text:
            return StreamingParseResult(normal_text=text, calls=[])
        try:
            # there are two possible captures - between tags, or between a
            # tag and end-of-string so the result of
            # findall is an array of tuples where one is a function call and
            # the other is None
            function_call_tuples = self.tool_call_regex.findall(text)

            logger.debug("function_call_tuples: %s", function_call_tuples)

            tool_calls = []
            for match in function_call_tuples:
                function_id, function_args = match
                function_name, function_idx = self._parse_tool_call_id(function_id)
                # Lazy %-style args avoid formatting cost when debug is disabled.
                logger.debug("function_name: %s", function_name)
                tool_calls.append(
                    ToolCallItem(
                        tool_index=function_idx,  # Use the call index in the response, not tool position
                        name=function_name,
                        parameters=function_args,
                    )
                )

            # Everything before the section-begin token is plain assistant text.
            content = text[: text.find(self.bot_token)]
            return StreamingParseResult(normal_text=content, calls=tool_calls)

        except Exception as e:
            logger.error(f"Error in detect_and_parse: {e}")
            # return the normal text if parsing fails
            return StreamingParseResult(normal_text=text)

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """
        Streaming incremental parsing tool calls for KimiK2 format.

        Buffers incoming deltas, emits the tool name once per call, then
        streams argument deltas; on a complete JSON argument object the
        finished call is dropped from the buffer and per-call state is reset.
        """
        self._buffer += new_text
        current_text = self._buffer

        # Check if we have a tool call (either the start token or individual tool call)
        has_tool_call = (
            self.bot_token in current_text or self.tool_call_start_token in current_text
        )

        if not has_tool_call:
            # NOTE(review): a bot_token split across two chunks is flushed as
            # normal text here — confirm upstream chunking makes this safe.
            self._buffer = ""
            for e_token in [self.eot_token, self.tool_call_end_token]:
                if e_token in new_text:
                    new_text = new_text.replace(e_token, "")
            return StreamingParseResult(normal_text=new_text)

        if not hasattr(self, "_tool_indices"):
            # Not read in this class; kept for parity with other detectors
            # (presumably consumed by shared/base-class logic — verify).
            self._tool_indices = {
                tool.function.name: i
                for i, tool in enumerate(tools)
                if tool.function and tool.function.name
            }

        calls: list[ToolCallItem] = []
        try:
            match = self.stream_tool_call_portion_regex.search(current_text)
            if match:
                function_id = match.group("tool_call_id")
                function_args = match.group("function_arguments")

                function_name = self._parse_tool_call_id(function_id)[0]

                # Initialize state if this is the first tool call
                if self.current_tool_id == -1:
                    self.current_tool_id = 0
                    self.prev_tool_call_arr = []
                    self.streamed_args_for_tool = [""]

                # Ensure we have enough entries in our tracking arrays
                while len(self.prev_tool_call_arr) <= self.current_tool_id:
                    self.prev_tool_call_arr.append({})
                while len(self.streamed_args_for_tool) <= self.current_tool_id:
                    self.streamed_args_for_tool.append("")

                if not self.current_tool_name_sent:
                    # First emission for this call: send the name with empty args.
                    calls.append(
                        ToolCallItem(
                            tool_index=self.current_tool_id,
                            name=function_name,
                            parameters="",
                        )
                    )
                    self.current_tool_name_sent = True
                    # Store the tool call info for adapter.py
                    self.prev_tool_call_arr[self.current_tool_id] = {
                        "name": function_name,
                        "arguments": {},
                    }
                else:
                    # Compute the argument delta since the last increment.
                    argument_diff = (
                        function_args[len(self._last_arguments) :]
                        if function_args.startswith(self._last_arguments)
                        else function_args
                    )

                    # Strip anything at/after the end token from the delta.
                    parsed_args_diff = argument_diff.split(self.tool_call_end_token, 1)[0]

                    if parsed_args_diff:
                        calls.append(
                            ToolCallItem(
                                tool_index=self.current_tool_id,
                                name=None,
                                parameters=parsed_args_diff,
                            )
                        )
                        self._last_arguments += argument_diff
                        self.streamed_args_for_tool[
                            self.current_tool_id
                        ] += parsed_args_diff

                    parsed_args = function_args.split(self.tool_call_end_token, 1)[0]
                    if _is_complete_json(parsed_args):
                        try:
                            parsed_args = json.loads(parsed_args)
                            self.prev_tool_call_arr[self.current_tool_id][
                                "arguments"
                            ] = parsed_args
                        except json.JSONDecodeError:
                            pass

                        # Find the end of the current tool call and remove only that part from buffer
                        tool_call_end_pattern = (
                            r"<\|tool_call_begin\|>.*?<\|tool_call_end\|>"
                        )
                        match = re.search(
                            tool_call_end_pattern, current_text, re.DOTALL
                        )
                        if match:
                            # Remove the completed tool call from buffer, keep any remaining content
                            self._buffer = current_text[match.end() :]
                        else:
                            self._buffer = ""

                        result = StreamingParseResult(normal_text="", calls=calls)
                        # Reset per-call state and advance to the next call index.
                        self.current_tool_id += 1
                        self._last_arguments = ""
                        self.current_tool_name_sent = False
                        return result

            return StreamingParseResult(normal_text="", calls=calls)

        except Exception as e:
            logger.error(f"Error in parse_streaming_increment: {e}")
            return StreamingParseResult(normal_text=current_text)

    def structure_info(self) -> _GetInfoFunc:
        # Constrained-generation structure info is not supported for this format.
        raise NotImplementedError()

    def build_ebnf(self, tools: List[Tool]):
        # EBNF grammar construction is not supported for this format.
        raise NotImplementedError()
python/sglang/srt/hf_transformers_utils.py
View file @
61555307
...
...
@@ -14,6 +14,7 @@
"""Utilities for Huggingface Transformers."""
import
contextlib
import
logging
import
os
import
warnings
from
pathlib
import
Path
...
...
@@ -25,6 +26,7 @@ from transformers import (
AutoConfig
,
AutoProcessor
,
AutoTokenizer
,
GenerationConfig
,
PretrainedConfig
,
PreTrainedTokenizer
,
PreTrainedTokenizerBase
,
...
...
@@ -153,6 +155,22 @@ def get_config(
return
config
@lru_cache_frozenset(maxsize=32)
def get_generation_config(
    model: str,
    trust_remote_code: bool,
    revision: Optional[str] = None,
    **kwargs,
):
    """Load the HuggingFace ``GenerationConfig`` for *model*.

    Args:
        model: Model name or local path passed to ``from_pretrained``.
        trust_remote_code: Forwarded to HF; allows custom model code.
        revision: Optional model revision (branch/tag/commit).
        **kwargs: Extra arguments forwarded to ``from_pretrained``.

    Returns:
        The ``GenerationConfig``, or ``None`` when the model ships no
        ``generation_config.json`` (``from_pretrained`` raises ``OSError``).
    """
    try:
        return GenerationConfig.from_pretrained(
            model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
        )
    except OSError:
        # Missing generation_config.json is a normal, expected case.
        # NOTE(review): uses the root logger (`logging.info`); consider the
        # module logger if one exists in this file.
        logging.info("model doesn't have generation_config.json")
        return None
# Models don't use the same configuration key for determining the maximum
# context length. Store them here so we can sanely check them.
# NOTE: The ordering here is important. Some models have two of these and we
...
...
python/sglang/srt/server_args.py
View file @
61555307
...
...
@@ -1048,9 +1048,16 @@ class ServerArgs:
parser
.
add_argument
(
"--tool-call-parser"
,
type
=
str
,
choices
=
[
"qwen25"
,
"mistral"
,
"llama3"
,
"deepseekv3"
,
"pythonic"
],
choices
=
[
"qwen25"
,
"mistral"
,
"llama3"
,
"deepseekv3"
,
"pythonic"
,
"kimi_k2"
,
],
default
=
ServerArgs
.
tool_call_parser
,
help
=
"Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3',
and
'pythonic'."
,
help
=
"Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic'
, and 'kimi_k2'
."
,
)
# Data parallelism
...
...
test/srt/test_function_call_parser.py
View file @
61555307
...
...
@@ -6,6 +6,7 @@ from xgrammar import GrammarCompiler, TokenizerInfo
from
sglang.srt.entrypoints.openai.protocol
import
Function
,
Tool
from
sglang.srt.function_call.base_format_detector
import
BaseFormatDetector
from
sglang.srt.function_call.deepseekv3_detector
import
DeepSeekV3Detector
from
sglang.srt.function_call.kimik2_detector
import
KimiK2Detector
from
sglang.srt.function_call.llama32_detector
import
Llama32Detector
from
sglang.srt.function_call.mistral_detector
import
MistralDetector
from
sglang.srt.function_call.pythonic_detector
import
PythonicDetector
...
...
@@ -1138,5 +1139,213 @@ class TestLlama32Detector(unittest.TestCase):
self
.
assertTrue
(
result
.
normal_text
.
strip
().
startswith
(
"Some intro."
))
class TestKimiK2Detector(unittest.TestCase):
    """Tests for KimiK2Detector one-shot and streaming tool-call parsing."""

    @staticmethod
    def _city_tool(name, description):
        """Build a function tool with a single required string param "city"."""
        return Tool(
            type="function",
            function=Function(
                name=name,
                description=description,
                parameters={
                    "type": "object",
                    "properties": {
                        "city": {
                            "type": "string",
                            "description": "City name",
                        }
                    },
                    "required": ["city"],
                },
            ),
        )

    def setUp(self):
        """Set up test tools and detector."""
        self.tools = [
            self._city_tool("get_weather", "Get weather information"),
            self._city_tool("get_tourist_attractions", "Get tourist attractions"),
        ]
        self.detector = KimiK2Detector()

    def _stream(self, chunks):
        """Feed chunks through the detector, accumulating streamed deltas.

        Returns a list of {"name": ..., "parameters": ...} dicts indexed by
        tool_index, with name/parameter deltas concatenated per call.
        """
        tool_calls = []
        for chunk in chunks:
            result = self.detector.parse_streaming_increment(chunk, self.tools)
            for tool_call_chunk in result.calls:
                if tool_call_chunk.tool_index is None:
                    continue
                while len(tool_calls) <= tool_call_chunk.tool_index:
                    tool_calls.append({"name": "", "parameters": ""})
                tc = tool_calls[tool_call_chunk.tool_index]
                if tool_call_chunk.name:
                    tc["name"] += tool_call_chunk.name
                if tool_call_chunk.parameters:
                    tc["parameters"] += tool_call_chunk.parameters
        return tool_calls

    def test_single_tool_call(self):
        """Test parsing a single tool call in a complete text."""
        text = '<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"city": "Paris"}<|tool_call_end|><|tool_calls_section_end|>'
        result = self.detector.detect_and_parse(text, self.tools)
        self.assertEqual(len(result.calls), 1)
        self.assertEqual(result.calls[0].name, "get_weather")
        self.assertEqual(result.calls[0].parameters, '{"city": "Paris"}')
        self.assertEqual(result.normal_text, "")

    def test_multiple_tool_calls(self):
        """Test parsing multiple tool calls in a complete text."""
        text = '<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"city": "Paris"}<|tool_call_end|><|tool_call_begin|>functions.get_tourist_attractions:1<|tool_call_argument_begin|>{"city": "London"}<|tool_call_end|><|tool_calls_section_end|>'
        result = self.detector.detect_and_parse(text, self.tools)
        self.assertEqual(len(result.calls), 2)
        self.assertEqual(result.calls[0].name, "get_weather")
        self.assertEqual(result.calls[0].parameters, '{"city": "Paris"}')
        self.assertEqual(result.calls[1].name, "get_tourist_attractions")
        self.assertEqual(result.calls[1].parameters, '{"city": "London"}')
        self.assertEqual(result.normal_text, "")

    def test_streaming_tool_call(self):
        """Test streaming incremental parsing of a tool call."""
        chunks = [
            "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{",
            '"city": "Paris"',
            "}",
            "<|tool_call_end|><|tool_calls_section_end|>",
        ]
        tool_calls = self._stream(chunks)
        self.assertEqual(len(tool_calls), 1)
        self.assertEqual(tool_calls[0]["name"], "get_weather")
        self.assertEqual(tool_calls[0]["parameters"], '{"city": "Paris"}')

    def test_streaming_multiple_tool_calls(self):
        """Test streaming incremental parsing of multiple tool calls."""
        chunks = [
            "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{",
            '"city": "Paris"',
            "}<|tool_call_end|>",
            "<|tool_call_begin|>functions.get_tourist_attractions:1<|tool_call_argument_begin|>{",
            '"city": "London"',
            "}<|tool_call_end|>",
            "<|tool_calls_section_end|>",
        ]
        tool_calls = self._stream(chunks)
        self.assertEqual(len(tool_calls), 2)
        self.assertEqual(tool_calls[0]["name"], "get_weather")
        self.assertEqual(tool_calls[0]["parameters"], '{"city": "Paris"}')
        self.assertEqual(tool_calls[1]["name"], "get_tourist_attractions")
        self.assertEqual(tool_calls[1]["parameters"], '{"city": "London"}')

    def test_tool_call_completion(self):
        """Test that the buffer and state are reset after a tool call is completed."""
        chunks = [
            "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{",
            '"city": "Paris"',
            "}",
            "<|tool_call_end|>",
            "<|tool_calls_section_end|>",
        ]
        self._stream(chunks)
        # After processing all chunks, the buffer should be empty and current_tool_id should be reset
        self.assertEqual(self.detector._buffer, "")
        self.assertEqual(self.detector.current_tool_id, 1)

    def test_tool_name_streaming(self):
        """Test that tool names are streamed correctly with the right index."""
        chunks = [
            "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{",
            '"city": "Paris"',
            "}",
            "<|tool_call_end|>",
            "<|tool_call_begin|>functions.get_tourist_attractions:1<|tool_call_argument_begin|>{",
        ]
        tool_calls = self._stream(chunks)
        self.assertEqual(len(tool_calls), 2)
        self.assertEqual(tool_calls[0]["name"], "get_weather")
        self.assertEqual(tool_calls[0]["parameters"], '{"city": "Paris"}')
        self.assertEqual(tool_calls[1]["name"], "get_tourist_attractions")

    def test_invalid_tool_call(self):
        """Test that invalid tool calls are handled correctly."""
        text = 'invalid_tool:0<|tool_call_argument_begin|>{"city": "Paris"}<|tool_call_end|><|tool_calls_section_end|>'
        result = self.detector.detect_and_parse(text, self.tools)
        self.assertEqual(len(result.calls), 0)
        self.assertEqual(result.normal_text, text)

    def test_partial_tool_call(self):
        """Test that partial tool calls are handled correctly in streaming mode."""
        chunks = [
            "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{",
            '"city": "Paris"',
        ]
        tool_calls = self._stream(chunks)
        self.assertEqual(len(tool_calls), 1)
        self.assertEqual(tool_calls[0]["name"], "get_weather")
        self.assertEqual(tool_calls[0]["parameters"], '{"city": "Paris"')


if __name__ == "__main__":
    unittest.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment