Unverified Commit cbe854fc authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

feat: [vLLM] implement cli args for tool and reasoning parsers (#2619)

parent b658ba61
......@@ -58,6 +58,10 @@ class Config:
# Connector list from CLI
connector_list: Optional[list] = None
# tool and reasoning parser info
tool_call_parser: Optional[str] = None
reasoning_parser: Optional[str] = None
def parse_args() -> Config:
parser = FlexibleArgumentParser(
......@@ -102,6 +106,19 @@ def parse_args() -> Config:
help="List of connectors to use in order (e.g., --connector nixl lmcache). "
"Options: nixl, lmcache, kvbm, null, none. Default: nixl. Order will be preserved in MultiConnector.",
)
# To avoid name conflicts with different backends, Dynamo-specific args use the "dyn-" prefix
parser.add_argument(
"--dyn-tool-call-parser",
type=str,
default=None,
help="Tool call parser name for the model. Available options: 'hermes', 'nemotron_deci', 'llama3_json', 'mistral', 'phi4'.",
)
parser.add_argument(
"--dyn-reasoning-parser",
type=str,
default=None,
help="Reasoning parser name for the model.",
)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
......@@ -151,7 +168,8 @@ def parse_args() -> Config:
config.port_range = DynamoPortRange(
min=args.dynamo_port_min, max=args.dynamo_port_max
)
config.tool_call_parser = args.dyn_tool_call_parser
config.reasoning_parser = args.dyn_reasoning_parser
# Check for conflicting flags
has_kv_transfer_config = (
hasattr(engine_args, "kv_transfer_config")
......
......@@ -234,6 +234,8 @@ async def init(runtime: DistributedRuntime, config: Config):
runtime_config.total_kv_blocks = runtime_values["num_gpu_blocks"]
runtime_config.max_num_seqs = runtime_values["max_num_seqs"]
runtime_config.max_num_batched_tokens = runtime_values["max_num_batched_tokens"]
runtime_config.tool_call_parser = config.tool_call_parser
runtime_config.reasoning_parser = config.reasoning_parser
await register_llm(
ModelType.Backend,
......
......@@ -34,6 +34,16 @@ impl ModelRuntimeConfig {
self.inner.max_num_batched_tokens = Some(max_num_batched_tokens);
}
/// Property setter (PyO3 `#[setter]`) for the tool-call parser name.
///
/// Stores `tool_call_parser` on the wrapped `inner` config, replacing whatever
/// was there before; passing `None` clears the value.
#[setter]
fn set_tool_call_parser(&mut self, tool_call_parser: Option<String>) {
self.inner.tool_call_parser = tool_call_parser;
}
/// Property setter (PyO3 `#[setter]`) for the reasoning parser name.
///
/// Stores `reasoning_parser` on the wrapped `inner` config, replacing whatever
/// was there before; passing `None` clears the value.
#[setter]
fn set_reasoning_parser(&mut self, reasoning_parser: Option<String>) {
self.inner.reasoning_parser = reasoning_parser;
}
fn set_engine_specific(&mut self, key: &str, value: String) -> PyResult<()> {
let value: serde_json::Value = serde_json::from_str(&value).map_err(to_pyerr)?;
self.inner
......@@ -57,6 +67,16 @@ impl ModelRuntimeConfig {
self.inner.max_num_batched_tokens
}
/// Property getter (PyO3 `#[getter]`) for the tool-call parser name.
///
/// Returns a clone of the value held by the wrapped `inner` config, or `None`
/// when no tool-call parser has been set.
#[getter]
fn tool_call_parser(&self) -> Option<String> {
self.inner.tool_call_parser.clone()
}
/// Property getter (PyO3 `#[getter]`) for the reasoning parser name.
///
/// Returns a clone of the value held by the wrapped `inner` config, or `None`
/// when no reasoning parser has been set.
#[getter]
fn reasoning_parser(&self) -> Option<String> {
self.inner.reasoning_parser.clone()
}
#[getter]
fn runtime_data(&self, py: Python<'_>) -> PyResult<PyObject> {
let dict = PyDict::new(py);
......
......@@ -246,6 +246,18 @@ impl ModelManager {
.insert(model_name.to_string(), new_kv_chooser.clone());
Ok(new_kv_chooser)
}
/// Looks up the tool-call parser name configured for `model`.
///
/// Returns `None` when any step of the lookup comes up empty:
/// * `entries.lock()` returns `Err` (treated as "no parser" rather than a panic),
/// * no registered entry has a `name` equal to `model`,
/// * the matching entry carries no runtime config, or
/// * the runtime config does not name a tool-call parser.
pub fn get_model_tool_call_parser(&self, model: &str) -> Option<String> {
    // Best-effort: a failed lock acquisition yields `None`, exactly like a missing model.
    let entries = self.entries.lock().ok()?;
    let entry = entries.values().find(|entry| entry.name == model)?;
    // `clone()` already produces an owned `String`, so no further conversion is required.
    entry.runtime_config.as_ref()?.tool_call_parser.clone()
}
}
pub struct ModelEngines<E> {
......
......@@ -37,6 +37,7 @@ use crate::protocols::openai::{
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
responses::{NvCreateResponse, NvResponse},
ParsingOptions,
};
use crate::request_template::RequestTemplate;
use crate::types::Annotated;
......@@ -194,6 +195,13 @@ fn get_or_create_request_id(primary: Option<&str>, headers: &HeaderMap) -> Strin
uuid.to_string()
}
/// Builds the [`ParsingOptions`] used when aggregating responses for `model`.
///
/// The tool-call parser is resolved by model name through the state's model
/// manager. No reasoning parser is supplied yet, so that slot is always `None`.
fn get_parsing_options(state: &Arc<service_v2::State>, model: &str) -> ParsingOptions {
    ParsingOptions::new(
        state.manager().get_model_tool_call_parser(model),
        None, // TODO: Implement reasoning parser
    )
}
/// OpenAI Completions Request Handler
///
/// This method will handle the incoming request for the `/v1/completions endpoint`. The endpoint is a "source"
......@@ -267,6 +275,8 @@ async fn completions(
.get_completions_engine(model)
.map_err(|_| ErrorMessage::model_not_found())?;
let parsing_options = get_parsing_options(&state, model);
let mut inflight_guard =
state
.metrics_clone()
......@@ -325,7 +335,7 @@ async fn completions(
process_metrics_only(response, &mut response_collector);
});
let response = NvCreateCompletionResponse::from_annotated_stream(stream)
let response = NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options)
.await
.map_err(|e| {
tracing::error!(
......@@ -494,6 +504,8 @@ async fn chat_completions(
.get_chat_completions_engine(model)
.map_err(|_| ErrorMessage::model_not_found())?;
let parsing_options = get_parsing_options(&state, model);
let mut inflight_guard =
state
.metrics_clone()
......@@ -553,19 +565,20 @@ async fn chat_completions(
process_metrics_only(response, &mut response_collector);
});
let response = NvCreateChatCompletionResponse::from_annotated_stream(stream)
.await
.map_err(|e| {
tracing::error!(
request_id,
"Failed to fold chat completions stream for: {:?}",
e
);
ErrorMessage::internal_server_error(&format!(
"Failed to fold chat completions stream: {}",
e
))
})?;
let response =
NvCreateChatCompletionResponse::from_annotated_stream(stream, parsing_options.clone())
.await
.map_err(|e| {
tracing::error!(
request_id,
"Failed to fold chat completions stream for: {:?}",
e
);
ErrorMessage::internal_server_error(&format!(
"Failed to fold chat completions stream: {}",
e
))
})?;
inflight_guard.mark_ok();
Ok(Json(response).into_response())
......@@ -726,6 +739,8 @@ async fn responses(
.get_chat_completions_engine(model)
.map_err(|_| ErrorMessage::model_not_found())?;
let parsing_options = get_parsing_options(&state, model);
let mut inflight_guard =
state
.metrics_clone()
......@@ -742,19 +757,20 @@ async fn responses(
.map_err(|e| ErrorMessage::from_anyhow(e, "Failed to generate completions"))?;
// TODO: handle streaming, currently just unary
let response = NvCreateChatCompletionResponse::from_annotated_stream(stream)
.await
.map_err(|e| {
tracing::error!(
request_id,
"Failed to fold chat completions stream for: {:?}",
e
);
ErrorMessage::internal_server_error(&format!(
"Failed to fold chat completions stream: {}",
e
))
})?;
let response =
NvCreateChatCompletionResponse::from_annotated_stream(stream, parsing_options.clone())
.await
.map_err(|e| {
tracing::error!(
request_id,
"Failed to fold chat completions stream for: {:?}",
e
);
ErrorMessage::internal_server_error(&format!(
"Failed to fold chat completions stream: {}",
e
))
})?;
// Convert NvCreateChatCompletionResponse --> NvResponse
let response: NvResponse = response.try_into().map_err(|e| {
......
......@@ -202,6 +202,7 @@ impl LocalModelBuilder {
);
card.migration_limit = self.migration_limit;
card.user_data = self.user_data.take();
return Ok(LocalModel {
card,
full_path: PathBuf::new(),
......@@ -392,6 +393,7 @@ impl LocalModel {
let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client.clone()));
let card_store = Arc::new(KeyValueStoreManager::new(kvstore));
let key = self.card.slug().to_string();
card_store
.publish(model_card::ROOT_PATH, None, &key, &mut self.card)
.await?;
......
......@@ -13,6 +13,10 @@ pub struct ModelRuntimeConfig {
pub max_num_batched_tokens: Option<u64>,
pub tool_call_parser: Option<String>,
pub reasoning_parser: Option<String>,
/// Mapping of engine-specific runtime configs
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub runtime_data: HashMap<String, serde_json::Value>,
......
......@@ -101,7 +101,6 @@ impl OpenAIPreprocessor {
let mdcsum = mdc.mdcsum();
let formatter = PromptFormatter::from_mdc(mdc.clone()).await?;
let PromptFormatter::OAI(formatter) = formatter;
let tokenizer = match &mdc.tokenizer {
Some(TokenizerKind::HfTokenizerJson(file)) => HuggingFaceTokenizer::from_file(file)?,
Some(TokenizerKind::GGUF(tokenizer)) => {
......
......@@ -193,3 +193,19 @@ pub trait DeltaGeneratorExt<ResponseType: Send + 'static + std::fmt::Debug>:
/// Gets the current prompt token count (Input Sequence Length).
fn get_isl(&self) -> Option<u32>;
}
/// Parser selection handed to the response aggregators.
///
/// Both fields are optional parser names; the derived `Default` leaves both unset.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ParsingOptions {
    /// Name of the tool-call parser to use, if one was configured.
    pub tool_call_parser: Option<String>,
    /// Name of the reasoning parser to use, if one was configured.
    pub reasoning_parser: Option<String>,
}

impl ParsingOptions {
    /// Creates options from the given tool-call and reasoning parser names.
    pub fn new(tool_call_parser: Option<String>, reasoning_parser: Option<String>) -> Self {
        ParsingOptions {
            tool_call_parser,
            reasoning_parser,
        }
    }
}
......@@ -19,7 +19,9 @@ use std::collections::HashMap;
use super::{NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse};
use crate::protocols::{
codec::{Message, SseCodecError},
convert_sse_stream, Annotated,
convert_sse_stream,
openai::ParsingOptions,
Annotated,
};
use dynamo_parsers::tool_calling::try_tool_call_parse_aggregate;
......@@ -99,6 +101,7 @@ impl DeltaAggregator {
/// * `Err(String)` if an error occurs during processing.
pub async fn apply(
stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>,
parsing_options: ParsingOptions,
) -> Result<NvCreateChatCompletionResponse, String> {
let aggregator = stream
.fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
......@@ -175,7 +178,10 @@ impl DeltaAggregator {
// After aggregation, inspect each choice's text for tool call syntax
for choice in aggregator.choices.values_mut() {
if choice.tool_calls.is_none() {
if let Ok(tool_calls) = try_tool_call_parse_aggregate(&choice.text, None) {
if let Ok(tool_calls) = try_tool_call_parse_aggregate(
&choice.text,
parsing_options.tool_call_parser.as_deref(),
) {
if tool_calls.is_empty() {
continue;
}
......@@ -262,6 +268,7 @@ pub trait ChatCompletionAggregator {
/// * `Err(String)` if an error occurs.
async fn from_annotated_stream(
stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>,
parsing_options: ParsingOptions,
) -> Result<NvCreateChatCompletionResponse, String>;
/// Converts an SSE stream into a [`NvCreateChatCompletionResponse`].
......@@ -274,21 +281,24 @@ pub trait ChatCompletionAggregator {
/// * `Err(String)` if an error occurs.
async fn from_sse_stream(
stream: DataStream<Result<Message, SseCodecError>>,
parsing_options: ParsingOptions,
) -> Result<NvCreateChatCompletionResponse, String>;
}
impl ChatCompletionAggregator for dynamo_async_openai::types::CreateChatCompletionResponse {
/// Folds an annotated stream of chat-completion stream responses into a single
/// [`NvCreateChatCompletionResponse`] by delegating to `DeltaAggregator::apply`.
///
/// Errors are reported as a `String`.
async fn from_annotated_stream(
stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>,
parsing_options: ParsingOptions,
) -> Result<NvCreateChatCompletionResponse, String> {
DeltaAggregator::apply(stream).await
DeltaAggregator::apply(stream, parsing_options).await
}
/// Aggregates an SSE message stream into a single [`NvCreateChatCompletionResponse`].
///
/// The messages are first converted with `convert_sse_stream`, and the resulting
/// stream is handed to `from_annotated_stream`.
async fn from_sse_stream(
stream: DataStream<Result<Message, SseCodecError>>,
parsing_options: ParsingOptions,
) -> Result<NvCreateChatCompletionResponse, String> {
let stream = convert_sse_stream::<NvCreateChatCompletionStreamResponse>(stream);
NvCreateChatCompletionResponse::from_annotated_stream(stream).await
NvCreateChatCompletionResponse::from_annotated_stream(stream, parsing_options).await
}
}
......@@ -347,7 +357,7 @@ mod tests {
Box::pin(stream::empty());
// Call DeltaAggregator::apply
let result = DeltaAggregator::apply(stream).await;
let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
// Check the result
assert!(result.is_ok());
......@@ -377,7 +387,7 @@ mod tests {
let stream = Box::pin(stream::iter(vec![annotated_delta]));
// Call DeltaAggregator::apply
let result = DeltaAggregator::apply(stream).await;
let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
// Check the result
assert!(result.is_ok());
......@@ -421,7 +431,7 @@ mod tests {
let stream = Box::pin(stream::iter(annotated_deltas));
// Call DeltaAggregator::apply
let result = DeltaAggregator::apply(stream).await;
let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
// Check the result
assert!(result.is_ok());
......@@ -492,7 +502,7 @@ mod tests {
let stream = Box::pin(stream::iter(vec![annotated_delta]));
// Call DeltaAggregator::apply
let result = DeltaAggregator::apply(stream).await;
let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
// Check the result
assert!(result.is_ok());
......@@ -550,7 +560,7 @@ mod tests {
let stream = Box::pin(stream::iter(vec![annotated_delta]));
// Call DeltaAggregator::apply
let result = DeltaAggregator::apply(stream).await;
let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
// Check the result
assert!(result.is_ok());
......
......@@ -22,7 +22,9 @@ use super::NvCreateCompletionResponse;
use crate::protocols::{
codec::{Message, SseCodecError},
common::FinishReason,
convert_sse_stream, Annotated, DataStream,
convert_sse_stream,
openai::ParsingOptions,
Annotated, DataStream,
};
/// Aggregates a stream of [`CompletionResponse`]s into a single [`CompletionResponse`].
......@@ -65,7 +67,9 @@ impl DeltaAggregator {
/// Aggregates a stream of [`Annotated<CompletionResponse>`]s into a single [`CompletionResponse`].
pub async fn apply(
stream: impl Stream<Item = Annotated<NvCreateCompletionResponse>>,
parsing_options: ParsingOptions,
) -> Result<NvCreateCompletionResponse> {
tracing::debug!("Tool Call Parser: {:?}", parsing_options.tool_call_parser); // TODO: remove this once completion has tool call support
let aggregator = stream
.fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
let delta = match delta.ok() {
......@@ -177,15 +181,17 @@ impl From<DeltaChoice> for dynamo_async_openai::types::Choice {
impl NvCreateCompletionResponse {
/// Aggregates an SSE message stream into a single [`NvCreateCompletionResponse`].
///
/// The messages are first converted with `convert_sse_stream`, and the resulting
/// stream is handed to `from_annotated_stream`.
pub async fn from_sse_stream(
stream: DataStream<Result<Message, SseCodecError>>,
parsing_options: ParsingOptions,
) -> Result<NvCreateCompletionResponse> {
let stream = convert_sse_stream::<NvCreateCompletionResponse>(stream);
NvCreateCompletionResponse::from_annotated_stream(stream).await
NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options).await
}
/// Aggregates an annotated stream of completion responses into a single
/// [`NvCreateCompletionResponse`] by delegating to `DeltaAggregator::apply`.
pub async fn from_annotated_stream(
stream: impl Stream<Item = Annotated<NvCreateCompletionResponse>>,
parsing_options: ParsingOptions,
) -> Result<NvCreateCompletionResponse> {
DeltaAggregator::apply(stream).await
DeltaAggregator::apply(stream, parsing_options).await
}
}
......@@ -241,7 +247,7 @@ mod tests {
let stream: DataStream<Annotated<NvCreateCompletionResponse>> = Box::pin(stream::empty());
// Call DeltaAggregator::apply
let result = DeltaAggregator::apply(stream).await;
let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
// Check the result
assert!(result.is_ok());
......@@ -265,7 +271,7 @@ mod tests {
let stream = Box::pin(stream::iter(vec![annotated_delta]));
// Call DeltaAggregator::apply
let result = DeltaAggregator::apply(stream).await;
let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
// Check the result
assert!(result.is_ok());
......@@ -305,7 +311,7 @@ mod tests {
let stream = Box::pin(stream::iter(annotated_deltas));
// Call DeltaAggregator::apply
let result = DeltaAggregator::apply(stream).await;
let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
// Check the result
assert!(result.is_ok());
......@@ -365,7 +371,7 @@ mod tests {
let stream = Box::pin(stream::iter(vec![annotated_delta]));
// Call DeltaAggregator::apply
let result = DeltaAggregator::apply(stream).await;
let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
// Check the result
assert!(result.is_ok());
......
......@@ -18,6 +18,7 @@ use dynamo_llm::protocols::{
openai::{
chat_completions::{aggregator::ChatCompletionAggregator, NvCreateChatCompletionResponse},
completions::NvCreateCompletionResponse,
ParsingOptions,
},
ContentProvider, DataStream,
};
......@@ -37,9 +38,12 @@ async fn test_openai_chat_stream() {
// note: we are only taking the first 16 messages to keep the size of the response small
let stream = create_message_stream(&data).take(16);
let result = NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream))
.await
.unwrap();
let result = NvCreateChatCompletionResponse::from_sse_stream(
Box::pin(stream),
ParsingOptions::default(),
)
.await
.unwrap();
// todo: provide a cleaner way to extract the content from choices
assert_eq!(
......@@ -59,9 +63,12 @@ async fn test_openai_chat_stream() {
#[tokio::test]
async fn test_openai_chat_edge_case_multi_line_data() {
let stream = create_stream(CHAT_ROOT_PATH, "edge_cases/valid-multi-line-data");
let result = NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream))
.await
.unwrap();
let result = NvCreateChatCompletionResponse::from_sse_stream(
Box::pin(stream),
ParsingOptions::default(),
)
.await
.unwrap();
assert_eq!(
result
......@@ -79,9 +86,12 @@ async fn test_openai_chat_edge_case_multi_line_data() {
#[tokio::test]
async fn test_openai_chat_edge_case_comments_per_response() {
let stream = create_stream(CHAT_ROOT_PATH, "edge_cases/valid-comments_per_response");
let result = NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream))
.await
.unwrap();
let result = NvCreateChatCompletionResponse::from_sse_stream(
Box::pin(stream),
ParsingOptions::default(),
)
.await
.unwrap();
assert_eq!(
result
......@@ -99,7 +109,11 @@ async fn test_openai_chat_edge_case_comments_per_response() {
#[tokio::test]
async fn test_openai_chat_edge_case_invalid_deserialize_error() {
let stream = create_stream(CHAT_ROOT_PATH, "edge_cases/invalid-deserialize_error");
let result = NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream)).await;
let result = NvCreateChatCompletionResponse::from_sse_stream(
Box::pin(stream),
ParsingOptions::default(),
)
.await;
assert!(result.is_err());
// insta::assert_debug_snapshot!(result);
......@@ -112,9 +126,10 @@ async fn test_openai_chat_edge_case_invalid_deserialize_error() {
#[tokio::test]
async fn test_openai_cmpl_stream() {
let stream = create_stream(CMPL_ROOT_PATH, "completion.streaming.1").take(16);
let result = NvCreateCompletionResponse::from_sse_stream(Box::pin(stream))
.await
.unwrap();
let result =
NvCreateCompletionResponse::from_sse_stream(Box::pin(stream), ParsingOptions::default())
.await
.unwrap();
// todo: provide a cleaner way to extract the content from choices
assert_eq!(
......
......@@ -14,6 +14,11 @@ pub fn try_tool_call_parse_aggregate(
message: &str,
parser_str: Option<&str>,
) -> anyhow::Result<Vec<dynamo_async_openai::types::ChatCompletionMessageToolCall>> {
if parser_str.is_none() {
tracing::info!("No tool parser provided. Trying parsing with default parser.");
} else {
tracing::info!("Using tool parser: {:?}", parser_str);
}
let parsed = detect_and_parse_tool_call(message, parser_str)?;
if parsed.is_empty() {
return Ok(vec![]);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment