Unverified Commit 124ecd98 authored by Yuewei Na's avatar Yuewei Na Committed by GitHub
Browse files

fix: reject prompts exceeding max_seq_len with HTTP 400 (#6635)


Signed-off-by: Yuewei Na <nv-yna@users.noreply.github.com>
Co-authored-by: Yuewei Na <nv-yna@users.noreply.github.com>
parent dc9f80b3
...@@ -55,6 +55,7 @@ async def _register_model_with_runtime_config( ...@@ -55,6 +55,7 @@ async def _register_model_with_runtime_config(
endpoint, endpoint,
server_args.model_path, server_args.model_path,
server_args.served_model_name, server_args.served_model_name,
context_length=server_args.context_length,
kv_cache_block_size=server_args.page_size, kv_cache_block_size=server_args.page_size,
runtime_config=runtime_config, runtime_config=runtime_config,
custom_template_path=dynamo_args.custom_jinja_template, custom_template_path=dynamo_args.custom_jinja_template,
......
...@@ -456,6 +456,7 @@ async def init_llm_worker( ...@@ -456,6 +456,7 @@ async def init_llm_worker(
endpoint, endpoint,
config.model, config.model,
config.served_model_name, config.served_model_name,
context_length=config.max_seq_len,
kv_cache_block_size=config.kv_block_size, kv_cache_block_size=config.kv_block_size,
runtime_config=runtime_config, runtime_config=runtime_config,
custom_template_path=config.custom_jinja_template, custom_template_path=config.custom_jinja_template,
......
...@@ -514,7 +514,7 @@ async def register_vllm_model( ...@@ -514,7 +514,7 @@ async def register_vllm_model(
generate_endpoint, generate_endpoint,
config: Config, config: Config,
engine_client: AsyncLLM, engine_client: AsyncLLM,
vllm_config, vllm_config: VllmConfig,
): ):
""" """
Helper function to register a vLLM model with runtime configuration. Helper function to register a vLLM model with runtime configuration.
...@@ -577,6 +577,7 @@ async def register_vllm_model( ...@@ -577,6 +577,7 @@ async def register_vllm_model(
generate_endpoint, generate_endpoint,
config.model, config.model,
config.served_model_name, config.served_model_name,
context_length=vllm_config.model_config.max_model_len,
kv_cache_block_size=runtime_values["block_size"], kv_cache_block_size=runtime_values["block_size"],
runtime_config=runtime_config, runtime_config=runtime_config,
custom_template_path=config.custom_jinja_template, custom_template_path=config.custom_jinja_template,
......
...@@ -209,6 +209,20 @@ impl ErrorMessage { ...@@ -209,6 +209,20 @@ impl ErrorMessage {
); );
} }
// Check for DynamoError with InvalidArgument → HTTP 400
if let Some(dynamo_err) = err.downcast_ref::<dynamo_runtime::error::DynamoError>()
&& dynamo_err.error_type() == dynamo_runtime::error::ErrorType::InvalidArgument
{
return (
StatusCode::BAD_REQUEST,
Json(ErrorMessage {
message: dynamo_err.message().to_string(),
error_type: map_error_code_to_error_type(StatusCode::BAD_REQUEST),
code: StatusCode::BAD_REQUEST.as_u16(),
}),
);
}
// Then check for HttpError // Then check for HttpError
match err.downcast::<HttpError>() { match err.downcast::<HttpError>() {
Ok(http_error) => ErrorMessage::from_http_error(http_error), Ok(http_error) => ErrorMessage::from_http_error(http_error),
......
...@@ -17,10 +17,12 @@ pub mod speculative_prefill; ...@@ -17,10 +17,12 @@ pub mod speculative_prefill;
pub mod tools; pub mod tools;
use anyhow::Context; use anyhow::Context;
use anyhow::{Result, bail}; use anyhow::{Result, bail};
use dynamo_async_openai::types::{ use dynamo_async_openai::types::{
ChatCompletionRequestMessage, ChatCompletionRequestUserMessageContent, ChatCompletionRequestMessage, ChatCompletionRequestUserMessageContent,
ChatCompletionRequestUserMessageContentPart, ChatCompletionToolChoiceOption, EncodingFormat, ChatCompletionRequestUserMessageContentPart, ChatCompletionToolChoiceOption, EncodingFormat,
}; };
use dynamo_runtime::error::{DynamoError, ErrorType};
use futures::Stream; use futures::Stream;
use futures::stream::{self, StreamExt}; use futures::stream::{self, StreamExt};
use prompt::OAIPromptFormatter; use prompt::OAIPromptFormatter;
...@@ -146,6 +148,8 @@ pub struct OpenAIPreprocessor { ...@@ -146,6 +148,8 @@ pub struct OpenAIPreprocessor {
runtime_config: crate::local_model::runtime_config::ModelRuntimeConfig, runtime_config: crate::local_model::runtime_config::ModelRuntimeConfig,
tool_call_parser: Option<String>, tool_call_parser: Option<String>,
media_loader: Option<MediaLoader>, media_loader: Option<MediaLoader>,
/// Max context length (in tokens) this model can handle, from ModelDeploymentCard
context_length: u32,
} }
impl OpenAIPreprocessor { impl OpenAIPreprocessor {
...@@ -185,6 +189,8 @@ impl OpenAIPreprocessor { ...@@ -185,6 +189,8 @@ impl OpenAIPreprocessor {
None => None, None => None,
}; };
let context_length = mdc.context_length;
Ok(Arc::new(Self { Ok(Arc::new(Self {
formatter, formatter,
tokenizer, tokenizer,
...@@ -194,6 +200,7 @@ impl OpenAIPreprocessor { ...@@ -194,6 +200,7 @@ impl OpenAIPreprocessor {
runtime_config, runtime_config,
tool_call_parser, tool_call_parser,
media_loader, media_loader,
context_length,
})) }))
} }
/// Encode a string to its tokens /// Encode a string to its tokens
...@@ -437,16 +444,19 @@ impl OpenAIPreprocessor { ...@@ -437,16 +444,19 @@ impl OpenAIPreprocessor {
tracker: Option<&RequestTracker>, tracker: Option<&RequestTracker>,
) -> Result<HashMap<String, String>> { ) -> Result<HashMap<String, String>> {
let mut annotations = HashMap::new(); let mut annotations = HashMap::new();
let mut token_count: Option<usize> = None;
// match request type before any conversion/processing // match request type before any conversion/processing
match request.prompt_input_type() { match request.prompt_input_type() {
PromptInput::Tokens(_) => { PromptInput::Tokens(_) => {
if let Some(token_input) = request.extract_tokens() { if let Some(token_input) = request.extract_tokens() {
match token_input { match token_input {
TokenInput::Single(tokens) => { TokenInput::Single(tokens) => {
token_count = Some(tokens.len());
builder.token_ids(tokens); builder.token_ids(tokens);
} }
TokenInput::Batch(token_batches) => { TokenInput::Batch(token_batches) => {
if token_batches.len() == 1 { if token_batches.len() == 1 {
token_count = Some(token_batches[0].len());
builder.token_ids(token_batches[0].clone()); builder.token_ids(token_batches[0].clone());
} else { } else {
bail!( bail!(
...@@ -511,12 +521,15 @@ impl OpenAIPreprocessor { ...@@ -511,12 +521,15 @@ impl OpenAIPreprocessor {
); );
} }
token_count = Some(tokens_vec.len());
builder.token_ids(tokens_vec); builder.token_ids(tokens_vec);
} }
TextInput::Batch(texts) => { TextInput::Batch(texts) => {
if texts.len() == 1 { if texts.len() == 1 {
let encoding = self.encode_with_timing(&texts[0], tracker)?; let encoding = self.encode_with_timing(&texts[0], tracker)?;
builder.token_ids(encoding.token_ids().to_vec()); let tokens = encoding.token_ids().to_vec();
token_count = Some(tokens.len());
builder.token_ids(tokens);
} else { } else {
bail!( bail!(
"Batch text input not supported for more than one text in requests (got {})", "Batch text input not supported for more than one text in requests (got {})",
...@@ -528,9 +541,38 @@ impl OpenAIPreprocessor { ...@@ -528,9 +541,38 @@ impl OpenAIPreprocessor {
} }
} }
} }
// Validate prompt token count against model's context length
if let Some(count) = token_count {
Self::validate_token_count(count, self.context_length)?;
}
Ok(annotations) Ok(annotations)
} }
/// Check a prompt's token count against the model's context length.
///
/// `context_length` is treated as the total budget shared by input and output
/// tokens, so a prompt that reaches the limit is rejected just like one that
/// exceeds it: nothing would be left for generation.
///
/// A `context_length` of 0 disables the check (see the note in the body).
///
/// # Errors
///
/// Returns a `DynamoError` of type `ErrorType::InvalidArgument` when
/// `context_length` is non-zero and `token_count >= context_length`.
fn validate_token_count(token_count: usize, context_length: u32) -> Result<()> {
    let limit = context_length as usize;

    // A limit of 0 means context_length was never configured (model_card.rs
    // defaults to 0 when max_position_embeddings is absent), so there is
    // nothing to validate against. Otherwise the prompt passes only while it
    // stays strictly below the limit.
    if limit == 0 || token_count < limit {
        return Ok(());
    }

    let message = format!(
        "This model's maximum context length is {} tokens. \
         However, your messages resulted in {} tokens. \
         Please reduce the length of the messages.",
        limit, token_count,
    );

    Err(DynamoError::builder()
        .error_type(ErrorType::InvalidArgument)
        .message(message)
        .build()
        .into())
}
fn encode_with_timing( fn encode_with_timing(
&self, &self,
prompt: &str, prompt: &str,
......
...@@ -622,3 +622,115 @@ async fn test_media_url_passthrough(#[case] media_chunks: &[(&str, usize)]) { ...@@ -622,3 +622,115 @@ async fn test_media_url_passthrough(#[case] media_chunks: &[(&str, usize)]) {
} }
} }
} }
mod context_length_validation {
    use dynamo_llm::model_card::ModelDeploymentCard;
    use dynamo_llm::preprocessor::OpenAIPreprocessor;
    use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest;
    use dynamo_runtime::error::{DynamoError, ErrorType};

    // The mock-llama fixture ships a chat_template in tokenizer_config.json,
    // which preprocessing requires.
    const MODEL_PATH: &str = "tests/data/sample-models/mock-llama-3.1-8b-instruct";

    // JSON-encoded message list shared by every test in this module.
    const USER_MESSAGES: &str = r#"[{"role": "user", "content": "What is deep learning?"}]"#;

    /// Build a chat completion request for `model` from a JSON array of messages.
    ///
    /// Panics if `message` is not valid JSON for a message list or if the
    /// request builder rejects the arguments.
    fn make_chat_request(message: &str, model: &str) -> NvCreateChatCompletionRequest {
        let parsed: Vec<dynamo_async_openai::types::ChatCompletionRequestMessage> =
            serde_json::from_str(message).unwrap();

        NvCreateChatCompletionRequest {
            inner: dynamo_async_openai::types::CreateChatCompletionRequestArgs::default()
                .model(model)
                .messages(parsed)
                .build()
                .unwrap(),
            common: Default::default(),
            nvext: None,
            chat_template_args: None,
            media_io_kwargs: None,
            unsupported_fields: Default::default(),
        }
    }

    /// The single-question request used by all tests below.
    fn sample_request() -> NvCreateChatCompletionRequest {
        make_chat_request(USER_MESSAGES, "test-model")
    }

    #[tokio::test]
    async fn test_prompt_exceeding_context_length_returns_400() {
        let mut card = ModelDeploymentCard::load_from_disk(MODEL_PATH, None).unwrap();
        // A tiny limit guarantees that even this short prompt is over budget.
        card.context_length = 5;
        let preprocessor = OpenAIPreprocessor::new(card).unwrap();
        let request = sample_request();

        // The failure must surface as a DynamoError tagged InvalidArgument.
        let err = preprocessor
            .preprocess_request(&request, None)
            .await
            .expect_err("should reject prompt exceeding context_length");
        let dynamo_err = err
            .downcast_ref::<DynamoError>()
            .expect("error should be DynamoError");
        assert_eq!(dynamo_err.error_type(), ErrorType::InvalidArgument);

        assert!(
            dynamo_err
                .message()
                .contains("maximum context length is 5 tokens"),
            "error message should state the context limit, got: {}",
            dynamo_err.message()
        );
        assert!(
            dynamo_err.message().contains("Please reduce the length"),
            "error message should tell user what to do, got: {}",
            dynamo_err.message()
        );
    }

    #[tokio::test]
    async fn test_prompt_exactly_at_context_length_returns_400() {
        let mut card = ModelDeploymentCard::load_from_disk(MODEL_PATH, None).unwrap();

        // Pass 1: use a generous limit to learn how many tokens the prompt yields.
        card.context_length = 131072;
        let probe = OpenAIPreprocessor::new(card.clone()).unwrap();
        let probe_request = sample_request();
        let (preprocessed, _) = probe
            .preprocess_request(&probe_request, None)
            .await
            .unwrap();
        let prompt_tokens = preprocessed.token_ids.len() as u32;

        // Pass 2: shrink the limit to exactly the prompt size, leaving no
        // room for output tokens.
        card.context_length = prompt_tokens;
        let preprocessor = OpenAIPreprocessor::new(card).unwrap();
        let request = sample_request();

        // A prompt that fills the whole context must be rejected.
        let err = preprocessor
            .preprocess_request(&request, None)
            .await
            .expect_err("should reject prompt that fills entire context_length");
        let dynamo_err = err
            .downcast_ref::<DynamoError>()
            .expect("error should be DynamoError");
        assert_eq!(dynamo_err.error_type(), ErrorType::InvalidArgument);
    }

    #[tokio::test]
    async fn test_context_length_zero_skips_validation() {
        let mut card = ModelDeploymentCard::load_from_disk(MODEL_PATH, None).unwrap();
        // Zero marks the limit as unconfigured; validation should be bypassed.
        card.context_length = 0;
        let preprocessor = OpenAIPreprocessor::new(card).unwrap();
        let request = sample_request();

        let outcome = preprocessor.preprocess_request(&request, None).await;
        assert!(outcome.is_ok(), "context_length=0 should skip validation");
    }
}
...@@ -43,6 +43,8 @@ use std::fmt; ...@@ -43,6 +43,8 @@ use std::fmt;
pub enum ErrorType { pub enum ErrorType {
/// Uncategorized or unknown error. /// Uncategorized or unknown error.
Unknown, Unknown,
/// The request contains invalid input (e.g., prompt exceeds context length).
InvalidArgument,
/// Failed to establish a connection to a remote worker. /// Failed to establish a connection to a remote worker.
CannotConnect, CannotConnect,
/// An established connection was lost unexpectedly. /// An established connection was lost unexpectedly.
...@@ -57,6 +59,7 @@ impl fmt::Display for ErrorType { ...@@ -57,6 +59,7 @@ impl fmt::Display for ErrorType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self { match self {
ErrorType::Unknown => write!(f, "Unknown"), ErrorType::Unknown => write!(f, "Unknown"),
ErrorType::InvalidArgument => write!(f, "InvalidArgument"),
ErrorType::CannotConnect => write!(f, "CannotConnect"), ErrorType::CannotConnect => write!(f, "CannotConnect"),
ErrorType::Disconnected => write!(f, "Disconnected"), ErrorType::Disconnected => write!(f, "Disconnected"),
ErrorType::ConnectionTimeout => write!(f, "ConnectionTimeout"), ErrorType::ConnectionTimeout => write!(f, "ConnectionTimeout"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment