Unverified Commit adad2ecd authored by Abrar Shivani, committed by GitHub

feat: Add request template support for default inference parameters (#841)

Adds support for specifying default request parameters through a JSON template file that is applied to all inference requests. This enables consistent parameter settings while still allowing per-request overrides.

Changes:
- Add a `--request-template` CLI flag to specify the template file path
- Integrate template support in the HTTP, batch, and text input modes
- Template values can be overridden by individual request parameters (see the sketch after the example below)
- Example template.json:
```
{
    "model": "Qwen2.5-3B-Instruct",
    "temperature": 0.7,
    "max_completion_tokens": 4096
}
```
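
The override rule can be sketched as follows; this is illustration only, using simplified stand-in structs rather than the real dynamo/async-openai types, and it mirrors the zero/empty checks the HTTP handler uses further down in this diff. A field the client sets explicitly is kept; anything left unset falls back to the template:
```
// Illustration only (simplified stand-ins, not the real request types):
// values the request sets explicitly win; unset values (empty model,
// zero temperature, zero max_completion_tokens) fall back to the template.
#[derive(Debug)]
struct RequestTemplate {
    model: String,
    temperature: f32,
    max_completion_tokens: u32,
}

#[derive(Debug)]
struct ChatRequest {
    model: String,
    temperature: Option<f32>,
    max_completion_tokens: Option<u32>,
}

fn apply_template(req: &mut ChatRequest, tpl: &RequestTemplate) {
    if req.model.is_empty() {
        req.model = tpl.model.clone();
    }
    if req.temperature.unwrap_or(0.0) == 0.0 {
        req.temperature = Some(tpl.temperature);
    }
    if req.max_completion_tokens.unwrap_or(0) == 0 {
        req.max_completion_tokens = Some(tpl.max_completion_tokens);
    }
}

fn main() {
    let tpl = RequestTemplate {
        model: "Qwen2.5-3B-Instruct".to_string(),
        temperature: 0.7,
        max_completion_tokens: 4096,
    };
    let mut req = ChatRequest {
        model: String::new(),        // unset -> filled from the template
        temperature: Some(0.2),      // set   -> kept
        max_completion_tokens: None, // unset -> filled from the template
    };
    apply_template(&mut req, &tpl);
    assert_eq!(req.model, "Qwen2.5-3B-Instruct");
    assert_eq!(req.temperature, Some(0.2));
    assert_eq!(req.max_completion_tokens, Some(4096));
    println!("{req:?}");
}
```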
parent 904730b9
......@@ -126,6 +126,17 @@ pub struct Flags {
#[arg(long)]
pub extra_engine_args: Option<PathBuf>,
/// Path to a JSON file containing default request fields.
/// These fields will be merged with each request, but can be overridden by the request.
/// Example file contents:
/// {
/// "model": "Qwen2.5-3B-Instruct",
/// "temperature": 0.7,
/// "max_completion_tokens": 4096
/// }
#[arg(long)]
pub request_template: Option<PathBuf>,
/// Everything after a `--`.
/// These are the command line arguments to the python engine when using `pystr` or `pytok`.
#[arg(index = 2, last = true, hide = true, allow_hyphen_values = true)]
......
......@@ -17,6 +17,7 @@ use anyhow::Context as _;
use async_openai::types::FinishReason;
use dynamo_llm::model_card::model::ModelDeploymentCard;
use dynamo_llm::preprocessor::OpenAIPreprocessor;
use dynamo_llm::request_template::RequestTemplate;
use dynamo_llm::types::openai::chat_completions::{
NvCreateChatCompletionRequest, OpenAIChatCompletionsStreamingEngine,
};
......@@ -68,6 +69,7 @@ pub async fn run(
maybe_card: Option<ModelDeploymentCard>,
input_jsonl: PathBuf,
engine_config: EngineConfig,
template: Option<RequestTemplate>,
) -> anyhow::Result<()> {
let cancel_token = runtime.primary_token();
// Check if the path exists and is a directory
......@@ -109,6 +111,7 @@ pub async fn run(
tracing::info!("Timer start.");
let start = Instant::now();
let mut lines = buffered_input.lines();
let template: Option<Arc<RequestTemplate>> = template.map(Arc::new);
while let Ok(Some(line)) = lines.next_line().await {
if cancel_token.is_cancelled() {
break;
......@@ -132,10 +135,18 @@ pub async fn run(
let tokens_out = tokens_out.clone();
let done_entries_tx = done_entries_tx.clone();
let service_name_ref = service_name_ref.clone();
let template_clone = template.clone();
let handle = tokio::spawn(async move {
let local_start = Instant::now();
let response =
match evaluate(request_id, service_name_ref.as_str(), engine, &mut entry).await {
let response = match evaluate(
request_id,
service_name_ref.as_str(),
engine,
&mut entry,
template_clone,
)
.await
{
Ok(r) => r,
Err(err) => {
tracing::error!(%err, entry.text, "Failed evaluating prompt");
......@@ -202,6 +213,7 @@ async fn evaluate(
service_name: &str,
engine: OpenAIChatCompletionsStreamingEngine,
entry: &mut Entry,
template: Option<Arc<RequestTemplate>>,
) -> anyhow::Result<String> {
let user_message = async_openai::types::ChatCompletionRequestMessage::User(
async_openai::types::ChatCompletionRequestUserMessage {
......@@ -213,9 +225,18 @@ async fn evaluate(
);
let inner = async_openai::types::CreateChatCompletionRequestArgs::default()
.messages(vec![user_message])
.model(service_name)
.model(
template
.as_ref()
.map_or_else(|| service_name.to_string(), |t| t.model.clone()),
)
.stream(true)
.max_completion_tokens(MAX_TOKENS)
.max_completion_tokens(
template
.as_ref()
.map_or(MAX_TOKENS, |t| t.max_completion_tokens),
)
.temperature(template.as_ref().map_or(0.7, |t| t.temperature))
.build()?;
let req = NvCreateChatCompletionRequest { inner, nvext: None };
let mut stream = engine.generate(Context::new(req)).await?;
......
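
A note on the Arc wrapping introduced above: each batch entry is evaluated in its own spawned task, so the optional template is converted to `Option<Arc<RequestTemplate>>` once and only a reference count is bumped per task. Below is a minimal standalone sketch of that pattern, using a hypothetical `Template` type (not the dynamo one) and assuming tokio with the `rt-multi-thread` and `macros` features:
```
use std::sync::Arc;

#[derive(Debug)]
struct Template {
    model: String,
}

#[tokio::main]
async fn main() {
    // Wrap once; every task clone afterwards is just a pointer copy.
    let template: Option<Arc<Template>> = Some(Template {
        model: "Qwen2.5-3B-Instruct".to_string(),
    })
    .map(Arc::new);

    let mut handles = Vec::new();
    for i in 0..4 {
        let template = template.clone();
        handles.push(tokio::spawn(async move {
            if let Some(t) = template.as_deref() {
                println!("task {i} uses model {}", t.model);
            }
        }));
    }
    for handle in handles {
        handle.await.unwrap();
    }
}
```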
......@@ -21,6 +21,7 @@ use dynamo_llm::{
engines::StreamingEngineAdapter,
http::service::{discovery, service_v2},
model_type::ModelType,
request_template::RequestTemplate,
types::{
openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
......@@ -35,11 +36,13 @@ pub async fn run(
runtime: Runtime,
flags: Flags,
engine_config: EngineConfig,
template: Option<RequestTemplate>,
) -> anyhow::Result<()> {
let http_service = service_v2::HttpService::builder()
.port(flags.http_port)
.enable_chat_endpoints(true)
.enable_cmpl_endpoints(true)
.with_request_template(template)
.build()?;
match engine_config {
EngineConfig::Dynamic(endpoint) => {
......
......@@ -22,7 +22,7 @@ use futures::StreamExt;
use std::io::{ErrorKind, Write};
use crate::input::common;
use crate::{EngineConfig, Flags};
use crate::{EngineConfig, Flags, RequestTemplate};
/// Max response tokens for each single query. Must be less than model context size.
/// TODO: Cmd line flag to override this
......@@ -33,6 +33,7 @@ pub async fn run(
flags: Flags,
single_prompt: Option<String>,
engine_config: EngineConfig,
template: Option<RequestTemplate>,
) -> anyhow::Result<()> {
let cancel_token = runtime.primary_token();
let (service_name, engine, inspect_template): (
......@@ -46,6 +47,7 @@ pub async fn run(
engine,
single_prompt,
inspect_template,
template,
)
.await
}
......@@ -56,6 +58,7 @@ async fn main_loop(
engine: OpenAIChatCompletionsStreamingEngine,
mut initial_prompt: Option<String>,
_inspect_template: bool,
template: Option<RequestTemplate>,
) -> anyhow::Result<()> {
if initial_prompt.is_none() {
tracing::info!("Ctrl-c to exit");
......@@ -101,14 +104,21 @@ async fn main_loop(
},
);
messages.push(user_message);
// Request
let inner = async_openai::types::CreateChatCompletionRequestArgs::default()
.messages(messages.clone())
.model(service_name)
.model(
template
.as_ref()
.map_or_else(|| service_name.to_string(), |t| t.model.clone()),
)
.stream(true)
.max_completion_tokens(MAX_TOKENS)
.temperature(0.7)
.max_completion_tokens(
template
.as_ref()
.map_or(MAX_TOKENS, |t| t.max_completion_tokens),
)
.temperature(template.as_ref().map_or(0.7, |t| t.temperature))
.n(1) // only generate one response
.build()?;
let nvext = NvExt {
......
......@@ -30,6 +30,7 @@ mod input;
#[cfg(any(feature = "vllm", feature = "sglang"))]
mod net;
mod opt;
pub use dynamo_llm::request_template::RequestTemplate;
pub use opt::{Input, Output};
/// How we identify a namespace/component/endpoint URL.
......@@ -206,6 +207,14 @@ pub async fn run(
#[cfg(any(feature = "vllm", feature = "sglang"))]
let mut extra: Option<Pin<Box<dyn Future<Output = ()> + Send>>> = None; // vllm and sglang sub-process
let template = if let Some(path) = flags.request_template.as_ref() {
let template = RequestTemplate::load(path)?;
tracing::debug!("Using request template: {template:?}");
Some(template)
} else {
None
};
// Create the engine matching `out`
let engine_config = match out_opt {
Output::EchoFull => {
......@@ -474,18 +483,32 @@ pub async fn run(
match in_opt {
Input::Http => {
crate::input::http::run(runtime.clone(), flags, engine_config).await?;
crate::input::http::run(runtime.clone(), flags, engine_config, template).await?;
}
Input::Text => {
crate::input::text::run(runtime.clone(), flags, None, engine_config).await?;
crate::input::text::run(runtime.clone(), flags, None, engine_config, template).await?;
}
Input::Stdin => {
let mut prompt = String::new();
std::io::stdin().read_to_string(&mut prompt).unwrap();
crate::input::text::run(runtime.clone(), flags, Some(prompt), engine_config).await?;
crate::input::text::run(
runtime.clone(),
flags,
Some(prompt),
engine_config,
template,
)
.await?;
}
Input::Batch(path) => {
crate::input::batch::run(runtime.clone(), flags, maybe_card, path, engine_config)
crate::input::batch::run(
runtime.clone(),
flags,
maybe_card,
path,
engine_config,
template,
)
.await?;
}
Input::Endpoint(path) => {
......
......@@ -43,6 +43,7 @@ use super::{
use crate::protocols::openai::{
chat_completions::NvCreateChatCompletionResponse, completions::CompletionResponse,
};
use crate::request_template::RequestTemplate;
use crate::types::{
openai::{chat_completions::NvCreateChatCompletionRequest, completions::CompletionRequest},
Annotated,
......@@ -219,12 +220,26 @@ async fn completions(
/// non-streaming requests, we will fold the stream into a single response as part of this handler.
#[tracing::instrument(skip_all)]
async fn chat_completions(
State(state): State<Arc<DeploymentState>>,
Json(request): Json<NvCreateChatCompletionRequest>,
State((state, template)): State<(Arc<DeploymentState>, Option<RequestTemplate>)>,
Json(mut request): Json<NvCreateChatCompletionRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
// return a 503 if the service is not ready
check_ready(&state)?;
// Apply template values if present
if let Some(template) = template {
if request.inner.model.is_empty() {
request.inner.model = template.model.clone();
}
if request.inner.temperature.unwrap_or(0.0) == 0.0 {
request.inner.temperature = Some(template.temperature);
}
if request.inner.max_completion_tokens.unwrap_or(0) == 0 {
request.inner.max_completion_tokens = Some(template.max_completion_tokens);
}
}
tracing::trace!("Received chat completions request: {:?}", request.inner);
// todo - extract distributed tracing id and context id from headers
let request_id = uuid::Uuid::new_v4().to_string();
......@@ -512,13 +527,14 @@ pub fn completions_router(
/// If no path is provided, the default path is `/v1/chat/completions`
pub fn chat_completions_router(
state: Arc<DeploymentState>,
template: Option<RequestTemplate>,
path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
let path = path.unwrap_or("/v1/chat/completions".to_string());
let doc = RouteDoc::new(axum::http::Method::POST, &path);
let router = Router::new()
.route(&path, post(chat_completions))
.with_state(state);
.with_state((state, template));
(vec![doc], router)
}
......
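
A note on the state change above: axum route state can be any `Clone` type, so pairing the shared deployment state with the optional template in a tuple keeps the template scoped to the chat-completions route without touching the other routers. A rough sketch of the pattern, under assumed axum 0.7 / tokio APIs and with placeholder types rather than the dynamo ones:
```
use std::sync::Arc;
use axum::{extract::State, routing::post, Router};

#[derive(Clone, Debug)]
struct Template {
    model: String,
}

// The handler extracts the whole tuple that was registered with `with_state`.
async fn chat_completions(
    State((shared, template)): State<(Arc<String>, Option<Template>)>,
) -> String {
    format!("shared={shared}, template={template:?}")
}

fn chat_completions_router(shared: Arc<String>, template: Option<Template>) -> Router {
    Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        .with_state((shared, template))
}

#[tokio::main]
async fn main() {
    let app = chat_completions_router(
        Arc::new("deployment-state".to_string()),
        Some(Template { model: "Qwen2.5-3B-Instruct".to_string() }),
    );
    let listener = tokio::net::TcpListener::bind("127.0.0.1:8080").await.unwrap();
    axum::serve(listener, app).await.unwrap();
}
```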
......@@ -15,6 +15,7 @@
use super::metrics;
use super::ModelManager;
use crate::request_template::RequestTemplate;
use anyhow::Result;
use derive_builder::Builder;
use tokio::task::JoinHandle;
......@@ -44,6 +45,9 @@ pub struct HttpServiceConfig {
#[builder(default = "true")]
enable_cmpl_endpoints: bool,
#[builder(default = "None")]
request_template: Option<RequestTemplate>,
}
impl HttpService {
......@@ -91,6 +95,7 @@ impl HttpServiceConfigBuilder {
model_manager.metrics().register(&registry)?;
let mut router = axum::Router::new();
let mut all_docs = Vec::new();
let mut routes = vec![
......@@ -101,6 +106,7 @@ impl HttpServiceConfigBuilder {
if config.enable_chat_endpoints {
routes.push(super::openai::chat_completions_router(
model_manager.state(),
config.request_template,
None,
));
}
......@@ -129,4 +135,9 @@ impl HttpServiceConfigBuilder {
host: config.host,
})
}
pub fn with_request_template(mut self, request_template: Option<RequestTemplate>) -> Self {
self.request_template = Some(request_template);
self
}
}
......@@ -31,6 +31,7 @@ pub mod model_type;
pub mod preprocessor;
pub mod protocols;
pub mod recorder;
pub mod request_template;
pub mod tokenizers;
pub mod tokens;
pub mod types;
......
// SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::path::Path;
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RequestTemplate {
pub model: String,
pub temperature: f32,
pub max_completion_tokens: u32,
}
impl RequestTemplate {
pub fn load(path: &Path) -> Result<Self> {
let template = std::fs::read_to_string(path)?;
let template: Self = serde_json::from_str(&template)?;
Ok(template)
}
}
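
A minimal usage sketch for the new type above, assuming it is called from the same module; it writes the example template from the commit message to an arbitrary file name and loads it back:
```
use std::path::Path;

fn main() -> anyhow::Result<()> {
    let json = r#"{
    "model": "Qwen2.5-3B-Instruct",
    "temperature": 0.7,
    "max_completion_tokens": 4096
}"#;
    // File name is arbitrary; it only needs to match the path passed to load().
    std::fs::write("template.json", json)?;

    let template = RequestTemplate::load(Path::new("template.json"))?;
    assert_eq!(template.model, "Qwen2.5-3B-Instruct");
    assert_eq!(template.max_completion_tokens, 4096);
    println!("{template:?}");
    Ok(())
}
```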