Unverified Commit cbe854fc authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

feat: [vLLM] implement cli args for tool and reasoning parsers (#2619)

parent b658ba61
...@@ -58,6 +58,10 @@ class Config: ...@@ -58,6 +58,10 @@ class Config:
# Connector list from CLI # Connector list from CLI
connector_list: Optional[list] = None connector_list: Optional[list] = None
# tool and reasoning parser info
tool_call_parser: Optional[str] = None
reasoning_parser: Optional[str] = None
def parse_args() -> Config: def parse_args() -> Config:
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
...@@ -102,6 +106,19 @@ def parse_args() -> Config: ...@@ -102,6 +106,19 @@ def parse_args() -> Config:
help="List of connectors to use in order (e.g., --connector nixl lmcache). " help="List of connectors to use in order (e.g., --connector nixl lmcache). "
"Options: nixl, lmcache, kvbm, null, none. Default: nixl. Order will be preserved in MultiConnector.", "Options: nixl, lmcache, kvbm, null, none. Default: nixl. Order will be preserved in MultiConnector.",
) )
# To avoid name conflicts with different backends, adoped prefix "dyn-" for dynamo specific args
parser.add_argument(
"--dyn-tool-call-parser",
type=str,
default=None,
help="Tool call parser name for the model. Available options: 'hermes', 'nemotron_deci', 'llama3_json', 'mistral', 'phi4'.",
)
parser.add_argument(
"--dyn-reasoning-parser",
type=str,
default=None,
help="Reasoning parser name for the model.",
)
parser = AsyncEngineArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args() args = parser.parse_args()
...@@ -151,7 +168,8 @@ def parse_args() -> Config: ...@@ -151,7 +168,8 @@ def parse_args() -> Config:
config.port_range = DynamoPortRange( config.port_range = DynamoPortRange(
min=args.dynamo_port_min, max=args.dynamo_port_max min=args.dynamo_port_min, max=args.dynamo_port_max
) )
config.tool_call_parser = args.dyn_tool_call_parser
config.reasoning_parser = args.dyn_reasoning_parser
# Check for conflicting flags # Check for conflicting flags
has_kv_transfer_config = ( has_kv_transfer_config = (
hasattr(engine_args, "kv_transfer_config") hasattr(engine_args, "kv_transfer_config")
......
...@@ -234,6 +234,8 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -234,6 +234,8 @@ async def init(runtime: DistributedRuntime, config: Config):
runtime_config.total_kv_blocks = runtime_values["num_gpu_blocks"] runtime_config.total_kv_blocks = runtime_values["num_gpu_blocks"]
runtime_config.max_num_seqs = runtime_values["max_num_seqs"] runtime_config.max_num_seqs = runtime_values["max_num_seqs"]
runtime_config.max_num_batched_tokens = runtime_values["max_num_batched_tokens"] runtime_config.max_num_batched_tokens = runtime_values["max_num_batched_tokens"]
runtime_config.tool_call_parser = config.tool_call_parser
runtime_config.reasoning_parser = config.reasoning_parser
await register_llm( await register_llm(
ModelType.Backend, ModelType.Backend,
......
...@@ -34,6 +34,16 @@ impl ModelRuntimeConfig { ...@@ -34,6 +34,16 @@ impl ModelRuntimeConfig {
self.inner.max_num_batched_tokens = Some(max_num_batched_tokens); self.inner.max_num_batched_tokens = Some(max_num_batched_tokens);
} }
#[setter]
fn set_tool_call_parser(&mut self, tool_call_parser: Option<String>) {
self.inner.tool_call_parser = tool_call_parser;
}
#[setter]
fn set_reasoning_parser(&mut self, reasoning_parser: Option<String>) {
self.inner.reasoning_parser = reasoning_parser;
}
fn set_engine_specific(&mut self, key: &str, value: String) -> PyResult<()> { fn set_engine_specific(&mut self, key: &str, value: String) -> PyResult<()> {
let value: serde_json::Value = serde_json::from_str(&value).map_err(to_pyerr)?; let value: serde_json::Value = serde_json::from_str(&value).map_err(to_pyerr)?;
self.inner self.inner
...@@ -57,6 +67,16 @@ impl ModelRuntimeConfig { ...@@ -57,6 +67,16 @@ impl ModelRuntimeConfig {
self.inner.max_num_batched_tokens self.inner.max_num_batched_tokens
} }
#[getter]
fn tool_call_parser(&self) -> Option<String> {
self.inner.tool_call_parser.clone()
}
#[getter]
fn reasoning_parser(&self) -> Option<String> {
self.inner.reasoning_parser.clone()
}
#[getter] #[getter]
fn runtime_data(&self, py: Python<'_>) -> PyResult<PyObject> { fn runtime_data(&self, py: Python<'_>) -> PyResult<PyObject> {
let dict = PyDict::new(py); let dict = PyDict::new(py);
......
...@@ -246,6 +246,18 @@ impl ModelManager { ...@@ -246,6 +246,18 @@ impl ModelManager {
.insert(model_name.to_string(), new_kv_chooser.clone()); .insert(model_name.to_string(), new_kv_chooser.clone());
Ok(new_kv_chooser) Ok(new_kv_chooser)
} }
pub fn get_model_tool_call_parser(&self, model: &str) -> Option<String> {
match self.entries.lock() {
Ok(entries) => entries
.values()
.find(|entry| entry.name == model)
.and_then(|entry| entry.runtime_config.as_ref())
.and_then(|config| config.tool_call_parser.clone())
.map(|parser| parser.to_string()),
Err(_) => None,
}
}
} }
pub struct ModelEngines<E> { pub struct ModelEngines<E> {
......
...@@ -37,6 +37,7 @@ use crate::protocols::openai::{ ...@@ -37,6 +37,7 @@ use crate::protocols::openai::{
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse}, embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
responses::{NvCreateResponse, NvResponse}, responses::{NvCreateResponse, NvResponse},
ParsingOptions,
}; };
use crate::request_template::RequestTemplate; use crate::request_template::RequestTemplate;
use crate::types::Annotated; use crate::types::Annotated;
...@@ -194,6 +195,13 @@ fn get_or_create_request_id(primary: Option<&str>, headers: &HeaderMap) -> Strin ...@@ -194,6 +195,13 @@ fn get_or_create_request_id(primary: Option<&str>, headers: &HeaderMap) -> Strin
uuid.to_string() uuid.to_string()
} }
fn get_parsing_options(state: &Arc<service_v2::State>, model: &str) -> ParsingOptions {
let tool_call_parser = state.manager().get_model_tool_call_parser(model);
let reasoning_parser = None; // TODO: Implement reasoning parser
ParsingOptions::new(tool_call_parser, reasoning_parser)
}
/// OpenAI Completions Request Handler /// OpenAI Completions Request Handler
/// ///
/// This method will handle the incoming request for the `/v1/completions endpoint`. The endpoint is a "source" /// This method will handle the incoming request for the `/v1/completions endpoint`. The endpoint is a "source"
...@@ -267,6 +275,8 @@ async fn completions( ...@@ -267,6 +275,8 @@ async fn completions(
.get_completions_engine(model) .get_completions_engine(model)
.map_err(|_| ErrorMessage::model_not_found())?; .map_err(|_| ErrorMessage::model_not_found())?;
let parsing_options = get_parsing_options(&state, model);
let mut inflight_guard = let mut inflight_guard =
state state
.metrics_clone() .metrics_clone()
...@@ -325,7 +335,7 @@ async fn completions( ...@@ -325,7 +335,7 @@ async fn completions(
process_metrics_only(response, &mut response_collector); process_metrics_only(response, &mut response_collector);
}); });
let response = NvCreateCompletionResponse::from_annotated_stream(stream) let response = NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options)
.await .await
.map_err(|e| { .map_err(|e| {
tracing::error!( tracing::error!(
...@@ -494,6 +504,8 @@ async fn chat_completions( ...@@ -494,6 +504,8 @@ async fn chat_completions(
.get_chat_completions_engine(model) .get_chat_completions_engine(model)
.map_err(|_| ErrorMessage::model_not_found())?; .map_err(|_| ErrorMessage::model_not_found())?;
let parsing_options = get_parsing_options(&state, model);
let mut inflight_guard = let mut inflight_guard =
state state
.metrics_clone() .metrics_clone()
...@@ -553,7 +565,8 @@ async fn chat_completions( ...@@ -553,7 +565,8 @@ async fn chat_completions(
process_metrics_only(response, &mut response_collector); process_metrics_only(response, &mut response_collector);
}); });
let response = NvCreateChatCompletionResponse::from_annotated_stream(stream) let response =
NvCreateChatCompletionResponse::from_annotated_stream(stream, parsing_options.clone())
.await .await
.map_err(|e| { .map_err(|e| {
tracing::error!( tracing::error!(
...@@ -726,6 +739,8 @@ async fn responses( ...@@ -726,6 +739,8 @@ async fn responses(
.get_chat_completions_engine(model) .get_chat_completions_engine(model)
.map_err(|_| ErrorMessage::model_not_found())?; .map_err(|_| ErrorMessage::model_not_found())?;
let parsing_options = get_parsing_options(&state, model);
let mut inflight_guard = let mut inflight_guard =
state state
.metrics_clone() .metrics_clone()
...@@ -742,7 +757,8 @@ async fn responses( ...@@ -742,7 +757,8 @@ async fn responses(
.map_err(|e| ErrorMessage::from_anyhow(e, "Failed to generate completions"))?; .map_err(|e| ErrorMessage::from_anyhow(e, "Failed to generate completions"))?;
// TODO: handle streaming, currently just unary // TODO: handle streaming, currently just unary
let response = NvCreateChatCompletionResponse::from_annotated_stream(stream) let response =
NvCreateChatCompletionResponse::from_annotated_stream(stream, parsing_options.clone())
.await .await
.map_err(|e| { .map_err(|e| {
tracing::error!( tracing::error!(
......
...@@ -202,6 +202,7 @@ impl LocalModelBuilder { ...@@ -202,6 +202,7 @@ impl LocalModelBuilder {
); );
card.migration_limit = self.migration_limit; card.migration_limit = self.migration_limit;
card.user_data = self.user_data.take(); card.user_data = self.user_data.take();
return Ok(LocalModel { return Ok(LocalModel {
card, card,
full_path: PathBuf::new(), full_path: PathBuf::new(),
...@@ -392,6 +393,7 @@ impl LocalModel { ...@@ -392,6 +393,7 @@ impl LocalModel {
let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client.clone())); let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client.clone()));
let card_store = Arc::new(KeyValueStoreManager::new(kvstore)); let card_store = Arc::new(KeyValueStoreManager::new(kvstore));
let key = self.card.slug().to_string(); let key = self.card.slug().to_string();
card_store card_store
.publish(model_card::ROOT_PATH, None, &key, &mut self.card) .publish(model_card::ROOT_PATH, None, &key, &mut self.card)
.await?; .await?;
......
...@@ -13,6 +13,10 @@ pub struct ModelRuntimeConfig { ...@@ -13,6 +13,10 @@ pub struct ModelRuntimeConfig {
pub max_num_batched_tokens: Option<u64>, pub max_num_batched_tokens: Option<u64>,
pub tool_call_parser: Option<String>,
pub reasoning_parser: Option<String>,
/// Mapping of engine-specific runtime configs /// Mapping of engine-specific runtime configs
#[serde(default, skip_serializing_if = "HashMap::is_empty")] #[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub runtime_data: HashMap<String, serde_json::Value>, pub runtime_data: HashMap<String, serde_json::Value>,
......
...@@ -101,7 +101,6 @@ impl OpenAIPreprocessor { ...@@ -101,7 +101,6 @@ impl OpenAIPreprocessor {
let mdcsum = mdc.mdcsum(); let mdcsum = mdc.mdcsum();
let formatter = PromptFormatter::from_mdc(mdc.clone()).await?; let formatter = PromptFormatter::from_mdc(mdc.clone()).await?;
let PromptFormatter::OAI(formatter) = formatter; let PromptFormatter::OAI(formatter) = formatter;
let tokenizer = match &mdc.tokenizer { let tokenizer = match &mdc.tokenizer {
Some(TokenizerKind::HfTokenizerJson(file)) => HuggingFaceTokenizer::from_file(file)?, Some(TokenizerKind::HfTokenizerJson(file)) => HuggingFaceTokenizer::from_file(file)?,
Some(TokenizerKind::GGUF(tokenizer)) => { Some(TokenizerKind::GGUF(tokenizer)) => {
......
...@@ -193,3 +193,19 @@ pub trait DeltaGeneratorExt<ResponseType: Send + 'static + std::fmt::Debug>: ...@@ -193,3 +193,19 @@ pub trait DeltaGeneratorExt<ResponseType: Send + 'static + std::fmt::Debug>:
/// Gets the current prompt token count (Input Sequence Length). /// Gets the current prompt token count (Input Sequence Length).
fn get_isl(&self) -> Option<u32>; fn get_isl(&self) -> Option<u32>;
} }
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct ParsingOptions {
pub tool_call_parser: Option<String>,
pub reasoning_parser: Option<String>,
}
impl ParsingOptions {
pub fn new(tool_call_parser: Option<String>, reasoning_parser: Option<String>) -> Self {
Self {
tool_call_parser,
reasoning_parser,
}
}
}
...@@ -19,7 +19,9 @@ use std::collections::HashMap; ...@@ -19,7 +19,9 @@ use std::collections::HashMap;
use super::{NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse}; use super::{NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse};
use crate::protocols::{ use crate::protocols::{
codec::{Message, SseCodecError}, codec::{Message, SseCodecError},
convert_sse_stream, Annotated, convert_sse_stream,
openai::ParsingOptions,
Annotated,
}; };
use dynamo_parsers::tool_calling::try_tool_call_parse_aggregate; use dynamo_parsers::tool_calling::try_tool_call_parse_aggregate;
...@@ -99,6 +101,7 @@ impl DeltaAggregator { ...@@ -99,6 +101,7 @@ impl DeltaAggregator {
/// * `Err(String)` if an error occurs during processing. /// * `Err(String)` if an error occurs during processing.
pub async fn apply( pub async fn apply(
stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>, stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>,
parsing_options: ParsingOptions,
) -> Result<NvCreateChatCompletionResponse, String> { ) -> Result<NvCreateChatCompletionResponse, String> {
let aggregator = stream let aggregator = stream
.fold(DeltaAggregator::new(), |mut aggregator, delta| async move { .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
...@@ -175,7 +178,10 @@ impl DeltaAggregator { ...@@ -175,7 +178,10 @@ impl DeltaAggregator {
// After aggregation, inspect each choice's text for tool call syntax // After aggregation, inspect each choice's text for tool call syntax
for choice in aggregator.choices.values_mut() { for choice in aggregator.choices.values_mut() {
if choice.tool_calls.is_none() { if choice.tool_calls.is_none() {
if let Ok(tool_calls) = try_tool_call_parse_aggregate(&choice.text, None) { if let Ok(tool_calls) = try_tool_call_parse_aggregate(
&choice.text,
parsing_options.tool_call_parser.as_deref(),
) {
if tool_calls.is_empty() { if tool_calls.is_empty() {
continue; continue;
} }
...@@ -262,6 +268,7 @@ pub trait ChatCompletionAggregator { ...@@ -262,6 +268,7 @@ pub trait ChatCompletionAggregator {
/// * `Err(String)` if an error occurs. /// * `Err(String)` if an error occurs.
async fn from_annotated_stream( async fn from_annotated_stream(
stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>, stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>,
parsing_options: ParsingOptions,
) -> Result<NvCreateChatCompletionResponse, String>; ) -> Result<NvCreateChatCompletionResponse, String>;
/// Converts an SSE stream into a [`NvCreateChatCompletionResponse`]. /// Converts an SSE stream into a [`NvCreateChatCompletionResponse`].
...@@ -274,21 +281,24 @@ pub trait ChatCompletionAggregator { ...@@ -274,21 +281,24 @@ pub trait ChatCompletionAggregator {
/// * `Err(String)` if an error occurs. /// * `Err(String)` if an error occurs.
async fn from_sse_stream( async fn from_sse_stream(
stream: DataStream<Result<Message, SseCodecError>>, stream: DataStream<Result<Message, SseCodecError>>,
parsing_options: ParsingOptions,
) -> Result<NvCreateChatCompletionResponse, String>; ) -> Result<NvCreateChatCompletionResponse, String>;
} }
impl ChatCompletionAggregator for dynamo_async_openai::types::CreateChatCompletionResponse { impl ChatCompletionAggregator for dynamo_async_openai::types::CreateChatCompletionResponse {
async fn from_annotated_stream( async fn from_annotated_stream(
stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>, stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>,
parsing_options: ParsingOptions,
) -> Result<NvCreateChatCompletionResponse, String> { ) -> Result<NvCreateChatCompletionResponse, String> {
DeltaAggregator::apply(stream).await DeltaAggregator::apply(stream, parsing_options).await
} }
async fn from_sse_stream( async fn from_sse_stream(
stream: DataStream<Result<Message, SseCodecError>>, stream: DataStream<Result<Message, SseCodecError>>,
parsing_options: ParsingOptions,
) -> Result<NvCreateChatCompletionResponse, String> { ) -> Result<NvCreateChatCompletionResponse, String> {
let stream = convert_sse_stream::<NvCreateChatCompletionStreamResponse>(stream); let stream = convert_sse_stream::<NvCreateChatCompletionStreamResponse>(stream);
NvCreateChatCompletionResponse::from_annotated_stream(stream).await NvCreateChatCompletionResponse::from_annotated_stream(stream, parsing_options).await
} }
} }
...@@ -347,7 +357,7 @@ mod tests { ...@@ -347,7 +357,7 @@ mod tests {
Box::pin(stream::empty()); Box::pin(stream::empty());
// Call DeltaAggregator::apply // Call DeltaAggregator::apply
let result = DeltaAggregator::apply(stream).await; let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
// Check the result // Check the result
assert!(result.is_ok()); assert!(result.is_ok());
...@@ -377,7 +387,7 @@ mod tests { ...@@ -377,7 +387,7 @@ mod tests {
let stream = Box::pin(stream::iter(vec![annotated_delta])); let stream = Box::pin(stream::iter(vec![annotated_delta]));
// Call DeltaAggregator::apply // Call DeltaAggregator::apply
let result = DeltaAggregator::apply(stream).await; let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
// Check the result // Check the result
assert!(result.is_ok()); assert!(result.is_ok());
...@@ -421,7 +431,7 @@ mod tests { ...@@ -421,7 +431,7 @@ mod tests {
let stream = Box::pin(stream::iter(annotated_deltas)); let stream = Box::pin(stream::iter(annotated_deltas));
// Call DeltaAggregator::apply // Call DeltaAggregator::apply
let result = DeltaAggregator::apply(stream).await; let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
// Check the result // Check the result
assert!(result.is_ok()); assert!(result.is_ok());
...@@ -492,7 +502,7 @@ mod tests { ...@@ -492,7 +502,7 @@ mod tests {
let stream = Box::pin(stream::iter(vec![annotated_delta])); let stream = Box::pin(stream::iter(vec![annotated_delta]));
// Call DeltaAggregator::apply // Call DeltaAggregator::apply
let result = DeltaAggregator::apply(stream).await; let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
// Check the result // Check the result
assert!(result.is_ok()); assert!(result.is_ok());
...@@ -550,7 +560,7 @@ mod tests { ...@@ -550,7 +560,7 @@ mod tests {
let stream = Box::pin(stream::iter(vec![annotated_delta])); let stream = Box::pin(stream::iter(vec![annotated_delta]));
// Call DeltaAggregator::apply // Call DeltaAggregator::apply
let result = DeltaAggregator::apply(stream).await; let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
// Check the result // Check the result
assert!(result.is_ok()); assert!(result.is_ok());
......
...@@ -22,7 +22,9 @@ use super::NvCreateCompletionResponse; ...@@ -22,7 +22,9 @@ use super::NvCreateCompletionResponse;
use crate::protocols::{ use crate::protocols::{
codec::{Message, SseCodecError}, codec::{Message, SseCodecError},
common::FinishReason, common::FinishReason,
convert_sse_stream, Annotated, DataStream, convert_sse_stream,
openai::ParsingOptions,
Annotated, DataStream,
}; };
/// Aggregates a stream of [`CompletionResponse`]s into a single [`CompletionResponse`]. /// Aggregates a stream of [`CompletionResponse`]s into a single [`CompletionResponse`].
...@@ -65,7 +67,9 @@ impl DeltaAggregator { ...@@ -65,7 +67,9 @@ impl DeltaAggregator {
/// Aggregates a stream of [`Annotated<CompletionResponse>`]s into a single [`CompletionResponse`]. /// Aggregates a stream of [`Annotated<CompletionResponse>`]s into a single [`CompletionResponse`].
pub async fn apply( pub async fn apply(
stream: impl Stream<Item = Annotated<NvCreateCompletionResponse>>, stream: impl Stream<Item = Annotated<NvCreateCompletionResponse>>,
parsing_options: ParsingOptions,
) -> Result<NvCreateCompletionResponse> { ) -> Result<NvCreateCompletionResponse> {
tracing::debug!("Tool Call Parser: {:?}", parsing_options.tool_call_parser); // TODO: remove this once completion has tool call support
let aggregator = stream let aggregator = stream
.fold(DeltaAggregator::new(), |mut aggregator, delta| async move { .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
let delta = match delta.ok() { let delta = match delta.ok() {
...@@ -177,15 +181,17 @@ impl From<DeltaChoice> for dynamo_async_openai::types::Choice { ...@@ -177,15 +181,17 @@ impl From<DeltaChoice> for dynamo_async_openai::types::Choice {
impl NvCreateCompletionResponse { impl NvCreateCompletionResponse {
pub async fn from_sse_stream( pub async fn from_sse_stream(
stream: DataStream<Result<Message, SseCodecError>>, stream: DataStream<Result<Message, SseCodecError>>,
parsing_options: ParsingOptions,
) -> Result<NvCreateCompletionResponse> { ) -> Result<NvCreateCompletionResponse> {
let stream = convert_sse_stream::<NvCreateCompletionResponse>(stream); let stream = convert_sse_stream::<NvCreateCompletionResponse>(stream);
NvCreateCompletionResponse::from_annotated_stream(stream).await NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options).await
} }
pub async fn from_annotated_stream( pub async fn from_annotated_stream(
stream: impl Stream<Item = Annotated<NvCreateCompletionResponse>>, stream: impl Stream<Item = Annotated<NvCreateCompletionResponse>>,
parsing_options: ParsingOptions,
) -> Result<NvCreateCompletionResponse> { ) -> Result<NvCreateCompletionResponse> {
DeltaAggregator::apply(stream).await DeltaAggregator::apply(stream, parsing_options).await
} }
} }
...@@ -241,7 +247,7 @@ mod tests { ...@@ -241,7 +247,7 @@ mod tests {
let stream: DataStream<Annotated<NvCreateCompletionResponse>> = Box::pin(stream::empty()); let stream: DataStream<Annotated<NvCreateCompletionResponse>> = Box::pin(stream::empty());
// Call DeltaAggregator::apply // Call DeltaAggregator::apply
let result = DeltaAggregator::apply(stream).await; let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
// Check the result // Check the result
assert!(result.is_ok()); assert!(result.is_ok());
...@@ -265,7 +271,7 @@ mod tests { ...@@ -265,7 +271,7 @@ mod tests {
let stream = Box::pin(stream::iter(vec![annotated_delta])); let stream = Box::pin(stream::iter(vec![annotated_delta]));
// Call DeltaAggregator::apply // Call DeltaAggregator::apply
let result = DeltaAggregator::apply(stream).await; let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
// Check the result // Check the result
assert!(result.is_ok()); assert!(result.is_ok());
...@@ -305,7 +311,7 @@ mod tests { ...@@ -305,7 +311,7 @@ mod tests {
let stream = Box::pin(stream::iter(annotated_deltas)); let stream = Box::pin(stream::iter(annotated_deltas));
// Call DeltaAggregator::apply // Call DeltaAggregator::apply
let result = DeltaAggregator::apply(stream).await; let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
// Check the result // Check the result
assert!(result.is_ok()); assert!(result.is_ok());
...@@ -365,7 +371,7 @@ mod tests { ...@@ -365,7 +371,7 @@ mod tests {
let stream = Box::pin(stream::iter(vec![annotated_delta])); let stream = Box::pin(stream::iter(vec![annotated_delta]));
// Call DeltaAggregator::apply // Call DeltaAggregator::apply
let result = DeltaAggregator::apply(stream).await; let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
// Check the result // Check the result
assert!(result.is_ok()); assert!(result.is_ok());
......
...@@ -18,6 +18,7 @@ use dynamo_llm::protocols::{ ...@@ -18,6 +18,7 @@ use dynamo_llm::protocols::{
openai::{ openai::{
chat_completions::{aggregator::ChatCompletionAggregator, NvCreateChatCompletionResponse}, chat_completions::{aggregator::ChatCompletionAggregator, NvCreateChatCompletionResponse},
completions::NvCreateCompletionResponse, completions::NvCreateCompletionResponse,
ParsingOptions,
}, },
ContentProvider, DataStream, ContentProvider, DataStream,
}; };
...@@ -37,7 +38,10 @@ async fn test_openai_chat_stream() { ...@@ -37,7 +38,10 @@ async fn test_openai_chat_stream() {
// note: we are only taking the first 16 messages to keep the size of the response small // note: we are only taking the first 16 messages to keep the size of the response small
let stream = create_message_stream(&data).take(16); let stream = create_message_stream(&data).take(16);
let result = NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream)) let result = NvCreateChatCompletionResponse::from_sse_stream(
Box::pin(stream),
ParsingOptions::default(),
)
.await .await
.unwrap(); .unwrap();
...@@ -59,7 +63,10 @@ async fn test_openai_chat_stream() { ...@@ -59,7 +63,10 @@ async fn test_openai_chat_stream() {
#[tokio::test] #[tokio::test]
async fn test_openai_chat_edge_case_multi_line_data() { async fn test_openai_chat_edge_case_multi_line_data() {
let stream = create_stream(CHAT_ROOT_PATH, "edge_cases/valid-multi-line-data"); let stream = create_stream(CHAT_ROOT_PATH, "edge_cases/valid-multi-line-data");
let result = NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream)) let result = NvCreateChatCompletionResponse::from_sse_stream(
Box::pin(stream),
ParsingOptions::default(),
)
.await .await
.unwrap(); .unwrap();
...@@ -79,7 +86,10 @@ async fn test_openai_chat_edge_case_multi_line_data() { ...@@ -79,7 +86,10 @@ async fn test_openai_chat_edge_case_multi_line_data() {
#[tokio::test] #[tokio::test]
async fn test_openai_chat_edge_case_comments_per_response() { async fn test_openai_chat_edge_case_comments_per_response() {
let stream = create_stream(CHAT_ROOT_PATH, "edge_cases/valid-comments_per_response"); let stream = create_stream(CHAT_ROOT_PATH, "edge_cases/valid-comments_per_response");
let result = NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream)) let result = NvCreateChatCompletionResponse::from_sse_stream(
Box::pin(stream),
ParsingOptions::default(),
)
.await .await
.unwrap(); .unwrap();
...@@ -99,7 +109,11 @@ async fn test_openai_chat_edge_case_comments_per_response() { ...@@ -99,7 +109,11 @@ async fn test_openai_chat_edge_case_comments_per_response() {
#[tokio::test] #[tokio::test]
async fn test_openai_chat_edge_case_invalid_deserialize_error() { async fn test_openai_chat_edge_case_invalid_deserialize_error() {
let stream = create_stream(CHAT_ROOT_PATH, "edge_cases/invalid-deserialize_error"); let stream = create_stream(CHAT_ROOT_PATH, "edge_cases/invalid-deserialize_error");
let result = NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream)).await; let result = NvCreateChatCompletionResponse::from_sse_stream(
Box::pin(stream),
ParsingOptions::default(),
)
.await;
assert!(result.is_err()); assert!(result.is_err());
// insta::assert_debug_snapshot!(result); // insta::assert_debug_snapshot!(result);
...@@ -112,7 +126,8 @@ async fn test_openai_chat_edge_case_invalid_deserialize_error() { ...@@ -112,7 +126,8 @@ async fn test_openai_chat_edge_case_invalid_deserialize_error() {
#[tokio::test] #[tokio::test]
async fn test_openai_cmpl_stream() { async fn test_openai_cmpl_stream() {
let stream = create_stream(CMPL_ROOT_PATH, "completion.streaming.1").take(16); let stream = create_stream(CMPL_ROOT_PATH, "completion.streaming.1").take(16);
let result = NvCreateCompletionResponse::from_sse_stream(Box::pin(stream)) let result =
NvCreateCompletionResponse::from_sse_stream(Box::pin(stream), ParsingOptions::default())
.await .await
.unwrap(); .unwrap();
......
...@@ -14,6 +14,11 @@ pub fn try_tool_call_parse_aggregate( ...@@ -14,6 +14,11 @@ pub fn try_tool_call_parse_aggregate(
message: &str, message: &str,
parser_str: Option<&str>, parser_str: Option<&str>,
) -> anyhow::Result<Vec<dynamo_async_openai::types::ChatCompletionMessageToolCall>> { ) -> anyhow::Result<Vec<dynamo_async_openai::types::ChatCompletionMessageToolCall>> {
if parser_str.is_none() {
tracing::info!("No tool parser provided. Trying parsing with default parser.");
} else {
tracing::info!("Using tool parser: {:?}", parser_str);
}
let parsed = detect_and_parse_tool_call(message, parser_str)?; let parsed = detect_and_parse_tool_call(message, parser_str)?;
if parsed.is_empty() { if parsed.is_empty() {
return Ok(vec![]); return Ok(vec![]);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment