Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
cbe854fc
Unverified
Commit
cbe854fc
authored
Aug 22, 2025
by
Ayush Agarwal
Committed by
GitHub
Aug 22, 2025
Browse files
feat: [vLLM] implement cli args for tool and reasoning parsers (#2619)
parent
b658ba61
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
183 additions
and
58 deletions
+183
-58
components/backends/vllm/src/dynamo/vllm/args.py
components/backends/vllm/src/dynamo/vllm/args.py
+19
-1
components/backends/vllm/src/dynamo/vllm/main.py
components/backends/vllm/src/dynamo/vllm/main.py
+2
-0
lib/bindings/python/rust/llm/local_model.rs
lib/bindings/python/rust/llm/local_model.rs
+20
-0
lib/llm/src/discovery/model_manager.rs
lib/llm/src/discovery/model_manager.rs
+12
-0
lib/llm/src/http/service/openai.rs
lib/llm/src/http/service/openai.rs
+43
-27
lib/llm/src/local_model.rs
lib/llm/src/local_model.rs
+2
-0
lib/llm/src/local_model/runtime_config.rs
lib/llm/src/local_model/runtime_config.rs
+4
-0
lib/llm/src/preprocessor.rs
lib/llm/src/preprocessor.rs
+0
-1
lib/llm/src/protocols/openai.rs
lib/llm/src/protocols/openai.rs
+16
-0
lib/llm/src/protocols/openai/chat_completions/aggregator.rs
lib/llm/src/protocols/openai/chat_completions/aggregator.rs
+19
-9
lib/llm/src/protocols/openai/completions/aggregator.rs
lib/llm/src/protocols/openai/completions/aggregator.rs
+13
-7
lib/llm/tests/aggregators.rs
lib/llm/tests/aggregators.rs
+28
-13
lib/parsers/src/tool_calling/tools.rs
lib/parsers/src/tool_calling/tools.rs
+5
-0
No files found.
components/backends/vllm/src/dynamo/vllm/args.py
View file @
cbe854fc
...
@@ -58,6 +58,10 @@ class Config:
...
@@ -58,6 +58,10 @@ class Config:
# Connector list from CLI
# Connector list from CLI
connector_list
:
Optional
[
list
]
=
None
connector_list
:
Optional
[
list
]
=
None
# tool and reasoning parser info
tool_call_parser
:
Optional
[
str
]
=
None
reasoning_parser
:
Optional
[
str
]
=
None
def
parse_args
()
->
Config
:
def
parse_args
()
->
Config
:
parser
=
FlexibleArgumentParser
(
parser
=
FlexibleArgumentParser
(
...
@@ -102,6 +106,19 @@ def parse_args() -> Config:
...
@@ -102,6 +106,19 @@ def parse_args() -> Config:
help
=
"List of connectors to use in order (e.g., --connector nixl lmcache). "
help
=
"List of connectors to use in order (e.g., --connector nixl lmcache). "
"Options: nixl, lmcache, kvbm, null, none. Default: nixl. Order will be preserved in MultiConnector."
,
"Options: nixl, lmcache, kvbm, null, none. Default: nixl. Order will be preserved in MultiConnector."
,
)
)
# To avoid name conflicts with different backends, adoped prefix "dyn-" for dynamo specific args
parser
.
add_argument
(
"--dyn-tool-call-parser"
,
type
=
str
,
default
=
None
,
help
=
"Tool call parser name for the model. Available options: 'hermes', 'nemotron_deci', 'llama3_json', 'mistral', 'phi4'."
,
)
parser
.
add_argument
(
"--dyn-reasoning-parser"
,
type
=
str
,
default
=
None
,
help
=
"Reasoning parser name for the model."
,
)
parser
=
AsyncEngineArgs
.
add_cli_args
(
parser
)
parser
=
AsyncEngineArgs
.
add_cli_args
(
parser
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
@@ -151,7 +168,8 @@ def parse_args() -> Config:
...
@@ -151,7 +168,8 @@ def parse_args() -> Config:
config
.
port_range
=
DynamoPortRange
(
config
.
port_range
=
DynamoPortRange
(
min
=
args
.
dynamo_port_min
,
max
=
args
.
dynamo_port_max
min
=
args
.
dynamo_port_min
,
max
=
args
.
dynamo_port_max
)
)
config
.
tool_call_parser
=
args
.
dyn_tool_call_parser
config
.
reasoning_parser
=
args
.
dyn_reasoning_parser
# Check for conflicting flags
# Check for conflicting flags
has_kv_transfer_config
=
(
has_kv_transfer_config
=
(
hasattr
(
engine_args
,
"kv_transfer_config"
)
hasattr
(
engine_args
,
"kv_transfer_config"
)
...
...
components/backends/vllm/src/dynamo/vllm/main.py
View file @
cbe854fc
...
@@ -234,6 +234,8 @@ async def init(runtime: DistributedRuntime, config: Config):
...
@@ -234,6 +234,8 @@ async def init(runtime: DistributedRuntime, config: Config):
runtime_config
.
total_kv_blocks
=
runtime_values
[
"num_gpu_blocks"
]
runtime_config
.
total_kv_blocks
=
runtime_values
[
"num_gpu_blocks"
]
runtime_config
.
max_num_seqs
=
runtime_values
[
"max_num_seqs"
]
runtime_config
.
max_num_seqs
=
runtime_values
[
"max_num_seqs"
]
runtime_config
.
max_num_batched_tokens
=
runtime_values
[
"max_num_batched_tokens"
]
runtime_config
.
max_num_batched_tokens
=
runtime_values
[
"max_num_batched_tokens"
]
runtime_config
.
tool_call_parser
=
config
.
tool_call_parser
runtime_config
.
reasoning_parser
=
config
.
reasoning_parser
await
register_llm
(
await
register_llm
(
ModelType
.
Backend
,
ModelType
.
Backend
,
...
...
lib/bindings/python/rust/llm/local_model.rs
View file @
cbe854fc
...
@@ -34,6 +34,16 @@ impl ModelRuntimeConfig {
...
@@ -34,6 +34,16 @@ impl ModelRuntimeConfig {
self
.inner.max_num_batched_tokens
=
Some
(
max_num_batched_tokens
);
self
.inner.max_num_batched_tokens
=
Some
(
max_num_batched_tokens
);
}
}
#[setter]
fn
set_tool_call_parser
(
&
mut
self
,
tool_call_parser
:
Option
<
String
>
)
{
self
.inner.tool_call_parser
=
tool_call_parser
;
}
#[setter]
fn
set_reasoning_parser
(
&
mut
self
,
reasoning_parser
:
Option
<
String
>
)
{
self
.inner.reasoning_parser
=
reasoning_parser
;
}
fn
set_engine_specific
(
&
mut
self
,
key
:
&
str
,
value
:
String
)
->
PyResult
<
()
>
{
fn
set_engine_specific
(
&
mut
self
,
key
:
&
str
,
value
:
String
)
->
PyResult
<
()
>
{
let
value
:
serde_json
::
Value
=
serde_json
::
from_str
(
&
value
)
.map_err
(
to_pyerr
)
?
;
let
value
:
serde_json
::
Value
=
serde_json
::
from_str
(
&
value
)
.map_err
(
to_pyerr
)
?
;
self
.inner
self
.inner
...
@@ -57,6 +67,16 @@ impl ModelRuntimeConfig {
...
@@ -57,6 +67,16 @@ impl ModelRuntimeConfig {
self
.inner.max_num_batched_tokens
self
.inner.max_num_batched_tokens
}
}
#[getter]
fn
tool_call_parser
(
&
self
)
->
Option
<
String
>
{
self
.inner.tool_call_parser
.clone
()
}
#[getter]
fn
reasoning_parser
(
&
self
)
->
Option
<
String
>
{
self
.inner.reasoning_parser
.clone
()
}
#[getter]
#[getter]
fn
runtime_data
(
&
self
,
py
:
Python
<
'_
>
)
->
PyResult
<
PyObject
>
{
fn
runtime_data
(
&
self
,
py
:
Python
<
'_
>
)
->
PyResult
<
PyObject
>
{
let
dict
=
PyDict
::
new
(
py
);
let
dict
=
PyDict
::
new
(
py
);
...
...
lib/llm/src/discovery/model_manager.rs
View file @
cbe854fc
...
@@ -246,6 +246,18 @@ impl ModelManager {
...
@@ -246,6 +246,18 @@ impl ModelManager {
.insert
(
model_name
.to_string
(),
new_kv_chooser
.clone
());
.insert
(
model_name
.to_string
(),
new_kv_chooser
.clone
());
Ok
(
new_kv_chooser
)
Ok
(
new_kv_chooser
)
}
}
pub
fn
get_model_tool_call_parser
(
&
self
,
model
:
&
str
)
->
Option
<
String
>
{
match
self
.entries
.lock
()
{
Ok
(
entries
)
=>
entries
.values
()
.find
(|
entry
|
entry
.name
==
model
)
.and_then
(|
entry
|
entry
.runtime_config
.as_ref
())
.and_then
(|
config
|
config
.tool_call_parser
.clone
())
.map
(|
parser
|
parser
.to_string
()),
Err
(
_
)
=>
None
,
}
}
}
}
pub
struct
ModelEngines
<
E
>
{
pub
struct
ModelEngines
<
E
>
{
...
...
lib/llm/src/http/service/openai.rs
View file @
cbe854fc
...
@@ -37,6 +37,7 @@ use crate::protocols::openai::{
...
@@ -37,6 +37,7 @@ use crate::protocols::openai::{
completions
::{
NvCreateCompletionRequest
,
NvCreateCompletionResponse
},
completions
::{
NvCreateCompletionRequest
,
NvCreateCompletionResponse
},
embeddings
::{
NvCreateEmbeddingRequest
,
NvCreateEmbeddingResponse
},
embeddings
::{
NvCreateEmbeddingRequest
,
NvCreateEmbeddingResponse
},
responses
::{
NvCreateResponse
,
NvResponse
},
responses
::{
NvCreateResponse
,
NvResponse
},
ParsingOptions
,
};
};
use
crate
::
request_template
::
RequestTemplate
;
use
crate
::
request_template
::
RequestTemplate
;
use
crate
::
types
::
Annotated
;
use
crate
::
types
::
Annotated
;
...
@@ -194,6 +195,13 @@ fn get_or_create_request_id(primary: Option<&str>, headers: &HeaderMap) -> Strin
...
@@ -194,6 +195,13 @@ fn get_or_create_request_id(primary: Option<&str>, headers: &HeaderMap) -> Strin
uuid
.to_string
()
uuid
.to_string
()
}
}
fn
get_parsing_options
(
state
:
&
Arc
<
service_v2
::
State
>
,
model
:
&
str
)
->
ParsingOptions
{
let
tool_call_parser
=
state
.manager
()
.get_model_tool_call_parser
(
model
);
let
reasoning_parser
=
None
;
// TODO: Implement reasoning parser
ParsingOptions
::
new
(
tool_call_parser
,
reasoning_parser
)
}
/// OpenAI Completions Request Handler
/// OpenAI Completions Request Handler
///
///
/// This method will handle the incoming request for the `/v1/completions endpoint`. The endpoint is a "source"
/// This method will handle the incoming request for the `/v1/completions endpoint`. The endpoint is a "source"
...
@@ -267,6 +275,8 @@ async fn completions(
...
@@ -267,6 +275,8 @@ async fn completions(
.get_completions_engine
(
model
)
.get_completions_engine
(
model
)
.map_err
(|
_
|
ErrorMessage
::
model_not_found
())
?
;
.map_err
(|
_
|
ErrorMessage
::
model_not_found
())
?
;
let
parsing_options
=
get_parsing_options
(
&
state
,
model
);
let
mut
inflight_guard
=
let
mut
inflight_guard
=
state
state
.metrics_clone
()
.metrics_clone
()
...
@@ -325,7 +335,7 @@ async fn completions(
...
@@ -325,7 +335,7 @@ async fn completions(
process_metrics_only
(
response
,
&
mut
response_collector
);
process_metrics_only
(
response
,
&
mut
response_collector
);
});
});
let
response
=
NvCreateCompletionResponse
::
from_annotated_stream
(
stream
)
let
response
=
NvCreateCompletionResponse
::
from_annotated_stream
(
stream
,
parsing_options
)
.await
.await
.map_err
(|
e
|
{
.map_err
(|
e
|
{
tracing
::
error!
(
tracing
::
error!
(
...
@@ -494,6 +504,8 @@ async fn chat_completions(
...
@@ -494,6 +504,8 @@ async fn chat_completions(
.get_chat_completions_engine
(
model
)
.get_chat_completions_engine
(
model
)
.map_err
(|
_
|
ErrorMessage
::
model_not_found
())
?
;
.map_err
(|
_
|
ErrorMessage
::
model_not_found
())
?
;
let
parsing_options
=
get_parsing_options
(
&
state
,
model
);
let
mut
inflight_guard
=
let
mut
inflight_guard
=
state
state
.metrics_clone
()
.metrics_clone
()
...
@@ -553,19 +565,20 @@ async fn chat_completions(
...
@@ -553,19 +565,20 @@ async fn chat_completions(
process_metrics_only
(
response
,
&
mut
response_collector
);
process_metrics_only
(
response
,
&
mut
response_collector
);
});
});
let
response
=
NvCreateChatCompletionResponse
::
from_annotated_stream
(
stream
)
let
response
=
.await
NvCreateChatCompletionResponse
::
from_annotated_stream
(
stream
,
parsing_options
.clone
())
.map_err
(|
e
|
{
.await
tracing
::
error!
(
.map_err
(|
e
|
{
request_id
,
tracing
::
error!
(
"Failed to fold chat completions stream for: {:?}"
,
request_id
,
e
"Failed to fold chat completions stream for: {:?}"
,
);
e
ErrorMessage
::
internal_server_error
(
&
format!
(
);
"Failed to fold chat completions stream: {}"
,
ErrorMessage
::
internal_server_error
(
&
format!
(
e
"Failed to fold chat completions stream: {}"
,
))
e
})
?
;
))
})
?
;
inflight_guard
.mark_ok
();
inflight_guard
.mark_ok
();
Ok
(
Json
(
response
)
.into_response
())
Ok
(
Json
(
response
)
.into_response
())
...
@@ -726,6 +739,8 @@ async fn responses(
...
@@ -726,6 +739,8 @@ async fn responses(
.get_chat_completions_engine
(
model
)
.get_chat_completions_engine
(
model
)
.map_err
(|
_
|
ErrorMessage
::
model_not_found
())
?
;
.map_err
(|
_
|
ErrorMessage
::
model_not_found
())
?
;
let
parsing_options
=
get_parsing_options
(
&
state
,
model
);
let
mut
inflight_guard
=
let
mut
inflight_guard
=
state
state
.metrics_clone
()
.metrics_clone
()
...
@@ -742,19 +757,20 @@ async fn responses(
...
@@ -742,19 +757,20 @@ async fn responses(
.map_err
(|
e
|
ErrorMessage
::
from_anyhow
(
e
,
"Failed to generate completions"
))
?
;
.map_err
(|
e
|
ErrorMessage
::
from_anyhow
(
e
,
"Failed to generate completions"
))
?
;
// TODO: handle streaming, currently just unary
// TODO: handle streaming, currently just unary
let
response
=
NvCreateChatCompletionResponse
::
from_annotated_stream
(
stream
)
let
response
=
.await
NvCreateChatCompletionResponse
::
from_annotated_stream
(
stream
,
parsing_options
.clone
())
.map_err
(|
e
|
{
.await
tracing
::
error!
(
.map_err
(|
e
|
{
request_id
,
tracing
::
error!
(
"Failed to fold chat completions stream for: {:?}"
,
request_id
,
e
"Failed to fold chat completions stream for: {:?}"
,
);
e
ErrorMessage
::
internal_server_error
(
&
format!
(
);
"Failed to fold chat completions stream: {}"
,
ErrorMessage
::
internal_server_error
(
&
format!
(
e
"Failed to fold chat completions stream: {}"
,
))
e
})
?
;
))
})
?
;
// Convert NvCreateChatCompletionResponse --> NvResponse
// Convert NvCreateChatCompletionResponse --> NvResponse
let
response
:
NvResponse
=
response
.try_into
()
.map_err
(|
e
|
{
let
response
:
NvResponse
=
response
.try_into
()
.map_err
(|
e
|
{
...
...
lib/llm/src/local_model.rs
View file @
cbe854fc
...
@@ -202,6 +202,7 @@ impl LocalModelBuilder {
...
@@ -202,6 +202,7 @@ impl LocalModelBuilder {
);
);
card
.migration_limit
=
self
.migration_limit
;
card
.migration_limit
=
self
.migration_limit
;
card
.user_data
=
self
.user_data
.take
();
card
.user_data
=
self
.user_data
.take
();
return
Ok
(
LocalModel
{
return
Ok
(
LocalModel
{
card
,
card
,
full_path
:
PathBuf
::
new
(),
full_path
:
PathBuf
::
new
(),
...
@@ -392,6 +393,7 @@ impl LocalModel {
...
@@ -392,6 +393,7 @@ impl LocalModel {
let
kvstore
:
Box
<
dyn
KeyValueStore
>
=
Box
::
new
(
EtcdStorage
::
new
(
etcd_client
.clone
()));
let
kvstore
:
Box
<
dyn
KeyValueStore
>
=
Box
::
new
(
EtcdStorage
::
new
(
etcd_client
.clone
()));
let
card_store
=
Arc
::
new
(
KeyValueStoreManager
::
new
(
kvstore
));
let
card_store
=
Arc
::
new
(
KeyValueStoreManager
::
new
(
kvstore
));
let
key
=
self
.card
.slug
()
.to_string
();
let
key
=
self
.card
.slug
()
.to_string
();
card_store
card_store
.publish
(
model_card
::
ROOT_PATH
,
None
,
&
key
,
&
mut
self
.card
)
.publish
(
model_card
::
ROOT_PATH
,
None
,
&
key
,
&
mut
self
.card
)
.await
?
;
.await
?
;
...
...
lib/llm/src/local_model/runtime_config.rs
View file @
cbe854fc
...
@@ -13,6 +13,10 @@ pub struct ModelRuntimeConfig {
...
@@ -13,6 +13,10 @@ pub struct ModelRuntimeConfig {
pub
max_num_batched_tokens
:
Option
<
u64
>
,
pub
max_num_batched_tokens
:
Option
<
u64
>
,
pub
tool_call_parser
:
Option
<
String
>
,
pub
reasoning_parser
:
Option
<
String
>
,
/// Mapping of engine-specific runtime configs
/// Mapping of engine-specific runtime configs
#[serde(default,
skip_serializing_if
=
"HashMap::is_empty"
)]
#[serde(default,
skip_serializing_if
=
"HashMap::is_empty"
)]
pub
runtime_data
:
HashMap
<
String
,
serde_json
::
Value
>
,
pub
runtime_data
:
HashMap
<
String
,
serde_json
::
Value
>
,
...
...
lib/llm/src/preprocessor.rs
View file @
cbe854fc
...
@@ -101,7 +101,6 @@ impl OpenAIPreprocessor {
...
@@ -101,7 +101,6 @@ impl OpenAIPreprocessor {
let
mdcsum
=
mdc
.mdcsum
();
let
mdcsum
=
mdc
.mdcsum
();
let
formatter
=
PromptFormatter
::
from_mdc
(
mdc
.clone
())
.await
?
;
let
formatter
=
PromptFormatter
::
from_mdc
(
mdc
.clone
())
.await
?
;
let
PromptFormatter
::
OAI
(
formatter
)
=
formatter
;
let
PromptFormatter
::
OAI
(
formatter
)
=
formatter
;
let
tokenizer
=
match
&
mdc
.tokenizer
{
let
tokenizer
=
match
&
mdc
.tokenizer
{
Some
(
TokenizerKind
::
HfTokenizerJson
(
file
))
=>
HuggingFaceTokenizer
::
from_file
(
file
)
?
,
Some
(
TokenizerKind
::
HfTokenizerJson
(
file
))
=>
HuggingFaceTokenizer
::
from_file
(
file
)
?
,
Some
(
TokenizerKind
::
GGUF
(
tokenizer
))
=>
{
Some
(
TokenizerKind
::
GGUF
(
tokenizer
))
=>
{
...
...
lib/llm/src/protocols/openai.rs
View file @
cbe854fc
...
@@ -193,3 +193,19 @@ pub trait DeltaGeneratorExt<ResponseType: Send + 'static + std::fmt::Debug>:
...
@@ -193,3 +193,19 @@ pub trait DeltaGeneratorExt<ResponseType: Send + 'static + std::fmt::Debug>:
/// Gets the current prompt token count (Input Sequence Length).
/// Gets the current prompt token count (Input Sequence Length).
fn
get_isl
(
&
self
)
->
Option
<
u32
>
;
fn
get_isl
(
&
self
)
->
Option
<
u32
>
;
}
}
#[derive(Clone,
Debug,
Serialize,
Deserialize,
Default)]
pub
struct
ParsingOptions
{
pub
tool_call_parser
:
Option
<
String
>
,
pub
reasoning_parser
:
Option
<
String
>
,
}
impl
ParsingOptions
{
pub
fn
new
(
tool_call_parser
:
Option
<
String
>
,
reasoning_parser
:
Option
<
String
>
)
->
Self
{
Self
{
tool_call_parser
,
reasoning_parser
,
}
}
}
lib/llm/src/protocols/openai/chat_completions/aggregator.rs
View file @
cbe854fc
...
@@ -19,7 +19,9 @@ use std::collections::HashMap;
...
@@ -19,7 +19,9 @@ use std::collections::HashMap;
use
super
::{
NvCreateChatCompletionResponse
,
NvCreateChatCompletionStreamResponse
};
use
super
::{
NvCreateChatCompletionResponse
,
NvCreateChatCompletionStreamResponse
};
use
crate
::
protocols
::{
use
crate
::
protocols
::{
codec
::{
Message
,
SseCodecError
},
codec
::{
Message
,
SseCodecError
},
convert_sse_stream
,
Annotated
,
convert_sse_stream
,
openai
::
ParsingOptions
,
Annotated
,
};
};
use
dynamo_parsers
::
tool_calling
::
try_tool_call_parse_aggregate
;
use
dynamo_parsers
::
tool_calling
::
try_tool_call_parse_aggregate
;
...
@@ -99,6 +101,7 @@ impl DeltaAggregator {
...
@@ -99,6 +101,7 @@ impl DeltaAggregator {
/// * `Err(String)` if an error occurs during processing.
/// * `Err(String)` if an error occurs during processing.
pub
async
fn
apply
(
pub
async
fn
apply
(
stream
:
impl
Stream
<
Item
=
Annotated
<
NvCreateChatCompletionStreamResponse
>>
,
stream
:
impl
Stream
<
Item
=
Annotated
<
NvCreateChatCompletionStreamResponse
>>
,
parsing_options
:
ParsingOptions
,
)
->
Result
<
NvCreateChatCompletionResponse
,
String
>
{
)
->
Result
<
NvCreateChatCompletionResponse
,
String
>
{
let
aggregator
=
stream
let
aggregator
=
stream
.fold
(
DeltaAggregator
::
new
(),
|
mut
aggregator
,
delta
|
async
move
{
.fold
(
DeltaAggregator
::
new
(),
|
mut
aggregator
,
delta
|
async
move
{
...
@@ -175,7 +178,10 @@ impl DeltaAggregator {
...
@@ -175,7 +178,10 @@ impl DeltaAggregator {
// After aggregation, inspect each choice's text for tool call syntax
// After aggregation, inspect each choice's text for tool call syntax
for
choice
in
aggregator
.choices
.values_mut
()
{
for
choice
in
aggregator
.choices
.values_mut
()
{
if
choice
.tool_calls
.is_none
()
{
if
choice
.tool_calls
.is_none
()
{
if
let
Ok
(
tool_calls
)
=
try_tool_call_parse_aggregate
(
&
choice
.text
,
None
)
{
if
let
Ok
(
tool_calls
)
=
try_tool_call_parse_aggregate
(
&
choice
.text
,
parsing_options
.tool_call_parser
.as_deref
(),
)
{
if
tool_calls
.is_empty
()
{
if
tool_calls
.is_empty
()
{
continue
;
continue
;
}
}
...
@@ -262,6 +268,7 @@ pub trait ChatCompletionAggregator {
...
@@ -262,6 +268,7 @@ pub trait ChatCompletionAggregator {
/// * `Err(String)` if an error occurs.
/// * `Err(String)` if an error occurs.
async
fn
from_annotated_stream
(
async
fn
from_annotated_stream
(
stream
:
impl
Stream
<
Item
=
Annotated
<
NvCreateChatCompletionStreamResponse
>>
,
stream
:
impl
Stream
<
Item
=
Annotated
<
NvCreateChatCompletionStreamResponse
>>
,
parsing_options
:
ParsingOptions
,
)
->
Result
<
NvCreateChatCompletionResponse
,
String
>
;
)
->
Result
<
NvCreateChatCompletionResponse
,
String
>
;
/// Converts an SSE stream into a [`NvCreateChatCompletionResponse`].
/// Converts an SSE stream into a [`NvCreateChatCompletionResponse`].
...
@@ -274,21 +281,24 @@ pub trait ChatCompletionAggregator {
...
@@ -274,21 +281,24 @@ pub trait ChatCompletionAggregator {
/// * `Err(String)` if an error occurs.
/// * `Err(String)` if an error occurs.
async
fn
from_sse_stream
(
async
fn
from_sse_stream
(
stream
:
DataStream
<
Result
<
Message
,
SseCodecError
>>
,
stream
:
DataStream
<
Result
<
Message
,
SseCodecError
>>
,
parsing_options
:
ParsingOptions
,
)
->
Result
<
NvCreateChatCompletionResponse
,
String
>
;
)
->
Result
<
NvCreateChatCompletionResponse
,
String
>
;
}
}
impl
ChatCompletionAggregator
for
dynamo_async_openai
::
types
::
CreateChatCompletionResponse
{
impl
ChatCompletionAggregator
for
dynamo_async_openai
::
types
::
CreateChatCompletionResponse
{
async
fn
from_annotated_stream
(
async
fn
from_annotated_stream
(
stream
:
impl
Stream
<
Item
=
Annotated
<
NvCreateChatCompletionStreamResponse
>>
,
stream
:
impl
Stream
<
Item
=
Annotated
<
NvCreateChatCompletionStreamResponse
>>
,
parsing_options
:
ParsingOptions
,
)
->
Result
<
NvCreateChatCompletionResponse
,
String
>
{
)
->
Result
<
NvCreateChatCompletionResponse
,
String
>
{
DeltaAggregator
::
apply
(
stream
)
.await
DeltaAggregator
::
apply
(
stream
,
parsing_options
)
.await
}
}
async
fn
from_sse_stream
(
async
fn
from_sse_stream
(
stream
:
DataStream
<
Result
<
Message
,
SseCodecError
>>
,
stream
:
DataStream
<
Result
<
Message
,
SseCodecError
>>
,
parsing_options
:
ParsingOptions
,
)
->
Result
<
NvCreateChatCompletionResponse
,
String
>
{
)
->
Result
<
NvCreateChatCompletionResponse
,
String
>
{
let
stream
=
convert_sse_stream
::
<
NvCreateChatCompletionStreamResponse
>
(
stream
);
let
stream
=
convert_sse_stream
::
<
NvCreateChatCompletionStreamResponse
>
(
stream
);
NvCreateChatCompletionResponse
::
from_annotated_stream
(
stream
)
.await
NvCreateChatCompletionResponse
::
from_annotated_stream
(
stream
,
parsing_options
)
.await
}
}
}
}
...
@@ -347,7 +357,7 @@ mod tests {
...
@@ -347,7 +357,7 @@ mod tests {
Box
::
pin
(
stream
::
empty
());
Box
::
pin
(
stream
::
empty
());
// Call DeltaAggregator::apply
// Call DeltaAggregator::apply
let
result
=
DeltaAggregator
::
apply
(
stream
)
.await
;
let
result
=
DeltaAggregator
::
apply
(
stream
,
ParsingOptions
::
default
()
)
.await
;
// Check the result
// Check the result
assert
!
(
result
.is_ok
());
assert
!
(
result
.is_ok
());
...
@@ -377,7 +387,7 @@ mod tests {
...
@@ -377,7 +387,7 @@ mod tests {
let
stream
=
Box
::
pin
(
stream
::
iter
(
vec!
[
annotated_delta
]));
let
stream
=
Box
::
pin
(
stream
::
iter
(
vec!
[
annotated_delta
]));
// Call DeltaAggregator::apply
// Call DeltaAggregator::apply
let
result
=
DeltaAggregator
::
apply
(
stream
)
.await
;
let
result
=
DeltaAggregator
::
apply
(
stream
,
ParsingOptions
::
default
()
)
.await
;
// Check the result
// Check the result
assert
!
(
result
.is_ok
());
assert
!
(
result
.is_ok
());
...
@@ -421,7 +431,7 @@ mod tests {
...
@@ -421,7 +431,7 @@ mod tests {
let
stream
=
Box
::
pin
(
stream
::
iter
(
annotated_deltas
));
let
stream
=
Box
::
pin
(
stream
::
iter
(
annotated_deltas
));
// Call DeltaAggregator::apply
// Call DeltaAggregator::apply
let
result
=
DeltaAggregator
::
apply
(
stream
)
.await
;
let
result
=
DeltaAggregator
::
apply
(
stream
,
ParsingOptions
::
default
()
)
.await
;
// Check the result
// Check the result
assert
!
(
result
.is_ok
());
assert
!
(
result
.is_ok
());
...
@@ -492,7 +502,7 @@ mod tests {
...
@@ -492,7 +502,7 @@ mod tests {
let
stream
=
Box
::
pin
(
stream
::
iter
(
vec!
[
annotated_delta
]));
let
stream
=
Box
::
pin
(
stream
::
iter
(
vec!
[
annotated_delta
]));
// Call DeltaAggregator::apply
// Call DeltaAggregator::apply
let
result
=
DeltaAggregator
::
apply
(
stream
)
.await
;
let
result
=
DeltaAggregator
::
apply
(
stream
,
ParsingOptions
::
default
()
)
.await
;
// Check the result
// Check the result
assert
!
(
result
.is_ok
());
assert
!
(
result
.is_ok
());
...
@@ -550,7 +560,7 @@ mod tests {
...
@@ -550,7 +560,7 @@ mod tests {
let
stream
=
Box
::
pin
(
stream
::
iter
(
vec!
[
annotated_delta
]));
let
stream
=
Box
::
pin
(
stream
::
iter
(
vec!
[
annotated_delta
]));
// Call DeltaAggregator::apply
// Call DeltaAggregator::apply
let
result
=
DeltaAggregator
::
apply
(
stream
)
.await
;
let
result
=
DeltaAggregator
::
apply
(
stream
,
ParsingOptions
::
default
()
)
.await
;
// Check the result
// Check the result
assert
!
(
result
.is_ok
());
assert
!
(
result
.is_ok
());
...
...
lib/llm/src/protocols/openai/completions/aggregator.rs
View file @
cbe854fc
...
@@ -22,7 +22,9 @@ use super::NvCreateCompletionResponse;
...
@@ -22,7 +22,9 @@ use super::NvCreateCompletionResponse;
use
crate
::
protocols
::{
use
crate
::
protocols
::{
codec
::{
Message
,
SseCodecError
},
codec
::{
Message
,
SseCodecError
},
common
::
FinishReason
,
common
::
FinishReason
,
convert_sse_stream
,
Annotated
,
DataStream
,
convert_sse_stream
,
openai
::
ParsingOptions
,
Annotated
,
DataStream
,
};
};
/// Aggregates a stream of [`CompletionResponse`]s into a single [`CompletionResponse`].
/// Aggregates a stream of [`CompletionResponse`]s into a single [`CompletionResponse`].
...
@@ -65,7 +67,9 @@ impl DeltaAggregator {
...
@@ -65,7 +67,9 @@ impl DeltaAggregator {
/// Aggregates a stream of [`Annotated<CompletionResponse>`]s into a single [`CompletionResponse`].
/// Aggregates a stream of [`Annotated<CompletionResponse>`]s into a single [`CompletionResponse`].
pub
async
fn
apply
(
pub
async
fn
apply
(
stream
:
impl
Stream
<
Item
=
Annotated
<
NvCreateCompletionResponse
>>
,
stream
:
impl
Stream
<
Item
=
Annotated
<
NvCreateCompletionResponse
>>
,
parsing_options
:
ParsingOptions
,
)
->
Result
<
NvCreateCompletionResponse
>
{
)
->
Result
<
NvCreateCompletionResponse
>
{
tracing
::
debug!
(
"Tool Call Parser: {:?}"
,
parsing_options
.tool_call_parser
);
// TODO: remove this once completion has tool call support
let
aggregator
=
stream
let
aggregator
=
stream
.fold
(
DeltaAggregator
::
new
(),
|
mut
aggregator
,
delta
|
async
move
{
.fold
(
DeltaAggregator
::
new
(),
|
mut
aggregator
,
delta
|
async
move
{
let
delta
=
match
delta
.ok
()
{
let
delta
=
match
delta
.ok
()
{
...
@@ -177,15 +181,17 @@ impl From<DeltaChoice> for dynamo_async_openai::types::Choice {
...
@@ -177,15 +181,17 @@ impl From<DeltaChoice> for dynamo_async_openai::types::Choice {
impl
NvCreateCompletionResponse
{
impl
NvCreateCompletionResponse
{
pub
async
fn
from_sse_stream
(
pub
async
fn
from_sse_stream
(
stream
:
DataStream
<
Result
<
Message
,
SseCodecError
>>
,
stream
:
DataStream
<
Result
<
Message
,
SseCodecError
>>
,
parsing_options
:
ParsingOptions
,
)
->
Result
<
NvCreateCompletionResponse
>
{
)
->
Result
<
NvCreateCompletionResponse
>
{
let
stream
=
convert_sse_stream
::
<
NvCreateCompletionResponse
>
(
stream
);
let
stream
=
convert_sse_stream
::
<
NvCreateCompletionResponse
>
(
stream
);
NvCreateCompletionResponse
::
from_annotated_stream
(
stream
)
.await
NvCreateCompletionResponse
::
from_annotated_stream
(
stream
,
parsing_options
)
.await
}
}
pub
async
fn
from_annotated_stream
(
pub
async
fn
from_annotated_stream
(
stream
:
impl
Stream
<
Item
=
Annotated
<
NvCreateCompletionResponse
>>
,
stream
:
impl
Stream
<
Item
=
Annotated
<
NvCreateCompletionResponse
>>
,
parsing_options
:
ParsingOptions
,
)
->
Result
<
NvCreateCompletionResponse
>
{
)
->
Result
<
NvCreateCompletionResponse
>
{
DeltaAggregator
::
apply
(
stream
)
.await
DeltaAggregator
::
apply
(
stream
,
parsing_options
)
.await
}
}
}
}
...
@@ -241,7 +247,7 @@ mod tests {
...
@@ -241,7 +247,7 @@ mod tests {
let
stream
:
DataStream
<
Annotated
<
NvCreateCompletionResponse
>>
=
Box
::
pin
(
stream
::
empty
());
let
stream
:
DataStream
<
Annotated
<
NvCreateCompletionResponse
>>
=
Box
::
pin
(
stream
::
empty
());
// Call DeltaAggregator::apply
// Call DeltaAggregator::apply
let
result
=
DeltaAggregator
::
apply
(
stream
)
.await
;
let
result
=
DeltaAggregator
::
apply
(
stream
,
ParsingOptions
::
default
()
)
.await
;
// Check the result
// Check the result
assert
!
(
result
.is_ok
());
assert
!
(
result
.is_ok
());
...
@@ -265,7 +271,7 @@ mod tests {
...
@@ -265,7 +271,7 @@ mod tests {
let
stream
=
Box
::
pin
(
stream
::
iter
(
vec!
[
annotated_delta
]));
let
stream
=
Box
::
pin
(
stream
::
iter
(
vec!
[
annotated_delta
]));
// Call DeltaAggregator::apply
// Call DeltaAggregator::apply
let
result
=
DeltaAggregator
::
apply
(
stream
)
.await
;
let
result
=
DeltaAggregator
::
apply
(
stream
,
ParsingOptions
::
default
()
)
.await
;
// Check the result
// Check the result
assert
!
(
result
.is_ok
());
assert
!
(
result
.is_ok
());
...
@@ -305,7 +311,7 @@ mod tests {
...
@@ -305,7 +311,7 @@ mod tests {
let
stream
=
Box
::
pin
(
stream
::
iter
(
annotated_deltas
));
let
stream
=
Box
::
pin
(
stream
::
iter
(
annotated_deltas
));
// Call DeltaAggregator::apply
// Call DeltaAggregator::apply
let
result
=
DeltaAggregator
::
apply
(
stream
)
.await
;
let
result
=
DeltaAggregator
::
apply
(
stream
,
ParsingOptions
::
default
()
)
.await
;
// Check the result
// Check the result
assert
!
(
result
.is_ok
());
assert
!
(
result
.is_ok
());
...
@@ -365,7 +371,7 @@ mod tests {
...
@@ -365,7 +371,7 @@ mod tests {
let
stream
=
Box
::
pin
(
stream
::
iter
(
vec!
[
annotated_delta
]));
let
stream
=
Box
::
pin
(
stream
::
iter
(
vec!
[
annotated_delta
]));
// Call DeltaAggregator::apply
// Call DeltaAggregator::apply
let
result
=
DeltaAggregator
::
apply
(
stream
)
.await
;
let
result
=
DeltaAggregator
::
apply
(
stream
,
ParsingOptions
::
default
()
)
.await
;
// Check the result
// Check the result
assert
!
(
result
.is_ok
());
assert
!
(
result
.is_ok
());
...
...
lib/llm/tests/aggregators.rs
View file @
cbe854fc
...
@@ -18,6 +18,7 @@ use dynamo_llm::protocols::{
...
@@ -18,6 +18,7 @@ use dynamo_llm::protocols::{
openai
::{
openai
::{
chat_completions
::{
aggregator
::
ChatCompletionAggregator
,
NvCreateChatCompletionResponse
},
chat_completions
::{
aggregator
::
ChatCompletionAggregator
,
NvCreateChatCompletionResponse
},
completions
::
NvCreateCompletionResponse
,
completions
::
NvCreateCompletionResponse
,
ParsingOptions
,
},
},
ContentProvider
,
DataStream
,
ContentProvider
,
DataStream
,
};
};
...
@@ -37,9 +38,12 @@ async fn test_openai_chat_stream() {
...
@@ -37,9 +38,12 @@ async fn test_openai_chat_stream() {
// note: we are only taking the first 16 messages to keep the size of the response small
// note: we are only taking the first 16 messages to keep the size of the response small
let
stream
=
create_message_stream
(
&
data
)
.take
(
16
);
let
stream
=
create_message_stream
(
&
data
)
.take
(
16
);
let
result
=
NvCreateChatCompletionResponse
::
from_sse_stream
(
Box
::
pin
(
stream
))
let
result
=
NvCreateChatCompletionResponse
::
from_sse_stream
(
.await
Box
::
pin
(
stream
),
.unwrap
();
ParsingOptions
::
default
(),
)
.await
.unwrap
();
// todo: provide a cleaner way to extract the content from choices
// todo: provide a cleaner way to extract the content from choices
assert_eq!
(
assert_eq!
(
...
@@ -59,9 +63,12 @@ async fn test_openai_chat_stream() {
...
@@ -59,9 +63,12 @@ async fn test_openai_chat_stream() {
#[tokio::test]
#[tokio::test]
async
fn
test_openai_chat_edge_case_multi_line_data
()
{
async
fn
test_openai_chat_edge_case_multi_line_data
()
{
let
stream
=
create_stream
(
CHAT_ROOT_PATH
,
"edge_cases/valid-multi-line-data"
);
let
stream
=
create_stream
(
CHAT_ROOT_PATH
,
"edge_cases/valid-multi-line-data"
);
let
result
=
NvCreateChatCompletionResponse
::
from_sse_stream
(
Box
::
pin
(
stream
))
let
result
=
NvCreateChatCompletionResponse
::
from_sse_stream
(
.await
Box
::
pin
(
stream
),
.unwrap
();
ParsingOptions
::
default
(),
)
.await
.unwrap
();
assert_eq!
(
assert_eq!
(
result
result
...
@@ -79,9 +86,12 @@ async fn test_openai_chat_edge_case_multi_line_data() {
...
@@ -79,9 +86,12 @@ async fn test_openai_chat_edge_case_multi_line_data() {
#[tokio::test]
#[tokio::test]
async
fn
test_openai_chat_edge_case_comments_per_response
()
{
async
fn
test_openai_chat_edge_case_comments_per_response
()
{
let
stream
=
create_stream
(
CHAT_ROOT_PATH
,
"edge_cases/valid-comments_per_response"
);
let
stream
=
create_stream
(
CHAT_ROOT_PATH
,
"edge_cases/valid-comments_per_response"
);
let
result
=
NvCreateChatCompletionResponse
::
from_sse_stream
(
Box
::
pin
(
stream
))
let
result
=
NvCreateChatCompletionResponse
::
from_sse_stream
(
.await
Box
::
pin
(
stream
),
.unwrap
();
ParsingOptions
::
default
(),
)
.await
.unwrap
();
assert_eq!
(
assert_eq!
(
result
result
...
@@ -99,7 +109,11 @@ async fn test_openai_chat_edge_case_comments_per_response() {
...
@@ -99,7 +109,11 @@ async fn test_openai_chat_edge_case_comments_per_response() {
#[tokio::test]
#[tokio::test]
async
fn
test_openai_chat_edge_case_invalid_deserialize_error
()
{
async
fn
test_openai_chat_edge_case_invalid_deserialize_error
()
{
let
stream
=
create_stream
(
CHAT_ROOT_PATH
,
"edge_cases/invalid-deserialize_error"
);
let
stream
=
create_stream
(
CHAT_ROOT_PATH
,
"edge_cases/invalid-deserialize_error"
);
let
result
=
NvCreateChatCompletionResponse
::
from_sse_stream
(
Box
::
pin
(
stream
))
.await
;
let
result
=
NvCreateChatCompletionResponse
::
from_sse_stream
(
Box
::
pin
(
stream
),
ParsingOptions
::
default
(),
)
.await
;
assert
!
(
result
.is_err
());
assert
!
(
result
.is_err
());
// insta::assert_debug_snapshot!(result);
// insta::assert_debug_snapshot!(result);
...
@@ -112,9 +126,10 @@ async fn test_openai_chat_edge_case_invalid_deserialize_error() {
...
@@ -112,9 +126,10 @@ async fn test_openai_chat_edge_case_invalid_deserialize_error() {
#[tokio::test]
#[tokio::test]
async
fn
test_openai_cmpl_stream
()
{
async
fn
test_openai_cmpl_stream
()
{
let
stream
=
create_stream
(
CMPL_ROOT_PATH
,
"completion.streaming.1"
)
.take
(
16
);
let
stream
=
create_stream
(
CMPL_ROOT_PATH
,
"completion.streaming.1"
)
.take
(
16
);
let
result
=
NvCreateCompletionResponse
::
from_sse_stream
(
Box
::
pin
(
stream
))
let
result
=
.await
NvCreateCompletionResponse
::
from_sse_stream
(
Box
::
pin
(
stream
),
ParsingOptions
::
default
())
.unwrap
();
.await
.unwrap
();
// todo: provide a cleaner way to extract the content from choices
// todo: provide a cleaner way to extract the content from choices
assert_eq!
(
assert_eq!
(
...
...
lib/parsers/src/tool_calling/tools.rs
View file @
cbe854fc
...
@@ -14,6 +14,11 @@ pub fn try_tool_call_parse_aggregate(
...
@@ -14,6 +14,11 @@ pub fn try_tool_call_parse_aggregate(
message
:
&
str
,
message
:
&
str
,
parser_str
:
Option
<&
str
>
,
parser_str
:
Option
<&
str
>
,
)
->
anyhow
::
Result
<
Vec
<
dynamo_async_openai
::
types
::
ChatCompletionMessageToolCall
>>
{
)
->
anyhow
::
Result
<
Vec
<
dynamo_async_openai
::
types
::
ChatCompletionMessageToolCall
>>
{
if
parser_str
.is_none
()
{
tracing
::
info!
(
"No tool parser provided. Trying parsing with default parser."
);
}
else
{
tracing
::
info!
(
"Using tool parser: {:?}"
,
parser_str
);
}
let
parsed
=
detect_and_parse_tool_call
(
message
,
parser_str
)
?
;
let
parsed
=
detect_and_parse_tool_call
(
message
,
parser_str
)
?
;
if
parsed
.is_empty
()
{
if
parsed
.is_empty
()
{
return
Ok
(
vec!
[]);
return
Ok
(
vec!
[]);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment