sglang · Commit 816c4c85 (Unverified)
Authored Aug 21, 2025 by Chang Su; committed via GitHub, Aug 21, 2025
Parent: 13ec8d42

[router] add tool parser base structure and partial json parser (#9482)

Showing 11 changed files with 1260 additions and 22 deletions (+1260 −22)
  sgl-router/Cargo.toml                        +1    −0
  sgl-router/benches/tokenizer_benchmark.rs    +22   −22
  sgl-router/src/lib.rs                        +1    −0
  sgl-router/src/tool_parser/errors.rs         +32   −0
  sgl-router/src/tool_parser/mod.rs            +20   −0
  sgl-router/src/tool_parser/partial_json.rs   +527  −0
  sgl-router/src/tool_parser/registry.rs       +119  −0
  sgl-router/src/tool_parser/state.rs          +181  −0
  sgl-router/src/tool_parser/tests.rs          +249  −0
  sgl-router/src/tool_parser/traits.rs         +35   −0
  sgl-router/src/tool_parser/types.rs          +73   −0
sgl-router/Cargo.toml

@@ -48,6 +48,7 @@ metrics = "0.24.2"
 metrics-exporter-prometheus = "0.17.0"
 uuid = { version = "1.10", features = ["v4", "serde"] }
 thiserror = "2.0.12"
+regex = "1.10"
 url = "2.5.4"
 tokio-stream = { version = "0.1", features = ["sync"] }
 anyhow = "1.0"
sgl-router/benches/tokenizer_benchmark.rs

@@ -100,7 +100,8 @@ fn bench_encode_throughput(c: &mut Criterion) {
     let tokenizer_clone = tokenizer.clone();

     // Get token count once
-    let token_count = tokenizer.encode(prompt).unwrap().token_ids().len();
+    let encoding = tokenizer.encode(prompt).unwrap();
+    let token_count = encoding.token_ids().len();

     // Track if metrics have been printed for this test case
     let printed = Arc::new(AtomicBool::new(false));
@@ -157,7 +158,8 @@ fn bench_batch_encode(c: &mut Criterion) {
     let batch_sizes = vec![1, 8, 16, 32, 64, 128];
     let prompt = MEDIUM_PROMPT;
     let prompt_len = prompt.len();
-    let token_count = tokenizer.encode(prompt).unwrap().token_ids().len();
+    let encoding = tokenizer.encode(prompt).unwrap();
+    let token_count = encoding.token_ids().len();

     let mut group = c.benchmark_group("batch_encode");
@@ -303,7 +305,8 @@ fn bench_decode_performance(c: &mut Criterion) {
     );

     let test_text = "The quick brown fox jumps over the lazy dog. ".repeat(10);
-    let tokens = tokenizer.encode(&test_text).unwrap().token_ids();
+    let encoding = tokenizer.encode(&test_text).unwrap();
+    let tokens = encoding.token_ids();
     let num_tokens = tokens.len();

     let mut group = c.benchmark_group("decode_performance");
@@ -313,12 +316,11 @@ fn bench_decode_performance(c: &mut Criterion) {
     group.bench_function("direct_decode", |b| {
         let printed = printed_direct.clone();
         let tokenizer = tokenizer.clone();
-        let tokens = tokens.clone();
         b.iter_custom(|iters| {
             let start = Instant::now();
             for _ in 0..iters {
-                black_box(tokenizer.decode(&tokens, false).unwrap());
+                black_box(tokenizer.decode(tokens, false).unwrap());
             }
             let duration = start.elapsed();
@@ -344,14 +346,13 @@ fn bench_decode_performance(c: &mut Criterion) {
     group.bench_function("decode_stream", |b| {
         let printed = printed_stream.clone();
         let tokenizer = tokenizer.clone();
-        let tokens = tokens.clone();
         b.iter_custom(|iters| {
             let start = Instant::now();
             for _ in 0..iters {
                 let mut decoder = DecodeStream::new(tokenizer.clone(), &[], false);
                 let mut output = String::new();
-                for token in &tokens {
+                for token in tokens {
                     if let Some(text) = decoder.step(*token).unwrap() {
                         output.push_str(&text);
                     }
@@ -382,14 +383,13 @@ fn bench_decode_performance(c: &mut Criterion) {
     group.bench_function("sequence_decode", |b| {
         let printed = printed_seq.clone();
         let tokenizer = tokenizer.clone();
-        let tokens = tokens.clone();
         b.iter_custom(|iters| {
             let start = Instant::now();
             for _ in 0..iters {
                 let mut sequence = Sequence::new(tokenizer.clone());
                 let mut output = String::new();
-                for token in &tokens {
+                for token in tokens {
                     let text = sequence.append_token(*token).unwrap();
                     output.push_str(&text);
                 }
@@ -424,7 +424,8 @@ fn bench_streaming_decode_100k(c: &mut Criterion) {
     );

     let sample_text = "The quick brown fox jumps over the lazy dog. ".repeat(1000);
-    let all_tokens = tokenizer.encode(&sample_text).unwrap().token_ids();
+    let encoding = tokenizer.encode(&sample_text).unwrap();
+    let all_tokens = encoding.token_ids();

     let mut group = c.benchmark_group("streaming_100k");
     group.measurement_time(Duration::from_secs(1));
@@ -434,7 +435,6 @@ fn bench_streaming_decode_100k(c: &mut Criterion) {
     group.bench_function("decode_stream_100k", |b| {
         let printed = printed_stream.clone();
         let tokenizer = tokenizer.clone();
-        let tokens = all_tokens.clone();
         b.iter_custom(|_iters| {
             let start = Instant::now();
@@ -442,7 +442,7 @@ fn bench_streaming_decode_100k(c: &mut Criterion) {
             let mut output = String::new();
             let mut tokens_processed = 0u64;

-            for token in tokens.iter().cycle() {
+            for token in all_tokens.iter().cycle() {
                 if start.elapsed() >= Duration::from_millis(500) {
                     break;
                 }
@@ -486,7 +486,6 @@ fn bench_streaming_decode_100k(c: &mut Criterion) {
     group.bench_function("sequence_100k", |b| {
         let printed = printed_seq.clone();
         let tokenizer = tokenizer.clone();
-        let tokens = all_tokens.clone();
         b.iter_custom(|_iters| {
             let start = Instant::now();
@@ -494,7 +493,7 @@ fn bench_streaming_decode_100k(c: &mut Criterion) {
             let mut output = String::new();
             let mut tokens_processed = 0u64;

-            for token in tokens.iter().cycle() {
+            for token in all_tokens.iter().cycle() {
                 if start.elapsed() >= Duration::from_millis(500) {
                     break;
                 }
@@ -693,7 +692,8 @@ fn bench_concurrent_streaming(c: &mut Criterion) {
     let tokens_per_sequence = 10_000;
     let sample_text = "The quick brown fox jumps over the lazy dog. ".repeat(100);
-    let token_batch = tokenizer.encode(&sample_text).unwrap().token_ids();
+    let encoding = tokenizer.encode(&sample_text).unwrap();
+    let token_batch: Vec<u32> = encoding.token_ids().to_vec();

     let mut group = c.benchmark_group("concurrent_streaming");
     group.measurement_time(Duration::from_secs(2));
@@ -775,7 +775,8 @@ fn bench_stop_sequences(c: &mut Criterion) {
         .with_stop_token(2);

     let sample_text = "Hello world! This is a test. ### Stop here. Continue after.".repeat(100);
-    let tokens = tokenizer.encode(&sample_text).unwrap().token_ids();
+    let encoding = tokenizer.encode(&sample_text).unwrap();
+    let tokens = encoding.token_ids();

     let mut group = c.benchmark_group("stop_sequences");
@@ -784,7 +785,6 @@ fn bench_stop_sequences(c: &mut Criterion) {
     group.bench_function("no_stops", |b| {
         let printed_clone = printed_no_stop.clone();
         let tokenizer = tokenizer.clone();
-        let tokens = tokens.clone();
         b.iter_custom(|iters| {
             let start = Instant::now();
@@ -796,7 +796,7 @@ fn bench_stop_sequences(c: &mut Criterion) {
                     StopSequenceConfig::default(),
                     false,
                 );
-                for token in &tokens {
+                for token in tokens {
                     let _ = decoder.process_token(*token).unwrap();
                     total_tokens += 1;
                 }
@@ -826,7 +826,6 @@ fn bench_stop_sequences(c: &mut Criterion) {
     group.bench_function("with_stops", |b| {
         let printed_clone = printed_with_stops.clone();
         let tokenizer = tokenizer.clone();
-        let tokens = tokens.clone();
         let config = config.clone();
         b.iter_custom(|iters| {
@@ -839,7 +838,7 @@ fn bench_stop_sequences(c: &mut Criterion) {
                     StopSequenceDecoder::new(tokenizer.clone(), config.clone(), false);
                 let mut sequence_tokens = 0u64;
-                for token in &tokens {
+                for token in tokens {
                     let result = decoder.process_token(*token).unwrap();
                     sequence_tokens += 1;
@@ -986,7 +985,8 @@ fn bench_multithreaded_decode(c: &mut Criterion) {
     // Generate tokens for decoding
     let test_text = "The quick brown fox jumps over the lazy dog. ".repeat(100);
-    let test_tokens = tokenizer.encode(&test_text).unwrap().token_ids();
+    let encoding = tokenizer.encode(&test_text).unwrap();
+    let test_tokens: Vec<u32> = encoding.token_ids().to_vec();

     let mut group = c.benchmark_group("multithreaded_decode");
     group.measurement_time(Duration::from_secs(2));
@@ -1130,7 +1130,7 @@ fn bench_memory_efficiency(c: &mut Criterion) {
         b.iter_custom(|iters| {
             let start = Instant::now();
             for _ in 0..iters {
-                let _ = black_box(encoding.token_ids_ref());
+                let _ = black_box(encoding.token_ids());
             }
             let duration = start.elapsed();
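The recurring change in this file is mechanical: every `tokenizer.encode(...).unwrap().token_ids()` chain is split so the encoding is bound to a named local first. This appears to be the standard Rust fix for borrowing from a temporary: if `token_ids()` returns a slice borrowed from the encoding, the one-liner would drop the owner at the end of the statement while the borrow lives on. It also explains the deleted `tokens.clone()` lines; a shared slice can simply be captured by the bench closures. A minimal self-contained sketch of the pattern, using `String` as a stand-in for the crate's encoding type:

fn make_owner() -> String {
    "some owned value".to_string()
}

fn main() {
    // Would not compile (E0716): the temporary returned by `make_owner()`
    // is dropped at the end of the statement while `s` still borrows it.
    // let s: &str = make_owner().as_str();

    // Works: bind the owner first so the borrow has something to live in.
    let owner = make_owner();
    let s: &str = owner.as_str();
    println!("{s}");
}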
sgl-router/src/lib.rs

@@ -14,6 +14,7 @@ pub mod routers;
 pub mod server;
 pub mod service_discovery;
 pub mod tokenizer;
+pub mod tool_parser;
 pub mod tree;

 use crate::metrics::PrometheusConfig;
sgl-router/src/tool_parser/errors.rs (new file, mode 100644)

use thiserror::Error;

/// Result type for tool parser operations
pub type ToolParserResult<T> = Result<T, ToolParserError>;

/// Errors that can occur during tool parsing
#[derive(Debug, Error)]
pub enum ToolParserError {
    #[error("Parsing failed: {0}")]
    ParsingFailed(String),

    #[error("Model not supported: {0}")]
    ModelNotSupported(String),

    #[error("Parse depth exceeded: max {0}")]
    DepthExceeded(usize),

    #[error("Invalid JSON: {0}")]
    JsonError(#[from] serde_json::Error),

    #[error("Regex error: {0}")]
    RegexError(#[from] regex::Error),

    #[error("Incomplete tool call")]
    Incomplete,

    #[error("Invalid tool name: {0}")]
    InvalidToolName(String),

    #[error("Token not found: {0}")]
    TokenNotFound(String),
}
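The two `#[from]` attributes give `ToolParserError` automatic `From` conversions, so `?` works directly on `serde_json` and `regex` results inside parser code. A small sketch (the `parse_strict` helper is hypothetical, not part of this commit):

use crate::tool_parser::{ToolParserError, ToolParserResult};

// `?` converts serde_json::Error into ToolParserError::JsonError
// through the #[from] impl generated by thiserror.
fn parse_strict(input: &str) -> ToolParserResult<serde_json::Value> {
    let value = serde_json::from_str(input)?;
    Ok(value)
}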
sgl-router/src/tool_parser/mod.rs (new file, mode 100644)

/// Tool parser module for handling function/tool calls in model outputs
///
/// This module provides infrastructure for parsing tool calls from various model formats.
/// Phase 1 focuses on core infrastructure: types, traits, registry, and partial JSON parsing.
pub mod errors;
pub mod partial_json;
pub mod registry;
pub mod state;
pub mod traits;
pub mod types;

#[cfg(test)]
mod tests;

// Re-export commonly used types
pub use errors::{ToolParserError, ToolParserResult};
pub use registry::ParserRegistry;
pub use state::{ParsePhase, ParseState};
pub use traits::{PartialJsonParser, ToolParser};
pub use types::{FunctionCall, PartialToolCall, StreamResult, TokenConfig, ToolCall};
sgl-router/src/tool_parser/partial_json.rs (new file, mode 100644)

use crate::tool_parser::{
    errors::{ToolParserError, ToolParserResult},
    traits::PartialJsonParser,
};
use serde_json::{Map, Value};

/// Parser for incomplete JSON
pub struct PartialJson {
    /// Maximum depth for nested structures
    max_depth: usize,
    /// Whether to allow incomplete values
    allow_incomplete: bool,
}

impl PartialJson {
    /// Create a new partial JSON parser
    pub fn new(max_depth: usize, allow_incomplete: bool) -> Self {
        Self {
            max_depth,
            allow_incomplete,
        }
    }

    /// Parse potentially incomplete JSON, returning parsed value and consumed bytes
    pub fn parse_value(&self, input: &str) -> ToolParserResult<(Value, usize)> {
        let mut parser = Parser::new(input, self.max_depth, self.allow_incomplete);
        let value = parser.parse_value(0)?;
        Ok((value, parser.position))
    }
}

impl Default for PartialJson {
    fn default() -> Self {
        Self::new(32, true)
    }
}

impl PartialJsonParser for PartialJson {
    fn parse(&self, input: &str) -> ToolParserResult<(Value, usize)> {
        self.parse_value(input)
    }

    fn is_complete(&self, input: &str) -> bool {
        // Try to parse as complete JSON
        serde_json::from_str::<Value>(input).is_ok()
    }

    fn max_depth(&self) -> usize {
        self.max_depth
    }
}

/// Internal parser state
struct Parser<'a> {
    chars: std::iter::Peekable<std::str::Chars<'a>>,
    position: usize,
    max_depth: usize,
    allow_incomplete: bool,
}

impl<'a> Parser<'a> {
    fn new(input: &'a str, max_depth: usize, allow_incomplete: bool) -> Self {
        Self {
            chars: input.chars().peekable(),
            position: 0,
            max_depth,
            allow_incomplete,
        }
    }

    fn peek(&mut self) -> Option<char> {
        self.chars.peek().copied()
    }

    fn advance(&mut self) {
        if self.chars.next().is_some() {
            self.position += 1;
        }
    }

    fn skip_whitespace(&mut self) {
        while let Some(ch) = self.peek() {
            if ch.is_whitespace() {
                self.advance();
            } else {
                break;
            }
        }
    }

    fn parse_value(&mut self, depth: usize) -> ToolParserResult<Value> {
        if depth > self.max_depth {
            return Err(ToolParserError::DepthExceeded(self.max_depth));
        }

        self.skip_whitespace();

        match self.peek() {
            Some('{') => self.parse_object(depth + 1),
            Some('[') => self.parse_array(depth + 1),
            Some('"') => self.parse_string(),
            Some('t') | Some('f') => self.parse_bool(),
            Some('n') => self.parse_null(),
            Some(c) if c == '-' || c.is_ascii_digit() => self.parse_number(),
            _ => {
                if self.allow_incomplete {
                    Ok(Value::Null)
                } else {
                    Err(ToolParserError::ParsingFailed(
                        "Unexpected character".into(),
                    ))
                }
            }
        }
    }

    fn parse_object(&mut self, depth: usize) -> ToolParserResult<Value> {
        if depth > self.max_depth {
            return Err(ToolParserError::DepthExceeded(self.max_depth));
        }

        let mut object = Map::new();

        // Consume '{'
        self.advance();
        self.skip_whitespace();

        // Check for empty object
        if self.peek() == Some('}') {
            self.advance();
            return Ok(Value::Object(object));
        }

        loop {
            // Parse key
            let key = match self.parse_string() {
                Ok(Value::String(s)) => s,
                Err(_) if self.allow_incomplete => {
                    // Incomplete object
                    return Ok(Value::Object(object));
                }
                Err(e) => return Err(e),
                _ => return Err(ToolParserError::ParsingFailed("Expected string key".into())),
            };

            self.skip_whitespace();

            // Expect ':'
            if self.peek() != Some(':') {
                if self.allow_incomplete {
                    // Add null value for incomplete pair
                    object.insert(key, Value::Null);
                    return Ok(Value::Object(object));
                }
                return Err(ToolParserError::ParsingFailed("Expected ':'".into()));
            }
            self.advance();
            self.skip_whitespace();

            // Parse value (keep same depth - we already incremented when entering this object)
            let value = match self.parse_value(depth) {
                Ok(v) => v,
                Err(_) if self.allow_incomplete => {
                    // Add null for incomplete value
                    object.insert(key, Value::Null);
                    return Ok(Value::Object(object));
                }
                Err(e) => return Err(e),
            };

            object.insert(key, value);
            self.skip_whitespace();

            match self.peek() {
                Some(',') => {
                    self.advance();
                    self.skip_whitespace();
                    // Check for trailing comma
                    if self.peek() == Some('}') {
                        self.advance();
                        return Ok(Value::Object(object));
                    }
                }
                Some('}') => {
                    self.advance();
                    return Ok(Value::Object(object));
                }
                None if self.allow_incomplete => {
                    return Ok(Value::Object(object));
                }
                _ => {
                    if self.allow_incomplete {
                        return Ok(Value::Object(object));
                    }
                    return Err(ToolParserError::ParsingFailed("Expected ',' or '}'".into()));
                }
            }
        }
    }

    fn parse_array(&mut self, depth: usize) -> ToolParserResult<Value> {
        if depth > self.max_depth {
            return Err(ToolParserError::DepthExceeded(self.max_depth));
        }

        let mut array = Vec::new();

        // Consume '['
        self.advance();
        self.skip_whitespace();

        // Check for empty array
        if self.peek() == Some(']') {
            self.advance();
            return Ok(Value::Array(array));
        }

        loop {
            // Parse value (keep same depth - we already incremented when entering this array)
            let value = match self.parse_value(depth) {
                Ok(v) => v,
                Err(_) if self.allow_incomplete => {
                    return Ok(Value::Array(array));
                }
                Err(e) => return Err(e),
            };

            array.push(value);
            self.skip_whitespace();

            match self.peek() {
                Some(',') => {
                    self.advance();
                    self.skip_whitespace();
                    // Check for trailing comma
                    if self.peek() == Some(']') {
                        self.advance();
                        return Ok(Value::Array(array));
                    }
                }
                Some(']') => {
                    self.advance();
                    return Ok(Value::Array(array));
                }
                None if self.allow_incomplete => {
                    return Ok(Value::Array(array));
                }
                _ => {
                    if self.allow_incomplete {
                        return Ok(Value::Array(array));
                    }
                    return Err(ToolParserError::ParsingFailed("Expected ',' or ']'".into()));
                }
            }
        }
    }

    fn parse_string(&mut self) -> ToolParserResult<Value> {
        if self.peek() != Some('"') {
            return Err(ToolParserError::ParsingFailed("Expected '\"'".into()));
        }

        // Consume opening quote
        self.advance();

        let mut string = String::new();
        let mut escaped = false;

        while let Some(ch) = self.peek() {
            if escaped {
                // Handle escape sequences
                let escaped_char = match ch {
                    '"' | '\\' | '/' => ch,
                    'b' => '\u{0008}',
                    'f' => '\u{000C}',
                    'n' => '\n',
                    'r' => '\r',
                    't' => '\t',
                    'u' => {
                        // Unicode escape
                        self.advance();
                        let hex = self.parse_unicode_escape()?;
                        string.push(hex);
                        escaped = false;
                        continue;
                    }
                    _ => ch, // Invalid escape, but be lenient
                };
                string.push(escaped_char);
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                // End of string
                self.advance();
                return Ok(Value::String(string));
            } else {
                string.push(ch);
            }
            self.advance();
        }

        // Incomplete string
        if self.allow_incomplete {
            Ok(Value::String(string))
        } else {
            Err(ToolParserError::ParsingFailed("Unterminated string".into()))
        }
    }

    fn parse_unicode_escape(&mut self) -> ToolParserResult<char> {
        let mut hex = String::new();
        for _ in 0..4 {
            if let Some(ch) = self.peek() {
                if ch.is_ascii_hexdigit() {
                    hex.push(ch);
                    self.advance();
                } else {
                    break;
                }
            } else {
                break;
            }
        }

        if hex.len() == 4 {
            u32::from_str_radix(&hex, 16)
                .ok()
                .and_then(char::from_u32)
                .ok_or_else(|| ToolParserError::ParsingFailed("Invalid unicode escape".into()))
        } else if self.allow_incomplete {
            Ok('\u{FFFD}') // Replacement character
        } else {
            Err(ToolParserError::ParsingFailed(
                "Incomplete unicode escape".into(),
            ))
        }
    }

    fn parse_number(&mut self) -> ToolParserResult<Value> {
        let mut number = String::new();

        // Handle negative sign
        if self.peek() == Some('-') {
            number.push('-');
            self.advance();
        }

        // Parse integer part
        if self.peek() == Some('0') {
            number.push('0');
            self.advance();
        } else {
            while let Some(ch) = self.peek() {
                if ch.is_ascii_digit() {
                    number.push(ch);
                    self.advance();
                } else {
                    break;
                }
            }
        }

        // Parse decimal part
        if self.peek() == Some('.') {
            number.push('.');
            self.advance();
            while let Some(ch) = self.peek() {
                if ch.is_ascii_digit() {
                    number.push(ch);
                    self.advance();
                } else {
                    break;
                }
            }
        }

        // Parse exponent
        if let Some(ch) = self.peek() {
            if ch == 'e' || ch == 'E' {
                number.push(ch);
                self.advance();
                if let Some(sign) = self.peek() {
                    if sign == '+' || sign == '-' {
                        number.push(sign);
                        self.advance();
                    }
                }
                while let Some(ch) = self.peek() {
                    if ch.is_ascii_digit() {
                        number.push(ch);
                        self.advance();
                    } else {
                        break;
                    }
                }
            }
        }

        // Try to parse as integer first, then as float
        if let Ok(n) = number.parse::<i64>() {
            Ok(Value::Number(serde_json::Number::from(n)))
        } else if let Ok(n) = number.parse::<f64>() {
            Ok(Value::Number(
                serde_json::Number::from_f64(n).unwrap_or_else(|| serde_json::Number::from(0)),
            ))
        } else if self.allow_incomplete {
            Ok(Value::Number(serde_json::Number::from(0)))
        } else {
            Err(ToolParserError::ParsingFailed("Invalid number".into()))
        }
    }

    fn parse_bool(&mut self) -> ToolParserResult<Value> {
        let mut word = String::new();

        // Peek at upcoming characters to validate it looks like a boolean
        let mut temp_chars = self.chars.clone();
        while let Some(&ch) = temp_chars.peek() {
            if ch.is_alphabetic() && word.len() < 5 {
                // "false" is 5 chars
                word.push(ch);
                temp_chars.next();
            } else {
                break;
            }
        }

        // Check if it's a valid boolean prefix
        let is_valid = word == "true"
            || word == "false"
            || (self.allow_incomplete
                && ("true".starts_with(&word) || "false".starts_with(&word)));

        if !is_valid {
            return Err(ToolParserError::ParsingFailed("Invalid boolean".into()));
        }

        // Now actually consume the characters
        word.clear();
        while let Some(ch) = self.peek() {
            if ch.is_alphabetic() {
                word.push(ch);
                self.advance();
            } else {
                break;
            }
        }

        match word.as_str() {
            "true" => Ok(Value::Bool(true)),
            "false" => Ok(Value::Bool(false)),
            partial if self.allow_incomplete => {
                if "true".starts_with(partial) {
                    Ok(Value::Bool(true))
                } else if "false".starts_with(partial) {
                    Ok(Value::Bool(false))
                } else {
                    Err(ToolParserError::ParsingFailed("Invalid boolean".into()))
                }
            }
            _ => Err(ToolParserError::ParsingFailed("Invalid boolean".into())),
        }
    }

    fn parse_null(&mut self) -> ToolParserResult<Value> {
        let mut word = String::new();

        // Peek at upcoming characters to validate it looks like "null"
        let mut temp_chars = self.chars.clone();
        while let Some(&ch) = temp_chars.peek() {
            if ch.is_alphabetic() && word.len() < 4 {
                // "null" is 4 chars
                word.push(ch);
                temp_chars.next();
            } else {
                break;
            }
        }

        // Check if it's a valid null prefix
        let is_valid = word == "null" || (self.allow_incomplete && "null".starts_with(&word));

        if !is_valid {
            return Err(ToolParserError::ParsingFailed("Invalid null".into()));
        }

        // Now actually consume the characters
        word.clear();
        while let Some(ch) = self.peek() {
            if ch.is_alphabetic() {
                word.push(ch);
                self.advance();
            } else {
                break;
            }
        }

        if word == "null" || (self.allow_incomplete && "null".starts_with(&word)) {
            Ok(Value::Null)
        } else {
            Err(ToolParserError::ParsingFailed("Invalid null".into()))
        }
    }
}

/// Utility function to check if a string contains complete JSON
pub fn is_complete_json(input: &str) -> bool {
    serde_json::from_str::<Value>(input).is_ok()
}

/// Utility function to find common prefix between two strings
pub fn find_common_prefix(s1: &str, s2: &str) -> usize {
    s1.chars()
        .zip(s2.chars())
        .take_while(|(a, b)| a == b)
        .count()
}

/// Utility function to compute diff between old and new strings
pub fn compute_diff(old: &str, new: &str) -> String {
    let common_len = find_common_prefix(old, new);
    // Convert character count to byte offset
    new.chars().skip(common_len).collect()
}
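A usage sketch of the lenient parser and the diff utilities, mirroring the tests later in this commit (the assertions follow directly from the implementation above):

use crate::tool_parser::partial_json::{compute_diff, PartialJson};

fn demo() {
    // Default parser: depth limit 32, incomplete input allowed.
    let parser = PartialJson::default();

    // A truncated object still yields a usable Value: the unterminated
    // string is kept as-is, and `consumed` reports how far parsing got.
    let (value, consumed) = parser.parse_value(r#"{"name": "te"#).unwrap();
    assert_eq!(value["name"], "te");
    assert!(consumed <= r#"{"name": "te"#.len());

    // compute_diff returns only the newly appended suffix, which is
    // what a streaming caller would forward to the client.
    assert_eq!(
        compute_diff(r#"{"a": 1"#, r#"{"a": 1, "b": 2}"#),
        r#", "b": 2}"#
    );
}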
sgl-router/src/tool_parser/registry.rs (new file, mode 100644)

use crate::tool_parser::traits::ToolParser;
use std::collections::HashMap;
use std::sync::Arc;

/// Registry for tool parsers and model mappings
pub struct ParserRegistry {
    /// Map of parser name to parser instance
    parsers: HashMap<String, Arc<dyn ToolParser>>,
    /// Map of model name/pattern to parser name
    model_mapping: HashMap<String, String>,
    /// Default parser to use when no match found
    default_parser: String,
}

impl ParserRegistry {
    /// Create a new parser registry with default mappings
    pub fn new() -> Self {
        let mut registry = Self {
            parsers: HashMap::new(),
            model_mapping: HashMap::new(),
            default_parser: "json".to_string(),
        };

        // Register default model mappings
        registry.register_default_mappings();
        registry
    }

    /// Register a parser
    pub fn register_parser(&mut self, name: impl Into<String>, parser: Arc<dyn ToolParser>) {
        self.parsers.insert(name.into(), parser);
    }

    /// Map a model name/pattern to a parser
    pub fn map_model(&mut self, model: impl Into<String>, parser: impl Into<String>) {
        self.model_mapping.insert(model.into(), parser.into());
    }

    /// Get parser for a specific model
    pub fn get_parser(&self, model: &str) -> Option<Arc<dyn ToolParser>> {
        // Try exact match first
        if let Some(parser_name) = self.model_mapping.get(model) {
            if let Some(parser) = self.parsers.get(parser_name) {
                return Some(parser.clone());
            }
        }

        // Try prefix matching (e.g., "gpt-4" matches "gpt-*")
        for (pattern, parser_name) in &self.model_mapping {
            if pattern.ends_with('*') {
                let prefix = &pattern[..pattern.len() - 1];
                if model.starts_with(prefix) {
                    if let Some(parser) = self.parsers.get(parser_name) {
                        return Some(parser.clone());
                    }
                }
            }
        }

        // Fall back to default parser if it exists
        self.parsers.get(&self.default_parser).cloned()
    }

    /// List all registered parsers
    pub fn list_parsers(&self) -> Vec<&str> {
        self.parsers.keys().map(|s| s.as_str()).collect()
    }

    /// List all model mappings
    pub fn list_mappings(&self) -> Vec<(&str, &str)> {
        self.model_mapping
            .iter()
            .map(|(k, v)| (k.as_str(), v.as_str()))
            .collect()
    }

    /// Register default model mappings
    fn register_default_mappings(&mut self) {
        // OpenAI models
        self.map_model("gpt-4*", "json");
        self.map_model("gpt-3.5*", "json");
        self.map_model("gpt-4o*", "json");

        // Anthropic models
        self.map_model("claude-*", "json");

        // Mistral models
        self.map_model("mistral-*", "mistral");
        self.map_model("mixtral-*", "mistral");

        // Qwen models
        self.map_model("qwen*", "qwen");

        // Llama models
        self.map_model("llama-*", "llama");
        self.map_model("meta-llama-*", "llama");

        // Other models default to JSON
        self.map_model("gemini-*", "json");
        self.map_model("palm-*", "json");
    }

    /// Set the default parser
    pub fn set_default_parser(&mut self, name: impl Into<String>) {
        self.default_parser = name.into();
    }

    /// Check if a parser is registered
    pub fn has_parser(&self, name: &str) -> bool {
        self.parsers.contains_key(name)
    }
}

impl Default for ParserRegistry {
    fn default() -> Self {
        Self::new()
    }
}
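A sketch of the registry surface. Note that this phase only ships mappings, not parser instances, so lookups resolve to `None` until a concrete parser is registered under the mapped name:

use crate::tool_parser::ParserRegistry;

fn demo() {
    let mut registry = ParserRegistry::new();

    // Patterns like "qwen*" are pre-mapped; add a custom one of our own
    // ("my-model-*" is a hypothetical pattern, not part of this commit).
    registry.map_model("my-model-*", "json");

    // No parser instances exist yet in this commit, so the prefix match
    // finds the mapping but falls through, as does the default parser.
    assert!(!registry.has_parser("json"));
    assert!(registry.get_parser("my-model-7b").is_none());
}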
sgl-router/src/tool_parser/state.rs (new file, mode 100644)

use crate::tool_parser::types::{PartialToolCall, ToolCall};

/// Current phase of parsing
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParsePhase {
    /// Looking for start of tool call
    Searching,
    /// Parsing function name
    InName,
    /// Parsing function arguments
    InArguments,
    /// Tool call complete
    Complete,
}

/// State for streaming parser
#[derive(Debug, Clone)]
pub struct ParseState {
    /// Buffer for accumulating input
    pub buffer: String,
    /// Position of last consumed character
    pub consumed: usize,
    /// Current partial tool being parsed
    pub partial_tool: Option<PartialToolCall>,
    /// Completed tool calls
    pub completed_tools: Vec<ToolCall>,
    /// Current parsing phase
    pub phase: ParsePhase,
    /// Bracket/brace depth for JSON parsing
    pub bracket_depth: i32,
    /// Whether currently inside a string literal
    pub in_string: bool,
    /// Whether next character should be escaped
    pub escape_next: bool,
    /// Current tool index (for streaming)
    pub tool_index: usize,
}

impl ParseState {
    /// Create a new parse state
    pub fn new() -> Self {
        Self {
            buffer: String::new(),
            consumed: 0,
            partial_tool: None,
            completed_tools: Vec::new(),
            phase: ParsePhase::Searching,
            bracket_depth: 0,
            in_string: false,
            escape_next: false,
            tool_index: 0,
        }
    }

    /// Reset state for parsing next tool
    pub fn reset(&mut self) {
        self.partial_tool = None;
        self.phase = ParsePhase::Searching;
        self.bracket_depth = 0;
        self.in_string = false;
        self.escape_next = false;
    }

    /// Process a single character for JSON parsing
    pub fn process_char(&mut self, ch: char) {
        // Handle escape sequences
        if self.escape_next {
            self.escape_next = false;
            self.buffer.push(ch);
            return;
        }

        if ch == '\\' && self.in_string {
            self.escape_next = true;
            self.buffer.push(ch);
            return;
        }

        // Track string boundaries
        if ch == '"' && !self.escape_next {
            self.in_string = !self.in_string;
        }

        // Track bracket depth for JSON
        if !self.in_string {
            match ch {
                '{' | '[' => {
                    self.bracket_depth += 1;
                }
                '}' | ']' => {
                    self.bracket_depth -= 1;
                    if self.bracket_depth == 0 && self.partial_tool.is_some() {
                        // Complete tool call found
                        self.phase = ParsePhase::Complete;
                    }
                }
                _ => {}
            }
        }

        self.buffer.push(ch);
    }

    /// Check if we have a complete JSON object/array
    pub fn has_complete_json(&self) -> bool {
        self.bracket_depth == 0 && !self.in_string && !self.buffer.is_empty()
    }

    /// Extract content from buffer starting at position
    pub fn extract_from(&self, start: usize) -> &str {
        if start >= self.buffer.len() {
            return "";
        }
        // Find the nearest character boundary at or after start
        let mut safe_start = start;
        while safe_start < self.buffer.len() && !self.buffer.is_char_boundary(safe_start) {
            safe_start += 1;
        }
        if safe_start < self.buffer.len() {
            &self.buffer[safe_start..]
        } else {
            ""
        }
    }

    /// Mark content as consumed up to position
    pub fn consume_to(&mut self, position: usize) {
        if position > self.consumed {
            self.consumed = position;
        }
    }

    /// Get unconsumed content
    pub fn unconsumed(&self) -> &str {
        if self.consumed >= self.buffer.len() {
            return "";
        }
        // Find the nearest character boundary at or after consumed
        let mut safe_consumed = self.consumed;
        while safe_consumed < self.buffer.len() && !self.buffer.is_char_boundary(safe_consumed) {
            safe_consumed += 1;
        }
        if safe_consumed < self.buffer.len() {
            &self.buffer[safe_consumed..]
        } else {
            ""
        }
    }

    /// Clear consumed content from buffer
    pub fn clear_consumed(&mut self) {
        if self.consumed > 0 {
            // Find the nearest character boundary at or before consumed
            let mut safe_consumed = self.consumed;
            while safe_consumed > 0 && !self.buffer.is_char_boundary(safe_consumed) {
                safe_consumed -= 1;
            }
            if safe_consumed > 0 {
                self.buffer.drain(..safe_consumed);
                self.consumed = self.consumed.saturating_sub(safe_consumed);
            }
        }
    }

    /// Add completed tool
    pub fn add_completed_tool(&mut self, tool: ToolCall) {
        self.completed_tools.push(tool);
        self.tool_index += 1;
    }
}

impl Default for ParseState {
    fn default() -> Self {
        Self::new()
    }
}
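A sketch of feeding a streaming chunk through `ParseState` one character at a time; the depth and string tracking follow directly from `process_char` above:

use crate::tool_parser::ParseState;

fn demo() {
    let mut state = ParseState::new();

    // Feed part of a tool-call payload; braces inside strings are ignored.
    for ch in r#"{"name": "search", "arguments": {"#.chars() {
        state.process_char(ch);
    }

    // Two '{' are still open, so the JSON is not complete yet.
    assert_eq!(state.bracket_depth, 2);
    assert!(!state.has_complete_json());

    // Closing both brackets balances the depth.
    for ch in "}}".chars() {
        state.process_char(ch);
    }
    assert!(state.has_complete_json());
}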
sgl-router/src/tool_parser/tests.rs (new file, mode 100644)

use super::*;
use crate::tool_parser::partial_json::{
    compute_diff, find_common_prefix, is_complete_json, PartialJson,
};

#[test]
fn test_parse_state_new() {
    let state = ParseState::new();
    assert_eq!(state.phase, ParsePhase::Searching);
    assert_eq!(state.buffer, "");
    assert_eq!(state.consumed, 0);
    assert_eq!(state.bracket_depth, 0);
    assert!(!state.in_string);
    assert!(!state.escape_next);
}

#[test]
fn test_parse_state_process_char() {
    let mut state = ParseState::new();

    // Test bracket tracking
    state.process_char('{');
    assert_eq!(state.bracket_depth, 1);
    state.process_char('}');
    assert_eq!(state.bracket_depth, 0);

    // Test string tracking
    state.process_char('"');
    assert!(state.in_string);
    state.process_char('"');
    assert!(!state.in_string);

    // Test escape handling
    state.process_char('"');
    state.process_char('\\');
    assert!(state.escape_next);
    state.process_char('"');
    assert!(!state.escape_next);
    assert!(state.in_string); // Still in string because quote was escaped
}

#[test]
fn test_token_config() {
    let config = TokenConfig {
        start_tokens: vec!["<start>".to_string(), "[".to_string()],
        end_tokens: vec!["</end>".to_string(), "]".to_string()],
        separator: ", ".to_string(),
    };

    let pairs: Vec<_> = config.iter_pairs().collect();
    assert_eq!(pairs.len(), 2);
    assert_eq!(pairs[0], ("<start>", "</end>"));
    assert_eq!(pairs[1], ("[", "]"));
}

#[test]
fn test_parser_registry() {
    let registry = ParserRegistry::new();

    // Test has default mappings
    assert!(!registry.list_mappings().is_empty());

    // Test model pattern matching
    let mappings = registry.list_mappings();
    let has_gpt = mappings.iter().any(|(m, _)| m.starts_with("gpt"));
    assert!(has_gpt);
}

#[test]
fn test_parser_registry_pattern_matching() {
    let mut registry = ParserRegistry::new();

    // Test that model mappings work by checking the list
    registry.map_model("test-model", "json");

    // Verify through list_mappings
    let mappings = registry.list_mappings();
    let has_test = mappings
        .iter()
        .any(|(m, p)| *m == "test-model" && *p == "json");
    assert!(has_test);
}

#[test]
fn test_tool_call_serialization() {
    let tool_call = ToolCall {
        id: "call-123".to_string(),
        r#type: "function".to_string(),
        function: FunctionCall {
            name: "search".to_string(),
            arguments: r#"{"query": "rust programming"}"#.to_string(),
        },
    };

    let json = serde_json::to_string(&tool_call).unwrap();
    assert!(json.contains("call-123"));
    assert!(json.contains("search"));
    assert!(json.contains("rust programming"));

    let parsed: ToolCall = serde_json::from_str(&json).unwrap();
    assert_eq!(parsed.id, "call-123");
    assert_eq!(parsed.function.name, "search");
}

#[test]
fn test_partial_json_parser() {
    let parser = PartialJson::default();

    // Test complete JSON
    let input = r#"{"name": "test", "value": 42}"#;
    let (value, consumed) = parser.parse_value(input).unwrap();
    assert_eq!(value["name"], "test");
    assert_eq!(value["value"], 42);
    assert_eq!(consumed, input.len());

    // Test incomplete JSON object
    let input = r#"{"name": "test", "value": "#;
    let (value, _consumed) = parser.parse_value(input).unwrap();
    assert_eq!(value["name"], "test");
    assert!(value["value"].is_null());

    // Test incomplete string
    let input = r#"{"name": "tes"#;
    let (value, _consumed) = parser.parse_value(input).unwrap();
    assert_eq!(value["name"], "tes");

    // Test incomplete array
    let input = r#"[1, 2, "#;
    let (value, _consumed) = parser.parse_value(input).unwrap();
    assert!(value.is_array());
    assert_eq!(value[0], 1);
    assert_eq!(value[1], 2);
}

#[test]
fn test_partial_json_depth_limit() {
    // max_depth of 3 allows nesting up to 3 levels
    // Set allow_incomplete to false to get errors instead of partial results
    let parser = PartialJson::new(3, false);

    // This should work (simple object)
    let input = r#"{"a": 1}"#;
    let result = parser.parse_value(input);
    assert!(result.is_ok());

    // This should work (nested to depth 3)
    let input = r#"{"a": {"b": {"c": 1}}}"#;
    let result = parser.parse_value(input);
    assert!(result.is_ok());

    // This should fail (nested to depth 4, exceeds limit)
    let input = r#"{"a": {"b": {"c": {"d": 1}}}}"#;
    let result = parser.parse_value(input);
    assert!(result.is_err());
}

#[test]
fn test_is_complete_json() {
    assert!(is_complete_json(r#"{"name": "test"}"#));
    assert!(is_complete_json(r#"[1, 2, 3]"#));
    assert!(is_complete_json(r#""string""#));
    assert!(is_complete_json("42"));
    assert!(is_complete_json("true"));
    assert!(is_complete_json("null"));

    assert!(!is_complete_json(r#"{"name": "#));
    assert!(!is_complete_json(r#"[1, 2, "#));
    assert!(!is_complete_json(r#""unclosed"#));
}

#[test]
fn test_find_common_prefix() {
    assert_eq!(find_common_prefix("hello", "hello"), 5);
    assert_eq!(find_common_prefix("hello", "help"), 3);
    assert_eq!(find_common_prefix("hello", "world"), 0);
    assert_eq!(find_common_prefix("", "hello"), 0);
    assert_eq!(find_common_prefix("hello", ""), 0);
}

#[test]
fn test_compute_diff() {
    assert_eq!(compute_diff("hello", "hello world"), " world");
    assert_eq!(compute_diff("", "hello"), "hello");
    assert_eq!(compute_diff("hello", "hello"), "");
    assert_eq!(compute_diff("test", "hello"), "hello");
}

#[test]
fn test_stream_result_variants() {
    // Test Incomplete
    let result = StreamResult::Incomplete;
    matches!(result, StreamResult::Incomplete);

    // Test ToolName
    let result = StreamResult::ToolName {
        index: 0,
        name: "test".to_string(),
    };
    if let StreamResult::ToolName { index, name } = result {
        assert_eq!(index, 0);
        assert_eq!(name, "test");
    } else {
        panic!("Expected ToolName variant");
    }

    // Test ToolComplete
    let tool = ToolCall {
        id: "123".to_string(),
        r#type: "function".to_string(),
        function: FunctionCall {
            name: "test".to_string(),
            arguments: "{}".to_string(),
        },
    };
    let result = StreamResult::ToolComplete(tool.clone());
    if let StreamResult::ToolComplete(t) = result {
        assert_eq!(t.id, "123");
    } else {
        panic!("Expected ToolComplete variant");
    }
}

#[test]
fn test_partial_tool_call() {
    let mut partial = PartialToolCall {
        name: None,
        arguments_buffer: String::new(),
        start_position: 0,
        name_sent: false,
        streamed_args: String::new(),
    };

    // Set name
    partial.name = Some("test_function".to_string());
    assert_eq!(partial.name.as_ref().unwrap(), "test_function");

    // Append arguments
    partial.arguments_buffer.push_str(r#"{"key": "value"}"#);
    assert_eq!(partial.arguments_buffer, r#"{"key": "value"}"#);

    // Update streaming state
    partial.name_sent = true;
    partial.streamed_args = r#"{"key": "#.to_string();
    assert!(partial.name_sent);
    assert_eq!(partial.streamed_args, r#"{"key": "#);
}
sgl-router/src/tool_parser/traits.rs (new file, mode 100644)

use crate::tool_parser::{
    errors::ToolParserResult,
    state::ParseState,
    types::{StreamResult, ToolCall},
};
use async_trait::async_trait;

/// Core trait for all tool parsers
#[async_trait]
pub trait ToolParser: Send + Sync {
    /// Parse complete tool calls from final output
    async fn parse_complete(&self, output: &str) -> ToolParserResult<Vec<ToolCall>>;

    /// Parse tool calls from model output (streaming)
    async fn parse_incremental(
        &self,
        chunk: &str,
        state: &mut ParseState,
    ) -> ToolParserResult<StreamResult>;

    /// Check if text contains tool calls in this parser's format
    fn detect_format(&self, text: &str) -> bool;
}

/// Trait for partial JSON parsing
pub trait PartialJsonParser: Send + Sync {
    /// Parse potentially incomplete JSON
    fn parse(&self, input: &str) -> ToolParserResult<(serde_json::Value, usize)>;

    /// Check if JSON is complete
    fn is_complete(&self, input: &str) -> bool;

    /// Get the maximum parsing depth
    fn max_depth(&self) -> usize;
}
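A minimal implementor sketch showing the trait surface; `NoopParser` is hypothetical and not part of this commit:

use crate::tool_parser::{ParseState, StreamResult, ToolCall, ToolParser, ToolParserResult};
use async_trait::async_trait;

/// Hypothetical parser that never finds a tool call, showing the shape
/// a concrete parser (json, mistral, qwen, llama, ...) will fill in later.
struct NoopParser;

#[async_trait]
impl ToolParser for NoopParser {
    async fn parse_complete(&self, _output: &str) -> ToolParserResult<Vec<ToolCall>> {
        Ok(Vec::new())
    }

    async fn parse_incremental(
        &self,
        chunk: &str,
        state: &mut ParseState,
    ) -> ToolParserResult<StreamResult> {
        // Accumulate input and ask for more; a real parser would scan
        // `state.buffer` for its start tokens here.
        state.buffer.push_str(chunk);
        Ok(StreamResult::Incomplete)
    }

    fn detect_format(&self, _text: &str) -> bool {
        false
    }
}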
sgl-router/src/tool_parser/types.rs (new file, mode 100644)

use serde::{Deserialize, Serialize};

/// Parsed tool call from model output (OpenAI format)
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCall {
    /// Unique identifier for the tool call
    pub id: String,
    /// Type of tool call (currently always "function")
    #[serde(rename = "type")]
    pub r#type: String,
    /// Function call details
    pub function: FunctionCall,
}

/// Function call within a tool call
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FunctionCall {
    /// Name of the function to call
    pub name: String,
    /// Arguments as JSON string
    pub arguments: String,
}

/// Streaming parse result
#[derive(Debug, Clone)]
pub enum StreamResult {
    /// Need more data to continue parsing
    Incomplete,
    /// Found a tool name (for streaming)
    ToolName { index: usize, name: String },
    /// Found incremental arguments (for streaming)
    ToolArguments { index: usize, arguments: String },
    /// Completed parsing a tool
    ToolComplete(ToolCall),
    /// Normal text (not part of tool call)
    NormalText(String),
}

/// Token configuration for parsing
#[derive(Debug, Clone)]
pub struct TokenConfig {
    /// Start tokens for tool calls
    pub start_tokens: Vec<String>,
    /// End tokens for tool calls
    pub end_tokens: Vec<String>,
    /// Separator between multiple tool calls
    pub separator: String,
}

impl TokenConfig {
    /// Iterate over start/end token pairs
    pub fn iter_pairs(&self) -> impl Iterator<Item = (&str, &str)> {
        self.start_tokens
            .iter()
            .zip(self.end_tokens.iter())
            .map(|(s, e)| (s.as_str(), e.as_str()))
    }
}

/// Simple partial tool call for streaming
#[derive(Debug, Clone)]
pub struct PartialToolCall {
    /// Tool name (if parsed)
    pub name: Option<String>,
    /// Buffer for accumulating arguments
    pub arguments_buffer: String,
    /// Start position in the input buffer
    pub start_position: usize,
    /// Whether the name has been sent (for streaming)
    pub name_sent: bool,
    /// Arguments already streamed
    pub streamed_args: String,
}