Unverified Commit adad2ecd authored by Abrar Shivani's avatar Abrar Shivani Committed by GitHub
Browse files

feat: Add request template support for default inference parameters (#841)

Adds support for specifying default request parameters through a JSON template file that is applied to all inference requests. This enables consistent parameter settings while still allowing per-request overrides.

Changes:
- Add --request-template CLI flag to specify template file path
- Integrate template support in HTTP, batch and text input modes
- Template values can be overridden by individual request parameters
- Example template.json:
```
{
    "model": "Qwen2.5-3B-Instruct",
    "temperature": 0.7,
    "max_completion_tokens": 4096
}
```
parent 904730b9
...@@ -126,6 +126,17 @@ pub struct Flags { ...@@ -126,6 +126,17 @@ pub struct Flags {
#[arg(long)] #[arg(long)]
pub extra_engine_args: Option<PathBuf>, pub extra_engine_args: Option<PathBuf>,
/// Path to a JSON file containing default request fields.
/// These fields will be merged with each request, but can be overridden by the request.
/// Example file contents:
/// {
/// "model": "Qwen2.5-3B-Instruct",
/// "temperature": 0.7,
/// "max_completion_tokens": 4096
/// }
#[arg(long)]
pub request_template: Option<PathBuf>,
/// Everything after a `--`. /// Everything after a `--`.
/// These are the command line arguments to the python engine when using `pystr` or `pytok`. /// These are the command line arguments to the python engine when using `pystr` or `pytok`.
#[arg(index = 2, last = true, hide = true, allow_hyphen_values = true)] #[arg(index = 2, last = true, hide = true, allow_hyphen_values = true)]
......
...@@ -17,6 +17,7 @@ use anyhow::Context as _; ...@@ -17,6 +17,7 @@ use anyhow::Context as _;
use async_openai::types::FinishReason; use async_openai::types::FinishReason;
use dynamo_llm::model_card::model::ModelDeploymentCard; use dynamo_llm::model_card::model::ModelDeploymentCard;
use dynamo_llm::preprocessor::OpenAIPreprocessor; use dynamo_llm::preprocessor::OpenAIPreprocessor;
use dynamo_llm::request_template::RequestTemplate;
use dynamo_llm::types::openai::chat_completions::{ use dynamo_llm::types::openai::chat_completions::{
NvCreateChatCompletionRequest, OpenAIChatCompletionsStreamingEngine, NvCreateChatCompletionRequest, OpenAIChatCompletionsStreamingEngine,
}; };
...@@ -68,6 +69,7 @@ pub async fn run( ...@@ -68,6 +69,7 @@ pub async fn run(
maybe_card: Option<ModelDeploymentCard>, maybe_card: Option<ModelDeploymentCard>,
input_jsonl: PathBuf, input_jsonl: PathBuf,
engine_config: EngineConfig, engine_config: EngineConfig,
template: Option<RequestTemplate>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let cancel_token = runtime.primary_token(); let cancel_token = runtime.primary_token();
// Check if the path exists and is a directory // Check if the path exists and is a directory
...@@ -109,6 +111,7 @@ pub async fn run( ...@@ -109,6 +111,7 @@ pub async fn run(
tracing::info!("Timer start."); tracing::info!("Timer start.");
let start = Instant::now(); let start = Instant::now();
let mut lines = buffered_input.lines(); let mut lines = buffered_input.lines();
let template: Option<Arc<RequestTemplate>> = template.map(Arc::new);
while let Ok(Some(line)) = lines.next_line().await { while let Ok(Some(line)) = lines.next_line().await {
if cancel_token.is_cancelled() { if cancel_token.is_cancelled() {
break; break;
...@@ -132,16 +135,24 @@ pub async fn run( ...@@ -132,16 +135,24 @@ pub async fn run(
let tokens_out = tokens_out.clone(); let tokens_out = tokens_out.clone();
let done_entries_tx = done_entries_tx.clone(); let done_entries_tx = done_entries_tx.clone();
let service_name_ref = service_name_ref.clone(); let service_name_ref = service_name_ref.clone();
let template_clone = template.clone();
let handle = tokio::spawn(async move { let handle = tokio::spawn(async move {
let local_start = Instant::now(); let local_start = Instant::now();
let response = let response = match evaluate(
match evaluate(request_id, service_name_ref.as_str(), engine, &mut entry).await { request_id,
Ok(r) => r, service_name_ref.as_str(),
Err(err) => { engine,
tracing::error!(%err, entry.text, "Failed evaluating prompt"); &mut entry,
return; template_clone,
} )
}; .await
{
Ok(r) => r,
Err(err) => {
tracing::error!(%err, entry.text, "Failed evaluating prompt");
return;
}
};
let local_elapsed = Instant::now() - local_start; let local_elapsed = Instant::now() - local_start;
entry.elapsed_ms = local_elapsed.as_millis() as usize; entry.elapsed_ms = local_elapsed.as_millis() as usize;
...@@ -202,6 +213,7 @@ async fn evaluate( ...@@ -202,6 +213,7 @@ async fn evaluate(
service_name: &str, service_name: &str,
engine: OpenAIChatCompletionsStreamingEngine, engine: OpenAIChatCompletionsStreamingEngine,
entry: &mut Entry, entry: &mut Entry,
template: Option<Arc<RequestTemplate>>,
) -> anyhow::Result<String> { ) -> anyhow::Result<String> {
let user_message = async_openai::types::ChatCompletionRequestMessage::User( let user_message = async_openai::types::ChatCompletionRequestMessage::User(
async_openai::types::ChatCompletionRequestUserMessage { async_openai::types::ChatCompletionRequestUserMessage {
...@@ -213,9 +225,18 @@ async fn evaluate( ...@@ -213,9 +225,18 @@ async fn evaluate(
); );
let inner = async_openai::types::CreateChatCompletionRequestArgs::default() let inner = async_openai::types::CreateChatCompletionRequestArgs::default()
.messages(vec![user_message]) .messages(vec![user_message])
.model(service_name) .model(
template
.as_ref()
.map_or_else(|| service_name.to_string(), |t| t.model.clone()),
)
.stream(true) .stream(true)
.max_completion_tokens(MAX_TOKENS) .max_completion_tokens(
template
.as_ref()
.map_or(MAX_TOKENS, |t| t.max_completion_tokens),
)
.temperature(template.as_ref().map_or(0.7, |t| t.temperature))
.build()?; .build()?;
let req = NvCreateChatCompletionRequest { inner, nvext: None }; let req = NvCreateChatCompletionRequest { inner, nvext: None };
let mut stream = engine.generate(Context::new(req)).await?; let mut stream = engine.generate(Context::new(req)).await?;
......
...@@ -21,6 +21,7 @@ use dynamo_llm::{ ...@@ -21,6 +21,7 @@ use dynamo_llm::{
engines::StreamingEngineAdapter, engines::StreamingEngineAdapter,
http::service::{discovery, service_v2}, http::service::{discovery, service_v2},
model_type::ModelType, model_type::ModelType,
request_template::RequestTemplate,
types::{ types::{
openai::chat_completions::{ openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse, NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
...@@ -35,11 +36,13 @@ pub async fn run( ...@@ -35,11 +36,13 @@ pub async fn run(
runtime: Runtime, runtime: Runtime,
flags: Flags, flags: Flags,
engine_config: EngineConfig, engine_config: EngineConfig,
template: Option<RequestTemplate>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let http_service = service_v2::HttpService::builder() let http_service = service_v2::HttpService::builder()
.port(flags.http_port) .port(flags.http_port)
.enable_chat_endpoints(true) .enable_chat_endpoints(true)
.enable_cmpl_endpoints(true) .enable_cmpl_endpoints(true)
.with_request_template(template)
.build()?; .build()?;
match engine_config { match engine_config {
EngineConfig::Dynamic(endpoint) => { EngineConfig::Dynamic(endpoint) => {
......
...@@ -22,7 +22,7 @@ use futures::StreamExt; ...@@ -22,7 +22,7 @@ use futures::StreamExt;
use std::io::{ErrorKind, Write}; use std::io::{ErrorKind, Write};
use crate::input::common; use crate::input::common;
use crate::{EngineConfig, Flags}; use crate::{EngineConfig, Flags, RequestTemplate};
/// Max response tokens for each single query. Must be less than model context size. /// Max response tokens for each single query. Must be less than model context size.
/// TODO: Cmd line flag to overwrite this /// TODO: Cmd line flag to overwrite this
...@@ -33,6 +33,7 @@ pub async fn run( ...@@ -33,6 +33,7 @@ pub async fn run(
flags: Flags, flags: Flags,
single_prompt: Option<String>, single_prompt: Option<String>,
engine_config: EngineConfig, engine_config: EngineConfig,
template: Option<RequestTemplate>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let cancel_token = runtime.primary_token(); let cancel_token = runtime.primary_token();
let (service_name, engine, inspect_template): ( let (service_name, engine, inspect_template): (
...@@ -46,6 +47,7 @@ pub async fn run( ...@@ -46,6 +47,7 @@ pub async fn run(
engine, engine,
single_prompt, single_prompt,
inspect_template, inspect_template,
template,
) )
.await .await
} }
...@@ -56,6 +58,7 @@ async fn main_loop( ...@@ -56,6 +58,7 @@ async fn main_loop(
engine: OpenAIChatCompletionsStreamingEngine, engine: OpenAIChatCompletionsStreamingEngine,
mut initial_prompt: Option<String>, mut initial_prompt: Option<String>,
_inspect_template: bool, _inspect_template: bool,
template: Option<RequestTemplate>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
if initial_prompt.is_none() { if initial_prompt.is_none() {
tracing::info!("Ctrl-c to exit"); tracing::info!("Ctrl-c to exit");
...@@ -101,14 +104,21 @@ async fn main_loop( ...@@ -101,14 +104,21 @@ async fn main_loop(
}, },
); );
messages.push(user_message); messages.push(user_message);
// Request // Request
let inner = async_openai::types::CreateChatCompletionRequestArgs::default() let inner = async_openai::types::CreateChatCompletionRequestArgs::default()
.messages(messages.clone()) .messages(messages.clone())
.model(service_name) .model(
template
.as_ref()
.map_or_else(|| service_name.to_string(), |t| t.model.clone()),
)
.stream(true) .stream(true)
.max_completion_tokens(MAX_TOKENS) .max_completion_tokens(
.temperature(0.7) template
.as_ref()
.map_or(MAX_TOKENS, |t| t.max_completion_tokens),
)
.temperature(template.as_ref().map_or(0.7, |t| t.temperature))
.n(1) // only generate one response .n(1) // only generate one response
.build()?; .build()?;
let nvext = NvExt { let nvext = NvExt {
......
...@@ -30,6 +30,7 @@ mod input; ...@@ -30,6 +30,7 @@ mod input;
#[cfg(any(feature = "vllm", feature = "sglang"))] #[cfg(any(feature = "vllm", feature = "sglang"))]
mod net; mod net;
mod opt; mod opt;
pub use dynamo_llm::request_template::RequestTemplate;
pub use opt::{Input, Output}; pub use opt::{Input, Output};
/// How we identify a namespace/component/endpoint URL. /// How we identify a namespace/component/endpoint URL.
...@@ -206,6 +207,14 @@ pub async fn run( ...@@ -206,6 +207,14 @@ pub async fn run(
#[cfg(any(feature = "vllm", feature = "sglang"))] #[cfg(any(feature = "vllm", feature = "sglang"))]
let mut extra: Option<Pin<Box<dyn Future<Output = ()> + Send>>> = None; // vllm and sglang sub-process let mut extra: Option<Pin<Box<dyn Future<Output = ()> + Send>>> = None; // vllm and sglang sub-process
let template = if let Some(path) = flags.request_template.as_ref() {
let template = RequestTemplate::load(path)?;
tracing::debug!("Using request template: {template:?}");
Some(template)
} else {
None
};
// Create the engine matching `out` // Create the engine matching `out`
let engine_config = match out_opt { let engine_config = match out_opt {
Output::EchoFull => { Output::EchoFull => {
...@@ -474,19 +483,33 @@ pub async fn run( ...@@ -474,19 +483,33 @@ pub async fn run(
match in_opt { match in_opt {
Input::Http => { Input::Http => {
crate::input::http::run(runtime.clone(), flags, engine_config).await?; crate::input::http::run(runtime.clone(), flags, engine_config, template).await?;
} }
Input::Text => { Input::Text => {
crate::input::text::run(runtime.clone(), flags, None, engine_config).await?; crate::input::text::run(runtime.clone(), flags, None, engine_config, template).await?;
} }
Input::Stdin => { Input::Stdin => {
let mut prompt = String::new(); let mut prompt = String::new();
std::io::stdin().read_to_string(&mut prompt).unwrap(); std::io::stdin().read_to_string(&mut prompt).unwrap();
crate::input::text::run(runtime.clone(), flags, Some(prompt), engine_config).await?; crate::input::text::run(
runtime.clone(),
flags,
Some(prompt),
engine_config,
template,
)
.await?;
} }
Input::Batch(path) => { Input::Batch(path) => {
crate::input::batch::run(runtime.clone(), flags, maybe_card, path, engine_config) crate::input::batch::run(
.await?; runtime.clone(),
flags,
maybe_card,
path,
engine_config,
template,
)
.await?;
} }
Input::Endpoint(path) => { Input::Endpoint(path) => {
let Some(dyn_input) = dyn_input else { let Some(dyn_input) = dyn_input else {
......
...@@ -43,6 +43,7 @@ use super::{ ...@@ -43,6 +43,7 @@ use super::{
use crate::protocols::openai::{ use crate::protocols::openai::{
chat_completions::NvCreateChatCompletionResponse, completions::CompletionResponse, chat_completions::NvCreateChatCompletionResponse, completions::CompletionResponse,
}; };
use crate::request_template::RequestTemplate;
use crate::types::{ use crate::types::{
openai::{chat_completions::NvCreateChatCompletionRequest, completions::CompletionRequest}, openai::{chat_completions::NvCreateChatCompletionRequest, completions::CompletionRequest},
Annotated, Annotated,
...@@ -219,12 +220,26 @@ async fn completions( ...@@ -219,12 +220,26 @@ async fn completions(
/// non-streaming requests, we will fold the stream into a single response as part of this handler. /// non-streaming requests, we will fold the stream into a single response as part of this handler.
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
async fn chat_completions( async fn chat_completions(
State(state): State<Arc<DeploymentState>>, State((state, template)): State<(Arc<DeploymentState>, Option<RequestTemplate>)>,
Json(request): Json<NvCreateChatCompletionRequest>, Json(mut request): Json<NvCreateChatCompletionRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
// return a 503 if the service is not ready // return a 503 if the service is not ready
check_ready(&state)?; check_ready(&state)?;
// Apply template values if present
if let Some(template) = template {
if request.inner.model.is_empty() {
request.inner.model = template.model.clone();
}
if request.inner.temperature.unwrap_or(0.0) == 0.0 {
request.inner.temperature = Some(template.temperature);
}
if request.inner.max_completion_tokens.unwrap_or(0) == 0 {
request.inner.max_completion_tokens = Some(template.max_completion_tokens);
}
}
tracing::trace!("Received chat completions request: {:?}", request.inner);
// todo - extract distributed tracing id and context id from headers // todo - extract distributed tracing id and context id from headers
let request_id = uuid::Uuid::new_v4().to_string(); let request_id = uuid::Uuid::new_v4().to_string();
...@@ -512,13 +527,14 @@ pub fn completions_router( ...@@ -512,13 +527,14 @@ pub fn completions_router(
/// If not path is provided, the default path is `/v1/chat/completions` /// If not path is provided, the default path is `/v1/chat/completions`
pub fn chat_completions_router( pub fn chat_completions_router(
state: Arc<DeploymentState>, state: Arc<DeploymentState>,
template: Option<RequestTemplate>,
path: Option<String>, path: Option<String>,
) -> (Vec<RouteDoc>, Router) { ) -> (Vec<RouteDoc>, Router) {
let path = path.unwrap_or("/v1/chat/completions".to_string()); let path = path.unwrap_or("/v1/chat/completions".to_string());
let doc = RouteDoc::new(axum::http::Method::POST, &path); let doc = RouteDoc::new(axum::http::Method::POST, &path);
let router = Router::new() let router = Router::new()
.route(&path, post(chat_completions)) .route(&path, post(chat_completions))
.with_state(state); .with_state((state, template));
(vec![doc], router) (vec![doc], router)
} }
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
use super::metrics; use super::metrics;
use super::ModelManager; use super::ModelManager;
use crate::request_template::RequestTemplate;
use anyhow::Result; use anyhow::Result;
use derive_builder::Builder; use derive_builder::Builder;
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
...@@ -44,6 +45,9 @@ pub struct HttpServiceConfig { ...@@ -44,6 +45,9 @@ pub struct HttpServiceConfig {
#[builder(default = "true")] #[builder(default = "true")]
enable_cmpl_endpoints: bool, enable_cmpl_endpoints: bool,
#[builder(default = "None")]
request_template: Option<RequestTemplate>,
} }
impl HttpService { impl HttpService {
...@@ -91,6 +95,7 @@ impl HttpServiceConfigBuilder { ...@@ -91,6 +95,7 @@ impl HttpServiceConfigBuilder {
model_manager.metrics().register(&registry)?; model_manager.metrics().register(&registry)?;
let mut router = axum::Router::new(); let mut router = axum::Router::new();
let mut all_docs = Vec::new(); let mut all_docs = Vec::new();
let mut routes = vec![ let mut routes = vec![
...@@ -101,6 +106,7 @@ impl HttpServiceConfigBuilder { ...@@ -101,6 +106,7 @@ impl HttpServiceConfigBuilder {
if config.enable_chat_endpoints { if config.enable_chat_endpoints {
routes.push(super::openai::chat_completions_router( routes.push(super::openai::chat_completions_router(
model_manager.state(), model_manager.state(),
config.request_template,
None, None,
)); ));
} }
...@@ -129,4 +135,9 @@ impl HttpServiceConfigBuilder { ...@@ -129,4 +135,9 @@ impl HttpServiceConfigBuilder {
host: config.host, host: config.host,
}) })
} }
pub fn with_request_template(mut self, request_template: Option<RequestTemplate>) -> Self {
self.request_template = Some(request_template);
self
}
} }
...@@ -31,6 +31,7 @@ pub mod model_type; ...@@ -31,6 +31,7 @@ pub mod model_type;
pub mod preprocessor; pub mod preprocessor;
pub mod protocols; pub mod protocols;
pub mod recorder; pub mod recorder;
pub mod request_template;
pub mod tokenizers; pub mod tokenizers;
pub mod tokens; pub mod tokens;
pub mod types; pub mod types;
......
// SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::path::Path;

use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
/// Default request fields loaded from a JSON template file.
///
/// These values are merged into each inference request (HTTP, batch, and
/// text input modes); fields explicitly set on an individual request take
/// precedence over the template values.
///
/// Example file contents:
/// ```json
/// {
///     "model": "Qwen2.5-3B-Instruct",
///     "temperature": 0.7,
///     "max_completion_tokens": 4096
/// }
/// ```
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RequestTemplate {
    /// Default model name used when a request does not specify one.
    pub model: String,
    /// Default sampling temperature.
    pub temperature: f32,
    /// Default cap on the number of completion tokens to generate.
    pub max_completion_tokens: u32,
}
impl RequestTemplate {
pub fn load(path: &Path) -> Result<Self> {
let template = std::fs::read_to_string(path)?;
let template: Self = serde_json::from_str(&template)?;
Ok(template)
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment