Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
cbe854fc
Unverified
Commit
cbe854fc
authored
Aug 22, 2025
by
Ayush Agarwal
Committed by
GitHub
Aug 22, 2025
Browse files
feat: [vLLM] implement cli args for tool and reasoning parsers (#2619)
parent
b658ba61
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
183 additions
and
58 deletions
+183
-58
components/backends/vllm/src/dynamo/vllm/args.py
components/backends/vllm/src/dynamo/vllm/args.py
+19
-1
components/backends/vllm/src/dynamo/vllm/main.py
components/backends/vllm/src/dynamo/vllm/main.py
+2
-0
lib/bindings/python/rust/llm/local_model.rs
lib/bindings/python/rust/llm/local_model.rs
+20
-0
lib/llm/src/discovery/model_manager.rs
lib/llm/src/discovery/model_manager.rs
+12
-0
lib/llm/src/http/service/openai.rs
lib/llm/src/http/service/openai.rs
+43
-27
lib/llm/src/local_model.rs
lib/llm/src/local_model.rs
+2
-0
lib/llm/src/local_model/runtime_config.rs
lib/llm/src/local_model/runtime_config.rs
+4
-0
lib/llm/src/preprocessor.rs
lib/llm/src/preprocessor.rs
+0
-1
lib/llm/src/protocols/openai.rs
lib/llm/src/protocols/openai.rs
+16
-0
lib/llm/src/protocols/openai/chat_completions/aggregator.rs
lib/llm/src/protocols/openai/chat_completions/aggregator.rs
+19
-9
lib/llm/src/protocols/openai/completions/aggregator.rs
lib/llm/src/protocols/openai/completions/aggregator.rs
+13
-7
lib/llm/tests/aggregators.rs
lib/llm/tests/aggregators.rs
+28
-13
lib/parsers/src/tool_calling/tools.rs
lib/parsers/src/tool_calling/tools.rs
+5
-0
No files found.
components/backends/vllm/src/dynamo/vllm/args.py
View file @
cbe854fc
...
...
@@ -58,6 +58,10 @@ class Config:
# Connector list from CLI
connector_list
:
Optional
[
list
]
=
None
# tool and reasoning parser info
tool_call_parser
:
Optional
[
str
]
=
None
reasoning_parser
:
Optional
[
str
]
=
None
def
parse_args
()
->
Config
:
parser
=
FlexibleArgumentParser
(
...
...
@@ -102,6 +106,19 @@ def parse_args() -> Config:
help
=
"List of connectors to use in order (e.g., --connector nixl lmcache). "
"Options: nixl, lmcache, kvbm, null, none. Default: nixl. Order will be preserved in MultiConnector."
,
)
# To avoid name conflicts with different backends, adoped prefix "dyn-" for dynamo specific args
parser
.
add_argument
(
"--dyn-tool-call-parser"
,
type
=
str
,
default
=
None
,
help
=
"Tool call parser name for the model. Available options: 'hermes', 'nemotron_deci', 'llama3_json', 'mistral', 'phi4'."
,
)
parser
.
add_argument
(
"--dyn-reasoning-parser"
,
type
=
str
,
default
=
None
,
help
=
"Reasoning parser name for the model."
,
)
parser
=
AsyncEngineArgs
.
add_cli_args
(
parser
)
args
=
parser
.
parse_args
()
...
...
@@ -151,7 +168,8 @@ def parse_args() -> Config:
config
.
port_range
=
DynamoPortRange
(
min
=
args
.
dynamo_port_min
,
max
=
args
.
dynamo_port_max
)
config
.
tool_call_parser
=
args
.
dyn_tool_call_parser
config
.
reasoning_parser
=
args
.
dyn_reasoning_parser
# Check for conflicting flags
has_kv_transfer_config
=
(
hasattr
(
engine_args
,
"kv_transfer_config"
)
...
...
components/backends/vllm/src/dynamo/vllm/main.py
View file @
cbe854fc
...
...
@@ -234,6 +234,8 @@ async def init(runtime: DistributedRuntime, config: Config):
runtime_config
.
total_kv_blocks
=
runtime_values
[
"num_gpu_blocks"
]
runtime_config
.
max_num_seqs
=
runtime_values
[
"max_num_seqs"
]
runtime_config
.
max_num_batched_tokens
=
runtime_values
[
"max_num_batched_tokens"
]
runtime_config
.
tool_call_parser
=
config
.
tool_call_parser
runtime_config
.
reasoning_parser
=
config
.
reasoning_parser
await
register_llm
(
ModelType
.
Backend
,
...
...
lib/bindings/python/rust/llm/local_model.rs
View file @
cbe854fc
...
...
@@ -34,6 +34,16 @@ impl ModelRuntimeConfig {
self
.inner.max_num_batched_tokens
=
Some
(
max_num_batched_tokens
);
}
#[setter]
fn
set_tool_call_parser
(
&
mut
self
,
tool_call_parser
:
Option
<
String
>
)
{
self
.inner.tool_call_parser
=
tool_call_parser
;
}
#[setter]
fn
set_reasoning_parser
(
&
mut
self
,
reasoning_parser
:
Option
<
String
>
)
{
self
.inner.reasoning_parser
=
reasoning_parser
;
}
fn
set_engine_specific
(
&
mut
self
,
key
:
&
str
,
value
:
String
)
->
PyResult
<
()
>
{
let
value
:
serde_json
::
Value
=
serde_json
::
from_str
(
&
value
)
.map_err
(
to_pyerr
)
?
;
self
.inner
...
...
@@ -57,6 +67,16 @@ impl ModelRuntimeConfig {
self
.inner.max_num_batched_tokens
}
#[getter]
fn
tool_call_parser
(
&
self
)
->
Option
<
String
>
{
self
.inner.tool_call_parser
.clone
()
}
#[getter]
fn
reasoning_parser
(
&
self
)
->
Option
<
String
>
{
self
.inner.reasoning_parser
.clone
()
}
#[getter]
fn
runtime_data
(
&
self
,
py
:
Python
<
'_
>
)
->
PyResult
<
PyObject
>
{
let
dict
=
PyDict
::
new
(
py
);
...
...
lib/llm/src/discovery/model_manager.rs
View file @
cbe854fc
...
...
@@ -246,6 +246,18 @@ impl ModelManager {
.insert
(
model_name
.to_string
(),
new_kv_chooser
.clone
());
Ok
(
new_kv_chooser
)
}
pub
fn
get_model_tool_call_parser
(
&
self
,
model
:
&
str
)
->
Option
<
String
>
{
match
self
.entries
.lock
()
{
Ok
(
entries
)
=>
entries
.values
()
.find
(|
entry
|
entry
.name
==
model
)
.and_then
(|
entry
|
entry
.runtime_config
.as_ref
())
.and_then
(|
config
|
config
.tool_call_parser
.clone
())
.map
(|
parser
|
parser
.to_string
()),
Err
(
_
)
=>
None
,
}
}
}
pub
struct
ModelEngines
<
E
>
{
...
...
lib/llm/src/http/service/openai.rs
View file @
cbe854fc
...
...
@@ -37,6 +37,7 @@ use crate::protocols::openai::{
completions
::{
NvCreateCompletionRequest
,
NvCreateCompletionResponse
},
embeddings
::{
NvCreateEmbeddingRequest
,
NvCreateEmbeddingResponse
},
responses
::{
NvCreateResponse
,
NvResponse
},
ParsingOptions
,
};
use
crate
::
request_template
::
RequestTemplate
;
use
crate
::
types
::
Annotated
;
...
...
@@ -194,6 +195,13 @@ fn get_or_create_request_id(primary: Option<&str>, headers: &HeaderMap) -> Strin
uuid
.to_string
()
}
fn
get_parsing_options
(
state
:
&
Arc
<
service_v2
::
State
>
,
model
:
&
str
)
->
ParsingOptions
{
let
tool_call_parser
=
state
.manager
()
.get_model_tool_call_parser
(
model
);
let
reasoning_parser
=
None
;
// TODO: Implement reasoning parser
ParsingOptions
::
new
(
tool_call_parser
,
reasoning_parser
)
}
/// OpenAI Completions Request Handler
///
/// This method will handle the incoming request for the `/v1/completions endpoint`. The endpoint is a "source"
...
...
@@ -267,6 +275,8 @@ async fn completions(
.get_completions_engine
(
model
)
.map_err
(|
_
|
ErrorMessage
::
model_not_found
())
?
;
let
parsing_options
=
get_parsing_options
(
&
state
,
model
);
let
mut
inflight_guard
=
state
.metrics_clone
()
...
...
@@ -325,7 +335,7 @@ async fn completions(
process_metrics_only
(
response
,
&
mut
response_collector
);
});
let
response
=
NvCreateCompletionResponse
::
from_annotated_stream
(
stream
)
let
response
=
NvCreateCompletionResponse
::
from_annotated_stream
(
stream
,
parsing_options
)
.await
.map_err
(|
e
|
{
tracing
::
error!
(
...
...
@@ -494,6 +504,8 @@ async fn chat_completions(
.get_chat_completions_engine
(
model
)
.map_err
(|
_
|
ErrorMessage
::
model_not_found
())
?
;
let
parsing_options
=
get_parsing_options
(
&
state
,
model
);
let
mut
inflight_guard
=
state
.metrics_clone
()
...
...
@@ -553,19 +565,20 @@ async fn chat_completions(
process_metrics_only
(
response
,
&
mut
response_collector
);
});
let
response
=
NvCreateChatCompletionResponse
::
from_annotated_stream
(
stream
)
.await
.map_err
(|
e
|
{
tracing
::
error!
(
request_id
,
"Failed to fold chat completions stream for: {:?}"
,
e
);
ErrorMessage
::
internal_server_error
(
&
format!
(
"Failed to fold chat completions stream: {}"
,
e
))
})
?
;
let
response
=
NvCreateChatCompletionResponse
::
from_annotated_stream
(
stream
,
parsing_options
.clone
())
.await
.map_err
(|
e
|
{
tracing
::
error!
(
request_id
,
"Failed to fold chat completions stream for: {:?}"
,
e
);
ErrorMessage
::
internal_server_error
(
&
format!
(
"Failed to fold chat completions stream: {}"
,
e
))
})
?
;
inflight_guard
.mark_ok
();
Ok
(
Json
(
response
)
.into_response
())
...
...
@@ -726,6 +739,8 @@ async fn responses(
.get_chat_completions_engine
(
model
)
.map_err
(|
_
|
ErrorMessage
::
model_not_found
())
?
;
let
parsing_options
=
get_parsing_options
(
&
state
,
model
);
let
mut
inflight_guard
=
state
.metrics_clone
()
...
...
@@ -742,19 +757,20 @@ async fn responses(
.map_err
(|
e
|
ErrorMessage
::
from_anyhow
(
e
,
"Failed to generate completions"
))
?
;
// TODO: handle streaming, currently just unary
let
response
=
NvCreateChatCompletionResponse
::
from_annotated_stream
(
stream
)
.await
.map_err
(|
e
|
{
tracing
::
error!
(
request_id
,
"Failed to fold chat completions stream for: {:?}"
,
e
);
ErrorMessage
::
internal_server_error
(
&
format!
(
"Failed to fold chat completions stream: {}"
,
e
))
})
?
;
let
response
=
NvCreateChatCompletionResponse
::
from_annotated_stream
(
stream
,
parsing_options
.clone
())
.await
.map_err
(|
e
|
{
tracing
::
error!
(
request_id
,
"Failed to fold chat completions stream for: {:?}"
,
e
);
ErrorMessage
::
internal_server_error
(
&
format!
(
"Failed to fold chat completions stream: {}"
,
e
))
})
?
;
// Convert NvCreateChatCompletionResponse --> NvResponse
let
response
:
NvResponse
=
response
.try_into
()
.map_err
(|
e
|
{
...
...
lib/llm/src/local_model.rs
View file @
cbe854fc
...
...
@@ -202,6 +202,7 @@ impl LocalModelBuilder {
);
card
.migration_limit
=
self
.migration_limit
;
card
.user_data
=
self
.user_data
.take
();
return
Ok
(
LocalModel
{
card
,
full_path
:
PathBuf
::
new
(),
...
...
@@ -392,6 +393,7 @@ impl LocalModel {
let
kvstore
:
Box
<
dyn
KeyValueStore
>
=
Box
::
new
(
EtcdStorage
::
new
(
etcd_client
.clone
()));
let
card_store
=
Arc
::
new
(
KeyValueStoreManager
::
new
(
kvstore
));
let
key
=
self
.card
.slug
()
.to_string
();
card_store
.publish
(
model_card
::
ROOT_PATH
,
None
,
&
key
,
&
mut
self
.card
)
.await
?
;
...
...
lib/llm/src/local_model/runtime_config.rs
View file @
cbe854fc
...
...
@@ -13,6 +13,10 @@ pub struct ModelRuntimeConfig {
pub
max_num_batched_tokens
:
Option
<
u64
>
,
pub
tool_call_parser
:
Option
<
String
>
,
pub
reasoning_parser
:
Option
<
String
>
,
/// Mapping of engine-specific runtime configs
#[serde(default,
skip_serializing_if
=
"HashMap::is_empty"
)]
pub
runtime_data
:
HashMap
<
String
,
serde_json
::
Value
>
,
...
...
lib/llm/src/preprocessor.rs
View file @
cbe854fc
...
...
@@ -101,7 +101,6 @@ impl OpenAIPreprocessor {
let
mdcsum
=
mdc
.mdcsum
();
let
formatter
=
PromptFormatter
::
from_mdc
(
mdc
.clone
())
.await
?
;
let
PromptFormatter
::
OAI
(
formatter
)
=
formatter
;
let
tokenizer
=
match
&
mdc
.tokenizer
{
Some
(
TokenizerKind
::
HfTokenizerJson
(
file
))
=>
HuggingFaceTokenizer
::
from_file
(
file
)
?
,
Some
(
TokenizerKind
::
GGUF
(
tokenizer
))
=>
{
...
...
lib/llm/src/protocols/openai.rs
View file @
cbe854fc
...
...
@@ -193,3 +193,19 @@ pub trait DeltaGeneratorExt<ResponseType: Send + 'static + std::fmt::Debug>:
/// Gets the current prompt token count (Input Sequence Length).
fn
get_isl
(
&
self
)
->
Option
<
u32
>
;
}
#[derive(Clone,
Debug,
Serialize,
Deserialize,
Default)]
pub
struct
ParsingOptions
{
pub
tool_call_parser
:
Option
<
String
>
,
pub
reasoning_parser
:
Option
<
String
>
,
}
impl
ParsingOptions
{
pub
fn
new
(
tool_call_parser
:
Option
<
String
>
,
reasoning_parser
:
Option
<
String
>
)
->
Self
{
Self
{
tool_call_parser
,
reasoning_parser
,
}
}
}
lib/llm/src/protocols/openai/chat_completions/aggregator.rs
View file @
cbe854fc
...
...
@@ -19,7 +19,9 @@ use std::collections::HashMap;
use
super
::{
NvCreateChatCompletionResponse
,
NvCreateChatCompletionStreamResponse
};
use
crate
::
protocols
::{
codec
::{
Message
,
SseCodecError
},
convert_sse_stream
,
Annotated
,
convert_sse_stream
,
openai
::
ParsingOptions
,
Annotated
,
};
use
dynamo_parsers
::
tool_calling
::
try_tool_call_parse_aggregate
;
...
...
@@ -99,6 +101,7 @@ impl DeltaAggregator {
/// * `Err(String)` if an error occurs during processing.
pub
async
fn
apply
(
stream
:
impl
Stream
<
Item
=
Annotated
<
NvCreateChatCompletionStreamResponse
>>
,
parsing_options
:
ParsingOptions
,
)
->
Result
<
NvCreateChatCompletionResponse
,
String
>
{
let
aggregator
=
stream
.fold
(
DeltaAggregator
::
new
(),
|
mut
aggregator
,
delta
|
async
move
{
...
...
@@ -175,7 +178,10 @@ impl DeltaAggregator {
// After aggregation, inspect each choice's text for tool call syntax
for
choice
in
aggregator
.choices
.values_mut
()
{
if
choice
.tool_calls
.is_none
()
{
if
let
Ok
(
tool_calls
)
=
try_tool_call_parse_aggregate
(
&
choice
.text
,
None
)
{
if
let
Ok
(
tool_calls
)
=
try_tool_call_parse_aggregate
(
&
choice
.text
,
parsing_options
.tool_call_parser
.as_deref
(),
)
{
if
tool_calls
.is_empty
()
{
continue
;
}
...
...
@@ -262,6 +268,7 @@ pub trait ChatCompletionAggregator {
/// * `Err(String)` if an error occurs.
async
fn
from_annotated_stream
(
stream
:
impl
Stream
<
Item
=
Annotated
<
NvCreateChatCompletionStreamResponse
>>
,
parsing_options
:
ParsingOptions
,
)
->
Result
<
NvCreateChatCompletionResponse
,
String
>
;
/// Converts an SSE stream into a [`NvCreateChatCompletionResponse`].
...
...
@@ -274,21 +281,24 @@ pub trait ChatCompletionAggregator {
/// * `Err(String)` if an error occurs.
async
fn
from_sse_stream
(
stream
:
DataStream
<
Result
<
Message
,
SseCodecError
>>
,
parsing_options
:
ParsingOptions
,
)
->
Result
<
NvCreateChatCompletionResponse
,
String
>
;
}
impl
ChatCompletionAggregator
for
dynamo_async_openai
::
types
::
CreateChatCompletionResponse
{
async
fn
from_annotated_stream
(
stream
:
impl
Stream
<
Item
=
Annotated
<
NvCreateChatCompletionStreamResponse
>>
,
parsing_options
:
ParsingOptions
,
)
->
Result
<
NvCreateChatCompletionResponse
,
String
>
{
DeltaAggregator
::
apply
(
stream
)
.await
DeltaAggregator
::
apply
(
stream
,
parsing_options
)
.await
}
async
fn
from_sse_stream
(
stream
:
DataStream
<
Result
<
Message
,
SseCodecError
>>
,
parsing_options
:
ParsingOptions
,
)
->
Result
<
NvCreateChatCompletionResponse
,
String
>
{
let
stream
=
convert_sse_stream
::
<
NvCreateChatCompletionStreamResponse
>
(
stream
);
NvCreateChatCompletionResponse
::
from_annotated_stream
(
stream
)
.await
NvCreateChatCompletionResponse
::
from_annotated_stream
(
stream
,
parsing_options
)
.await
}
}
...
...
@@ -347,7 +357,7 @@ mod tests {
Box
::
pin
(
stream
::
empty
());
// Call DeltaAggregator::apply
let
result
=
DeltaAggregator
::
apply
(
stream
)
.await
;
let
result
=
DeltaAggregator
::
apply
(
stream
,
ParsingOptions
::
default
()
)
.await
;
// Check the result
assert
!
(
result
.is_ok
());
...
...
@@ -377,7 +387,7 @@ mod tests {
let
stream
=
Box
::
pin
(
stream
::
iter
(
vec!
[
annotated_delta
]));
// Call DeltaAggregator::apply
let
result
=
DeltaAggregator
::
apply
(
stream
)
.await
;
let
result
=
DeltaAggregator
::
apply
(
stream
,
ParsingOptions
::
default
()
)
.await
;
// Check the result
assert
!
(
result
.is_ok
());
...
...
@@ -421,7 +431,7 @@ mod tests {
let
stream
=
Box
::
pin
(
stream
::
iter
(
annotated_deltas
));
// Call DeltaAggregator::apply
let
result
=
DeltaAggregator
::
apply
(
stream
)
.await
;
let
result
=
DeltaAggregator
::
apply
(
stream
,
ParsingOptions
::
default
()
)
.await
;
// Check the result
assert
!
(
result
.is_ok
());
...
...
@@ -492,7 +502,7 @@ mod tests {
let
stream
=
Box
::
pin
(
stream
::
iter
(
vec!
[
annotated_delta
]));
// Call DeltaAggregator::apply
let
result
=
DeltaAggregator
::
apply
(
stream
)
.await
;
let
result
=
DeltaAggregator
::
apply
(
stream
,
ParsingOptions
::
default
()
)
.await
;
// Check the result
assert
!
(
result
.is_ok
());
...
...
@@ -550,7 +560,7 @@ mod tests {
let
stream
=
Box
::
pin
(
stream
::
iter
(
vec!
[
annotated_delta
]));
// Call DeltaAggregator::apply
let
result
=
DeltaAggregator
::
apply
(
stream
)
.await
;
let
result
=
DeltaAggregator
::
apply
(
stream
,
ParsingOptions
::
default
()
)
.await
;
// Check the result
assert
!
(
result
.is_ok
());
...
...
lib/llm/src/protocols/openai/completions/aggregator.rs
View file @
cbe854fc
...
...
@@ -22,7 +22,9 @@ use super::NvCreateCompletionResponse;
use
crate
::
protocols
::{
codec
::{
Message
,
SseCodecError
},
common
::
FinishReason
,
convert_sse_stream
,
Annotated
,
DataStream
,
convert_sse_stream
,
openai
::
ParsingOptions
,
Annotated
,
DataStream
,
};
/// Aggregates a stream of [`CompletionResponse`]s into a single [`CompletionResponse`].
...
...
@@ -65,7 +67,9 @@ impl DeltaAggregator {
/// Aggregates a stream of [`Annotated<CompletionResponse>`]s into a single [`CompletionResponse`].
pub
async
fn
apply
(
stream
:
impl
Stream
<
Item
=
Annotated
<
NvCreateCompletionResponse
>>
,
parsing_options
:
ParsingOptions
,
)
->
Result
<
NvCreateCompletionResponse
>
{
tracing
::
debug!
(
"Tool Call Parser: {:?}"
,
parsing_options
.tool_call_parser
);
// TODO: remove this once completion has tool call support
let
aggregator
=
stream
.fold
(
DeltaAggregator
::
new
(),
|
mut
aggregator
,
delta
|
async
move
{
let
delta
=
match
delta
.ok
()
{
...
...
@@ -177,15 +181,17 @@ impl From<DeltaChoice> for dynamo_async_openai::types::Choice {
impl
NvCreateCompletionResponse
{
pub
async
fn
from_sse_stream
(
stream
:
DataStream
<
Result
<
Message
,
SseCodecError
>>
,
parsing_options
:
ParsingOptions
,
)
->
Result
<
NvCreateCompletionResponse
>
{
let
stream
=
convert_sse_stream
::
<
NvCreateCompletionResponse
>
(
stream
);
NvCreateCompletionResponse
::
from_annotated_stream
(
stream
)
.await
NvCreateCompletionResponse
::
from_annotated_stream
(
stream
,
parsing_options
)
.await
}
pub
async
fn
from_annotated_stream
(
stream
:
impl
Stream
<
Item
=
Annotated
<
NvCreateCompletionResponse
>>
,
parsing_options
:
ParsingOptions
,
)
->
Result
<
NvCreateCompletionResponse
>
{
DeltaAggregator
::
apply
(
stream
)
.await
DeltaAggregator
::
apply
(
stream
,
parsing_options
)
.await
}
}
...
...
@@ -241,7 +247,7 @@ mod tests {
let
stream
:
DataStream
<
Annotated
<
NvCreateCompletionResponse
>>
=
Box
::
pin
(
stream
::
empty
());
// Call DeltaAggregator::apply
let
result
=
DeltaAggregator
::
apply
(
stream
)
.await
;
let
result
=
DeltaAggregator
::
apply
(
stream
,
ParsingOptions
::
default
()
)
.await
;
// Check the result
assert
!
(
result
.is_ok
());
...
...
@@ -265,7 +271,7 @@ mod tests {
let
stream
=
Box
::
pin
(
stream
::
iter
(
vec!
[
annotated_delta
]));
// Call DeltaAggregator::apply
let
result
=
DeltaAggregator
::
apply
(
stream
)
.await
;
let
result
=
DeltaAggregator
::
apply
(
stream
,
ParsingOptions
::
default
()
)
.await
;
// Check the result
assert
!
(
result
.is_ok
());
...
...
@@ -305,7 +311,7 @@ mod tests {
let
stream
=
Box
::
pin
(
stream
::
iter
(
annotated_deltas
));
// Call DeltaAggregator::apply
let
result
=
DeltaAggregator
::
apply
(
stream
)
.await
;
let
result
=
DeltaAggregator
::
apply
(
stream
,
ParsingOptions
::
default
()
)
.await
;
// Check the result
assert
!
(
result
.is_ok
());
...
...
@@ -365,7 +371,7 @@ mod tests {
let
stream
=
Box
::
pin
(
stream
::
iter
(
vec!
[
annotated_delta
]));
// Call DeltaAggregator::apply
let
result
=
DeltaAggregator
::
apply
(
stream
)
.await
;
let
result
=
DeltaAggregator
::
apply
(
stream
,
ParsingOptions
::
default
()
)
.await
;
// Check the result
assert
!
(
result
.is_ok
());
...
...
lib/llm/tests/aggregators.rs
View file @
cbe854fc
...
...
@@ -18,6 +18,7 @@ use dynamo_llm::protocols::{
openai
::{
chat_completions
::{
aggregator
::
ChatCompletionAggregator
,
NvCreateChatCompletionResponse
},
completions
::
NvCreateCompletionResponse
,
ParsingOptions
,
},
ContentProvider
,
DataStream
,
};
...
...
@@ -37,9 +38,12 @@ async fn test_openai_chat_stream() {
// note: we are only taking the first 16 messages to keep the size of the response small
let
stream
=
create_message_stream
(
&
data
)
.take
(
16
);
let
result
=
NvCreateChatCompletionResponse
::
from_sse_stream
(
Box
::
pin
(
stream
))
.await
.unwrap
();
let
result
=
NvCreateChatCompletionResponse
::
from_sse_stream
(
Box
::
pin
(
stream
),
ParsingOptions
::
default
(),
)
.await
.unwrap
();
// todo: provide a cleaner way to extract the content from choices
assert_eq!
(
...
...
@@ -59,9 +63,12 @@ async fn test_openai_chat_stream() {
#[tokio::test]
async
fn
test_openai_chat_edge_case_multi_line_data
()
{
let
stream
=
create_stream
(
CHAT_ROOT_PATH
,
"edge_cases/valid-multi-line-data"
);
let
result
=
NvCreateChatCompletionResponse
::
from_sse_stream
(
Box
::
pin
(
stream
))
.await
.unwrap
();
let
result
=
NvCreateChatCompletionResponse
::
from_sse_stream
(
Box
::
pin
(
stream
),
ParsingOptions
::
default
(),
)
.await
.unwrap
();
assert_eq!
(
result
...
...
@@ -79,9 +86,12 @@ async fn test_openai_chat_edge_case_multi_line_data() {
#[tokio::test]
async
fn
test_openai_chat_edge_case_comments_per_response
()
{
let
stream
=
create_stream
(
CHAT_ROOT_PATH
,
"edge_cases/valid-comments_per_response"
);
let
result
=
NvCreateChatCompletionResponse
::
from_sse_stream
(
Box
::
pin
(
stream
))
.await
.unwrap
();
let
result
=
NvCreateChatCompletionResponse
::
from_sse_stream
(
Box
::
pin
(
stream
),
ParsingOptions
::
default
(),
)
.await
.unwrap
();
assert_eq!
(
result
...
...
@@ -99,7 +109,11 @@ async fn test_openai_chat_edge_case_comments_per_response() {
#[tokio::test]
async
fn
test_openai_chat_edge_case_invalid_deserialize_error
()
{
let
stream
=
create_stream
(
CHAT_ROOT_PATH
,
"edge_cases/invalid-deserialize_error"
);
let
result
=
NvCreateChatCompletionResponse
::
from_sse_stream
(
Box
::
pin
(
stream
))
.await
;
let
result
=
NvCreateChatCompletionResponse
::
from_sse_stream
(
Box
::
pin
(
stream
),
ParsingOptions
::
default
(),
)
.await
;
assert
!
(
result
.is_err
());
// insta::assert_debug_snapshot!(result);
...
...
@@ -112,9 +126,10 @@ async fn test_openai_chat_edge_case_invalid_deserialize_error() {
#[tokio::test]
async
fn
test_openai_cmpl_stream
()
{
let
stream
=
create_stream
(
CMPL_ROOT_PATH
,
"completion.streaming.1"
)
.take
(
16
);
let
result
=
NvCreateCompletionResponse
::
from_sse_stream
(
Box
::
pin
(
stream
))
.await
.unwrap
();
let
result
=
NvCreateCompletionResponse
::
from_sse_stream
(
Box
::
pin
(
stream
),
ParsingOptions
::
default
())
.await
.unwrap
();
// todo: provide a cleaner way to extract the content from choices
assert_eq!
(
...
...
lib/parsers/src/tool_calling/tools.rs
View file @
cbe854fc
...
...
@@ -14,6 +14,11 @@ pub fn try_tool_call_parse_aggregate(
message
:
&
str
,
parser_str
:
Option
<&
str
>
,
)
->
anyhow
::
Result
<
Vec
<
dynamo_async_openai
::
types
::
ChatCompletionMessageToolCall
>>
{
if
parser_str
.is_none
()
{
tracing
::
info!
(
"No tool parser provided. Trying parsing with default parser."
);
}
else
{
tracing
::
info!
(
"Using tool parser: {:?}"
,
parser_str
);
}
let
parsed
=
detect_and_parse_tool_call
(
message
,
parser_str
)
?
;
if
parsed
.is_empty
()
{
return
Ok
(
vec!
[]);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment