Unverified Commit f5a41004 authored by nachiketb-nvidia's avatar nachiketb-nvidia Committed by GitHub
Browse files

feat: enable --dyn-reasoning-parser flag to set reasoning parser for vllm deployments (#2700)

parent 68fb3d95
......@@ -117,7 +117,7 @@ def parse_args() -> Config:
"--dyn-reasoning-parser",
type=str,
default=None,
help="Reasoning parser name for the model.",
help="Reasoning parser name for the model. Available options: 'basic', 'deepseek_r1', 'gpt_oss'.",
)
parser = AsyncEngineArgs.add_cli_args(parser)
......
......@@ -14,6 +14,7 @@ use dynamo_runtime::pipeline::{Error, ManyOut, SingleIn};
use dynamo_runtime::protocols::annotated::Annotated;
use crate::backend::ExecutionContext;
use crate::local_model::runtime_config;
use crate::preprocessor::PreprocessedRequest;
use crate::protocols::common::llm_backend::LLMEngineOutput;
use crate::protocols::openai::{
......@@ -183,7 +184,7 @@ impl
incoming_request: SingleIn<NvCreateChatCompletionRequest>,
) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
let (request, context) = incoming_request.transfer(());
let mut deltas = request.response_generator();
let mut deltas = request.response_generator(runtime_config::ModelRuntimeConfig::default());
let ctx = context.context();
let req = request.inner.messages.into_iter().next_back().unwrap();
......
......@@ -202,6 +202,7 @@ impl LocalModelBuilder {
);
card.migration_limit = self.migration_limit;
card.user_data = self.user_data.take();
card.runtime_config = self.runtime_config.clone();
return Ok(LocalModel {
card,
......@@ -276,6 +277,7 @@ impl LocalModelBuilder {
card.migration_limit = self.migration_limit;
card.user_data = self.user_data.take();
card.runtime_config = self.runtime_config.clone();
Ok(LocalModel {
card,
......
......@@ -19,6 +19,7 @@ use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use crate::local_model::runtime_config::ModelRuntimeConfig;
use anyhow::{Context, Result};
use derive_builder::Builder;
use dynamo_runtime::{slug::Slug, storage::key_value_store::Versioned, transports::nats};
......@@ -137,6 +138,9 @@ pub struct ModelDeploymentCard {
/// User-defined metadata for custom worker behavior
#[serde(default, skip_serializing_if = "Option::is_none")]
pub user_data: Option<serde_json::Value>,
#[serde(default)]
pub runtime_config: ModelRuntimeConfig,
}
impl ModelDeploymentCard {
......@@ -441,6 +445,7 @@ impl ModelDeploymentCard {
kv_cache_block_size: 0,
migration_limit: 0,
user_data: None,
runtime_config: ModelRuntimeConfig::default(),
})
}
......@@ -482,6 +487,7 @@ impl ModelDeploymentCard {
kv_cache_block_size: 0, // set later
migration_limit: 0,
user_data: None,
runtime_config: ModelRuntimeConfig::default(),
})
}
}
......
......@@ -22,6 +22,7 @@ use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use std::{collections::HashMap, sync::Arc};
use tracing;
use crate::local_model::runtime_config::ModelRuntimeConfig;
use crate::model_card::{ModelDeploymentCard, ModelInfo, TokenizerKind};
use crate::preprocessor::prompt::OAIChatLikeRequest;
use crate::tokenizers::Encoding;
......@@ -94,6 +95,7 @@ pub struct OpenAIPreprocessor {
formatter: Arc<dyn OAIPromptFormatter>,
tokenizer: Arc<dyn Tokenizer>,
model_info: Arc<dyn ModelInfo>,
runtime_config: ModelRuntimeConfig,
}
impl OpenAIPreprocessor {
......@@ -121,11 +123,14 @@ impl OpenAIPreprocessor {
};
let model_info = model_info.get_model_info().await?;
let runtime_config = mdc.runtime_config.clone();
Ok(Arc::new(Self {
formatter,
tokenizer,
model_info,
mdcsum,
runtime_config,
}))
}
......@@ -494,7 +499,7 @@ impl
let (request, context) = request.into_parts();
// create a response generator
let response_generator = request.response_generator();
let response_generator = request.response_generator(self.runtime_config.clone());
let mut response_generator = Box::new(response_generator);
// convert the chat completion request to a common completion request
......
......@@ -5,6 +5,7 @@ use dynamo_parsers::{ParserResult, ReasoningParser, ReasoningParserType, Reasoni
use super::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse};
use crate::{
local_model::runtime_config,
protocols::common::{self},
types::TokenIdType,
};
......@@ -15,11 +16,15 @@ impl NvCreateChatCompletionRequest {
///
/// # Returns
/// * [`DeltaGenerator`] configured with model name and response options.
pub fn response_generator(&self) -> DeltaGenerator {
pub fn response_generator(
&self,
runtime_config: runtime_config::ModelRuntimeConfig,
) -> DeltaGenerator {
let options = DeltaGeneratorOptions {
enable_usage: true,
enable_logprobs: self.inner.logprobs.unwrap_or(false)
|| self.inner.top_logprobs.unwrap_or(0) > 0,
runtime_config,
};
DeltaGenerator::new(self.inner.model.clone(), options)
......@@ -33,6 +38,8 @@ pub struct DeltaGeneratorOptions {
pub enable_usage: bool,
/// Determines whether log probabilities should be included in the response.
pub enable_logprobs: bool,
pub runtime_config: runtime_config::ModelRuntimeConfig,
}
/// Generates incremental chat completion responses in a streaming fashion.
......@@ -92,10 +99,14 @@ impl DeltaGenerator {
// This is hardcoded for now, but can be made configurable later.
// TODO: Make parser type configurable once front-end integration is determined
// Change to GptOss to test GptOSS parser
let reasoning_parser_type = ReasoningParserType::Basic;
// Reasoning parser wrapper
let reasoning_parser = reasoning_parser_type.get_reasoning_parser();
let reasoning_parser = ReasoningParserType::get_reasoning_parser_from_name(
options
.runtime_config
.reasoning_parser
.as_deref()
.unwrap_or("basic"),
);
Self {
id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
......
......@@ -16,9 +16,20 @@
use anyhow::Error;
use async_stream::stream;
use dynamo_async_openai::config::OpenAIConfig;
use dynamo_llm::http::{
use dynamo_llm::protocols::{
Annotated,
codec::SseLineCodec,
convert_sse_stream,
openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
},
};
use dynamo_llm::{
http::{
client::{
GenericBYOTClient, HttpClientConfig, HttpRequestContext, NvCustomClient, PureOpenAIClient,
GenericBYOTClient, HttpClientConfig, HttpRequestContext, NvCustomClient,
PureOpenAIClient,
},
service::{
Metrics,
......@@ -26,15 +37,8 @@ use dynamo_llm::http::{
metrics::{Endpoint, FRONTEND_METRIC_PREFIX, RequestType, Status},
service_v2::HttpService,
},
};
use dynamo_llm::protocols::{
Annotated,
codec::SseLineCodec,
convert_sse_stream,
openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
},
local_model::runtime_config,
};
use dynamo_runtime::{
CancellationToken,
......@@ -95,7 +99,8 @@ impl
let max_tokens = request.inner.max_tokens.unwrap_or(0) as u64;
// let generator = NvCreateChatCompletionStreamResponse::generator(request.model.clone());
let mut generator = request.response_generator();
let mut generator =
request.response_generator(runtime_config::ModelRuntimeConfig::default());
let stream = stream! {
tokio::time::sleep(std::time::Duration::from_millis(max_tokens)).await;
......
......@@ -115,4 +115,20 @@ impl ReasoningParserType {
},
}
}
pub fn get_reasoning_parser_from_name(name: &str) -> ReasoningParserWrapper {
tracing::debug!("Selected reasoning parser: {}", name);
match name.to_lowercase().as_str() {
"deepseek_r1" => Self::DeepseekR1.get_reasoning_parser(),
"basic" => Self::Basic.get_reasoning_parser(),
"gpt_oss" => Self::GptOss.get_reasoning_parser(),
_ => {
tracing::warn!(
"Unknown reasoning parser type '{}', falling back to Basic Reasoning Parser",
name
);
Self::Basic.get_reasoning_parser()
}
}
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment