change / sglang / Commits / 61555307

Commit 61555307 (Unverified)
Support Kimi K2 (#7940)

Authored Jul 11, 2025 by Atream; committed by GitHub, Jul 11, 2025.
Parent: 49a5915f

Showing 7 changed files with 480 additions and 3 deletions (+480 −3):
| File | Changes |
|---|---|
| docs/backend/server_arguments.md | +1 −1 |
| python/sglang/srt/configs/model_config.py | +21 −0 |
| python/sglang/srt/function_call/function_call_parser.py | +2 −0 |
| python/sglang/srt/function_call/kimik2_detector.py (new) | +220 −0 |
| python/sglang/srt/hf_transformers_utils.py | +18 −0 |
| python/sglang/srt/server_args.py | +9 −2 |
| test/srt/test_function_call_parser.py | +209 −0 |
docs/backend/server_arguments.md

```diff
@@ -135,7 +135,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `--file-storage-path` | The path of the file storage in backend. | sglang_storage |
 | `--enable-cache-report` | Return number of cached tokens in usage.prompt_tokens_details for each openai request. | False |
 | `--reasoning-parser` | Specify the parser for reasoning models, supported parsers are: {list(ReasoningParser.DetectorMap.keys())}. | None |
-| `--tool-call-parser` | Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', and 'pythonic'. | None |
+| `--tool-call-parser` | Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', and 'kimi_k2'. | None |

 ## Data parallelism
```
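For context, a minimal client-side sketch of the new option in action. It assumes a server already launched with the new flag (e.g. `python -m sglang.launch_server --model-path moonshotai/Kimi-K2-Instruct --tool-call-parser kimi_k2`) and listening on sglang's default port 30000; the model path, endpoint, and tool schema here are illustrative, not part of this commit.

```python
# Hedged sketch: call the OpenAI-compatible endpoint of a running sglang
# server; the kimi_k2 parser turns the model's tool-call tokens into
# structured tool_calls. Endpoint, model name, and schema are assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get weather information",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

response = client.chat.completions.create(
    model="moonshotai/Kimi-K2-Instruct",  # illustrative model path
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=tools,
)
print(response.choices[0].message.tool_calls)
```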
python/sglang/srt/configs/model_config.py

```diff
@@ -25,6 +25,7 @@ from transformers import PretrainedConfig
 from sglang.srt.hf_transformers_utils import (
     get_config,
     get_context_length,
+    get_generation_config,
     get_hf_text_config,
 )
 from sglang.srt.layers.quantization import QUANTIZATION_METHODS
@@ -83,6 +84,13 @@ class ModelConfig:
             **kwargs,
         )
+        self.hf_generation_config = get_generation_config(
+            self.model_path,
+            trust_remote_code=trust_remote_code,
+            revision=revision,
+            **kwargs,
+        )
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.attention_chunk_size = getattr(
             self.hf_text_config, "attention_chunk_size", None
@@ -467,6 +475,19 @@ class ModelConfig:
         if eos_ids:
             # it can be either int or list of int
             eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
+        if eos_ids is None:
+            eos_ids = set()
+        if self.hf_generation_config:
+            generation_eos_ids = getattr(
+                self.hf_generation_config, "eos_token_id", None
+            )
+            if generation_eos_ids:
+                generation_eos_ids = (
+                    {generation_eos_ids}
+                    if isinstance(generation_eos_ids, int)
+                    else set(generation_eos_ids)
+                )
+                eos_ids = eos_ids | generation_eos_ids
         return eos_ids

     def maybe_pull_model_tokenizer_from_remote(self) -> None:
```
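The second hunk merges EOS token ids from `generation_config.json` into those taken from the model config, so stop tokens declared only in the generation config are still honored. A standalone sketch of the same normalize-and-union rule (the function name is illustrative, not sglang's):

```python
# Sketch of the merging rule above: each source may be None, an int, or a
# list of ints; both are normalized to sets and unioned.
from typing import Optional, Union

def merge_eos_ids(
    config_eos: Optional[Union[int, list]],
    generation_eos: Optional[Union[int, list]],
) -> set:
    eos_ids = {config_eos} if isinstance(config_eos, int) else set(config_eos or [])
    if generation_eos:
        eos_ids |= (
            {generation_eos}
            if isinstance(generation_eos, int)
            else set(generation_eos)
        )
    return eos_ids

assert merge_eos_ids(1, [1, 2]) == {1, 2}
assert merge_eos_ids(None, 2) == {2}
```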
python/sglang/srt/function_call/function_call_parser.py

```diff
@@ -10,6 +10,7 @@ from sglang.srt.entrypoints.openai.protocol import (
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import ToolCallItem
 from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
+from sglang.srt.function_call.kimik2_detector import KimiK2Detector
 from sglang.srt.function_call.llama32_detector import Llama32Detector
 from sglang.srt.function_call.mistral_detector import MistralDetector
 from sglang.srt.function_call.pythonic_detector import PythonicDetector
@@ -33,6 +34,7 @@ class FunctionCallParser:
         "mistral": MistralDetector,
         "deepseekv3": DeepSeekV3Detector,
         "pythonic": PythonicDetector,
+        "kimi_k2": KimiK2Detector,
     }

     def __init__(self, tools: List[Tool], tool_call_parser: str):
```
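With this entry in place, constructing `FunctionCallParser(tools, tool_call_parser="kimi_k2")` resolves to `KimiK2Detector`. A minimal sketch of that dispatch, assuming the registry dict shown above is the class attribute `ToolCallParserEnum` (its name lies outside this hunk):

```python
# Sketch under the assumption that the dict above is named
# FunctionCallParser.ToolCallParserEnum (not visible in this hunk).
from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.function_call.kimik2_detector import KimiK2Detector

detector_cls = FunctionCallParser.ToolCallParserEnum["kimi_k2"]
assert detector_cls is KimiK2Detector
```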
python/sglang/srt/function_call/kimik2_detector.py (new file, mode 100644)

```python
import json
import logging
import re
from typing import List

from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
    StreamingParseResult,
    StructureInfo,
    ToolCallItem,
    _GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.function_call.utils import _is_complete_json

logger = logging.getLogger(__name__)


class KimiK2Detector(BaseFormatDetector):
    def __init__(self):
        super().__init__()

        self._buffer = ""
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: list[str] = (
            []
        )  # map what has been streamed for each tool so far to a list

        self.bot_token: str = "<|tool_calls_section_begin|>"
        self.eot_token: str = "<|tool_calls_section_end|>"

        self.tool_call_start_token: str = "<|tool_call_begin|>"
        self.tool_call_end_token: str = "<|tool_call_end|>"

        self.tool_call_regex = re.compile(
            r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*?\})\s*<\|tool_call_end\|>"
        )

        self.stream_tool_call_portion_regex = re.compile(
            r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*)"
        )

        self._last_arguments = ""

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a KimiK2 format tool call."""
        return self.bot_token in text

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
        """
        if self.bot_token not in text:
            return StreamingParseResult(normal_text=text, calls=[])
        try:
            # there are two possible captures - between tags, or between a
            # tag and end-of-string so the result of
            # findall is an array of tuples where one is a function call and
            # the other is None
            function_call_tuples = self.tool_call_regex.findall(text)

            logger.debug("function_call_tuples: %s", function_call_tuples)

            tool_calls = []
            for match in function_call_tuples:
                function_id, function_args = match
                function_name = function_id.split(".")[1].split(":")[0]
                function_idx = int(function_id.split(".")[1].split(":")[1])

                logger.info(f"function_name {function_name}")

                tool_calls.append(
                    ToolCallItem(
                        tool_index=function_idx,  # Use the call index in the response, not tool position
                        name=function_name,
                        parameters=function_args,
                    )
                )

            content = text[: text.find(self.bot_token)]
            return StreamingParseResult(normal_text=content, calls=tool_calls)

        except Exception as e:
            logger.error(f"Error in detect_and_parse: {e}")
            # return the normal text if parsing fails
            return StreamingParseResult(normal_text=text)

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """
        Streaming incremental parsing tool calls for KimiK2 format.
        """
        self._buffer += new_text
        current_text = self._buffer

        # Check if we have a tool call (either the start token or individual tool call)
        has_tool_call = (
            self.bot_token in current_text
            or self.tool_call_start_token in current_text
        )

        if not has_tool_call:
            self._buffer = ""
            for e_token in [self.eot_token, self.tool_call_end_token]:
                if e_token in new_text:
                    new_text = new_text.replace(e_token, "")
            return StreamingParseResult(normal_text=new_text)

        if not hasattr(self, "_tool_indices"):
            self._tool_indices = {
                tool.function.name: i
                for i, tool in enumerate(tools)
                if tool.function and tool.function.name
            }

        calls: list[ToolCallItem] = []
        try:
            match = self.stream_tool_call_portion_regex.search(current_text)
            if match:
                function_id = match.group("tool_call_id")
                function_args = match.group("function_arguments")

                function_name = function_id.split(".")[1].split(":")[0]

                # Initialize state if this is the first tool call
                if self.current_tool_id == -1:
                    self.current_tool_id = 0
                    self.prev_tool_call_arr = []
                    self.streamed_args_for_tool = [""]

                # Ensure we have enough entries in our tracking arrays
                while len(self.prev_tool_call_arr) <= self.current_tool_id:
                    self.prev_tool_call_arr.append({})
                while len(self.streamed_args_for_tool) <= self.current_tool_id:
                    self.streamed_args_for_tool.append("")

                if not self.current_tool_name_sent:
                    calls.append(
                        ToolCallItem(
                            tool_index=self.current_tool_id,
                            name=function_name,
                            parameters="",
                        )
                    )
                    self.current_tool_name_sent = True
                    # Store the tool call info for adapter.py
                    self.prev_tool_call_arr[self.current_tool_id] = {
                        "name": function_name,
                        "arguments": {},
                    }
                else:
                    argument_diff = (
                        function_args[len(self._last_arguments) :]
                        if function_args.startswith(self._last_arguments)
                        else function_args
                    )

                    parsed_args_diff = argument_diff.split("<|tool_call_end|>", 1)[0]

                    if parsed_args_diff:
                        calls.append(
                            ToolCallItem(
                                tool_index=self.current_tool_id,
                                name=None,
                                parameters=parsed_args_diff,
                            )
                        )
                        self._last_arguments += argument_diff
                        self.streamed_args_for_tool[
                            self.current_tool_id
                        ] += parsed_args_diff

                    parsed_args = function_args.split("<|tool_call_end|>", 1)[0]
                    if _is_complete_json(parsed_args):
                        try:
                            parsed_args = json.loads(parsed_args)
                            self.prev_tool_call_arr[self.current_tool_id][
                                "arguments"
                            ] = parsed_args
                        except json.JSONDecodeError:
                            pass

                        # Find the end of the current tool call and remove only that part from buffer
                        tool_call_end_pattern = (
                            r"<\|tool_call_begin\|>.*?<\|tool_call_end\|>"
                        )
                        match = re.search(
                            tool_call_end_pattern, current_text, re.DOTALL
                        )
                        if match:
                            # Remove the completed tool call from buffer, keep any remaining content
                            self._buffer = current_text[match.end() :]
                        else:
                            self._buffer = ""

                        result = StreamingParseResult(normal_text="", calls=calls)
                        self.current_tool_id += 1
                        self._last_arguments = ""
                        self.current_tool_name_sent = False
                        return result

            return StreamingParseResult(normal_text="", calls=calls)

        except Exception as e:
            logger.error(f"Error in parse_streaming_increment: {e}")
            return StreamingParseResult(normal_text=current_text)

    def structure_info(self) -> _GetInfoFunc:
        raise NotImplementedError()

    def build_ebnf(self, tools: List[Tool]):
        raise NotImplementedError()
```
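The detector keys on Kimi K2's tool-call ID format, `functions.{name}:{index}` (e.g. `functions.get_weather:0`). A quick standalone check of how the regex and the `split`-based extraction above behave on a complete call, using only the stdlib:

```python
# Standalone demo of the parsing above (pure re, no sglang imports).
import re

tool_call_regex = re.compile(
    r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*"
    r"<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*?\})\s*"
    r"<\|tool_call_end\|>"
)

text = (
    "<|tool_calls_section_begin|>"
    "<|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>"
    '{"city": "Paris"}<|tool_call_end|>'
    "<|tool_calls_section_end|>"
)

function_id, function_args = tool_call_regex.findall(text)[0]
assert function_id.split(".")[1].split(":")[0] == "get_weather"  # tool name
assert int(function_id.split(".")[1].split(":")[1]) == 0  # call index
assert function_args == '{"city": "Paris"}'
```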
python/sglang/srt/hf_transformers_utils.py

```diff
@@ -14,6 +14,7 @@
 """Utilities for Huggingface Transformers."""

 import contextlib
+import logging
 import os
 import warnings
 from pathlib import Path
@@ -25,6 +26,7 @@ from transformers import (
     AutoConfig,
     AutoProcessor,
     AutoTokenizer,
+    GenerationConfig,
     PretrainedConfig,
     PreTrainedTokenizer,
     PreTrainedTokenizerBase,
@@ -153,6 +155,22 @@ def get_config(
     return config


+@lru_cache_frozenset(maxsize=32)
+def get_generation_config(
+    model: str,
+    trust_remote_code: bool,
+    revision: Optional[str] = None,
+    **kwargs,
+):
+    try:
+        return GenerationConfig.from_pretrained(
+            model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
+        )
+    except OSError as e:
+        logging.info("model doesn't have generation_config.json")
+        return None
+
+
 # Models don't use the same configuration key for determining the maximum
 # context length. Store them here so we can sanely check them.
 # NOTE: The ordering here is important. Some models have two of these and we
```
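A standalone sketch of the fallback behavior: `GenerationConfig.from_pretrained` raises `OSError` when a checkpoint ships no `generation_config.json`, and the helper above converts that into `None` so callers can simply truth-test the result.

```python
# Minimal repro of the try/except pattern above, outside sglang.
from transformers import GenerationConfig

def load_generation_config_or_none(model: str):
    try:
        return GenerationConfig.from_pretrained(model)
    except OSError:
        # e.g. the repo has no generation_config.json
        return None

# Illustrative model id; any repo without a generation_config.json yields None.
cfg = load_generation_config_or_none("gpt2")
print(cfg.eos_token_id if cfg is not None else "no generation config")
```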
python/sglang/srt/server_args.py

```diff
@@ -1048,9 +1048,16 @@ class ServerArgs:
         parser.add_argument(
             "--tool-call-parser",
             type=str,
-            choices=["qwen25", "mistral", "llama3", "deepseekv3", "pythonic"],
+            choices=[
+                "qwen25",
+                "mistral",
+                "llama3",
+                "deepseekv3",
+                "pythonic",
+                "kimi_k2",
+            ],
             default=ServerArgs.tool_call_parser,
-            help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', and 'pythonic'.",
+            help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', and 'kimi_k2'.",
         )

         # Data parallelism
```
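Because the value is validated through argparse `choices`, an unsupported parser name fails fast at startup rather than at request time. A minimal standalone repro of that behavior:

```python
# Sketch: the choices list mirrors the one added above; an unknown value
# makes parse_args() exit with an error before the server starts.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--tool-call-parser",
    type=str,
    choices=["qwen25", "mistral", "llama3", "deepseekv3", "pythonic", "kimi_k2"],
    default=None,
)
args = parser.parse_args(["--tool-call-parser", "kimi_k2"])
assert args.tool_call_parser == "kimi_k2"
```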
test/srt/test_function_call_parser.py

```diff
@@ -6,6 +6,7 @@ from xgrammar import GrammarCompiler, TokenizerInfo
 from sglang.srt.entrypoints.openai.protocol import Function, Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
+from sglang.srt.function_call.kimik2_detector import KimiK2Detector
 from sglang.srt.function_call.llama32_detector import Llama32Detector
 from sglang.srt.function_call.mistral_detector import MistralDetector
 from sglang.srt.function_call.pythonic_detector import PythonicDetector
```

The second hunk (`@@ -1138,5 +1139,213 @@ class TestLlama32Detector(unittest.TestCase):`) appends a new `TestKimiK2Detector` class after the last existing assertion (`self.assertTrue(result.normal_text.strip().startswith("Some intro."))`), ahead of the unchanged `unittest.main()` guard:

```python
class TestKimiK2Detector(unittest.TestCase):
    def setUp(self):
        """Set up test tools and detector."""
        self.tools = [
            Tool(
                type="function",
                function=Function(
                    name="get_weather",
                    description="Get weather information",
                    parameters={
                        "type": "object",
                        "properties": {
                            "city": {
                                "type": "string",
                                "description": "City name",
                            }
                        },
                        "required": ["city"],
                    },
                ),
            ),
            Tool(
                type="function",
                function=Function(
                    name="get_tourist_attractions",
                    description="Get tourist attractions",
                    parameters={
                        "type": "object",
                        "properties": {
                            "city": {
                                "type": "string",
                                "description": "City name",
                            }
                        },
                        "required": ["city"],
                    },
                ),
            ),
        ]
        self.detector = KimiK2Detector()

    def test_single_tool_call(self):
        """Test parsing a single tool call in a complete text."""
        text = '<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"city": "Paris"}<|tool_call_end|><|tool_calls_section_end|>'
        result = self.detector.detect_and_parse(text, self.tools)
        self.assertEqual(len(result.calls), 1)
        self.assertEqual(result.calls[0].name, "get_weather")
        self.assertEqual(result.calls[0].parameters, '{"city": "Paris"}')
        self.assertEqual(result.normal_text, "")

    def test_multiple_tool_calls(self):
        """Test parsing multiple tool calls in a complete text."""
        text = '<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"city": "Paris"}<|tool_call_end|><|tool_call_begin|>functions.get_tourist_attractions:1<|tool_call_argument_begin|>{"city": "London"}<|tool_call_end|><|tool_calls_section_end|>'
        result = self.detector.detect_and_parse(text, self.tools)
        self.assertEqual(len(result.calls), 2)
        self.assertEqual(result.calls[0].name, "get_weather")
        self.assertEqual(result.calls[0].parameters, '{"city": "Paris"}')
        self.assertEqual(result.calls[1].name, "get_tourist_attractions")
        self.assertEqual(result.calls[1].parameters, '{"city": "London"}')
        self.assertEqual(result.normal_text, "")

    def test_streaming_tool_call(self):
        """Test streaming incremental parsing of a tool call."""
        chunks = [
            "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{",
            '"city": "Paris"',
            "}",
            "<|tool_call_end|><|tool_calls_section_end|>",
        ]
        tool_calls = []
        for chunk in chunks:
            result = self.detector.parse_streaming_increment(chunk, self.tools)
            for tool_call_chunk in result.calls:
                if tool_call_chunk.tool_index is not None:
                    while len(tool_calls) <= tool_call_chunk.tool_index:
                        tool_calls.append({"name": "", "parameters": ""})
                    tc = tool_calls[tool_call_chunk.tool_index]
                    if tool_call_chunk.name:
                        tc["name"] += tool_call_chunk.name
                    if tool_call_chunk.parameters:
                        tc["parameters"] += tool_call_chunk.parameters
        self.assertEqual(len(tool_calls), 1)
        self.assertEqual(tool_calls[0]["name"], "get_weather")
        self.assertEqual(tool_calls[0]["parameters"], '{"city": "Paris"}')

    def test_streaming_multiple_tool_calls(self):
        """Test streaming incremental parsing of multiple tool calls."""
        chunks = [
            "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{",
            '"city": "Paris"',
            "}<|tool_call_end|>",
            "<|tool_call_begin|>functions.get_tourist_attractions:1<|tool_call_argument_begin|>{",
            '"city": "London"',
            "}<|tool_call_end|>",
            "<|tool_calls_section_end|>",
        ]
        tool_calls = []
        for chunk in chunks:
            result = self.detector.parse_streaming_increment(chunk, self.tools)
            for tool_call_chunk in result.calls:
                if tool_call_chunk.tool_index is not None:
                    while len(tool_calls) <= tool_call_chunk.tool_index:
                        tool_calls.append({"name": "", "parameters": ""})
                    tc = tool_calls[tool_call_chunk.tool_index]
                    if tool_call_chunk.name:
                        tc["name"] += tool_call_chunk.name
                    if tool_call_chunk.parameters:
                        tc["parameters"] += tool_call_chunk.parameters
        self.assertEqual(len(tool_calls), 2)
        self.assertEqual(tool_calls[0]["name"], "get_weather")
        self.assertEqual(tool_calls[0]["parameters"], '{"city": "Paris"}')
        self.assertEqual(tool_calls[1]["name"], "get_tourist_attractions")
        self.assertEqual(tool_calls[1]["parameters"], '{"city": "London"}')

    def test_tool_call_completion(self):
        """Test that the buffer and state are reset after a tool call is completed."""
        chunks = [
            "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{",
            '"city": "Paris"',
            "}",
            "<|tool_call_end|>",
            "<|tool_calls_section_end|>",
        ]
        for chunk in chunks:
            result = self.detector.parse_streaming_increment(chunk, self.tools)
        # After processing all chunks, the buffer should be empty and current_tool_id should be reset
        self.assertEqual(self.detector._buffer, "")
        self.assertEqual(self.detector.current_tool_id, 1)

    def test_tool_name_streaming(self):
        """Test that tool names are streamed correctly with the right index."""
        chunks = [
            "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{",
            '"city": "Paris"',
            "}",
            "<|tool_call_end|>",
            "<|tool_call_begin|>functions.get_tourist_attractions:1<|tool_call_argument_begin|>{",
        ]
        tool_calls = []
        for chunk in chunks:
            result = self.detector.parse_streaming_increment(chunk, self.tools)
            for tool_call_chunk in result.calls:
                if tool_call_chunk.tool_index is not None:
                    while len(tool_calls) <= tool_call_chunk.tool_index:
                        tool_calls.append({"name": "", "parameters": ""})
                    tc = tool_calls[tool_call_chunk.tool_index]
                    if tool_call_chunk.name:
                        tc["name"] += tool_call_chunk.name
                    if tool_call_chunk.parameters:
                        tc["parameters"] += tool_call_chunk.parameters
        self.assertEqual(len(tool_calls), 2)
        self.assertEqual(tool_calls[0]["name"], "get_weather")
        self.assertEqual(tool_calls[0]["parameters"], '{"city": "Paris"}')
        self.assertEqual(tool_calls[1]["name"], "get_tourist_attractions")

    def test_invalid_tool_call(self):
        """Test that invalid tool calls are handled correctly."""
        text = 'invalid_tool:0<|tool_call_argument_begin|>{"city": "Paris"}<|tool_call_end|><|tool_calls_section_end|>'
        result = self.detector.detect_and_parse(text, self.tools)
        self.assertEqual(len(result.calls), 0)
        self.assertEqual(result.normal_text, text)

    def test_partial_tool_call(self):
        """Test that partial tool calls are handled correctly in streaming mode."""
        chunks = [
            "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{",
            '"city": "Paris"',
        ]
        tool_calls = []
        for chunk in chunks:
            result = self.detector.parse_streaming_increment(chunk, self.tools)
            for tool_call_chunk in result.calls:
                if tool_call_chunk.tool_index is not None:
                    while len(tool_calls) <= tool_call_chunk.tool_index:
                        tool_calls.append({"name": "", "parameters": ""})
                    tc = tool_calls[tool_call_chunk.tool_index]
                    if tool_call_chunk.name:
                        tc["name"] += tool_call_chunk.name
                    if tool_call_chunk.parameters:
                        tc["parameters"] += tool_call_chunk.parameters
        self.assertEqual(len(tool_calls), 1)
        self.assertEqual(tool_calls[0]["name"], "get_weather")
        self.assertEqual(tool_calls[0]["parameters"], '{"city": "Paris"')


if __name__ == "__main__":  # existing guard (unchanged context)
    unittest.main()
```